rokmr committed on
Commit
34a75a5
·
verified ·
1 Parent(s): 998921e

Optimizing for GPU time

Browse files
Files changed (1) hide show
  1. app.py +60 -42
app.py CHANGED
@@ -14,29 +14,21 @@ torch_dtype = torch.bfloat16
14
 
15
  print("Starting Flux2 Image Generator...")
16
 
17
- # Global variable to hold the pipeline
 
18
  pipe = None
19
 
20
- def load_pipeline():
21
- """Lazy load the pipeline when needed."""
22
- global pipe
23
- if pipe is None:
24
- print("Loading Flux2 pipeline...")
25
- device = "cuda" if torch.cuda.is_available() else "cpu"
26
- print(f"Using device: {device}")
27
-
28
- try:
29
- pipe = Flux2Pipeline.from_pretrained(
30
- repo_id,
31
- text_encoder=None,
32
- torch_dtype=torch_dtype,
33
- device_map="cuda"
34
- )
35
- print("Pipeline loaded successfully!")
36
- except Exception as e:
37
- print(f"Error loading pipeline: {e}")
38
- raise
39
- return pipe
40
 
41
  def remote_text_encoder(prompts):
42
  """Encode prompts using remote text encoder API."""
@@ -46,25 +38,39 @@ def remote_text_encoder(prompts):
46
 
47
  # Method 1: From huggingface_hub
48
  try:
49
- token = get_token()
 
50
  except:
51
  pass
52
 
53
- # Method 2: From environment variable (Spaces sets this automatically)
 
 
 
 
 
 
 
54
  if not token:
55
  token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
56
 
57
- # Method 3: From Spaces secrets
58
  if not token:
59
- token = os.environ.get("SPACE_TOKEN")
60
 
61
  if not token:
62
  raise ValueError(
63
- "HuggingFace token not found. "
64
- "If running on Spaces, make sure your Space has access to gated models. "
65
- "If running locally, please login using 'huggingface-cli login'"
 
 
 
 
66
  )
67
 
 
 
68
  response = requests.post(
69
  "https://remote-text-encoder-flux-2.huggingface.co/predict",
70
  json={"prompt": prompts},
@@ -82,23 +88,31 @@ def remote_text_encoder(prompts):
82
  except requests.HTTPError as e:
83
  if e.response.status_code == 401:
84
  raise Exception(
85
- "Authentication failed (401). Your HuggingFace token may not have access to this model. "
 
86
  "Please ensure your token has permission to access FLUX.2 models."
87
  )
88
  elif e.response.status_code == 403:
89
  raise Exception(
90
- "Access forbidden (403). You may need to accept the model's license agreement on HuggingFace."
 
 
91
  )
92
  else:
93
  raise Exception(f"HTTP error {e.response.status_code}: {str(e)}")
94
  except Exception as e:
 
 
95
  raise Exception(f"Failed to encode prompt: {str(e)}")
96
 
97
  def get_duration(prompt: str, input_image: Image.Image = None, num_inference_steps: int = 28, guidance_scale: float = 4.0, seed: int = 42, progress=None):
98
  """Calculate dynamic GPU duration based on inference steps and input image."""
99
  num_images = 0 if input_image is None else 1
100
- step_duration = 1 + 0.7 * num_images
101
- return max(65, num_inference_steps * step_duration + 10)
 
 
 
102
 
103
  @spaces.GPU(duration=get_duration) # Dynamic GPU allocation
104
  def generate_image(
@@ -119,6 +133,8 @@ def generate_image(
119
  guidance_scale: How closely to follow the prompt (higher = more strict)
120
  seed: Random seed for reproducibility (-1 for random)
121
  """
 
 
122
  print(f"=== Starting generation ===")
123
  print(f"Prompt: {prompt[:100]}...")
124
  print(f"CUDA available: {torch.cuda.is_available()}")
@@ -126,13 +142,15 @@ def generate_image(
126
  if not prompt or prompt.strip() == "":
127
  raise gr.Error("Please enter a prompt!")
128
 
129
- progress(0, desc="Loading model...")
130
 
131
  try:
132
- # Load pipeline (lazy loading)
133
- print("Loading pipeline...")
134
- pipeline = load_pipeline()
135
- print("Pipeline loaded successfully")
 
 
136
 
137
  progress(0.1, desc="Encoding prompt...")
138
  print("Encoding prompt...")
@@ -145,7 +163,7 @@ def generate_image(
145
  print(f"Error encoding prompt: {str(e)}")
146
  raise gr.Error(f"Failed to encode prompt. Please check your HuggingFace token. Error: {str(e)}")
147
 
148
- progress(0.3, desc="Generating image...")
149
 
150
  # Set up generator
151
  generator_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -169,14 +187,14 @@ def generate_image(
169
  # Add input image if provided
170
  if input_image is not None:
171
  pipe_kwargs["image"] = input_image
172
- progress(0.4, desc="Processing input image...")
173
  print("Processing with input image")
174
 
175
  print(f"Starting generation with {num_inference_steps} steps...")
176
 
177
  # Generate image
178
  with torch.inference_mode():
179
- result = pipeline(**pipe_kwargs)
180
  image = result.images[0]
181
 
182
  print("Generation complete!")
@@ -193,8 +211,8 @@ def generate_image(
193
  print(error_msg)
194
 
195
  # Provide more helpful error messages
196
- if "CUDA" in str(e):
197
- raise gr.Error(f"GPU Error: {str(e)}. The model requires GPU to run.")
198
  elif "token" in str(e).lower() or "401" in str(e):
199
  raise gr.Error("Authentication failed. Please ensure your HuggingFace token is set correctly.")
200
  elif "timeout" in str(e).lower():
 
14
 
15
  print("Starting Flux2 Image Generator...")
16
 
17
+ # Load the pipeline at startup (NOT inside GPU decorator)
18
+ print("Loading Flux2 pipeline...")
19
  pipe = None
20
 
21
+ try:
22
+ pipe = Flux2Pipeline.from_pretrained(
23
+ repo_id,
24
+ text_encoder=None,
25
+ torch_dtype=torch_dtype,
26
+ device_map="balanced" # Use balanced for CPU during startup
27
+ )
28
+ print("Pipeline loaded successfully!")
29
+ except Exception as e:
30
+ print(f"Error loading pipeline: {e}")
31
+ # Don't raise - will try to load later if needed
 
 
 
 
 
 
 
 
 
32
 
33
  def remote_text_encoder(prompts):
34
  """Encode prompts using remote text encoder API."""
 
38
 
39
  # Method 1: From huggingface_hub
40
  try:
41
+ from huggingface_hub import HfFolder
42
+ token = HfFolder.get_token()
43
  except:
44
  pass
45
 
46
+ # Method 2: get_token from huggingface_hub
47
+ if not token:
48
+ try:
49
+ token = get_token()
50
+ except:
51
+ pass
52
+
53
+ # Method 3: From environment variable (Spaces sets this automatically)
54
  if not token:
55
  token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
56
 
57
+ # Method 4: From Spaces secrets
58
  if not token:
59
+ token = os.environ.get("SPACE_TOKEN") or os.environ.get("SPACES_TOKEN")
60
 
61
  if not token:
62
  raise ValueError(
63
+ "HuggingFace token not found!\n\n"
64
+ "📝 To fix this:\n"
65
+ "1. Go to https://huggingface.co/settings/tokens\n"
66
+ "2. Create a token with 'read' access\n"
67
+ "3. In your Space settings, add a secret named 'HF_TOKEN' with your token value\n"
68
+ "4. Restart your Space\n\n"
69
+ "If running locally, use: huggingface-cli login"
70
  )
71
 
72
+ print(f"Token found: {token[:10]}... (length: {len(token)})")
73
+
74
  response = requests.post(
75
  "https://remote-text-encoder-flux-2.huggingface.co/predict",
76
  json={"prompt": prompts},
 
88
  except requests.HTTPError as e:
89
  if e.response.status_code == 401:
90
  raise Exception(
91
+ "Authentication failed (401).\n\n"
92
+ "Your HuggingFace token may not have access to this model.\n"
93
  "Please ensure your token has permission to access FLUX.2 models."
94
  )
95
  elif e.response.status_code == 403:
96
  raise Exception(
97
+ "Access forbidden (403).\n\n"
98
+ "You may need to accept the model's license agreement on HuggingFace:\n"
99
+ "Visit: https://huggingface.co/black-forest-labs/FLUX.1-dev"
100
  )
101
  else:
102
  raise Exception(f"HTTP error {e.response.status_code}: {str(e)}")
103
  except Exception as e:
104
+ if "token" in str(e).lower():
105
+ raise # Re-raise token errors as-is
106
  raise Exception(f"Failed to encode prompt: {str(e)}")
107
 
108
  def get_duration(prompt: str, input_image: Image.Image = None, num_inference_steps: int = 28, guidance_scale: float = 4.0, seed: int = 42, progress=None):
109
  """Calculate dynamic GPU duration based on inference steps and input image."""
110
  num_images = 0 if input_image is None else 1
111
+ step_duration = 1.3 + 0.7 * num_images # Increased from 1 to 1.3
112
+ # Add extra time for model transfer to GPU + generation
113
+ base_time = 30 # Time for moving model to GPU
114
+ generation_time = num_inference_steps * step_duration
115
+ return int(base_time + generation_time + 15) # Extra 15s buffer
116
 
117
  @spaces.GPU(duration=get_duration) # Dynamic GPU allocation
118
  def generate_image(
 
133
  guidance_scale: How closely to follow the prompt (higher = more strict)
134
  seed: Random seed for reproducibility (-1 for random)
135
  """
136
+ global pipe
137
+
138
  print(f"=== Starting generation ===")
139
  print(f"Prompt: {prompt[:100]}...")
140
  print(f"CUDA available: {torch.cuda.is_available()}")
 
142
  if not prompt or prompt.strip() == "":
143
  raise gr.Error("Please enter a prompt!")
144
 
145
+ progress(0, desc="Moving model to GPU...")
146
 
147
  try:
148
+ # Move pipeline to GPU
149
+ if pipe is None:
150
+ raise gr.Error("Pipeline not loaded. Please refresh the page.")
151
+
152
+ print("Moving pipeline to CUDA...")
153
+ pipe = pipe.to("cuda")
154
 
155
  progress(0.1, desc="Encoding prompt...")
156
  print("Encoding prompt...")
 
163
  print(f"Error encoding prompt: {str(e)}")
164
  raise gr.Error(f"Failed to encode prompt. Please check your HuggingFace token. Error: {str(e)}")
165
 
166
+ progress(0.2, desc="Generating image...")
167
 
168
  # Set up generator
169
  generator_device = "cuda" if torch.cuda.is_available() else "cpu"
 
187
  # Add input image if provided
188
  if input_image is not None:
189
  pipe_kwargs["image"] = input_image
190
+ progress(0.25, desc="Processing input image...")
191
  print("Processing with input image")
192
 
193
  print(f"Starting generation with {num_inference_steps} steps...")
194
 
195
  # Generate image
196
  with torch.inference_mode():
197
+ result = pipe(**pipe_kwargs)
198
  image = result.images[0]
199
 
200
  print("Generation complete!")
 
211
  print(error_msg)
212
 
213
  # Provide more helpful error messages
214
+ if "CUDA" in str(e) or "out of memory" in str(e).lower():
215
+ raise gr.Error(f"GPU Error: {str(e)}. Try reducing inference steps.")
216
  elif "token" in str(e).lower() or "401" in str(e):
217
  raise gr.Error("Authentication failed. Please ensure your HuggingFace token is set correctly.")
218
  elif "timeout" in str(e).lower():