Chyd19 commited on
Commit
8bf3645
·
verified ·
1 Parent(s): 9fbf22f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +210 -367
app.py CHANGED
@@ -39,6 +39,14 @@
39
  # ==============================
40
  # Install
41
 
 
 
 
 
 
 
 
 
42
 
43
  # Libraries
44
  import torch
@@ -50,6 +58,8 @@ import lpips
50
  import clip
51
  from bert_score import score
52
  import torchvision.transforms as T
 
 
53
 
54
  device = "cuda" if torch.cuda.is_available() else "cpu"
55
 
@@ -73,18 +83,30 @@ dreamshaper_pipe = DiffusionPipeline.from_pretrained(
73
  captioner = pipeline(
74
  "image-to-text",
75
  model="Salesforce/blip-image-captioning-large",
76
- device=0 if device=="cuda" else -1,)
77
- #generate_kwargs={"max_new_tokens":256, "num_beams":5, "temperature":0.7})
78
-
79
- sentiment_model = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english",
80
- device=0 if device=="cuda" else -1)
81
- ner_model = pipeline("ner", model="dbmdz/bert-large-cased-finetuned-conll03-english",
82
- aggregation_strategy="simple", device=0 if device=="cuda" else -1)
83
- topic_model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli",
84
- device=0 if device=="cuda" else -1)
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
87
- vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to("cpu")
88
 
89
  clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
90
  lpips_model = lpips.LPIPS(net='alex').to(device)
@@ -103,68 +125,36 @@ style_map = {
103
  "Cyberpunk": "neon cyberpunk futuristic",
104
  }
105
 
106
- # **Section Two**
107
 
 
108
  # ==============================
109
- # SECTION 2 — FUNCTIONS
110
  # ==============================
111
- def generate_image_with_enhancer(base_caption, enhancer, negative, seed, style, images):
112
  images = images or []
113
  base_caption = base_caption or ""
114
  enhancer = enhancer or ""
115
-
116
  final_prompt = f"{base_caption}, {enhancer}".strip(", ")
117
  final_prompt = f"{final_prompt}, {style_map.get(style,'')}".strip(", ")
118
-
119
  try:
120
  seed = int(seed)
121
  except:
122
  seed = 42
123
-
124
- generator = torch.Generator(device="cpu").manual_seed(seed)
125
-
126
  try:
127
  with torch.no_grad():
128
- out = gen_pipe(prompt=final_prompt, negative_prompt=negative, generator=generator)
129
  img = out.images[0]
130
  except Exception as e:
131
- print("SD Turbo failed:", e)
132
  img = None
133
-
134
  if img:
135
  images.append(img)
136
-
137
  free_gpu_cache()
138
  return img, images
139
 
140
- def generate_dreamshaper_with_enhancer(base_caption, enhancer, negative, seed, style, images):
141
- images = images or []
142
- base_caption = base_caption or ""
143
- enhancer = enhancer or ""
144
-
145
- final_prompt = f"{base_caption}, {enhancer}".strip(", ")
146
- final_prompt = f"{final_prompt}, {style_map.get(style,'')}".strip(", ")
147
-
148
- try:
149
- seed = int(seed)
150
- except:
151
- seed = 42
152
-
153
- generator = torch.Generator(device="cpu").manual_seed(seed)
154
-
155
- try:
156
- with torch.no_grad():
157
- out = dreamshaper_pipe(prompt=final_prompt, negative_prompt=negative, generator=generator)
158
- img = out.images[0]
159
- except Exception as e:
160
- print("DreamShaper failed:", e)
161
- img = None
162
-
163
- if img:
164
- images.append(img)
165
-
166
- free_gpu_cache()
167
- return img, images
168
 
169
  def caption_for_image(img):
170
  try:
@@ -178,7 +168,7 @@ def answer_vqa(question, image):
178
  return "Provide image + question."
179
  try:
180
  inputs_raw = vqa_processor(images=image, text=question, return_tensors="pt")
181
- inputs = {k:v.to("cpu") for k,v in inputs_raw.items()}
182
  with torch.no_grad():
183
  out = vqa_model(**inputs)
184
  ans_id = out.logits.argmax(-1)
@@ -187,26 +177,21 @@ def answer_vqa(question, image):
187
  return "VQA failed."
188
 
189
  def compute_metrics(images, captions, i1, i2):
190
- img1 = images[i1]
191
- img2 = images[i2]
192
- cap1 = captions[i1]
193
- cap2 = captions[i2]
194
-
195
- # CLIP
196
- t1 = clip_preprocess(img1).unsqueeze(0).to("cpu")
197
- t2 = clip_preprocess(img2).unsqueeze(0).to("cpu")
198
  with torch.no_grad():
199
  f1 = clip_model.encode_image(t1)
200
  f2 = clip_model.encode_image(t2)
201
  clip_sim = float(torch.cosine_similarity(f1, f2))
202
 
203
- # LPIPS
204
- L1 = (lpips_transform(img1).unsqueeze(0)*2 - 1)
205
- L2 = (lpips_transform(img2).unsqueeze(0)*2 - 1)
206
  with torch.no_grad():
207
  lp = float(lpips_model(L1, L2))
208
 
209
- # BERTScore
210
  if cap1 and cap2:
211
  _, _, F = score([cap1],[cap2], lang="en", verbose=False)
212
  bert_f1 = float(F.mean())
@@ -215,14 +200,31 @@ def compute_metrics(images, captions, i1, i2):
215
 
216
  return clip_sim, lp, bert_f1
217
 
218
- # **Section Three**
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
- # ==============================
221
- # Section Three
222
- # ==============================
223
 
224
- # 1
225
- # ---------------- Build Gradio UI with Custom Look ----------------
226
  def build_ui_with_custom_ui():
227
  with gr.Blocks(title="Multimodal AI Image Studio") as demo:
228
 
@@ -230,339 +232,180 @@ def build_ui_with_custom_ui():
230
  gr.HTML("""
231
  <style>
232
  .heading-orange h2, .heading-orange h3 { color: #ff5500 !important; }
233
- .orange-btn button {
234
- background-color: #ff5500 !important;
235
- color: white !important;
236
- border-radius: 6px !important;
237
- height: 36px !important;
238
- font-weight: bold;
239
- }
240
- .teal-btn button {
241
- background-color: #008080 !important;
242
- color: white !important;
243
- border-radius: 6px !important;
244
- height: 40px !important;
245
- font-weight: bold;
246
- }
247
-
248
- /* Horizontal thin spinner */
249
- .loading-line {
250
- height: 4px;
251
- background: linear-gradient(90deg, #008080 0%, #00cccc 50%, #008080 100%);
252
- background-size: 200% 100%;
253
- animation: loading 1s linear infinite;
254
- }
255
- @keyframes loading {
256
- 0% { background-position: 200% 0; }
257
- 100% { background-position: -200% 0; }
258
- }
259
-
260
- /* Match enhancer box to upload button */
261
- .enhancer-box textarea {
262
- width: 100% !important;
263
- height: 36px !important;
264
- box-sizing: border-box;
265
- font-size: 14px;
266
- }
267
-
268
- /* Equal-height styling for Step-1 columns */
269
- .equal-height-row {
270
- display: flex;
271
- align-items: stretch;
272
- }
273
- .equal-height-row > .gr-column {
274
- display: flex;
275
- flex-direction: column;
276
- }
277
-
278
- /* Target Gradio image container */
279
- .stretch-img .gr-image-container {
280
- flex-grow: 1;
281
- display: flex;
282
- }
283
-
284
- .stretch-img .gr-image-container img {
285
- width: 100% !important;
286
- height: 100% !important;
287
- object-fit: contain; /* or cover */
288
- }
289
-
290
-
291
-
292
  </style>
293
  """)
294
 
295
  # ---------------- Heading ----------------
296
- gr.Markdown(
297
- "## Multimodal AI Image Studio: An Integrated Comparative Perspective",
298
- elem_classes="heading-orange"
299
- )
300
 
301
- # ---------------- States ----------------
302
  images_state = gr.State([])
303
  captions_state = gr.State([])
304
 
305
- # ---------------- Step 1: Upload Reference Image ----------------
306
  gr.Markdown("### Upload Reference Image", elem_classes="heading-orange")
307
 
308
- with gr.Row(elem_classes="equal-height-row"):
309
- with gr.Column(scale=1):
310
- upload_input = gr.Image(label="Drag & Drop Image", type="pil")
311
- upload_btn = gr.Button(
312
- "Upload Image & Generate Caption",
313
- elem_classes="orange-btn"
314
- )
315
-
316
- with gr.Column(scale=1):
317
- upload_preview = gr.Image(
318
- label="Uploaded Image",
319
- interactive=False, elem_classes="stretch-img"
320
- )
321
-
322
- enhancer_box = gr.Textbox(
323
- label="Add Prompt Enhancer (Optional)",
324
- placeholder="Example: 'at night with neon lights', 'wearing a red jacket', etc.",
325
- elem_classes="enhancer-box"
326
- )
327
-
328
- caption_out = gr.Markdown(label="Generated Caption")
329
-
330
- # ---------------- Robust Captioning ----------------
331
- def upload_and_generate_caption_ui(img, images_state, captions_state):
332
- if img is None:
333
- return None, "No image uploaded.", [], []
334
-
335
- images = [img]
336
- try:
337
- output = captioner(img)
338
- caption = (
339
- output[0]["generated_text"]
340
- if len(output) > 0 and "generated_text" in output[0]
341
- else "Caption failed."
342
- )
343
- except Exception as e:
344
- print("Captioning error:", e)
345
- caption = "Caption failed."
346
-
347
- captions = [caption]
348
- return img, caption, images, captions
349
-
350
- upload_btn.click(
351
- upload_and_generate_caption_ui,
352
- inputs=[upload_input, images_state, captions_state],
353
- outputs=[upload_preview, caption_out, images_state, captions_state]
354
- )
355
-
356
- # ---------------- Step 2: Generate SD-Turbo & DreamShaper ----------------
357
  gr.Markdown("### Generate Images from Caption", elem_classes="heading-orange")
358
-
359
  with gr.Row():
360
- with gr.Column(scale=1, min_width=300):
361
- sd_btn = gr.Button(
362
- "Generate SD-Turbo Image",
363
- elem_classes="orange-btn"
364
- )
365
- sd_preview = gr.Image(
366
- label="SD-Turbo Image",
367
- interactive=False
368
- )
369
-
370
- with gr.Column(scale=1, min_width=300):
371
- ds_btn = gr.Button(
372
- "Generate DreamShaper Image",
373
- elem_classes="orange-btn"
374
- )
375
- ds_preview = gr.Image(
376
- label="DreamShaper Image",
377
- interactive=False
378
- )
379
-
380
- def generate_sd_from_caption_ui(caption, enhancer, images_state, captions_state):
381
- final_prompt = f"{caption}, {enhancer}".strip(", ")
382
- img, images = generate_image_with_enhancer(
383
- final_prompt,
384
- enhancer="",
385
- negative="",
386
- seed=42,
387
- style="Photorealistic",
388
- images=images_state
389
- )
390
- try:
391
- generated_caption = captioner(img)[0]["generated_text"]
392
- except:
393
- generated_caption = "Caption failed."
394
-
395
- captions_state[1:2] = [generated_caption]
396
- return img, images, captions_state
397
-
398
- def generate_ds_from_caption_ui(caption, enhancer, images_state, captions_state):
399
- final_prompt = f"{caption}, {enhancer}".strip(", ")
400
- img, images = generate_dreamshaper_with_enhancer(
401
- final_prompt,
402
- enhancer="",
403
- negative="",
404
- seed=123,
405
- style="Photorealistic",
406
- images=images_state
407
- )
408
- try:
409
- generated_caption = captioner(img)[0]["generated_text"]
410
- except:
411
- generated_caption = "Caption failed."
412
-
413
- captions_state[2:3] = [generated_caption]
414
- return img, images, captions_state
415
-
416
- sd_btn.click(
417
- generate_sd_from_caption_ui,
418
- inputs=[caption_out, enhancer_box, images_state, captions_state],
419
- outputs=[sd_preview, images_state, captions_state]
420
- )
421
-
422
- ds_btn.click(
423
- generate_ds_from_caption_ui,
424
- inputs=[caption_out, enhancer_box, images_state, captions_state],
425
- outputs=[ds_preview, images_state, captions_state]
426
- )
427
-
428
- # ---------------- Step 3: Compute Pairwise Metrics ----------------
429
  gr.Markdown("### Compute Pairwise Metrics", elem_classes="heading-orange")
430
-
431
- metrics_btn = gr.Button(
432
- "Compute Metrics for All Pairs",
433
- elem_classes="teal-btn"
434
- )
435
-
436
- with gr.Row():
437
  metrics_A = gr.Markdown()
438
  metrics_B = gr.Markdown()
439
  metrics_C = gr.Markdown()
440
 
441
  def compute_metrics_all_pairs_ui(images, captions):
442
- yield (
443
- "<div class='loading-line'></div>",
444
- "<div class='loading-line'></div>",
445
- "<div class='loading-line'></div>"
446
- )
447
-
448
- if len(images) < 3:
449
- msg = "All three images and captions are required to compute metrics."
450
  yield msg, msg, msg
451
- else:
452
- A = compute_metrics(images, captions, 0, 1)
453
- B = compute_metrics(images, captions, 0, 2)
454
- C = compute_metrics(images, captions, 1, 2)
455
- yield (
456
- f"**Reference ↔ SD-Turbo**\n{A}",
457
- f"**Reference ↔ DreamShaper**\n{B}",
458
- f"**SD-Turbo ↔ DreamShaper**\n{C}"
459
- )
460
-
461
- metrics_btn.click(
462
- compute_metrics_all_pairs_ui,
463
- inputs=[images_state, captions_state],
464
- outputs=[metrics_A, metrics_B, metrics_C]
465
- )
466
-
467
- # ---------------- Step 4: NLP Analysis ----------------
468
  gr.Markdown("### NLP Analysis of Captions", elem_classes="heading-orange")
469
-
470
- nlp_btn = gr.Button(
471
- "Analyze Captions",
472
- elem_classes="teal-btn"
473
- )
474
-
475
- nlp_out = gr.HTML()
476
 
477
  def analyze_caption_pipeline_ui(captions):
478
- yield "<div class='loading-line'></div>"
479
-
480
  if len(captions) < 3:
481
- yield "<b>All three captions are required for NLP analysis.</b>"
482
- else:
483
- labels = ["Reference Image", "SD-Turbo", "DreamShaper"]
484
- blocks = []
485
-
486
- for label, caption in zip(labels, captions):
487
- sentiment = "<br>".join(
488
- [f"{s['label']}: {s['score']:.2f}"
489
- for s in sentiment_model(caption)]
490
- )
491
-
492
- ents = (
493
- "<br>".join(
494
- [f"{e['entity_group']}: {e['word']}"
495
- for e in ner_model(caption)]
496
- ) or "None"
497
- )
498
-
499
- topics_data = topic_model(
500
- caption,
501
- candidate_labels=[
502
- "people", "animals", "objects", "food", "nature"
503
- ]
504
- )
505
-
506
- topics = "<br>".join(
507
- [f"{l}: {sc:.2f}"
508
- for l, sc in zip(
509
- topics_data["labels"],
510
- topics_data["scores"]
511
- )]
512
- )
513
-
514
- block = f"""
515
- <div style='flex:1;padding:10px;min-width:250px;'>
516
- <h3><u>{label}</u></h3>
517
- <b>Sentiment</b><br>{sentiment}<br><br>
518
- <b>Entities</b><br>{ents}<br><br>
519
- <b>Topics</b><br>{topics}
520
- </div>
521
- """
522
- blocks.append(block)
523
-
524
- yield (
525
- "<div style='display:flex; gap:20px; justify-content:space-between;'>"
526
- + "".join(blocks) +
527
- "</div>"
528
- )
529
-
530
- nlp_btn.click(
531
- analyze_caption_pipeline_ui,
532
- inputs=[captions_state],
533
- outputs=[nlp_out]
534
- )
535
-
536
- # ---------------- Step 5: Visual Question Answering ----------------
537
  gr.Markdown("### Visual Question Answering (VQA)", elem_classes="heading-orange")
538
-
539
  with gr.Row():
 
540
  with gr.Column(scale=1):
541
- vqa_input = gr.Textbox(
542
- label="Enter a question about the reference image"
543
- )
544
- vqa_btn = gr.Button(
545
- "Get Answer",
546
- elem_classes="teal-btn"
547
- )
548
-
549
  with gr.Column(scale=1):
550
  vqa_out = gr.Markdown(label="VQA Output")
551
 
552
  def answer_vqa_ui(question, image):
553
  yield "<div class='loading-line'></div>"
554
- ans = answer_vqa(question, image)
555
- yield ans
 
 
 
 
 
 
 
 
 
 
 
556
 
557
- vqa_btn.click(
558
- answer_vqa_ui,
559
- inputs=[vqa_input, upload_preview],
560
- outputs=[vqa_out]
561
- )
562
 
563
  return demo
564
 
565
-
566
  # ---------------- Launch ----------------
567
  demo = build_ui_with_custom_ui()
568
  demo.launch()
 
39
  # ==============================
40
  # Install
41
 
42
+ # Section One
43
+ # Section One
44
+ # ---------------- Install Libraries ----------------
45
+ !pip install -qq git+https://github.com/openai/CLIP.git
46
+ !pip install -qq lpips
47
+ !pip install -qq bert-score
48
+ !pip install -qq transformers accelerate
49
+ !pip install -qq diffusers gradio
50
 
51
  # Libraries
52
  import torch
 
58
  import clip
59
  from bert_score import score
60
  import torchvision.transforms as T
61
+ import requests
62
+ from io import BytesIO
63
 
64
  device = "cuda" if torch.cuda.is_available() else "cpu"
65
 
 
83
  captioner = pipeline(
84
  "image-to-text",
85
  model="Salesforce/blip-image-captioning-large",
86
+ device=0 if device=="cuda" else -1
87
+ )
88
+
89
+ sentiment_model = pipeline(
90
+ "sentiment-analysis",
91
+ model="distilbert-base-uncased-finetuned-sst-2-english",
92
+ device=-1
93
+ )
94
+
95
+ ner_model = pipeline(
96
+ "ner",
97
+ model="dbmdz/bert-large-cased-finetuned-conll03-english",
98
+ aggregation_strategy="simple",
99
+ device=-1
100
+ )
101
+
102
+ topic_model = pipeline(
103
+ "zero-shot-classification",
104
+ model="facebook/bart-large-mnli",
105
+ device=-1
106
+ )
107
 
108
  vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
109
+ vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(device)
110
 
111
  clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
112
  lpips_model = lpips.LPIPS(net='alex').to(device)
 
125
  "Cyberpunk": "neon cyberpunk futuristic",
126
  }
127
 
 
128
 
129
+ # SEction Two
130
  # ==============================
131
+ # FUNCTIONS
132
  # ==============================
133
+ def generate_image_with_enhancer(base_caption, enhancer, negative, seed, style, images, pipe=gen_pipe):
134
  images = images or []
135
  base_caption = base_caption or ""
136
  enhancer = enhancer or ""
 
137
  final_prompt = f"{base_caption}, {enhancer}".strip(", ")
138
  final_prompt = f"{final_prompt}, {style_map.get(style,'')}".strip(", ")
 
139
  try:
140
  seed = int(seed)
141
  except:
142
  seed = 42
143
+ generator = torch.Generator(device=device).manual_seed(seed)
 
 
144
  try:
145
  with torch.no_grad():
146
+ out = pipe(prompt=final_prompt, negative_prompt=negative, generator=generator)
147
  img = out.images[0]
148
  except Exception as e:
149
+ print(f"{pipe} failed:", e)
150
  img = None
 
151
  if img:
152
  images.append(img)
 
153
  free_gpu_cache()
154
  return img, images
155
 
156
+ generate_dreamshaper_with_enhancer = lambda base_caption, enhancer, negative, seed, style, images: \
157
+ generate_image_with_enhancer(base_caption, enhancer, negative, seed, style, images, pipe=dreamshaper_pipe)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  def caption_for_image(img):
160
  try:
 
168
  return "Provide image + question."
169
  try:
170
  inputs_raw = vqa_processor(images=image, text=question, return_tensors="pt")
171
+ inputs = {k:v.to(device) for k,v in inputs_raw.items()}
172
  with torch.no_grad():
173
  out = vqa_model(**inputs)
174
  ans_id = out.logits.argmax(-1)
 
177
  return "VQA failed."
178
 
179
  def compute_metrics(images, captions, i1, i2):
180
+ img1, img2 = images[i1], images[i2]
181
+ cap1, cap2 = captions[i1], captions[i2]
182
+
183
+ t1 = clip_preprocess(img1).unsqueeze(0).to(device)
184
+ t2 = clip_preprocess(img2).unsqueeze(0).to(device)
 
 
 
185
  with torch.no_grad():
186
  f1 = clip_model.encode_image(t1)
187
  f2 = clip_model.encode_image(t2)
188
  clip_sim = float(torch.cosine_similarity(f1, f2))
189
 
190
+ L1 = (lpips_transform(img1).unsqueeze(0)*2 - 1).to(device)
191
+ L2 = (lpips_transform(img2).unsqueeze(0)*2 - 1).to(device)
 
192
  with torch.no_grad():
193
  lp = float(lpips_model(L1, L2))
194
 
 
195
  if cap1 and cap2:
196
  _, _, F = score([cap1],[cap2], lang="en", verbose=False)
197
  bert_f1 = float(F.mean())
 
200
 
201
  return clip_sim, lp, bert_f1
202
 
203
+ def caption_and_store(img, images, captions):
204
+ if img is None:
205
+ return None, "", images, captions
206
+ try:
207
+ caption = captioner(img)[0]["generated_text"]
208
+ except Exception as e:
209
+ print("Captioning failed:", e)
210
+ caption = "Caption failed."
211
+ images = images + [img]
212
+ captions = captions + [caption]
213
+ return img, caption, images, captions
214
+
215
+ def fetch_and_caption(url, images, captions):
216
+ if not url:
217
+ return None, "", images, captions
218
+ try:
219
+ response = requests.get(url)
220
+ img = Image.open(BytesIO(response.content)).convert("RGB")
221
+ except Exception as e:
222
+ print("Failed to fetch image from URL:", e)
223
+ return None, "Failed to fetch image", images, captions
224
+ return caption_and_store(img, images, captions)
225
 
 
 
 
226
 
227
+ # ---------------- Section Three: UI ----------------
 
228
  def build_ui_with_custom_ui():
229
  with gr.Blocks(title="Multimodal AI Image Studio") as demo:
230
 
 
232
  gr.HTML("""
233
  <style>
234
  .heading-orange h2, .heading-orange h3 { color: #ff5500 !important; }
235
+ .orange-btn button { background-color: #ff5500 !important; color: white !important; border-radius: 6px !important; height: 36px !important; font-weight: bold; }
236
+ .teal-btn button { background-color: #008080 !important; color: white !important; border-radius: 6px !important; height: 40px !important; font-weight: bold; }
237
+ .loading-line { height: 4px; background: linear-gradient(90deg, #008080 0%, #00cccc 50%, #008080 100%); background-size: 200% 100%; animation: loading 1s linear infinite; margin-bottom:4px; }
238
+ @keyframes loading { 0% { background-position: 200% 0; } 100% { background-position: -200% 0; } }
239
+ .enhancer-box textarea { width: 100% !important; height: 36px !important; font-size: 14px; }
240
+ .equal-height-row { display: flex; align-items: stretch; }
241
+ .equal-height-row > .gr-column { display: flex; flex-direction: column; }
242
+ .stretch-img .gr-image-container { flex-grow: 1; display: flex; }
243
+ .stretch-img img { width: 100% !important; height: 100% !important; object-fit: contain; }
244
+ .metrics-row { display: flex; gap: 20px; }
245
+ .metrics-row > div { flex: 1; }
246
+ .gradio-tabs button.selected { background-color: #ff5500 !important; color: white !important; font-weight: bold; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  </style>
248
  """)
249
 
250
  # ---------------- Heading ----------------
251
+ gr.Markdown("## Multimodal AI Image Studio: An Integrated Comparative Perspective",
252
+ elem_classes="heading-orange")
 
 
253
 
 
254
  images_state = gr.State([])
255
  captions_state = gr.State([])
256
 
257
+ # ---------------- Step 1: Upload Image ----------------
258
  gr.Markdown("### Upload Reference Image", elem_classes="heading-orange")
259
 
260
+ with gr.Tabs():
261
+ with gr.Tab("📁 Upload Image"):
262
+ with gr.Row(elem_classes="equal-height-row"):
263
+ with gr.Column(scale=1):
264
+ upload_input = gr.Image(label="Drag & Drop Image", type="pil")
265
+ upload_btn = gr.Button("Upload Image & Generate Caption", elem_classes="orange-btn")
266
+ with gr.Column(scale=1):
267
+ upload_preview = gr.Image(label="Uploaded Image", interactive=False, elem_classes="stretch-img")
268
+ enhancer_box = gr.Textbox(label="Add Prompt Enhancer (Optional)", elem_classes="enhancer-box")
269
+ caption_out = gr.Markdown(label="Generated Caption")
270
+ with gr.Tab("📷 Webcam"):
271
+ with gr.Row(elem_classes="equal-height-row"):
272
+ with gr.Column(scale=1):
273
+ webcam_input = gr.Image(label="Webcam Live", type="pil", sources=["webcam"], elem_classes="stretch-img")
274
+ webcam_btn = gr.Button("Capture & Generate Caption", elem_classes="orange-btn")
275
+ with gr.Column(scale=1):
276
+ webcam_preview = gr.Image(label="Captured Image", interactive=False, elem_classes="stretch-img")
277
+ enhancer_box_webcam = gr.Textbox(label="Add Prompt Enhancer (Optional)", elem_classes="enhancer-box")
278
+ caption_out_webcam = gr.Markdown(label="Generated Caption")
279
+ with gr.Tab("🔗 From URL"):
280
+ url_input = gr.Textbox(label="Paste Image URL")
281
+ url_btn = gr.Button("Fetch & Generate Caption", elem_classes="orange-btn")
282
+
283
+ # ---------------- Caption Buttons ----------------
284
+ upload_btn.click(caption_and_store, [upload_input, images_state, captions_state],
285
+ [upload_preview, caption_out, images_state, captions_state])
286
+ webcam_btn.click(caption_and_store, [webcam_input, images_state, captions_state],
287
+ [webcam_preview, caption_out_webcam, images_state, captions_state])
288
+ url_btn.click(fetch_and_caption, [url_input, images_state, captions_state],
289
+ [upload_preview, caption_out, images_state, captions_state])
290
+
291
+ # ---------------- Step 2: Generate Images ----------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  gr.Markdown("### Generate Images from Caption", elem_classes="heading-orange")
 
293
  with gr.Row():
294
+ with gr.Column():
295
+ sd_btn = gr.Button("Generate SD-Turbo Image", elem_classes="orange-btn")
296
+ sd_preview = gr.Image(label="SD-Turbo Image")
297
+ with gr.Column():
298
+ ds_btn = gr.Button("Generate DreamShaper Image", elem_classes="orange-btn")
299
+ ds_preview = gr.Image(label="DreamShaper Image")
300
+
301
+ # ---------------- Image Generation Functions ----------------
302
+ def generate_sd(_, enhancer, images, captions):
303
+ if not captions:
304
+ return None, images, captions
305
+ base_caption = captions[-1]
306
+ img, images = generate_image_with_enhancer(base_caption, enhancer or "", negative="", seed=42, style="Photorealistic", images=images)
307
+ if img:
308
+ new_caption = captioner(img)[0]["generated_text"]
309
+ captions = captions + [new_caption]
310
+ return img, images, captions
311
+
312
+ def generate_ds(_, enhancer, images, captions):
313
+ if not captions:
314
+ return None, images, captions
315
+ base_caption = captions[-1]
316
+ img, images = generate_dreamshaper_with_enhancer(base_caption, enhancer or "", negative="", seed=123, style="Photorealistic", images=images)
317
+ if img:
318
+ new_caption = captioner(img)[0]["generated_text"]
319
+ captions = captions + [new_caption]
320
+ return img, images, captions
321
+
322
+ # ---------------- Attach Clicks ----------------
323
+ sd_btn.click(generate_sd, [caption_out, enhancer_box, images_state, captions_state],
324
+ [sd_preview, images_state, captions_state])
325
+ ds_btn.click(generate_ds, [caption_out, enhancer_box, images_state, captions_state],
326
+ [ds_preview, images_state, captions_state])
327
+
328
+ # ---------------- Step 3: Metrics ----------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
  gr.Markdown("### Compute Pairwise Metrics", elem_classes="heading-orange")
330
+ metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn")
331
+ with gr.Row(elem_classes="metrics-row"):
 
 
 
 
 
332
  metrics_A = gr.Markdown()
333
  metrics_B = gr.Markdown()
334
  metrics_C = gr.Markdown()
335
 
336
  def compute_metrics_all_pairs_ui(images, captions):
337
+ yield ("<div class='loading-line'></div>",) * 3
338
+ if len(images) < 3 or len(captions) < 3:
339
+ msg = "⚠️ All three images and captions required."
 
 
 
 
 
340
  yield msg, msg, msg
341
+ return
342
+ pairs = [(0,1,"Reference ↔ SD-Turbo"), (0,2,"Reference ↔ DreamShaper"), (1,2,"SD-Turbo ↔ DreamShaper")]
343
+ results = []
344
+ for i1, i2, label in pairs:
345
+ clip_sim, lp, bert_f1 = compute_metrics(images, captions, i1, i2)
346
+ results.append(f"**{label}**<br>CLIP similarity: {clip_sim:.3f}<br>LPIPS: {lp:.3f}<br>BERT F1: {bert_f1:.3f}")
347
+ yield tuple(results)
348
+
349
+ metrics_btn.click(compute_metrics_all_pairs_ui, [images_state, captions_state],
350
+ [metrics_A, metrics_B, metrics_C])
351
+
352
+ # ---------------- Step 4: NLP ----------------
 
 
 
 
 
353
  gr.Markdown("### NLP Analysis of Captions", elem_classes="heading-orange")
354
+ nlp_btn = gr.Button("Analyze Captions", elem_classes="teal-btn")
355
+ with gr.Row(elem_classes="metrics-row"):
356
+ nlp_out_A = gr.HTML()
357
+ nlp_out_B = gr.HTML()
358
+ nlp_out_C = gr.HTML()
 
 
359
 
360
  def analyze_caption_pipeline_ui(captions):
361
+ yield ("<div class='loading-line'></div>",) * 3
 
362
  if len(captions) < 3:
363
+ yield "<b>All three captions required.</b>", "<b>All three captions required.</b>", "<b>All three captions required.</b>"
364
+ return
365
+ labels = ["Reference Image","SD-Turbo","DreamShaper"]
366
+ results = []
367
+ for label, caption in zip(labels, captions):
368
+ sentiment = "<br>".join(f"{s['label']}: {s['score']:.2f}" for s in sentiment_model(caption))
369
+ ents = "<br>".join(f"{e['entity_group']}: {e['word']}" for e in ner_model(caption)) or "None"
370
+ topics_data = topic_model(caption, candidate_labels=["people","animals","objects","food","nature"])
371
+ topics = "<br>".join(f"{l}: {sc:.2f}" for l, sc in zip(topics_data["labels"], topics_data["scores"]))
372
+ results.append(f"<b>{label}</b><br><b>Sentiment</b><br>{sentiment}<br><b>Entities</b><br>{ents}<br><b>Topics</b><br>{topics}")
373
+ yield tuple(results)
374
+
375
+ nlp_btn.click(analyze_caption_pipeline_ui, captions_state,
376
+ [nlp_out_A, nlp_out_B, nlp_out_C])
377
+
378
+ # ---------------- Step 5: VQA ----------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
  gr.Markdown("### Visual Question Answering (VQA)", elem_classes="heading-orange")
 
380
  with gr.Row():
381
+ # Left column: question input and button
382
  with gr.Column(scale=1):
383
+ vqa_input = gr.Textbox(label="Enter a question about the reference image")
384
+ vqa_btn = gr.Button("Get Answer", elem_classes="teal-btn")
385
+ # Right column: VQA output
 
 
 
 
 
386
  with gr.Column(scale=1):
387
  vqa_out = gr.Markdown(label="VQA Output")
388
 
389
  def answer_vqa_ui(question, image):
390
  yield "<div class='loading-line'></div>"
391
+ if image is None or not question.strip():
392
+ yield "⚠️ Provide image + question."
393
+ return
394
+ try:
395
+ inputs_raw = vqa_processor(images=image, text=question, return_tensors="pt")
396
+ inputs = {k:v.to(device) for k,v in inputs_raw.items()}
397
+ with torch.no_grad():
398
+ out = vqa_model(**inputs)
399
+ ans_id = out.logits.argmax(-1)
400
+ answer = vqa_processor.decode(ans_id[0], skip_special_tokens=True)
401
+ yield answer
402
+ except Exception as e:
403
+ yield f"⚠️ VQA failed: {str(e)}"
404
 
405
+ vqa_btn.click(answer_vqa_ui, [vqa_input, upload_preview], vqa_out)
 
 
 
 
406
 
407
  return demo
408
 
 
409
  # ---------------- Launch ----------------
410
  demo = build_ui_with_custom_ui()
411
  demo.launch()