bug fix: updating duration func
app.py
CHANGED
@@ -62,7 +62,7 @@ def remote_text_encoder(prompts):
     except Exception as e:
         raise Exception(f"Failed to encode prompt: {str(e)}")
 
-def get_duration(
+def get_duration(prompt: str, input_image: Image.Image = None, num_inference_steps: int = 28, guidance_scale: float = 4.0, seed: int = 42, progress=None):
     """Calculate dynamic GPU duration based on inference steps and input image."""
     num_images = 0 if input_image is None else 1
     step_duration = 1 + 0.7 * num_images
@@ -87,6 +87,10 @@ def generate_image(
         guidance_scale: How closely to follow the prompt (higher = more strict)
         seed: Random seed for reproducibility (-1 for random)
     """
+    print(f"=== Starting generation ===")
+    print(f"Prompt: {prompt[:100]}...")
+    print(f"CUDA available: {torch.cuda.is_available()}")
+
     if not prompt or prompt.strip() == "":
         raise gr.Error("Please enter a prompt!")
 
@@ -94,21 +98,32 @@ def generate_image(
 
     try:
         # Load pipeline (lazy loading)
+        print("Loading pipeline...")
         pipeline = load_pipeline()
+        print("Pipeline loaded successfully")
 
         progress(0.1, desc="Encoding prompt...")
+        print("Encoding prompt...")
 
         # Get prompt embeddings from remote encoder
-        prompt_embeds = remote_text_encoder(prompt)
+        try:
+            prompt_embeds = remote_text_encoder(prompt)
+            print(f"Prompt embeds shape: {prompt_embeds.shape}")
+        except Exception as e:
+            print(f"Error encoding prompt: {str(e)}")
+            raise gr.Error(f"Failed to encode prompt. Please check your HuggingFace token. Error: {str(e)}")
 
         progress(0.3, desc="Generating image...")
 
         # Set up generator
         generator_device = "cuda" if torch.cuda.is_available() else "cpu"
+        print(f"Generator device: {generator_device}")
+
         if seed == -1:
             import random
             seed = random.randint(0, 2**32 - 1)
 
+        print(f"Using seed: {seed}")
         generator = torch.Generator(device=generator_device).manual_seed(int(seed))
 
         # Prepare pipeline arguments
@@ -123,25 +138,43 @@ def generate_image(
         if input_image is not None:
             pipe_kwargs["image"] = input_image
             progress(0.4, desc="Processing input image...")
+            print("Processing with input image")
+
+        print(f"Starting generation with {num_inference_steps} steps...")
 
         # Generate image
         with torch.inference_mode():
-            image = pipeline(**pipe_kwargs).images[0]
+            result = pipeline(**pipe_kwargs)
+            image = result.images[0]
 
+        print("Generation complete!")
         progress(1.0, desc="Done!")
 
         return image
 
+    except gr.Error:
+        # Re-raise Gradio errors as-is
+        raise
     except Exception as e:
        import traceback
        error_msg = f"Error generating image: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
-        raise gr.Error(error_msg)
+
+        # Provide more helpful error messages
+        if "CUDA" in str(e):
+            raise gr.Error(f"GPU Error: {str(e)}. The model requires GPU to run.")
+        elif "token" in str(e).lower() or "401" in str(e):
+            raise gr.Error("Authentication failed. Please ensure your HuggingFace token is set correctly.")
+        elif "timeout" in str(e).lower():
+            raise gr.Error("Request timed out. Please try again.")
+        else:
+            raise gr.Error(f"Error: {str(e)}")
 
 
 # Create Gradio interface
 with gr.Blocks(
     title="Flux2 Image Generator",
+    theme=gr.themes.Soft(),
 ) as demo:
     gr.Markdown(
         """
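
Note on the duration function: on ZeroGPU Spaces, the `duration` argument of `@spaces.GPU` may be a callable, and that callable is invoked with the same arguments as the decorated handler, which would explain why the new `get_duration` signature mirrors `generate_image`. Below is a minimal sketch of the assumed wiring; the 30-second base overhead and the final return formula are assumptions, since only `num_images` and `step_duration` appear in this commit.

# Sketch only: assumed ZeroGPU dynamic-duration wiring, not taken from this commit.
import spaces  # HuggingFace Spaces helper package (ZeroGPU)
import torch
from PIL import Image

def get_duration(prompt: str, input_image: Image.Image = None,
                 num_inference_steps: int = 28, guidance_scale: float = 4.0,
                 seed: int = 42, progress=None) -> int:
    """Calculate dynamic GPU duration based on inference steps and input image."""
    num_images = 0 if input_image is None else 1
    step_duration = 1 + 0.7 * num_images  # per-step cost rises with an input image (from the diff)
    return int(30 + step_duration * num_inference_steps)  # 30 s base overhead is an assumption

@spaces.GPU(duration=get_duration)  # ZeroGPU calls get_duration with generate_image's arguments
def generate_image(prompt, input_image=None, num_inference_steps=28,
                   guidance_scale=4.0, seed=42, progress=None):
    ...  # body as in the diff above

Under this reading, the bug being fixed is that the old truncated signature did not match the parameters ZeroGPU forwards to the duration callable, so the GPU reservation could not be computed per request.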