Chyd19 committed on
Commit
446e4eb
·
verified ·
1 Parent(s): c7bda74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +163 -539
app.py CHANGED
@@ -20,7 +20,6 @@ def free_gpu_cache():
20
  # =========================
21
  # MODELS
22
  # =========================
23
- # Image generation
24
  gen_pipe = DiffusionPipeline.from_pretrained(
25
  "stabilityai/sdxl-turbo",
26
  torch_dtype=torch.float16 if device=="cuda" else torch.float32
@@ -31,7 +30,6 @@ dreamshaper_pipe = DiffusionPipeline.from_pretrained(
31
  torch_dtype=torch.float16 if device=="cuda" else torch.float32
32
  ).to(device)
33
 
34
- # Captioning
35
  captioner = pipeline(
36
  "image-to-text",
37
  model="Salesforce/blip-image-captioning-large",
@@ -39,7 +37,6 @@ captioner = pipeline(
39
  generate_kwargs={"max_new_tokens":256, "num_beams":5, "temperature":0.7}
40
  )
41
 
42
- # NLP
43
  sentiment_model = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english",
44
  device=0 if device=="cuda" else -1)
45
  ner_model = pipeline("ner", model="dbmdz/bert-large-cased-finetuned-conll03-english",
@@ -47,16 +44,13 @@ ner_model = pipeline("ner", model="dbmdz/bert-large-cased-finetuned-conll03-engl
47
  topic_model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli",
48
  device=0 if device=="cuda" else -1)
49
 
50
- # VQA
51
  vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
52
  vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to("cpu")
53
 
54
- # Metrics
55
  clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
56
  lpips_model = lpips.LPIPS(net='alex').to(device)
57
  lpips_transform = T.Compose([T.ToTensor(), T.Resize((256,256))])
58
 
59
- # Styles
60
  style_map = {
61
  "Photorealistic": "photorealistic, ultra-detailed, 8k, cinematic lighting",
62
  "Real Life": "natural lighting, true-to-life colors, DSLR",
@@ -74,9 +68,7 @@ style_map = {
74
  # IMAGE GENERATION FUNCTIONS
75
  # =========================
76
  def generate_image_with_enhancer(base_caption, enhancer, negative, seed, style, images):
77
- images = images or []
78
- base_caption = base_caption or ""
79
- enhancer = enhancer or ""
80
  final_prompt = f"{base_caption}, {enhancer}".strip(", ")
81
  final_prompt = f"{final_prompt}, {style_map.get(style,'')}".strip(", ")
82
  try:
@@ -92,14 +84,12 @@ def generate_image_with_enhancer(base_caption, enhancer, negative, seed, style,
92
  print("SD Turbo failed:", e)
93
  img = None
94
  if img:
95
- images.append(img)
96
  free_gpu_cache()
97
  return img, images
98
 
99
  def generate_dreamshaper_with_enhancer(base_caption, enhancer, negative, seed, style, images):
100
- images = images or []
101
- base_caption = base_caption or ""
102
- enhancer = enhancer or ""
103
  final_prompt = f"{base_caption}, {enhancer}".strip(", ")
104
  final_prompt = f"{final_prompt}, {style_map.get(style,'')}".strip(", ")
105
  try:
@@ -115,7 +105,7 @@ def generate_dreamshaper_with_enhancer(base_caption, enhancer, negative, seed, s
115
  print("DreamShaper failed:", e)
116
  img = None
117
  if img:
118
- images.append(img)
119
  free_gpu_cache()
120
  return img, images
121
 
@@ -133,7 +123,7 @@ def caption_for_image(img):
133
  # VQA
134
  # =========================
135
  def answer_vqa(question, image):
136
- if not image or not question.strip():
137
  return "Provide image + question."
138
  try:
139
  inputs_raw = vqa_processor(images=image, text=question, return_tensors="pt")
@@ -175,14 +165,13 @@ def compute_metrics(images, captions, i1, i2):
175
  else:
176
  bert_f1 = 0.0
177
 
178
- return clip_sim, lp, bert_f1
179
 
180
  # =========================
181
  # GRADIO UI BUILD
182
  # =========================
183
  def build_full_ui():
184
  with gr.Blocks(title="Multimodal AI Image Studio") as demo:
185
- # --- CSS Styling ---
186
  gr.HTML("""
187
  <style>
188
  .heading-orange h2, .heading-orange h3 { color: #ff5500 !important; }
@@ -196,13 +185,10 @@ def build_full_ui():
196
  </style>
197
  """)
198
 
199
- # --- States ---
200
  images_state = gr.State([None, None, None])
201
  captions_state = gr.State(["", "", ""])
202
 
203
- # =========================
204
- # Section 1: Upload Reference Image
205
- # =========================
206
  gr.Markdown("## 1️⃣ Upload Reference Image", elem_classes="heading-orange")
207
  with gr.Row(elem_classes="equal-height-row"):
208
  with gr.Column(scale=1):
@@ -213,24 +199,17 @@ def build_full_ui():
213
  upload_preview = gr.Image(label="Uploaded Image", interactive=False)
214
  caption_out = gr.Markdown(label="Generated Caption")
215
 
216
- # Upload & caption function
217
  def upload_and_caption(img, images_state, captions_state):
218
  if img is None:
219
  return None, "No image uploaded.", images_state, captions_state
220
  images_state[0] = img
221
- try:
222
- cap = caption_for_image(img)
223
- except:
224
- cap = "Caption failed."
225
- captions_state[0] = cap
226
- return img, cap, images_state, captions_state
227
 
228
  upload_btn.click(upload_and_caption, inputs=[upload_input, images_state, captions_state],
229
  outputs=[upload_preview, caption_out, images_state, captions_state])
230
 
231
- # =========================
232
- # Section 2: Generate SD-Turbo & DreamShaper
233
- # =========================
234
  gr.Markdown("## 2️⃣ Generate Images from Caption", elem_classes="heading-orange")
235
  with gr.Row():
236
  with gr.Column(scale=1):
@@ -240,16 +219,14 @@ def build_full_ui():
240
  ds_btn = gr.Button("Generate DreamShaper Image", elem_classes="orange-btn")
241
  ds_preview = gr.Image(label="DreamShaper Image", interactive=False)
242
 
243
- # Generate SD-Turbo
244
  def generate_sd(caption, enhancer, images_state, captions_state):
245
- img, images_state = generate_image_with_enhancer(caption, enhancer, negative="", seed=42, style="Photorealistic", images=images_state)
246
  if img:
247
  captions_state[1] = caption_for_image(img)
248
  return img, images_state, captions_state
249
 
250
- # Generate DreamShaper
251
  def generate_ds(caption, enhancer, images_state, captions_state):
252
- img, images_state = generate_dreamshaper_with_enhancer(caption, enhancer, negative="", seed=123, style="Photorealistic", images=images_state)
253
  if img:
254
  captions_state[2] = caption_for_image(img)
255
  return img, images_state, captions_state
@@ -259,79 +236,29 @@ def build_full_ui():
259
  ds_btn.click(generate_ds, inputs=[caption_out, enhancer_box, images_state, captions_state],
260
  outputs=[ds_preview, images_state, captions_state])
261
 
262
- # =========================
263
- # Section 3: Compute Pairwise Metrics (Side-by-Side)
264
- # =========================
265
  gr.Markdown("## 3️⃣ Compute Pairwise Metrics", elem_classes="heading-orange")
266
  metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn")
267
  metrics_spinner = gr.HTML("<div style='height:4px;'></div>")
268
- metrics_out = gr.HTML()
 
 
269
 
270
  def compute_metrics_ui(images, captions):
271
- yield "<div class='loading-line'></div>", ""
272
  if any(i is None for i in images):
273
- yield "All three images and captions are required."
 
274
  else:
275
- try:
276
- A = compute_metrics(images, captions, 0, 1)
277
- B = compute_metrics(images, captions, 0, 2)
278
- C = compute_metrics(images, captions, 1, 2)
279
- def fmt(m):
280
- return f"CLIP: {m[0]:.3f}<br>LPIPS: {m[1]:.3f}<br>BERTScore F1: {m[2]:.3f}"
281
- html = f"""
282
- <div style='display:flex; gap:40px; justify-content:space-around;'>
283
- <div style='text-align:center;'><b>Metrics A</b><br>{fmt(A)}</div>
284
- <div style='text-align:center;'><b>Metrics B</b><br>{fmt(B)}</div>
285
- <div style='text-align:center;'><b>Metrics C</b><br>{fmt(C)}</div>
286
- </div>
287
- """
288
- yield html
289
- except Exception as e:
290
- print("Metrics error:", e)
291
- yield "Failed to compute metrics."
292
 
293
  metrics_btn.click(compute_metrics_ui, inputs=[images_state, captions_state],
294
- outputs=[metrics_out])
295
-
296
- # =========================
297
- # Section 4: NLP Analysis
298
- # =========================
299
- gr.Markdown("## 4️⃣ NLP Analysis of Captions", elem_classes="heading-orange")
300
- nlp_btn = gr.Button("Analyze Captions", elem_classes="teal-btn")
301
- nlp_spinner = gr.HTML("<div style='height:4px;'></div>")
302
- nlp_out = gr.HTML()
303
-
304
- def analyze_captions_ui(captions):
305
- yield "<div class='loading-line'></div>", ""
306
- if any(c=="" for c in captions):
307
- yield "<b>All three captions are required for NLP analysis.</b>"
308
- else:
309
- labels = ["Reference", "SD-Turbo", "DreamShaper"]
310
- blocks = []
311
- for label, caption in zip(labels, captions):
312
- try:
313
- sentiment = "<br>".join([f"{s['label']}: {s['score']:.2f}" for s in sentiment_model(caption)])
314
- except:
315
- sentiment = "Sentiment failed."
316
- try:
317
- ents_list = ner_model(caption)
318
- ents = "<br>".join([f"{e.get('entity_group','')}: {e.get('word','')}" for e in ents_list]) or "None"
319
- except:
320
- ents = "NER failed."
321
- try:
322
- topics_data = topic_model(caption, candidate_labels=['people','animals','objects','food','nature'])
323
- topics = "<br>".join([f"{l}: {sc:.2f}" for l, sc in zip(topics_data.get('labels',[]), topics_data.get('scores',[]))])
324
- except:
325
- topics = "Topics failed."
326
- block = f"<div style='flex:1;padding:10px;min-width:250px;'><h3><u>{label}</u></h3><b>Sentiment</b><br>{sentiment}<br><br><b>Entities</b><br>{ents}<br><br><b>Topics</b><br>{topics}</div>"
327
- blocks.append(block)
328
- yield f"<div style='display:flex; gap:20px; justify-content:space-between;'>{''.join(blocks)}</div>"
329
-
330
- nlp_btn.click(analyze_captions_ui, inputs=[captions_state], outputs=[nlp_out])
331
 
332
- # =========================
333
- # Section 5: Visual Question Answering
334
- # =========================
335
  gr.Markdown("## 5️⃣ Visual Question Answering (VQA)", elem_classes="heading-orange")
336
  with gr.Row():
337
  with gr.Column(scale=1):
@@ -341,19 +268,12 @@ def build_full_ui():
341
  vqa_spinner = gr.HTML("<div style='height:4px;'></div>")
342
  vqa_out = gr.Markdown(label="VQA Output")
343
 
344
- def vqa_ui(question, image):
345
  yield "<div class='loading-line'></div>", ""
346
- if not question.strip() or image is None:
347
- yield "Provide image + question."
348
- else:
349
- try:
350
- ans = answer_vqa(question, image)
351
- yield f"<b>Answer:</b> {ans}"
352
- except Exception as e:
353
- print("VQA error:", e)
354
- yield "Could not determine the answer."
355
 
356
- vqa_btn.click(vqa_ui, inputs=[vqa_input, upload_preview], outputs=[vqa_out])
357
 
358
  return demo
359
 
@@ -361,269 +281,11 @@ def build_full_ui():
361
  demo = build_full_ui()
362
  demo.launch()
363
 
364
-
365
-
366
  """
367
- # Dump code
368
- # ==============================
369
- # Libraries
370
- # ==============================
371
- import torch
372
- import gradio as gr
373
- from PIL import Image
374
- from diffusers import DiffusionPipeline
375
- from transformers import pipeline, BlipProcessor, BlipForQuestionAnswering
376
- import lpips
377
- import clip
378
- from bert_score import score
379
- import torchvision.transforms as T
380
-
381
- device = "cuda" if torch.cuda.is_available() else "cpu"
382
-
383
- def free_gpu_cache():
384
- if device == "cuda":
385
- torch.cuda.empty_cache()
386
-
387
- # ==============================
388
- # Load Models (HF-ready, memory safe)
389
- # ==============================
390
- # SDXL-Turbo
391
- gen_pipe = DiffusionPipeline.from_pretrained(
392
- "stabilityai/sdxl-turbo",
393
- torch_dtype=torch.float16 if device=="cuda" else torch.float32
394
- ).to(device)
395
-
396
- # DreamShaper
397
- dreamshaper_pipe = DiffusionPipeline.from_pretrained(
398
- "Lykon/dreamshaper-7",
399
- torch_dtype=torch.float16 if device=="cuda" else torch.float32
400
- ).to(device)
401
-
402
- # BLIP Captioning
403
- captioner = pipeline(
404
- "image-to-text",
405
- model="Salesforce/blip-image-captioning-large",
406
- device=0 if device=="cuda" else -1,
407
- generate_kwargs={"max_new_tokens":256, "num_beams":5, "temperature":0.7}
408
- )
409
-
410
- # Sentiment / NER / Topic
411
- sentiment_model = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english",
412
- device=0 if device=="cuda" else -1)
413
- ner_model = pipeline("ner", model="dbmdz/bert-large-cased-finetuned-conll03-english",
414
- aggregation_strategy="simple", device=0 if device=="cuda" else -1)
415
- topic_model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli",
416
- device=0 if device=="cuda" else -1)
417
-
418
- # BLIP VQA
419
- vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
420
- vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to("cpu")
421
-
422
- # CLIP / LPIPS
423
- clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
424
- lpips_model = lpips.LPIPS(net='alex').to(device)
425
- lpips_transform = T.Compose([T.ToTensor(), T.Resize((256,256))])
426
-
427
- # Style map
428
- style_map = {
429
- "Photorealistic": "photorealistic, ultra-detailed, 8k, cinematic lighting",
430
- "Real Life": "natural lighting, true-to-life colors, DSLR",
431
- "Documentary": "documentary handheld muted colors",
432
- "iPhone Camera": "iPhone photo natural HDR",
433
- "Street Photography": "candid street ambient shadows",
434
- "Cinematic": "cinematic lighting dramatic depth",
435
- "Anime": "anime cel shaded vibrant",
436
- "Watercolor": "watercolor soft wash art",
437
- "Macro": "macro lens shallow DOF",
438
- "Cyberpunk": "neon cyberpunk futuristic",
439
- }
440
-
441
- # ==============================
442
- # Functions
443
- # ==============================
444
-
445
- def generate_image(pipe, caption, enhancer, negative, seed, style):
446
- final_prompt = f"{caption}, {enhancer}".strip(", ")
447
- final_prompt = f"{final_prompt}, {style_map.get(style,'')}".strip(", ")
448
-
449
- try:
450
- seed = int(seed)
451
- except:
452
- seed = 42
453
-
454
- generator = torch.Generator(device="cpu").manual_seed(seed)
455
- img = None
456
-
457
- try:
458
- with torch.no_grad():
459
- out = pipe(prompt=final_prompt, negative_prompt=negative, generator=generator, height=512, width=512)
460
- img = out.images[0]
461
- except Exception as e:
462
- print(f"{pipe} generation failed:", e)
463
-
464
- free_gpu_cache()
465
- return img
466
-
467
- def caption_for_image(img):
468
- try:
469
- out = captioner(img)
470
- return out[0]["generated_text"]
471
- except:
472
- return "Caption failed."
473
-
474
- def compute_metrics(images, captions, i1, i2):
475
- img1, img2 = images[i1], images[i2]
476
- cap1, cap2 = captions[i1], captions[i2]
477
-
478
- # CLIP similarity
479
- t1, t2 = clip_preprocess(img1).unsqueeze(0).to(device), clip_preprocess(img2).unsqueeze(0).to(device)
480
- with torch.no_grad():
481
- f1, f2 = clip_model.encode_image(t1), clip_model.encode_image(t2)
482
- clip_sim = float(torch.cosine_similarity(f1, f2))
483
-
484
- # LPIPS
485
- L1 = (lpips_transform(img1).unsqueeze(0)*2 - 1).to(device)
486
- L2 = (lpips_transform(img2).unsqueeze(0)*2 - 1).to(device)
487
- with torch.no_grad():
488
- lp = float(lpips_model(L1, L2))
489
-
490
- # BERTScore
491
- if cap1 and cap2:
492
- _, _, F = score([cap1],[cap2], lang="en", verbose=False)
493
- bert_f1 = float(F.mean())
494
- else:
495
- bert_f1 = 0.0
496
-
497
- return clip_sim, lp, bert_f1
498
-
499
- def answer_vqa(question, image):
500
- if not image or not question.strip():
501
- return "Provide image + question."
502
- try:
503
- inputs_raw = vqa_processor(images=image, text=question, return_tensors="pt")
504
- inputs = {k:v.to("cpu") for k,v in inputs_raw.items()}
505
- with torch.no_grad():
506
- out = vqa_model(**inputs)
507
- ans_id = out.logits.argmax(-1)
508
- return vqa_processor.decode(ans_id[0], skip_special_tokens=True)
509
- except:
510
- return "I could not determine the answer."
511
-
512
- # ==============================
513
- # Gradio UI
514
- # ==============================
515
- def build_ui():
516
- with gr.Blocks(title="Multimodal AI Image Studio") as demo:
517
- images_state = gr.State([None, None, None])
518
- captions_state = gr.State(["", "", ""])
519
-
520
- gr.Markdown("## Multimodal AI Image Studio (HF-ready)")
521
-
522
- # --- Step 1: Upload Reference ---
523
- upload_input = gr.Image(label="Upload Reference Image", type="pil")
524
- upload_btn = gr.Button("Upload & Caption")
525
- upload_preview = gr.Image(interactive=False)
526
- caption_out = gr.Markdown()
527
-
528
- def upload_and_caption(img, images_state, captions_state):
529
- if img is None:
530
- return None, "No image uploaded.", images_state, captions_state
531
- caption = caption_for_image(img)
532
- images_state[0] = img
533
- captions_state[0] = caption
534
- return img, caption, images_state, captions_state
535
-
536
- upload_btn.click(upload_and_caption, inputs=[upload_input, images_state, captions_state],
537
- outputs=[upload_preview, caption_out, images_state, captions_state])
538
-
539
- # --- Step 2: Generate SDXL & DreamShaper ---
540
- sd_btn = gr.Button("Generate SD-Turbo")
541
- ds_btn = gr.Button("Generate DreamShaper")
542
- sd_preview = gr.Image(interactive=False)
543
- ds_preview = gr.Image(interactive=False)
544
-
545
- def gen_sd(caption, images_state, captions_state):
546
- img = generate_image(gen_pipe, caption, enhancer="", negative="", seed=42, style="Photorealistic")
547
- if img:
548
- images_state[1] = img
549
- captions_state[1] = caption_for_image(img)
550
- return img, images_state, captions_state
551
-
552
- def gen_ds(caption, images_state, captions_state):
553
- img = generate_image(dreamshaper_pipe, caption, enhancer="", negative="", seed=123, style="Photorealistic")
554
- if img:
555
- images_state[2] = img
556
- captions_state[2] = caption_for_image(img)
557
- return img, images_state, captions_state
558
-
559
- sd_btn.click(gen_sd, inputs=[caption_out, images_state, captions_state],
560
- outputs=[sd_preview, images_state, captions_state])
561
- ds_btn.click(gen_ds, inputs=[caption_out, images_state, captions_state],
562
- outputs=[ds_preview, images_state, captions_state])
563
-
564
- # --- Step 3: Metrics ---
565
- metrics_btn = gr.Button("Compute Metrics")
566
- metrics_out = gr.Markdown()
567
-
568
- def metrics_ui(images_state, captions_state):
569
- imgs = images_state or []
570
- caps = captions_state or []
571
- if None in imgs or "" in caps:
572
- return "All three images and captions are required."
573
- A = compute_metrics(imgs, caps, 0, 1)
574
- B = compute_metrics(imgs, caps, 0, 2)
575
- C = compute_metrics(imgs, caps, 1, 2)
576
- return f"Reference ↔ SD-Turbo: {A}\nReference ↔ DreamShaper: {B}\nSD-Turbo ↔ DreamShaper: {C}"
577
-
578
- metrics_btn.click(metrics_ui, inputs=[images_state, captions_state], outputs=[metrics_out])
579
-
580
- # --- Step 4: NLP ---
581
- nlp_btn = gr.Button("Analyze Captions")
582
- nlp_out = gr.HTML()
583
-
584
- def analyze_nlp(captions_state):
585
- caps = captions_state or []
586
- if "" in caps:
587
- return "<b>All three captions are required.</b>"
588
- labels = ["Reference", "SD-Turbo", "DreamShaper"]
589
- html_blocks = []
590
- for label, cap in zip(labels, caps):
591
- # Sentiment
592
- sentiment = "<br>".join([f"{s['label']}: {s['score']:.2f}" for s in sentiment_model(cap)])
593
- # Entities
594
- ents_list = ner_model(cap)
595
- ents = "<br>".join([f"{e['entity_group']}: {e['word']}" for e in ents_list])
596
- # Topics
597
- topics_data = topic_model(cap, candidate_labels=['people','animals','objects','food','nature'])
598
- topics = "<br>".join([f"{l}: {sc:.2f}" for l, sc in zip(topics_data['labels'], topics_data['scores'])])
599
- html_blocks.append(f"<div style='padding:10px;'><h3>{label}</h3><b>Sentiment</b><br>{sentiment}<br><b>Entities</b><br>{ents}<br><b>Topics</b><br>{topics}</div>")
600
- return "<div style='display:flex;gap:20px;'>" + "".join(html_blocks) + "</div>"
601
-
602
- nlp_btn.click(analyze_nlp, inputs=[captions_state], outputs=[nlp_out])
603
-
604
- # --- Step 5: VQA ---
605
- vqa_input = gr.Textbox(label="Ask about reference image")
606
- vqa_btn = gr.Button("Get Answer")
607
- vqa_out = gr.Markdown()
608
-
609
- def vqa_ui(question, img):
610
- return answer_vqa(question, img)
611
-
612
- vqa_btn.click(vqa_ui, inputs=[vqa_input, upload_preview], outputs=[vqa_out])
613
-
614
- return demo
615
-
616
- # Launch
617
- demo = build_ui()
618
- demo.launch()
619
-
620
-
621
- ####################################################################################
622
- # ==============================
623
- # SECTION 1
624
- # ==============================
625
-
626
- # Libraries
627
  import torch
628
  import gradio as gr
629
  from PIL import Image
@@ -640,9 +302,10 @@ def free_gpu_cache():
640
  if device == "cuda":
641
  torch.cuda.empty_cache()
642
 
643
- # ==============================
644
  # MODELS
645
- # ==============================
 
646
  gen_pipe = DiffusionPipeline.from_pretrained(
647
  "stabilityai/sdxl-turbo",
648
  torch_dtype=torch.float16 if device=="cuda" else torch.float32
@@ -653,6 +316,7 @@ dreamshaper_pipe = DiffusionPipeline.from_pretrained(
653
  torch_dtype=torch.float16 if device=="cuda" else torch.float32
654
  ).to(device)
655
 
 
656
  captioner = pipeline(
657
  "image-to-text",
658
  model="Salesforce/blip-image-captioning-large",
@@ -660,6 +324,7 @@ captioner = pipeline(
660
  generate_kwargs={"max_new_tokens":256, "num_beams":5, "temperature":0.7}
661
  )
662
 
 
663
  sentiment_model = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english",
664
  device=0 if device=="cuda" else -1)
665
  ner_model = pipeline("ner", model="dbmdz/bert-large-cased-finetuned-conll03-english",
@@ -667,13 +332,16 @@ ner_model = pipeline("ner", model="dbmdz/bert-large-cased-finetuned-conll03-engl
667
  topic_model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli",
668
  device=0 if device=="cuda" else -1)
669
 
 
670
  vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
671
  vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to("cpu")
672
 
 
673
  clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
674
  lpips_model = lpips.LPIPS(net='alex').to(device)
675
  lpips_transform = T.Compose([T.ToTensor(), T.Resize((256,256))])
676
 
 
677
  style_map = {
678
  "Photorealistic": "photorealistic, ultra-detailed, 8k, cinematic lighting",
679
  "Real Life": "natural lighting, true-to-life colors, DSLR",
@@ -686,25 +354,21 @@ style_map = {
686
  "Macro": "macro lens shallow DOF",
687
  "Cyberpunk": "neon cyberpunk futuristic",
688
  }
689
- # Section 2
690
- # ==============================
691
- # SECTION 2 FUNCTIONS
692
- # ==============================
693
  def generate_image_with_enhancer(base_caption, enhancer, negative, seed, style, images):
694
  images = images or []
695
  base_caption = base_caption or ""
696
  enhancer = enhancer or ""
697
-
698
  final_prompt = f"{base_caption}, {enhancer}".strip(", ")
699
  final_prompt = f"{final_prompt}, {style_map.get(style,'')}".strip(", ")
700
-
701
  try:
702
  seed = int(seed)
703
  except:
704
  seed = 42
705
-
706
  generator = torch.Generator(device="cpu").manual_seed(seed)
707
-
708
  try:
709
  with torch.no_grad():
710
  out = gen_pipe(prompt=final_prompt, negative_prompt=negative, generator=generator)
@@ -712,10 +376,8 @@ def generate_image_with_enhancer(base_caption, enhancer, negative, seed, style,
712
  except Exception as e:
713
  print("SD Turbo failed:", e)
714
  img = None
715
-
716
  if img:
717
  images.append(img)
718
-
719
  free_gpu_cache()
720
  return img, images
721
 
@@ -723,17 +385,13 @@ def generate_dreamshaper_with_enhancer(base_caption, enhancer, negative, seed, s
723
  images = images or []
724
  base_caption = base_caption or ""
725
  enhancer = enhancer or ""
726
-
727
  final_prompt = f"{base_caption}, {enhancer}".strip(", ")
728
  final_prompt = f"{final_prompt}, {style_map.get(style,'')}".strip(", ")
729
-
730
  try:
731
  seed = int(seed)
732
  except:
733
  seed = 42
734
-
735
  generator = torch.Generator(device="cpu").manual_seed(seed)
736
-
737
  try:
738
  with torch.no_grad():
739
  out = dreamshaper_pipe(prompt=final_prompt, negative_prompt=negative, generator=generator)
@@ -741,13 +399,14 @@ def generate_dreamshaper_with_enhancer(base_caption, enhancer, negative, seed, s
741
  except Exception as e:
742
  print("DreamShaper failed:", e)
743
  img = None
744
-
745
  if img:
746
  images.append(img)
747
-
748
  free_gpu_cache()
749
  return img, images
750
 
 
 
 
751
  def caption_for_image(img):
752
  try:
753
  out = captioner(img)
@@ -755,6 +414,9 @@ def caption_for_image(img):
755
  except:
756
  return "Caption failed."
757
 
 
 
 
758
  def answer_vqa(question, image):
759
  if not image or not question.strip():
760
  return "Provide image + question."
@@ -766,8 +428,11 @@ def answer_vqa(question, image):
766
  ans_id = out.logits.argmax(-1)
767
  return vqa_processor.decode(ans_id[0], skip_special_tokens=True)
768
  except:
769
- return "VQA failed."
770
 
 
 
 
771
  def compute_metrics(images, captions, i1, i2):
772
  img1 = images[i1]
773
  img2 = images[i2]
@@ -797,190 +462,142 @@ def compute_metrics(images, captions, i1, i2):
797
 
798
  return clip_sim, lp, bert_f1
799
 
800
- # ---------------- Build Gradio UI with Custom Look (Fully Robust) ----------------
801
- def build_ui_with_custom_ui():
 
 
802
  with gr.Blocks(title="Multimodal AI Image Studio") as demo:
803
- # ---------------- CSS Styling ----------------
804
  gr.HTML(
805
  <style>
806
  .heading-orange h2, .heading-orange h3 { color: #ff5500 !important; }
807
  .orange-btn button { background-color: #ff5500 !important; color: white !important; border-radius: 6px !important; height: 36px !important; font-weight: bold; }
808
  .teal-btn button { background-color: #008080 !important; color: white !important; border-radius: 6px !important; height: 40px !important; font-weight: bold; }
809
-
810
- /* Horizontal thin spinner */
811
- .loading-line {
812
- height: 4px;
813
- background: linear-gradient(90deg, #008080 0%, #00cccc 50%, #008080 100%);
814
- background-size: 200% 100%;
815
- animation: loading 1s linear infinite;
816
- }
817
- @keyframes loading {
818
- 0% { background-position: 200% 0; }
819
- 100% { background-position: -200% 0; }
820
- }
821
-
822
- /* Match enhancer box to upload button */
823
- .enhancer-box textarea {
824
- width: 100% !important;
825
- height: 36px !important;
826
- box-sizing: border-box;
827
- font-size: 14px;
828
- }
829
-
830
- /* Equal-height styling for Step-1 columns */
831
- .equal-height-row {
832
- display: flex;
833
- align-items: stretch;
834
- }
835
- .equal-height-row > .gr-column {
836
- display: flex;
837
- flex-direction: column;
838
- }
839
  </style>
840
  )
841
 
842
- # ---------------- Heading ----------------
843
- gr.Markdown("## Multimodal AI Image Studio: An Integrated Comparative Perspective", elem_classes="heading-orange")
844
-
845
- # ---------------- States ----------------
846
- images_state = gr.State([])
847
- captions_state = gr.State([])
848
-
849
- # ---------------- Step 1: Upload Reference Image ----------------
850
- gr.Markdown("### Upload Reference Image", elem_classes="heading-orange")
851
 
 
 
 
 
852
  with gr.Row(elem_classes="equal-height-row"):
853
  with gr.Column(scale=1):
854
  upload_input = gr.Image(label="Drag & Drop Image", type="pil")
855
  upload_btn = gr.Button("Upload Image & Generate Caption", elem_classes="orange-btn")
 
856
  with gr.Column(scale=1):
857
  upload_preview = gr.Image(label="Uploaded Image", interactive=False)
858
- enhancer_box = gr.Textbox(
859
- label="Add Prompt Enhancer (Optional)",
860
- placeholder="Example: 'at night with neon lights', 'wearing a red jacket', etc.",
861
- elem_classes="enhancer-box"
862
- )
863
  caption_out = gr.Markdown(label="Generated Caption")
864
 
865
- # Safe caption generation
866
- def upload_and_generate_caption_ui(img, images_state, captions_state):
867
  if img is None:
868
- return None, "No image uploaded.", images_state or [], captions_state or []
869
-
870
- images = [img]
871
  try:
872
- output = captioner(img)
873
- caption = output[0].get("generated_text", "Caption failed.") if output else "Caption failed."
874
- except Exception as e:
875
- print("Captioning error:", e)
876
- caption = "Caption failed."
877
-
878
- captions = [caption]
879
- return img, caption, images, captions
880
-
881
- upload_btn.click(
882
- upload_and_generate_caption_ui,
883
- inputs=[upload_input, images_state, captions_state],
884
- outputs=[upload_preview, caption_out, images_state, captions_state]
885
- )
886
 
887
- # ---------------- Step 2: Generate SD-Turbo & DreamShaper ----------------
888
- gr.Markdown("### Generate Images from Caption", elem_classes="heading-orange")
 
 
 
 
 
889
  with gr.Row():
890
- with gr.Column(scale=1, min_width=300):
891
  sd_btn = gr.Button("Generate SD-Turbo Image", elem_classes="orange-btn")
892
  sd_preview = gr.Image(label="SD-Turbo Image", interactive=False)
893
- with gr.Column(scale=1, min_width=300):
894
  ds_btn = gr.Button("Generate DreamShaper Image", elem_classes="orange-btn")
895
  ds_preview = gr.Image(label="DreamShaper Image", interactive=False)
896
 
897
- def generate_sd_from_caption_ui(caption, enhancer, images_state, captions_state):
898
- images_state = images_state or []
899
- captions_state = captions_state or []
900
- img, images = generate_image_with_enhancer(caption, enhancer="", negative="", seed=42, style="Photorealistic", images=images_state)
901
  if img:
902
- try:
903
- generated_caption = captioner(img)[0].get("generated_text", "Caption failed.")
904
- except:
905
- generated_caption = "Caption failed."
906
- if len(captions_state) >= 2:
907
- captions_state[1] = generated_caption
908
- else:
909
- captions_state.append(generated_caption)
910
- return img, images, captions_state
911
-
912
- def generate_ds_from_caption_ui(caption, enhancer, images_state, captions_state):
913
- images_state = images_state or []
914
- captions_state = captions_state or []
915
- img, images = generate_dreamshaper_with_enhancer(caption, enhancer="", negative="", seed=123, style="Photorealistic", images=images_state)
916
  if img:
917
- try:
918
- generated_caption = captioner(img)[0].get("generated_text", "Caption failed.")
919
- except:
920
- generated_caption = "Caption failed."
921
- if len(captions_state) >= 3:
922
- captions_state[2] = generated_caption
923
- else:
924
- captions_state.append(generated_caption)
925
- return img, images, captions_state
926
-
927
- sd_btn.click(generate_sd_from_caption_ui, inputs=[caption_out, enhancer_box, images_state, captions_state],
928
  outputs=[sd_preview, images_state, captions_state])
929
- ds_btn.click(generate_ds_from_caption_ui, inputs=[caption_out, enhancer_box, images_state, captions_state],
930
  outputs=[ds_preview, images_state, captions_state])
931
 
932
- # ---------------- Step 3: Compute Pairwise Metrics ----------------
933
- gr.Markdown("### Compute Pairwise Metrics", elem_classes="heading-orange")
 
 
934
  metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn")
935
- with gr.Row():
936
- metrics_spinner = gr.HTML("<div style='height:4px;'></div>") # single spinner
937
- with gr.Row():
938
- metrics_A = gr.Markdown()
939
- metrics_B = gr.Markdown()
940
- metrics_C = gr.Markdown()
941
-
942
- def compute_metrics_all_pairs_ui(images, captions):
943
- images = images or []
944
- captions = captions or []
945
- # show spinner
946
- yield "<div class='loading-line'></div>", "", "", ""
947
- if len(images) < 3 or len(captions) < 3:
948
- msg = "All three images and captions are required to compute metrics."
949
- yield "", msg, msg, msg
950
  else:
951
  try:
952
  A = compute_metrics(images, captions, 0, 1)
953
  B = compute_metrics(images, captions, 0, 2)
954
  C = compute_metrics(images, captions, 1, 2)
955
- # remove spinner, show results
956
- yield "", f"**Reference ↔ SD-Turbo**\n{A}", f"**Reference ↔ DreamShaper**\n{B}", f"**SD-Turbo ↔ DreamShaper**\n{C}"
 
 
 
 
 
 
 
 
957
  except Exception as e:
958
- print("Metrics computation error:", e)
959
- msg = "Failed to compute metrics."
960
- yield "", msg, msg, msg
961
 
962
- metrics_btn.click(compute_metrics_all_pairs_ui, inputs=[images_state, captions_state],
963
- outputs=[metrics_spinner, metrics_A, metrics_B, metrics_C])
964
 
965
- # ---------------- Step 4: NLP Analysis ----------------
966
- gr.Markdown("### NLP Analysis of Captions", elem_classes="heading-orange")
 
 
967
  nlp_btn = gr.Button("Analyze Captions", elem_classes="teal-btn")
968
- nlp_spinner = gr.HTML("<div style='height:4px;'></div>") # single spinner
969
  nlp_out = gr.HTML()
970
 
971
- def analyze_caption_pipeline_ui(captions):
972
- captions = captions or []
973
  yield "<div class='loading-line'></div>", ""
974
- if len(captions) < 3:
975
- yield "", "<b>All three captions are required for NLP analysis.</b>"
976
  else:
977
- labels = ["Reference Image", "SD-Turbo", "DreamShaper"]
978
  blocks = []
979
  for label, caption in zip(labels, captions):
980
  try:
981
  sentiment = "<br>".join([f"{s['label']}: {s['score']:.2f}" for s in sentiment_model(caption)])
982
  except:
983
- sentiment = "Sentiment analysis failed."
984
  try:
985
  ents_list = ner_model(caption)
986
  ents = "<br>".join([f"{e.get('entity_group','')}: {e.get('word','')}" for e in ents_list]) or "None"
@@ -990,37 +607,44 @@ def build_ui_with_custom_ui():
990
  topics_data = topic_model(caption, candidate_labels=['people','animals','objects','food','nature'])
991
  topics = "<br>".join([f"{l}: {sc:.2f}" for l, sc in zip(topics_data.get('labels',[]), topics_data.get('scores',[]))])
992
  except:
993
- topics = "Topic modeling failed."
994
  block = f"<div style='flex:1;padding:10px;min-width:250px;'><h3><u>{label}</u></h3><b>Sentiment</b><br>{sentiment}<br><br><b>Entities</b><br>{ents}<br><br><b>Topics</b><br>{topics}</div>"
995
  blocks.append(block)
996
- yield "", f"<div style='display:flex; gap:20px; justify-content:space-between;'>{''.join(blocks)}</div>"
997
 
998
- nlp_btn.click(analyze_caption_pipeline_ui, inputs=[captions_state], outputs=[nlp_spinner, nlp_out])
999
 
1000
- # ---------------- Step 5: Visual Question Answering (VQA) ----------------
1001
- gr.Markdown("### Visual Question Answering (VQA)", elem_classes="heading-orange")
 
 
1002
  with gr.Row():
1003
  with gr.Column(scale=1):
1004
  vqa_input = gr.Textbox(label="Enter a question about the reference image")
1005
  vqa_btn = gr.Button("Get Answer", elem_classes="teal-btn")
1006
  with gr.Column(scale=1):
1007
- vqa_spinner = gr.HTML("<div style='height:4px;'></div>") # single spinner
1008
  vqa_out = gr.Markdown(label="VQA Output")
1009
 
1010
- def answer_vqa_ui(question, image):
1011
  yield "<div class='loading-line'></div>", ""
1012
- try:
1013
- ans = answer_vqa(question, image)
1014
- except Exception as e:
1015
- print("VQA error:", e)
1016
- ans = "I could not determine the answer."
1017
- yield "", ans
 
 
 
1018
 
1019
- vqa_btn.click(answer_vqa_ui, inputs=[vqa_input, upload_preview], outputs=[vqa_spinner, vqa_out])
1020
 
1021
  return demo
1022
 
1023
- # Launch the interface
1024
- demo = build_ui_with_custom_ui()
1025
  demo.launch()
 
 
1026
  """
 
20
  # =========================
21
  # MODELS
22
  # =========================
 
23
  gen_pipe = DiffusionPipeline.from_pretrained(
24
  "stabilityai/sdxl-turbo",
25
  torch_dtype=torch.float16 if device=="cuda" else torch.float32
 
30
  torch_dtype=torch.float16 if device=="cuda" else torch.float32
31
  ).to(device)
32
 
 
33
  captioner = pipeline(
34
  "image-to-text",
35
  model="Salesforce/blip-image-captioning-large",
 
37
  generate_kwargs={"max_new_tokens":256, "num_beams":5, "temperature":0.7}
38
  )
39
 
 
40
  sentiment_model = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english",
41
  device=0 if device=="cuda" else -1)
42
  ner_model = pipeline("ner", model="dbmdz/bert-large-cased-finetuned-conll03-english",
 
44
  topic_model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli",
45
  device=0 if device=="cuda" else -1)
46
 
 
47
  vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
48
  vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to("cpu")
49
 
 
50
  clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
51
  lpips_model = lpips.LPIPS(net='alex').to(device)
52
  lpips_transform = T.Compose([T.ToTensor(), T.Resize((256,256))])
53
 
 
54
  style_map = {
55
  "Photorealistic": "photorealistic, ultra-detailed, 8k, cinematic lighting",
56
  "Real Life": "natural lighting, true-to-life colors, DSLR",
 
68
  # IMAGE GENERATION FUNCTIONS
69
  # =========================
70
  def generate_image_with_enhancer(base_caption, enhancer, negative, seed, style, images):
71
+ images = images or [None, None, None]
 
 
72
  final_prompt = f"{base_caption}, {enhancer}".strip(", ")
73
  final_prompt = f"{final_prompt}, {style_map.get(style,'')}".strip(", ")
74
  try:
 
84
  print("SD Turbo failed:", e)
85
  img = None
86
  if img:
87
+ images[1] = img # Always put SD-Turbo at index 1
88
  free_gpu_cache()
89
  return img, images
90
 
91
  def generate_dreamshaper_with_enhancer(base_caption, enhancer, negative, seed, style, images):
92
+ images = images or [None, None, None]
 
 
93
  final_prompt = f"{base_caption}, {enhancer}".strip(", ")
94
  final_prompt = f"{final_prompt}, {style_map.get(style,'')}".strip(", ")
95
  try:
 
105
  print("DreamShaper failed:", e)
106
  img = None
107
  if img:
108
+ images[2] = img # Always put DreamShaper at index 2
109
  free_gpu_cache()
110
  return img, images
111
 
 
123
  # VQA
124
  # =========================
125
  def answer_vqa(question, image):
126
+ if image is None or not question.strip():
127
  return "Provide image + question."
128
  try:
129
  inputs_raw = vqa_processor(images=image, text=question, return_tensors="pt")
 
165
  else:
166
  bert_f1 = 0.0
167
 
168
+ return f"CLIP: {clip_sim:.2f}\nLPIPS: {lp:.2f}\nBERTScore F1: {bert_f1:.2f}"
169
 
170
  # =========================
171
  # GRADIO UI BUILD
172
  # =========================
173
  def build_full_ui():
174
  with gr.Blocks(title="Multimodal AI Image Studio") as demo:
 
175
  gr.HTML("""
176
  <style>
177
  .heading-orange h2, .heading-orange h3 { color: #ff5500 !important; }
 
185
  </style>
186
  """)
187
 
 
188
  images_state = gr.State([None, None, None])
189
  captions_state = gr.State(["", "", ""])
190
 
191
+ # --- Upload Section ---
 
 
192
  gr.Markdown("## 1️⃣ Upload Reference Image", elem_classes="heading-orange")
193
  with gr.Row(elem_classes="equal-height-row"):
194
  with gr.Column(scale=1):
 
199
  upload_preview = gr.Image(label="Uploaded Image", interactive=False)
200
  caption_out = gr.Markdown(label="Generated Caption")
201
 
 
202
  def upload_and_caption(img, images_state, captions_state):
203
  if img is None:
204
  return None, "No image uploaded.", images_state, captions_state
205
  images_state[0] = img
206
+ captions_state[0] = caption_for_image(img)
207
+ return img, captions_state[0], images_state, captions_state
 
 
 
 
208
 
209
  upload_btn.click(upload_and_caption, inputs=[upload_input, images_state, captions_state],
210
  outputs=[upload_preview, caption_out, images_state, captions_state])
211
 
212
+ # --- Generate SD-Turbo & DreamShaper ---
 
 
213
  gr.Markdown("## 2️⃣ Generate Images from Caption", elem_classes="heading-orange")
214
  with gr.Row():
215
  with gr.Column(scale=1):
 
219
  ds_btn = gr.Button("Generate DreamShaper Image", elem_classes="orange-btn")
220
  ds_preview = gr.Image(label="DreamShaper Image", interactive=False)
221
 
 
222
  def generate_sd(caption, enhancer, images_state, captions_state):
223
+ img, images_state = generate_image_with_enhancer(caption, enhancer, "", 42, "Photorealistic", images_state)
224
  if img:
225
  captions_state[1] = caption_for_image(img)
226
  return img, images_state, captions_state
227
 
 
228
  def generate_ds(caption, enhancer, images_state, captions_state):
229
+ img, images_state = generate_dreamshaper_with_enhancer(caption, enhancer, "", 123, "Photorealistic", images_state)
230
  if img:
231
  captions_state[2] = caption_for_image(img)
232
  return img, images_state, captions_state
 
236
  ds_btn.click(generate_ds, inputs=[caption_out, enhancer_box, images_state, captions_state],
237
  outputs=[ds_preview, images_state, captions_state])
238
 
239
+ # --- Compute Metrics ---
 
 
240
  gr.Markdown("## 3️⃣ Compute Pairwise Metrics", elem_classes="heading-orange")
241
  metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn")
242
  metrics_spinner = gr.HTML("<div style='height:4px;'></div>")
243
+ metrics_A = gr.Markdown()
244
+ metrics_B = gr.Markdown()
245
+ metrics_C = gr.Markdown()
246
 
247
  def compute_metrics_ui(images, captions):
248
+ yield "<div class='loading-line'></div>", "", "", ""
249
  if any(i is None for i in images):
250
+ msg = "All three images and captions are required."
251
+ yield "", msg, msg, msg
252
  else:
253
+ A = compute_metrics(images, captions, 0, 1)
254
+ B = compute_metrics(images, captions, 0, 2)
255
+ C = compute_metrics(images, captions, 1, 2)
256
+ yield "", f"**Reference ↔ SD-Turbo**\n{A}", f"**Reference ↔ DreamShaper**\n{B}", f"**SD-Turbo ↔ DreamShaper**\n{C}"
 
 
 
 
 
 
 
 
 
 
 
 
 
257
 
258
  metrics_btn.click(compute_metrics_ui, inputs=[images_state, captions_state],
259
+ outputs=[metrics_spinner, metrics_A, metrics_B, metrics_C])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
 
261
+ # --- VQA ---
 
 
262
  gr.Markdown("## 5️⃣ Visual Question Answering (VQA)", elem_classes="heading-orange")
263
  with gr.Row():
264
  with gr.Column(scale=1):
 
268
  vqa_spinner = gr.HTML("<div style='height:4px;'></div>")
269
  vqa_out = gr.Markdown(label="VQA Output")
270
 
271
+ def vqa_ui(question, images_state):
272
  yield "<div class='loading-line'></div>", ""
273
+ ans = answer_vqa(question, images_state[0])
274
+ yield "", ans
 
 
 
 
 
 
 
275
 
276
+ vqa_btn.click(vqa_ui, inputs=[vqa_input, images_state], outputs=[vqa_spinner, vqa_out])
277
 
278
  return demo
279
 
 
281
  demo = build_full_ui()
282
  demo.launch()
283
 
 
 
284
  """
285
+ #Dumped code
286
+ # =========================
287
+ # LIBRARIES & DEVICE SETUP
288
+ # =========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  import torch
290
  import gradio as gr
291
  from PIL import Image
 
302
  if device == "cuda":
303
  torch.cuda.empty_cache()
304
 
305
+ # =========================
306
  # MODELS
307
+ # =========================
308
+ # Image generation
309
  gen_pipe = DiffusionPipeline.from_pretrained(
310
  "stabilityai/sdxl-turbo",
311
  torch_dtype=torch.float16 if device=="cuda" else torch.float32
 
316
  torch_dtype=torch.float16 if device=="cuda" else torch.float32
317
  ).to(device)
318
 
319
+ # Captioning
320
  captioner = pipeline(
321
  "image-to-text",
322
  model="Salesforce/blip-image-captioning-large",
 
324
  generate_kwargs={"max_new_tokens":256, "num_beams":5, "temperature":0.7}
325
  )
326
 
327
+ # NLP
328
  sentiment_model = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english",
329
  device=0 if device=="cuda" else -1)
330
  ner_model = pipeline("ner", model="dbmdz/bert-large-cased-finetuned-conll03-english",
 
332
  topic_model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli",
333
  device=0 if device=="cuda" else -1)
334
 
335
+ # VQA
336
  vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
337
  vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to("cpu")
338
 
339
+ # Metrics
340
  clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
341
  lpips_model = lpips.LPIPS(net='alex').to(device)
342
  lpips_transform = T.Compose([T.ToTensor(), T.Resize((256,256))])
343
 
344
+ # Styles
345
  style_map = {
346
  "Photorealistic": "photorealistic, ultra-detailed, 8k, cinematic lighting",
347
  "Real Life": "natural lighting, true-to-life colors, DSLR",
 
354
  "Macro": "macro lens shallow DOF",
355
  "Cyberpunk": "neon cyberpunk futuristic",
356
  }
357
+
358
+ # =========================
359
+ # IMAGE GENERATION FUNCTIONS
360
+ # =========================
361
  def generate_image_with_enhancer(base_caption, enhancer, negative, seed, style, images):
362
  images = images or []
363
  base_caption = base_caption or ""
364
  enhancer = enhancer or ""
 
365
  final_prompt = f"{base_caption}, {enhancer}".strip(", ")
366
  final_prompt = f"{final_prompt}, {style_map.get(style,'')}".strip(", ")
 
367
  try:
368
  seed = int(seed)
369
  except:
370
  seed = 42
 
371
  generator = torch.Generator(device="cpu").manual_seed(seed)
 
372
  try:
373
  with torch.no_grad():
374
  out = gen_pipe(prompt=final_prompt, negative_prompt=negative, generator=generator)
 
376
  except Exception as e:
377
  print("SD Turbo failed:", e)
378
  img = None
 
379
  if img:
380
  images.append(img)
 
381
  free_gpu_cache()
382
  return img, images
383
 
 
385
  images = images or []
386
  base_caption = base_caption or ""
387
  enhancer = enhancer or ""
 
388
  final_prompt = f"{base_caption}, {enhancer}".strip(", ")
389
  final_prompt = f"{final_prompt}, {style_map.get(style,'')}".strip(", ")
 
390
  try:
391
  seed = int(seed)
392
  except:
393
  seed = 42
 
394
  generator = torch.Generator(device="cpu").manual_seed(seed)
 
395
  try:
396
  with torch.no_grad():
397
  out = dreamshaper_pipe(prompt=final_prompt, negative_prompt=negative, generator=generator)
 
399
  except Exception as e:
400
  print("DreamShaper failed:", e)
401
  img = None
 
402
  if img:
403
  images.append(img)
 
404
  free_gpu_cache()
405
  return img, images
406
 
407
+ # =========================
408
+ # CAPTIONING
409
+ # =========================
410
  def caption_for_image(img):
411
  try:
412
  out = captioner(img)
 
414
  except:
415
  return "Caption failed."
416
 
417
+ # =========================
418
+ # VQA
419
+ # =========================
420
  def answer_vqa(question, image):
421
  if not image or not question.strip():
422
  return "Provide image + question."
 
428
  ans_id = out.logits.argmax(-1)
429
  return vqa_processor.decode(ans_id[0], skip_special_tokens=True)
430
  except:
431
+ return "I could not determine the answer."
432
 
433
+ # =========================
434
+ # METRICS
435
+ # =========================
436
  def compute_metrics(images, captions, i1, i2):
437
  img1 = images[i1]
438
  img2 = images[i2]
 
462
 
463
  return clip_sim, lp, bert_f1
464
 
465
+ # =========================
466
+ # GRADIO UI BUILD
467
+ # =========================
468
+ def build_full_ui():
469
  with gr.Blocks(title="Multimodal AI Image Studio") as demo:
470
+ # --- CSS Styling ---
471
  gr.HTML(
472
  <style>
473
  .heading-orange h2, .heading-orange h3 { color: #ff5500 !important; }
474
  .orange-btn button { background-color: #ff5500 !important; color: white !important; border-radius: 6px !important; height: 36px !important; font-weight: bold; }
475
  .teal-btn button { background-color: #008080 !important; color: white !important; border-radius: 6px !important; height: 40px !important; font-weight: bold; }
476
+ .loading-line { height:4px; background: linear-gradient(90deg,#008080 0%,#00cccc 50%,#008080 100%); background-size: 200% 100%; animation: loading 1s linear infinite; }
477
+ @keyframes loading { 0% { background-position:200% 0; } 100% { background-position:-200% 0; } }
478
+ .enhancer-box textarea { width:100% !important; height:36px !important; box-sizing:border-box; font-size:14px; }
479
+ .equal-height-row { display:flex; align-items:stretch; }
480
+ .equal-height-row > .gr-column { display:flex; flex-direction:column; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
481
  </style>
482
  )
483
 
484
+ # --- States ---
485
+ images_state = gr.State([None, None, None])
486
+ captions_state = gr.State(["", "", ""])
 
 
 
 
 
 
487
 
488
+ # =========================
489
+ # Section 1: Upload Reference Image
490
+ # =========================
491
+ gr.Markdown("## 1️⃣ Upload Reference Image", elem_classes="heading-orange")
492
  with gr.Row(elem_classes="equal-height-row"):
493
  with gr.Column(scale=1):
494
  upload_input = gr.Image(label="Drag & Drop Image", type="pil")
495
  upload_btn = gr.Button("Upload Image & Generate Caption", elem_classes="orange-btn")
496
+ enhancer_box = gr.Textbox(label="Prompt Enhancer (Optional)", placeholder="Example: 'at night with neon lights'", elem_classes="enhancer-box")
497
  with gr.Column(scale=1):
498
  upload_preview = gr.Image(label="Uploaded Image", interactive=False)
 
 
 
 
 
499
  caption_out = gr.Markdown(label="Generated Caption")
500
 
501
+ # Upload & caption function
502
+ def upload_and_caption(img, images_state, captions_state):
503
  if img is None:
504
+ return None, "No image uploaded.", images_state, captions_state
505
+ images_state[0] = img
 
506
  try:
507
+ cap = caption_for_image(img)
508
+ except:
509
+ cap = "Caption failed."
510
+ captions_state[0] = cap
511
+ return img, cap, images_state, captions_state
 
 
 
 
 
 
 
 
 
512
 
513
+ upload_btn.click(upload_and_caption, inputs=[upload_input, images_state, captions_state],
514
+ outputs=[upload_preview, caption_out, images_state, captions_state])
515
+
516
+ # =========================
517
+ # Section 2: Generate SD-Turbo & DreamShaper
518
+ # =========================
519
+ gr.Markdown("## 2️⃣ Generate Images from Caption", elem_classes="heading-orange")
520
  with gr.Row():
521
+ with gr.Column(scale=1):
522
  sd_btn = gr.Button("Generate SD-Turbo Image", elem_classes="orange-btn")
523
  sd_preview = gr.Image(label="SD-Turbo Image", interactive=False)
524
+ with gr.Column(scale=1):
525
  ds_btn = gr.Button("Generate DreamShaper Image", elem_classes="orange-btn")
526
  ds_preview = gr.Image(label="DreamShaper Image", interactive=False)
527
 
528
+ # Generate SD-Turbo
529
+ def generate_sd(caption, enhancer, images_state, captions_state):
530
+ img, images_state = generate_image_with_enhancer(caption, enhancer, negative="", seed=42, style="Photorealistic", images=images_state)
 
531
  if img:
532
+ captions_state[1] = caption_for_image(img)
533
+ return img, images_state, captions_state
534
+
535
+ # Generate DreamShaper
536
+ def generate_ds(caption, enhancer, images_state, captions_state):
537
+ img, images_state = generate_dreamshaper_with_enhancer(caption, enhancer, negative="", seed=123, style="Photorealistic", images=images_state)
 
 
 
 
 
 
 
 
538
  if img:
539
+ captions_state[2] = caption_for_image(img)
540
+ return img, images_state, captions_state
541
+
542
+ sd_btn.click(generate_sd, inputs=[caption_out, enhancer_box, images_state, captions_state],
 
 
 
 
 
 
 
543
  outputs=[sd_preview, images_state, captions_state])
544
+ ds_btn.click(generate_ds, inputs=[caption_out, enhancer_box, images_state, captions_state],
545
  outputs=[ds_preview, images_state, captions_state])
546
 
547
+ # =========================
548
+ # Section 3: Compute Pairwise Metrics (Side-by-Side)
549
+ # =========================
550
+ gr.Markdown("## 3️⃣ Compute Pairwise Metrics", elem_classes="heading-orange")
551
  metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn")
552
+ metrics_spinner = gr.HTML("<div style='height:4px;'></div>")
553
+ metrics_out = gr.HTML()
554
+
555
+ def compute_metrics_ui(images, captions):
556
+ yield "<div class='loading-line'></div>", ""
557
+ if any(i is None for i in images):
558
+ yield "All three images and captions are required."
 
 
 
 
 
 
 
 
559
  else:
560
  try:
561
  A = compute_metrics(images, captions, 0, 1)
562
  B = compute_metrics(images, captions, 0, 2)
563
  C = compute_metrics(images, captions, 1, 2)
564
+ def fmt(m):
565
+ return f"CLIP: {m[0]:.3f}<br>LPIPS: {m[1]:.3f}<br>BERTScore F1: {m[2]:.3f}"
566
+ html = f"""
567
+ #<div style='display:flex; gap:40px; justify-content:space-around;'>
568
+ # <div style='text-align:center;'><b>Metrics A</b><br>{fmt(A)}</div>
569
+ # <div style='text-align:center;'><b>Metrics B</b><br>{fmt(B)}</div>
570
+ # <div style='text-align:center;'><b>Metrics C</b><br>{fmt(C)}</div>
571
+ #</div>
572
+ """
573
+ yield html
574
  except Exception as e:
575
+ print("Metrics error:", e)
576
+ yield "Failed to compute metrics."
 
577
 
578
+ metrics_btn.click(compute_metrics_ui, inputs=[images_state, captions_state],
579
+ outputs=[metrics_out])
580
 
581
+ # =========================
582
+ # Section 4: NLP Analysis
583
+ # =========================
584
+ gr.Markdown("## 4️⃣ NLP Analysis of Captions", elem_classes="heading-orange")
585
  nlp_btn = gr.Button("Analyze Captions", elem_classes="teal-btn")
586
+ nlp_spinner = gr.HTML("<div style='height:4px;'></div>")
587
  nlp_out = gr.HTML()
588
 
589
+ def analyze_captions_ui(captions):
 
590
  yield "<div class='loading-line'></div>", ""
591
+ if any(c=="" for c in captions):
592
+ yield "<b>All three captions are required for NLP analysis.</b>"
593
  else:
594
+ labels = ["Reference", "SD-Turbo", "DreamShaper"]
595
  blocks = []
596
  for label, caption in zip(labels, captions):
597
  try:
598
  sentiment = "<br>".join([f"{s['label']}: {s['score']:.2f}" for s in sentiment_model(caption)])
599
  except:
600
+ sentiment = "Sentiment failed."
601
  try:
602
  ents_list = ner_model(caption)
603
  ents = "<br>".join([f"{e.get('entity_group','')}: {e.get('word','')}" for e in ents_list]) or "None"
 
607
  topics_data = topic_model(caption, candidate_labels=['people','animals','objects','food','nature'])
608
  topics = "<br>".join([f"{l}: {sc:.2f}" for l, sc in zip(topics_data.get('labels',[]), topics_data.get('scores',[]))])
609
  except:
610
+ topics = "Topics failed."
611
  block = f"<div style='flex:1;padding:10px;min-width:250px;'><h3><u>{label}</u></h3><b>Sentiment</b><br>{sentiment}<br><br><b>Entities</b><br>{ents}<br><br><b>Topics</b><br>{topics}</div>"
612
  blocks.append(block)
613
+ yield f"<div style='display:flex; gap:20px; justify-content:space-between;'>{''.join(blocks)}</div>"
614
 
615
+ nlp_btn.click(analyze_captions_ui, inputs=[captions_state], outputs=[nlp_out])
616
 
617
+ # =========================
618
+ # Section 5: Visual Question Answering
619
+ # =========================
620
+ gr.Markdown("## 5️⃣ Visual Question Answering (VQA)", elem_classes="heading-orange")
621
  with gr.Row():
622
  with gr.Column(scale=1):
623
  vqa_input = gr.Textbox(label="Enter a question about the reference image")
624
  vqa_btn = gr.Button("Get Answer", elem_classes="teal-btn")
625
  with gr.Column(scale=1):
626
+ vqa_spinner = gr.HTML("<div style='height:4px;'></div>")
627
  vqa_out = gr.Markdown(label="VQA Output")
628
 
629
+ def vqa_ui(question, image):
630
  yield "<div class='loading-line'></div>", ""
631
+ if not question.strip() or image is None:
632
+ yield "Provide image + question."
633
+ else:
634
+ try:
635
+ ans = answer_vqa(question, image)
636
+ yield f"<b>Answer:</b> {ans}"
637
+ except Exception as e:
638
+ print("VQA error:", e)
639
+ yield "Could not determine the answer."
640
 
641
+ vqa_btn.click(vqa_ui, inputs=[vqa_input, upload_preview], outputs=[vqa_out])
642
 
643
  return demo
644
 
645
+ # Launch
646
+ demo = build_full_ui()
647
  demo.launch()
648
+
649
+
650
  """