rokmr commited on
Commit
97aaf94
·
verified ·
1 Parent(s): 8247cad

bug fix: updating duration func

Browse files
Files changed (1) hide show
  1. app.py +37 -4
app.py CHANGED
@@ -62,7 +62,7 @@ def remote_text_encoder(prompts):
62
  except Exception as e:
63
  raise Exception(f"Failed to encode prompt: {str(e)}")
64
 
65
- def get_duration(num_inference_steps: int, input_image: Image.Image = None):
66
  """Calculate dynamic GPU duration based on inference steps and input image."""
67
  num_images = 0 if input_image is None else 1
68
  step_duration = 1 + 0.7 * num_images
@@ -87,6 +87,10 @@ def generate_image(
87
  guidance_scale: How closely to follow the prompt (higher = more strict)
88
  seed: Random seed for reproducibility (-1 for random)
89
  """
 
 
 
 
90
  if not prompt or prompt.strip() == "":
91
  raise gr.Error("Please enter a prompt!")
92
 
@@ -94,21 +98,32 @@ def generate_image(
94
 
95
  try:
96
  # Load pipeline (lazy loading)
 
97
  pipeline = load_pipeline()
 
98
 
99
  progress(0.1, desc="Encoding prompt...")
 
100
 
101
  # Get prompt embeddings from remote encoder
102
- prompt_embeds = remote_text_encoder(prompt)
 
 
 
 
 
103
 
104
  progress(0.3, desc="Generating image...")
105
 
106
  # Set up generator
107
  generator_device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
108
  if seed == -1:
109
  import random
110
  seed = random.randint(0, 2**32 - 1)
111
 
 
112
  generator = torch.Generator(device=generator_device).manual_seed(int(seed))
113
 
114
  # Prepare pipeline arguments
@@ -123,25 +138,43 @@ def generate_image(
123
  if input_image is not None:
124
  pipe_kwargs["image"] = input_image
125
  progress(0.4, desc="Processing input image...")
 
 
 
126
 
127
  # Generate image
128
  with torch.inference_mode():
129
- image = pipeline(**pipe_kwargs).images[0]
 
130
 
 
131
  progress(1.0, desc="Done!")
132
 
133
  return image
134
 
 
 
 
135
  except Exception as e:
136
  import traceback
137
  error_msg = f"Error generating image: {str(e)}\n{traceback.format_exc()}"
138
  print(error_msg)
139
- raise gr.Error(f"Error generating image: {str(e)}")
 
 
 
 
 
 
 
 
 
140
 
141
 
142
  # Create Gradio interface
143
  with gr.Blocks(
144
  title="Flux2 Image Generator",
 
145
  ) as demo:
146
  gr.Markdown(
147
  """
 
62
  except Exception as e:
63
  raise Exception(f"Failed to encode prompt: {str(e)}")
64
 
65
+ def get_duration(prompt: str, input_image: Image.Image = None, num_inference_steps: int = 28, guidance_scale: float = 4.0, seed: int = 42, progress=None):
66
  """Calculate dynamic GPU duration based on inference steps and input image."""
67
  num_images = 0 if input_image is None else 1
68
  step_duration = 1 + 0.7 * num_images
 
87
  guidance_scale: How closely to follow the prompt (higher = more strict)
88
  seed: Random seed for reproducibility (-1 for random)
89
  """
90
+ print(f"=== Starting generation ===")
91
+ print(f"Prompt: {prompt[:100]}...")
92
+ print(f"CUDA available: {torch.cuda.is_available()}")
93
+
94
  if not prompt or prompt.strip() == "":
95
  raise gr.Error("Please enter a prompt!")
96
 
 
98
 
99
  try:
100
  # Load pipeline (lazy loading)
101
+ print("Loading pipeline...")
102
  pipeline = load_pipeline()
103
+ print("Pipeline loaded successfully")
104
 
105
  progress(0.1, desc="Encoding prompt...")
106
+ print("Encoding prompt...")
107
 
108
  # Get prompt embeddings from remote encoder
109
+ try:
110
+ prompt_embeds = remote_text_encoder(prompt)
111
+ print(f"Prompt embeds shape: {prompt_embeds.shape}")
112
+ except Exception as e:
113
+ print(f"Error encoding prompt: {str(e)}")
114
+ raise gr.Error(f"Failed to encode prompt. Please check your HuggingFace token. Error: {str(e)}")
115
 
116
  progress(0.3, desc="Generating image...")
117
 
118
  # Set up generator
119
  generator_device = "cuda" if torch.cuda.is_available() else "cpu"
120
+ print(f"Generator device: {generator_device}")
121
+
122
  if seed == -1:
123
  import random
124
  seed = random.randint(0, 2**32 - 1)
125
 
126
+ print(f"Using seed: {seed}")
127
  generator = torch.Generator(device=generator_device).manual_seed(int(seed))
128
 
129
  # Prepare pipeline arguments
 
138
  if input_image is not None:
139
  pipe_kwargs["image"] = input_image
140
  progress(0.4, desc="Processing input image...")
141
+ print("Processing with input image")
142
+
143
+ print(f"Starting generation with {num_inference_steps} steps...")
144
 
145
  # Generate image
146
  with torch.inference_mode():
147
+ result = pipeline(**pipe_kwargs)
148
+ image = result.images[0]
149
 
150
+ print("Generation complete!")
151
  progress(1.0, desc="Done!")
152
 
153
  return image
154
 
155
+ except gr.Error:
156
+ # Re-raise Gradio errors as-is
157
+ raise
158
  except Exception as e:
159
  import traceback
160
  error_msg = f"Error generating image: {str(e)}\n{traceback.format_exc()}"
161
  print(error_msg)
162
+
163
+ # Provide more helpful error messages
164
+ if "CUDA" in str(e):
165
+ raise gr.Error(f"GPU Error: {str(e)}. The model requires GPU to run.")
166
+ elif "token" in str(e).lower() or "401" in str(e):
167
+ raise gr.Error("Authentication failed. Please ensure your HuggingFace token is set correctly.")
168
+ elif "timeout" in str(e).lower():
169
+ raise gr.Error("Request timed out. Please try again.")
170
+ else:
171
+ raise gr.Error(f"Error: {str(e)}")
172
 
173
 
174
  # Create Gradio interface
175
  with gr.Blocks(
176
  title="Flux2 Image Generator",
177
+ theme=gr.themes.Soft(),
178
  ) as demo:
179
  gr.Markdown(
180
  """