jiuface committed on
Commit
078123a
·
verified ·
1 Parent(s): b93bd46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -70
app.py CHANGED
@@ -38,13 +38,13 @@ device = "cuda:0"
38
 
39
  base_model = "black-forest-labs/FLUX.1-Krea-dev"
40
 
41
- pipeline_quant_config = PipelineQuantizationConfig(
42
- quant_backend="bitsandbytes_4bit",
43
- quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
44
- components_to_quantize=["transformer", "text_encoder_2"],
45
- )
46
 
47
- txt2img_pipe = FluxKontextPipeline.from_pretrained(base_model, quantization_config=pipeline_quant_config, torch_dtype=dtype)
48
  txt2img_pipe = txt2img_pipe.to(device)
49
 
50
  MAX_SEED = 2**32 - 1
@@ -69,13 +69,16 @@ class calculateDuration:
69
  else:
70
  print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
71
 
72
-
 
 
 
 
 
73
 
74
  def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
75
  with calculateDuration("Upload images"):
76
- print("upload_image_to_r2", account_id, access_key, secret_key, bucket_name)
77
  connectionUrl = f"https://{account_id}.r2.cloudflarestorage.com"
78
-
79
  s3 = boto3.client(
80
  's3',
81
  endpoint_url=connectionUrl,
@@ -113,97 +116,116 @@ def generate_random_4_digit_string():
113
  return ''.join(random.choices(string.digits, k=4))
114
 
115
  @spaces.GPU(duration=120)
116
- def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, steps, randomize_seed, seed, width, height, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  print("run_lora", prompt, lora_strings_json, cfg_scale, steps, width, height)
118
  gr.Info("Starting process")
119
- img2img_model = False
120
- orginal_image = None
121
- device = txt2img_pipe.device
122
-
123
  print(device)
124
-
125
- # Set random seed for reproducibility
126
  if randomize_seed:
127
  with calculateDuration("Set random seed"):
128
  seed = random.randint(0, MAX_SEED)
 
129
 
130
- # Load LoRA weights
131
  gr.Info("Start to load LoRA ...")
132
  with calculateDuration("Unloading LoRA"):
133
- txt2img_pipe.unload_lora_weights()
134
-
135
- lora_configs = None
 
 
 
136
  adapter_names = []
137
- lora_names = []
138
  if lora_strings_json:
139
  try:
140
  lora_configs = json.loads(lora_strings_json)
141
- except:
 
142
  gr.Warning("Parse lora config json failed")
143
  print("parse lora config json failed")
144
-
145
  if lora_configs:
146
-
147
  with calculateDuration("Loading LoRA weights"):
148
- adapter_weights = []
149
-
150
- for idx, lora_info in enumerate(lora_configs):
151
- lora_repo = lora_info.get("repo")
152
  weights = lora_info.get("weights")
153
- adapter_name = lora_info.get("adapter_name")
154
-
155
- lora_name = generate_random_4_digit_string()
156
- lora_names.append(lora_name)
157
- adapter_weight = lora_info.get("adapter_weight")
158
- adapter_names.append(adapter_name)
159
- adapter_weights.append(adapter_weight)
160
- if lora_repo and weights and adapter_name:
161
- try:
162
- txt2img_pipe.load_lora_weights(hf_hub_download(lora_repo, weights), adapter_name=lora_name)
163
- except:
164
- print("load lora error")
165
-
166
- # set lora weights
167
- if len(lora_names) > 0:
168
- txt2img_pipe.set_adapters(lora_names, adapter_weights=adapter_weights)
169
- txt2img_pipe.fuse_lora(adapter_names=lora_names)
170
- txt2img_pipe.enable_vae_slicing()
171
-
172
- # Generate image
 
 
 
 
 
173
  error_message = ""
174
  try:
175
  gr.Info("Start to generate images ...")
176
- print(device)
177
- # Generate image
178
- pipe = txt2img_pipe.to(device)
179
- generator = torch.Generator("cuda").manual_seed(seed)
180
  joint_attention_kwargs = {"scale": 1}
181
- final_image = pipe(
182
  prompt=prompt,
183
- num_inference_steps=steps,
184
- guidance_scale=cfg_scale,
185
- width=width,
186
- height=height,
187
  max_sequence_length=512,
188
  generator=generator,
189
  joint_attention_kwargs=joint_attention_kwargs
190
  ).images[0]
191
  except Exception as e:
192
- error_message = str(e)
193
  gr.Error(error_message)
194
  print("fatal error", e)
195
- final_image = None
196
-
197
- if final_image:
198
- if upload_to_r2:
199
- url = upload_image_to_r2(final_image, account_id, access_key, secret_key, bucket)
200
- result = {"status": "success", "message": "upload image success", "url": url}
201
- else:
202
- result = {"status": "success", "message": "Image generated but not uploaded"}
203
- else:
204
- result = {"status": "failed", "message": error_message}
205
- final_image = None
206
-
 
207
  gr.Info("Completed!")
208
  progress(100, "Completed!")
209
  return json.dumps(result)
 
38
 
39
# Base checkpoint used for generation.
base_model = "black-forest-labs/FLUX.1-Krea-dev"

# Quantization is currently disabled; re-enable this config (and pass it as
# `quantization_config=` below) to reduce VRAM usage via 4-bit bitsandbytes.
# pipeline_quant_config = PipelineQuantizationConfig(
#     quant_backend="bitsandbytes_4bit",
#     quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
#     components_to_quantize=["transformer", "text_encoder_2"],
# )

# NOTE(review): a Krea-dev checkpoint is loaded through FluxKontextPipeline —
# confirm the pipeline class matches this model family.
txt2img_pipe = FluxKontextPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)

# Largest value usable as a 32-bit unsigned random seed.
MAX_SEED = 2**32 - 1
 
69
  else:
70
  print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
71
 
72
def safe_trim_for_clip(text: str, max_words: int = 77) -> str:
    """Return *text* cut down to at most *max_words* whitespace-separated words.

    A simple word-level trim (no tokenizer, no keyword extraction): text that
    is already short enough is returned completely untouched, including its
    original surrounding whitespace.
    """
    words = text.strip().split()
    if len(words) > max_words:
        return " ".join(words[:max_words])
    return text
78
 
79
  def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
80
  with calculateDuration("Upload images"):
 
81
  connectionUrl = f"https://{account_id}.r2.cloudflarestorage.com"
 
82
  s3 = boto3.client(
83
  's3',
84
  endpoint_url=connectionUrl,
 
116
  return ''.join(random.choices(string.digits, k=4))
117
 
118
@spaces.GPU(duration=120)
def run_lora(
    prompt,
    image_url,
    lora_strings_json,
    image_strength,
    cfg_scale,
    steps,
    randomize_seed,
    seed,
    width,
    height,
    upload_to_r2,
    account_id,
    access_key,
    secret_key,
    bucket,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate one image with optional LoRA adapters; optionally upload it to R2.

    Parameters
    ----------
    prompt : str
        Text prompt passed to the pipeline.
    image_url, image_strength :
        Unused in this revision; kept so the caller-facing signature stays
        compatible (the img2img path is not implemented here).
    lora_strings_json : str
        JSON list of entries: {"repo", "weights", "adapter_name", "adapter_weight"}.
        Invalid JSON or invalid entries are skipped with a warning, not fatal.
    cfg_scale, steps, width, height :
        Sampling controls; coerced with float()/int() before use.
    randomize_seed, seed :
        When randomize_seed is truthy, seed is replaced by a fresh draw
        in [0, MAX_SEED].
    upload_to_r2, account_id, access_key, secret_key, bucket :
        Cloudflare R2 upload toggle and credentials.

    Returns
    -------
    str
        JSON-encoded status dict: {"status", "message"[, "url"]}.
    """
    print("run_lora", prompt, lora_strings_json, cfg_scale, steps, width, height)
    gr.Info("Starting process")
    pipe = txt2img_pipe
    device = pipe.device
    print(device)

    # ---------- Seed ----------
    if randomize_seed:
        with calculateDuration("Set random seed"):
            seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    # ---------- LoRA ----------
    gr.Info("Start to load LoRA ...")
    with calculateDuration("Unloading LoRA"):
        try:
            pipe.unload_lora_weights()
        except Exception:
            # Some diffusers versions raise when nothing is loaded; safe to ignore.
            pass

    adapter_names = []
    adapter_weights = []
    if lora_strings_json:
        try:
            lora_configs = json.loads(lora_strings_json)
        except Exception:
            lora_configs = None
            gr.Warning("Parse lora config json failed")
            print("parse lora config json failed")

        if lora_configs:
            with calculateDuration("Loading LoRA weights"):
                for lora_info in lora_configs:
                    repo = lora_info.get("repo")
                    weights = lora_info.get("weights")
                    # Prefer the caller-supplied adapter name; fall back to a random one.
                    adapter_name = lora_info.get("adapter_name") or f"adp_{generate_random_4_digit_string()}"
                    # `or 1.0` also covers an explicit JSON null, which would
                    # otherwise crash float(None) before the try below.
                    weight = float(lora_info.get("adapter_weight") or 1.0)
                    if not (repo and weights):
                        print(f"skip invalid lora entry: {lora_info}")
                        continue
                    try:
                        weight_path = hf_hub_download(repo_id=repo, filename=weights)
                        # prefix=None loads transformer-side LoRA keys as well,
                        # not only the text-encoder ones.
                        pipe.load_lora_weights(weight_path, adapter_name=adapter_name, prefix=None)
                        adapter_names.append(adapter_name)
                        adapter_weights.append(weight)
                    except Exception as e:
                        print(f"load lora error for {repo}/{weights}: {e}")

            if adapter_names:
                pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
                # Fusing is faster at inference time, but the adapter weights
                # can no longer be adjusted dynamically afterwards.
                pipe.fuse_lora(adapter_names=adapter_names)

    pipe.enable_vae_slicing()

    # ---------- Generation ----------
    error_message = ""
    try:
        gr.Info("Start to generate images ...")
        joint_attention_kwargs = {"scale": 1}
        image = pipe(
            prompt=prompt,
            num_inference_steps=int(steps),
            guidance_scale=float(cfg_scale),
            width=int(width),
            height=int(height),
            max_sequence_length=512,
            generator=generator,
            joint_attention_kwargs=joint_attention_kwargs
        ).images[0]
    except Exception as e:
        error_message = str(e)
        gr.Error(error_message)
        print("fatal error", e)
        image = None

    # ---------- Result / optional upload ----------
    result = {"status": "failed", "message": error_message} if image is None else {"status": "success", "message": "Image generated but not uploaded"}
    if image is not None and upload_to_r2:
        try:
            url = upload_image_to_r2(image, account_id, access_key, secret_key, bucket)
            result = {"status": "success", "message": "upload image success", "url": url}
        except Exception as e:
            err = f"Upload failed: {e}"
            gr.Warning(err)
            print(err)
            # The image itself was generated, so overall status stays "success".
            result = {"status": "success", "message": "generated but upload failed"}

    gr.Info("Completed!")
    progress(100, "Completed!")
    return json.dumps(result)