dssddsdf commited on
Commit
46added
Β·
verified Β·
1 Parent(s): 6b2dd21

feat: Quick Edit tab, LoRA Dataset Generator (50-150 captioned poses), fix LoRA reload for ZeroGPU

Browse files
Files changed (1) hide show
  1. app.py +208 -12
app.py CHANGED
@@ -157,8 +157,7 @@ def infer(
157
  if not prompt or not prompt.strip():
158
  raise gr.Error("Enter a prompt!")
159
 
160
- # LoRA activation
161
- global LOADED_ADAPTERS
162
  active_adapters = []
163
  active_weights = []
164
 
@@ -166,28 +165,31 @@ def infer(
166
  style = LORA_MAP.get(lora_name)
167
  if style and style["adapter_name"]:
168
  aname = style["adapter_name"]
169
- if aname not in LOADED_ADAPTERS:
170
- print(f"Loading LoRA: {style['title']}")
171
  pipe.load_lora_weights(style["repo"], weight_name=style["weights"], adapter_name=aname)
172
- LOADED_ADAPTERS.add(aname)
 
173
  active_adapters.append(aname)
174
  active_weights.append(lora_strength)
175
 
176
  # Custom LoRA
177
  if custom_lora_repo and custom_lora_repo.strip() and custom_lora_file and custom_lora_file.strip():
178
  cname = "custom-lora"
179
- if cname not in LOADED_ADAPTERS:
180
  print(f"Loading custom LoRA: {custom_lora_repo}/{custom_lora_file}")
181
  pipe.load_lora_weights(custom_lora_repo.strip(), weight_name=custom_lora_file.strip(), adapter_name=cname)
182
- LOADED_ADAPTERS.add(cname)
 
183
  active_adapters.append(cname)
184
  active_weights.append(custom_lora_strength)
185
 
186
  if active_adapters:
187
  pipe.set_adapters(active_adapters, adapter_weights=active_weights)
188
- print(f"Active: {list(zip(active_adapters, active_weights))}")
189
  else:
190
  pipe.disable_lora()
 
191
 
192
  if randomize_seed:
193
  seed = random.randint(0, MAX_SEED)
@@ -230,14 +232,14 @@ def generate_character_sheet(
230
  ref = pil_images[0]
231
  prefix = custom_prefix.strip() + " " if custom_prefix and custom_prefix.strip() else ""
232
 
233
- # Activate LoRA
234
- global LOADED_ADAPTERS
235
  style = LORA_MAP.get(lora_name)
236
  if style and style["adapter_name"]:
237
  aname = style["adapter_name"]
238
- if aname not in LOADED_ADAPTERS:
239
  pipe.load_lora_weights(style["repo"], weight_name=style["weights"], adapter_name=aname)
240
- LOADED_ADAPTERS.add(aname)
 
241
  pipe.set_adapters([aname], adapter_weights=[lora_strength])
242
  else:
243
  pipe.disable_lora()
@@ -325,6 +327,55 @@ with gr.Blocks(css=css) as demo:
325
  outputs=[output_image, seed_output],
326
  )
327
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
  # ==================== CHARACTER SHEET TAB ====================
329
  with gr.TabItem("360 Character Sheet"):
330
  gr.Markdown("Generate a multi-angle character turnaround from a single reference image. Produces 7 views: front/left/right face + front/left/right/back body.")
@@ -361,5 +412,150 @@ with gr.Blocks(css=css) as demo:
361
  outputs=[cs_gallery],
362
  )
363
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
  if __name__ == "__main__":
365
  demo.queue().launch(ssr_mode=False, show_error=True)
 
157
  if not prompt or not prompt.strip():
158
  raise gr.Error("Enter a prompt!")
159
 
160
+ # LoRA activation β€” always reload to survive ZeroGPU tensor packing
 
161
  active_adapters = []
162
  active_weights = []
163
 
 
165
  style = LORA_MAP.get(lora_name)
166
  if style and style["adapter_name"]:
167
  aname = style["adapter_name"]
168
+ try:
169
+ print(f"Loading LoRA: {style['title']} (strength={lora_strength})")
170
  pipe.load_lora_weights(style["repo"], weight_name=style["weights"], adapter_name=aname)
171
+ except ValueError:
172
+ pass # already loaded
173
  active_adapters.append(aname)
174
  active_weights.append(lora_strength)
175
 
176
  # Custom LoRA
177
  if custom_lora_repo and custom_lora_repo.strip() and custom_lora_file and custom_lora_file.strip():
178
  cname = "custom-lora"
179
+ try:
180
  print(f"Loading custom LoRA: {custom_lora_repo}/{custom_lora_file}")
181
  pipe.load_lora_weights(custom_lora_repo.strip(), weight_name=custom_lora_file.strip(), adapter_name=cname)
182
+ except ValueError:
183
+ pass # already loaded
184
  active_adapters.append(cname)
185
  active_weights.append(custom_lora_strength)
186
 
187
  if active_adapters:
188
  pipe.set_adapters(active_adapters, adapter_weights=active_weights)
189
+ print(f"Active LoRAs: {list(zip(active_adapters, active_weights))}")
190
  else:
191
  pipe.disable_lora()
192
+ print("No LoRA active")
193
 
194
  if randomize_seed:
195
  seed = random.randint(0, MAX_SEED)
 
232
  ref = pil_images[0]
233
  prefix = custom_prefix.strip() + " " if custom_prefix and custom_prefix.strip() else ""
234
 
235
+ # Activate LoRA β€” always reload for ZeroGPU
 
236
  style = LORA_MAP.get(lora_name)
237
  if style and style["adapter_name"]:
238
  aname = style["adapter_name"]
239
+ try:
240
  pipe.load_lora_weights(style["repo"], weight_name=style["weights"], adapter_name=aname)
241
+ except ValueError:
242
+ pass
243
  pipe.set_adapters([aname], adapter_weights=[lora_strength])
244
  else:
245
  pipe.disable_lora()
 
327
  outputs=[output_image, seed_output],
328
  )
329
 
330
+ # ==================== SIMPLE IMAGE EDIT TAB ====================
331
+ with gr.TabItem("Quick Image Edit"):
332
+ gr.Markdown("Upload one image and describe what to change. No LoRA β€” pure Klein 9B editing.")
333
+ with gr.Row():
334
+ with gr.Column():
335
+ qe_image = gr.Image(label="Source Image", type="pil", sources=["upload"], height=350)
336
+ qe_prompt = gr.Textbox(label="Edit instruction", lines=2,
337
+ placeholder="e.g. remove the shirt, change hair to blonde, add sunglasses...")
338
+ with gr.Row():
339
+ qe_steps = gr.Slider(1, 50, value=4, step=1, label="Steps")
340
+ qe_guidance = gr.Slider(0.0, 10.0, value=1.0, step=0.1, label="Guidance")
341
+ qe_seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")
342
+ qe_rand = gr.Checkbox(value=True, label="Randomize seed")
343
+ qe_btn = gr.Button("Edit Image", variant="primary", size="lg")
344
+ with gr.Column():
345
+ qe_output = gr.Image(label="Result", interactive=False, format="png", height=500)
346
+ qe_seed_out = gr.Number(label="Seed Used", interactive=False)
347
+
348
+ @spaces.GPU
349
+ def quick_edit(image, prompt, steps, guidance, seed, randomize):
350
+ gc.collect()
351
+ torch.cuda.empty_cache()
352
+ try:
353
+ if image is None:
354
+ raise gr.Error("Upload an image!")
355
+ if not prompt or not prompt.strip():
356
+ raise gr.Error("Enter an edit instruction!")
357
+ pipe.disable_lora()
358
+ if randomize:
359
+ seed = random.randint(0, MAX_SEED)
360
+ w, h = update_dimensions(image)
361
+ img = image.resize((w, h), Image.LANCZOS).convert("RGB")
362
+ result = pipe(
363
+ image=img, prompt=prompt,
364
+ guidance_scale=guidance, width=w, height=h,
365
+ num_inference_steps=steps,
366
+ generator=torch.Generator(device=device).manual_seed(seed),
367
+ ).images[0]
368
+ return result, seed
369
+ finally:
370
+ gc.collect()
371
+ torch.cuda.empty_cache()
372
+
373
+ qe_btn.click(
374
+ fn=quick_edit,
375
+ inputs=[qe_image, qe_prompt, qe_steps, qe_guidance, qe_seed, qe_rand],
376
+ outputs=[qe_output, qe_seed_out],
377
+ )
378
+
379
  # ==================== CHARACTER SHEET TAB ====================
380
  with gr.TabItem("360 Character Sheet"):
381
  gr.Markdown("Generate a multi-angle character turnaround from a single reference image. Produces 7 views: front/left/right face + front/left/right/back body.")
 
412
  outputs=[cs_gallery],
413
  )
414
 
415
+ # ==================== DATASET GENERATOR TAB ====================
416
+ with gr.TabItem("LoRA Dataset Generator"):
417
+ gr.Markdown("Generate a captioned image dataset from a single reference for LoRA training. Each image gets a text caption file.")
418
+
419
+ POSE_LIBRARY = [
420
+ "standing facing the camera, neutral pose, arms at sides",
421
+ "standing with arms crossed, confident pose",
422
+ "standing with hands on hips",
423
+ "standing, slight lean to the left, relaxed",
424
+ "standing three-quarter view from the left",
425
+ "standing three-quarter view from the right",
426
+ "standing side profile, looking right",
427
+ "standing side profile, looking left",
428
+ "standing from behind, back view",
429
+ "standing over the shoulder look, glancing back at camera",
430
+ "sitting on a chair, legs crossed, relaxed",
431
+ "sitting on the floor, legs extended",
432
+ "sitting cross-legged on the ground",
433
+ "sitting on a stool, leaning forward slightly",
434
+ "sitting sideways on a chair, arm draped over backrest",
435
+ "kneeling on one knee",
436
+ "kneeling on both knees, upright posture",
437
+ "leaning against a wall, arms crossed",
438
+ "leaning against a wall, one foot up",
439
+ "leaning forward with hands on knees",
440
+ "walking towards the camera, mid-stride",
441
+ "walking away from camera, back view mid-stride",
442
+ "walking side view, profile mid-stride",
443
+ "running towards the camera, dynamic pose",
444
+ "looking up at the sky, chin raised",
445
+ "looking down, contemplative",
446
+ "head tilted to the left, slight smile",
447
+ "head tilted to the right, serious expression",
448
+ "laughing naturally, candid expression",
449
+ "hands behind head, stretching",
450
+ "one hand touching hair, casual pose",
451
+ "hands in pockets, casual standing",
452
+ "waving at camera, friendly gesture",
453
+ "pointing at camera, direct gesture",
454
+ "arms raised above head, celebratory",
455
+ "crouching down, low angle",
456
+ "bending forward slightly, looking at camera",
457
+ "twisting torso, looking over shoulder",
458
+ "dancing pose, one leg lifted",
459
+ "yoga tree pose, balanced on one leg",
460
+ "lying on back, looking up at camera from above",
461
+ "lying on side, propped on elbow",
462
+ "lying on stomach, chin in hands",
463
+ "close-up portrait, direct eye contact",
464
+ "close-up portrait, eyes looking away",
465
+ "close-up portrait, slight smile",
466
+ "close-up portrait, serious expression",
467
+ "medium shot from waist up, arms at sides",
468
+ "medium shot from waist up, one hand raised",
469
+ "full body shot, standing tall, power pose",
470
+ ]
471
+
472
+ with gr.Row():
473
+ with gr.Column(scale=1):
474
+ ds_ref = gr.Gallery(label="Reference Image", type="filepath", columns=1, rows=1, height=200)
475
+ ds_subject = gr.Textbox(
476
+ label="Subject description (used as caption prefix)",
477
+ placeholder="e.g. a woman with red hair, green eyes, freckles",
478
+ lines=2,
479
+ )
480
+ ds_extra = gr.Textbox(
481
+ label="Extra prompt (appended to each pose)",
482
+ placeholder="e.g. nude, studio lighting, white background",
483
+ lines=1,
484
+ )
485
+ ds_count = gr.Slider(10, 150, value=50, step=5, label="Number of images")
486
+ ds_lora = gr.Dropdown(LORA_TITLES, value="None (Base Klein 9B)", label="LoRA")
487
+ ds_lora_str = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="LoRA Strength")
488
+ with gr.Row():
489
+ ds_seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Starting Seed")
490
+ ds_guidance = gr.Slider(0.0, 10.0, value=1.0, step=0.1, label="Guidance")
491
+ ds_steps = gr.Slider(1, 50, value=4, step=1, label="Steps")
492
+ ds_btn = gr.Button("Generate Dataset", variant="primary", size="lg")
493
+
494
+ with gr.Column(scale=2):
495
+ ds_gallery = gr.Gallery(label="Generated Dataset", columns=5, rows=3, height=500, object_fit="contain")
496
+ ds_status = gr.Textbox(label="Status / Captions Preview", lines=8, interactive=False)
497
+
498
+ @spaces.GPU(duration=300)
499
+ def generate_dataset(ref_images, subject, extra, count, lora_name, lora_str,
500
+ seed, guidance, steps, progress=gr.Progress(track_tqdm=True)):
501
+ gc.collect()
502
+ torch.cuda.empty_cache()
503
+ try:
504
+ pil_images = process_gallery_images(ref_images)
505
+ if not pil_images:
506
+ raise gr.Error("Upload a reference image!")
507
+
508
+ ref = pil_images[0]
509
+ count = int(count)
510
+ poses = (POSE_LIBRARY * ((count // len(POSE_LIBRARY)) + 1))[:count]
511
+
512
+ # LoRA
513
+ style = LORA_MAP.get(lora_name)
514
+ if style and style["adapter_name"]:
515
+ try:
516
+ pipe.load_lora_weights(style["repo"], weight_name=style["weights"],
517
+ adapter_name=style["adapter_name"])
518
+ except ValueError:
519
+ pass
520
+ pipe.set_adapters([style["adapter_name"]], adapter_weights=[lora_str])
521
+ else:
522
+ pipe.disable_lora()
523
+
524
+ w, h = update_dimensions(ref)
525
+ ref_resized = ref.resize((w, h), Image.LANCZOS).convert("RGB")
526
+
527
+ results = []
528
+ captions = []
529
+ subject_text = subject.strip() if subject else "a person"
530
+ extra_text = ", " + extra.strip() if extra and extra.strip() else ""
531
+
532
+ for i, pose in enumerate(poses):
533
+ progress((i + 1) / count, desc=f"Image {i+1}/{count}")
534
+ caption = f"{subject_text}, {pose}{extra_text}"
535
+ gen = torch.Generator(device=device).manual_seed(seed + i)
536
+ img = pipe(
537
+ image=ref_resized, prompt=caption,
538
+ guidance_scale=guidance, width=w, height=h,
539
+ num_inference_steps=steps, generator=gen,
540
+ ).images[0]
541
+ results.append((img, f"{i:03d}"))
542
+ captions.append(f"{i:03d}.txt: {caption}")
543
+
544
+ status = f"Generated {count} images.\n\nCaption preview:\n" + "\n".join(captions[:10])
545
+ if count > 10:
546
+ status += f"\n... and {count - 10} more"
547
+
548
+ return results, status
549
+ finally:
550
+ gc.collect()
551
+ torch.cuda.empty_cache()
552
+
553
+ ds_btn.click(
554
+ fn=generate_dataset,
555
+ inputs=[ds_ref, ds_subject, ds_extra, ds_count, ds_lora, ds_lora_str,
556
+ ds_seed, ds_guidance, ds_steps],
557
+ outputs=[ds_gallery, ds_status],
558
+ )
559
+
560
  if __name__ == "__main__":
561
  demo.queue().launch(ssr_mode=False, show_error=True)