anujithc commited on
Commit
50b6e72
·
1 Parent(s): 9b55e5c
Files changed (1) hide show
  1. app.py +27 -28
app.py CHANGED
@@ -9,7 +9,6 @@ import logging
9
  from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image
10
  from huggingface_hub import login
11
  from diffusers.utils import load_image
12
- #from lora_loading_patch import load_lora_into_transformer
13
  import time
14
  from datetime import datetime
15
  from io import BytesIO
@@ -27,28 +26,28 @@ from huggingface_hub import hf_hub_download
27
  from diffusers.quantizers import PipelineQuantizationConfig
28
  from diffusers import (FluxPriorReduxPipeline, FluxInpaintPipeline, FluxFillPipeline, FluxKontextPipeline, FluxPipeline)
29
 
 
30
  # Login Hugging Face Hub
31
  HF_TOKEN = os.environ.get("HF_TOKEN")
32
  login(token=HF_TOKEN)
33
  import diffusers
34
 
 
35
  # init
36
  dtype = torch.bfloat16
37
  device = "cuda:0"
38
 
 
39
  base_model = "black-forest-labs/FLUX.1-Krea-dev"
40
 
41
- # pipeline_quant_config = PipelineQuantizationConfig(
42
- # quant_backend="bitsandbytes_4bit",
43
- # quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
44
- # components_to_quantize=["transformer", "text_encoder_2"],
45
- # )
46
 
47
  txt2img_pipe = FluxKontextPipeline.from_pretrained(base_model, torch_dtype=dtype)
48
  txt2img_pipe = txt2img_pipe.to(device)
49
 
 
50
  MAX_SEED = 2**32 - 1
51
 
 
52
  class calculateDuration:
53
  def __init__(self, activity_name=""):
54
  self.activity_name = activity_name
@@ -69,13 +68,14 @@ class calculateDuration:
69
  else:
70
  print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
71
 
 
72
  def safe_trim_for_clip(text: str, max_words: int = 77) -> str:
73
- # 简单按词裁,不破坏主 prompt。你也可以做更智能的关键词抽取。
74
  tokens = re.split(r"\s+", text.strip())
75
  if len(tokens) <= max_words:
76
  return text
77
  return " ".join(tokens[:max_words])
78
 
 
79
  def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
80
  with calculateDuration("Upload images"):
81
  connectionUrl = f"https://{account_id}.r2.cloudflarestorage.com"
@@ -94,17 +94,16 @@ def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
94
  buffer.seek(0)
95
  s3.upload_fileobj(buffer, bucket_name, image_file)
96
  print("upload finish", image_file)
97
- # start to generate thumbnail
 
98
  thumbnail = image.copy()
99
  thumbnail_width = 256
100
  aspect_ratio = image.height / image.width
101
  thumbnail_height = int(thumbnail_width * aspect_ratio)
102
  thumbnail = thumbnail.resize((thumbnail_width, thumbnail_height), Image.LANCZOS)
103
 
104
- # Generate the thumbnail image filename
105
  thumbnail_file = image_file.replace(".png", "_thumbnail.png")
106
 
107
- # Save thumbnail to buffer and upload
108
  thumbnail_buffer = BytesIO()
109
  thumbnail.save(thumbnail_buffer, "PNG")
110
  thumbnail_buffer.seek(0)
@@ -112,9 +111,11 @@ def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
112
  print("upload thumbnail finish", thumbnail_file)
113
  return image_file
114
 
 
115
  def generate_random_4_digit_string():
116
  return ''.join(random.choices(string.digits, k=4))
117
 
 
118
  @spaces.GPU(duration=120)
119
  def run_lora(
120
  prompt,
@@ -152,7 +153,6 @@ def run_lora(
152
  try:
153
  pipe.unload_lora_weights()
154
  except Exception as _:
155
- # 某些版本上未加载时调用可能抛异常,忽略
156
  pass
157
 
158
  adapter_names = []
@@ -170,7 +170,6 @@ def run_lora(
170
  for lora_info in lora_configs:
171
  repo = lora_info.get("repo")
172
  weights = lora_info.get("weights")
173
- # 优先使用用户提供的 adapter_name;没有则随机
174
  adapter_name = lora_info.get("adapter_name") or f"adp_{generate_random_4_digit_string()}"
175
  weight = float(lora_info.get("adapter_weight", 1.0))
176
  if not (repo and weights):
@@ -178,7 +177,6 @@ def run_lora(
178
  continue
179
  try:
180
  weight_path = hf_hub_download(repo_id=repo, filename=weights)
181
- # 关键修复:prefix=None,避免仅在 text_encoder 查找
182
  pipe.load_lora_weights(weight_path, adapter_name=adapter_name, prefix=None)
183
  adapter_names.append(adapter_name)
184
  adapter_weights.append(weight)
@@ -187,8 +185,6 @@ def run_lora(
187
 
188
  if adapter_names:
189
  pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
190
- # 可选:融合后推理更快,但无法动态调整权重
191
- # pipe.fuse_lora(adapter_names=adapter_names)
192
  try:
193
  active = pipe.get_active_adapters() if hasattr(pipe, "get_active_adapters") else []
194
  print("Active adapters:", active)
@@ -202,15 +198,15 @@ def run_lora(
202
  lora_layer_count += 1
203
  print(f"[DEBUG] transformer LoRA layers: {lora_layer_count}")
204
 
205
- # 若层数为 0,给出直观警告
206
  if lora_layer_count == 0:
207
  gr.Warning("LoRA seems not injected (0 layers on transformer). Check whether the LoRA is trained for FLUX and `prefix=None` is set.")
208
 
209
-
210
  pipe.enable_vae_slicing()
211
  clip_side_prompt = safe_trim_for_clip(prompt, max_words=77)
212
  init_image = None
213
  error_message = ""
 
 
214
  try:
215
  gr.Info("Start to generate images ...")
216
  joint_attention_kwargs = {"scale": 1}
@@ -246,25 +242,25 @@ def run_lora(
246
  progress(100, "Completed!")
247
 
248
  # CHANGED: Return both image AND json
249
- return image, json.dumps(result) # <--- THIS IS THE KEY CHANGE
250
-
251
- # Gradio interface
252
-
253
 
254
 
 
255
  with gr.Blocks() as demo:
256
- gr.Markdown("flux-dev-multi-lora")
257
  with gr.Row():
258
 
259
  with gr.Column():
260
-
261
  prompt = gr.Text(label="Prompt", placeholder="Enter prompt", lines=10)
262
- lora_strings_json = gr.Text(label="LoRA Configs (JSON List String)", placeholder='[{"repo": "lora_repo1", "weights": "weights1", "adapter_name": "adapter_name1", "adapter_weight": 1}, {"repo": "lora_repo2", "weights": "weights2", "adapter_name": "adapter_name2", "adapter_weight": 1}]', lines=5)
 
 
 
 
263
  image_url = gr.Text(label="Image url", placeholder="Enter image url to enable image to image model", lines=1)
264
  run_button = gr.Button("Run", scale=0)
265
 
266
  with gr.Accordion("Advanced Settings", open=False):
267
-
268
  with gr.Row():
269
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
270
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
@@ -284,14 +280,16 @@ with gr.Blocks() as demo:
284
  secret_key = gr.Textbox(label="Secret Key", placeholder="Enter R2 secret key here")
285
  bucket = gr.Textbox(label="Bucket Name", placeholder="Enter R2 bucket name here")
286
 
287
-
288
  with gr.Column():
 
 
289
  json_text = gr.Text(label="Result JSON")
290
 
291
  gr.Markdown("**Disclaimer:**")
292
  gr.Markdown(
293
  "This demo is only for research purpose. This space owner cannot be held responsible for the generation of NSFW (Not Safe For Work) content through the use of this demo. Users are solely responsible for any content they create, and it is their obligation to ensure that it adheres to appropriate and ethical standards. This space owner provides the tools, but the responsibility for their use lies with the individual user."
294
  )
 
295
  inputs = [
296
  prompt,
297
  image_url,
@@ -310,7 +308,8 @@ with gr.Blocks() as demo:
310
  bucket
311
  ]
312
 
313
- outputs = [json_text]
 
314
 
315
  run_button.click(
316
  fn=run_lora,
@@ -321,4 +320,4 @@ with gr.Blocks() as demo:
321
  try:
322
  demo.queue().launch()
323
  except:
324
- print("demo exception ...")
 
9
  from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image
10
  from huggingface_hub import login
11
  from diffusers.utils import load_image
 
12
  import time
13
  from datetime import datetime
14
  from io import BytesIO
 
26
  from diffusers.quantizers import PipelineQuantizationConfig
27
  from diffusers import (FluxPriorReduxPipeline, FluxInpaintPipeline, FluxFillPipeline, FluxKontextPipeline, FluxPipeline)
28
 
29
+
30
  # Login Hugging Face Hub
31
  HF_TOKEN = os.environ.get("HF_TOKEN")
32
  login(token=HF_TOKEN)
33
  import diffusers
34
 
35
+
36
  # init
37
  dtype = torch.bfloat16
38
  device = "cuda:0"
39
 
40
+
41
  base_model = "black-forest-labs/FLUX.1-Krea-dev"
42
 
 
 
 
 
 
43
 
44
  txt2img_pipe = FluxKontextPipeline.from_pretrained(base_model, torch_dtype=dtype)
45
  txt2img_pipe = txt2img_pipe.to(device)
46
 
47
+
48
  MAX_SEED = 2**32 - 1
49
 
50
+
51
  class calculateDuration:
52
  def __init__(self, activity_name=""):
53
  self.activity_name = activity_name
 
68
  else:
69
  print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
70
 
71
+
72
  def safe_trim_for_clip(text: str, max_words: int = 77) -> str:
 
73
  tokens = re.split(r"\s+", text.strip())
74
  if len(tokens) <= max_words:
75
  return text
76
  return " ".join(tokens[:max_words])
77
 
78
+
79
  def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
80
  with calculateDuration("Upload images"):
81
  connectionUrl = f"https://{account_id}.r2.cloudflarestorage.com"
 
94
  buffer.seek(0)
95
  s3.upload_fileobj(buffer, bucket_name, image_file)
96
  print("upload finish", image_file)
97
+
98
+ # Generate thumbnail
99
  thumbnail = image.copy()
100
  thumbnail_width = 256
101
  aspect_ratio = image.height / image.width
102
  thumbnail_height = int(thumbnail_width * aspect_ratio)
103
  thumbnail = thumbnail.resize((thumbnail_width, thumbnail_height), Image.LANCZOS)
104
 
 
105
  thumbnail_file = image_file.replace(".png", "_thumbnail.png")
106
 
 
107
  thumbnail_buffer = BytesIO()
108
  thumbnail.save(thumbnail_buffer, "PNG")
109
  thumbnail_buffer.seek(0)
 
111
  print("upload thumbnail finish", thumbnail_file)
112
  return image_file
113
 
114
+
115
  def generate_random_4_digit_string():
116
  return ''.join(random.choices(string.digits, k=4))
117
 
118
+
119
  @spaces.GPU(duration=120)
120
  def run_lora(
121
  prompt,
 
153
  try:
154
  pipe.unload_lora_weights()
155
  except Exception as _:
 
156
  pass
157
 
158
  adapter_names = []
 
170
  for lora_info in lora_configs:
171
  repo = lora_info.get("repo")
172
  weights = lora_info.get("weights")
 
173
  adapter_name = lora_info.get("adapter_name") or f"adp_{generate_random_4_digit_string()}"
174
  weight = float(lora_info.get("adapter_weight", 1.0))
175
  if not (repo and weights):
 
177
  continue
178
  try:
179
  weight_path = hf_hub_download(repo_id=repo, filename=weights)
 
180
  pipe.load_lora_weights(weight_path, adapter_name=adapter_name, prefix=None)
181
  adapter_names.append(adapter_name)
182
  adapter_weights.append(weight)
 
185
 
186
  if adapter_names:
187
  pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
 
 
188
  try:
189
  active = pipe.get_active_adapters() if hasattr(pipe, "get_active_adapters") else []
190
  print("Active adapters:", active)
 
198
  lora_layer_count += 1
199
  print(f"[DEBUG] transformer LoRA layers: {lora_layer_count}")
200
 
 
201
  if lora_layer_count == 0:
202
  gr.Warning("LoRA seems not injected (0 layers on transformer). Check whether the LoRA is trained for FLUX and `prefix=None` is set.")
203
 
 
204
  pipe.enable_vae_slicing()
205
  clip_side_prompt = safe_trim_for_clip(prompt, max_words=77)
206
  init_image = None
207
  error_message = ""
208
+ image = None
209
+
210
  try:
211
  gr.Info("Start to generate images ...")
212
  joint_attention_kwargs = {"scale": 1}
 
242
  progress(100, "Completed!")
243
 
244
  # CHANGED: Return both image AND json
245
+ return image, json.dumps(result)
 
 
 
246
 
247
 
248
+ # Gradio interface
249
  with gr.Blocks() as demo:
250
+ gr.Markdown("# flux-dev-multi-lora")
251
  with gr.Row():
252
 
253
  with gr.Column():
 
254
  prompt = gr.Text(label="Prompt", placeholder="Enter prompt", lines=10)
255
+ lora_strings_json = gr.Text(
256
+ label="LoRA Configs (JSON List String)",
257
+ placeholder='[{"repo": "lora_repo1", "weights": "weights1.safetensors", "adapter_name": "adapter1", "adapter_weight": 1.0}]',
258
+ lines=5
259
+ )
260
  image_url = gr.Text(label="Image url", placeholder="Enter image url to enable image to image model", lines=1)
261
  run_button = gr.Button("Run", scale=0)
262
 
263
  with gr.Accordion("Advanced Settings", open=False):
 
264
  with gr.Row():
265
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
266
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
 
280
  secret_key = gr.Textbox(label="Secret Key", placeholder="Enter R2 secret key here")
281
  bucket = gr.Textbox(label="Bucket Name", placeholder="Enter R2 bucket name here")
282
 
 
283
  with gr.Column():
284
+ # CHANGED: Add image output
285
+ output_image = gr.Image(label="Generated Image", type="pil")
286
  json_text = gr.Text(label="Result JSON")
287
 
288
  gr.Markdown("**Disclaimer:**")
289
  gr.Markdown(
290
  "This demo is only for research purpose. This space owner cannot be held responsible for the generation of NSFW (Not Safe For Work) content through the use of this demo. Users are solely responsible for any content they create, and it is their obligation to ensure that it adheres to appropriate and ethical standards. This space owner provides the tools, but the responsibility for their use lies with the individual user."
291
  )
292
+
293
  inputs = [
294
  prompt,
295
  image_url,
 
308
  bucket
309
  ]
310
 
311
+ # CHANGED: Two outputs now
312
+ outputs = [output_image, json_text]
313
 
314
  run_button.click(
315
  fn=run_lora,
 
320
  try:
321
  demo.queue().launch()
322
  except:
323
+ print("demo exception ...")