Chyd19 committed on
Commit
6ebebf6
·
verified ·
1 Parent(s): e07bf97

Update app.py

Files changed (1)
  1. app.py +401 -39
app.py CHANGED
@@ -121,12 +121,11 @@ def compute_metrics_button(images, captions, idx1, idx2):
     rouge_scores = scorer.score(captions[idx1], captions[idx2])

     return f"""
-    **Metrics Comparison**
-    - CLIP Similarity: {clip_sim:.4f}
-    - LPIPS Score: {lpips_score:.4f}
-    - BERTScore F1: {bert_f1:.4f}
-    - Cosine Similarity: {cosine_sim:.4f}
-    - Jaccard Similarity: {jaccard_sim:.4f}
+    - CLIP: {clip_sim:.4f}
+    - LPIPS: {lpips_score:.4f}
+    - BERT-F1: {bert_f1:.4f}
+    - Cosine: {cosine_sim:.4f}
+    - Jaccard: {jaccard_sim:.4f}
     - ROUGE-1: {rouge_scores['rouge1'].fmeasure:.4f}
     - ROUGE-L: {rouge_scores['rougeL'].fmeasure:.4f}
     """
@@ -194,6 +193,16 @@ def build_ui():
                         5px 5px 15px rgba(0,0,0,0.3);
             border: 2px solid rgba(255,255,255,0.6);
         }
+
+        .metrics-row {
+            display: flex;
+            flex-direction: row;
+            gap: 20px;
+        }
+        .metrics-row > div {
+            flex: 1;
+        }
+
         </style>
         """)

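
The new `.metrics-row` rules only take effect because a later hunk creates the Row with a matching `elem_classes`. A minimal sketch of that pairing, with placeholder Markdown standing in for the real metrics widgets:

```python
import gradio as gr

with gr.Blocks() as demo:
    # Inject the rule the same way app.py does, via a raw <style> block.
    gr.HTML("""
    <style>
    .metrics-row { display: flex; flex-direction: row; gap: 20px; }
    .metrics-row > div { flex: 1; }
    </style>
    """)
    # elem_classes attaches the CSS class to the rendered Row container.
    with gr.Row(elem_classes="metrics-row"):
        col_a = gr.Markdown("Pair A")
        col_b = gr.Markdown("Pair B")
        col_c = gr.Markdown("Pair C")

demo.launch()
```
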
@@ -253,45 +262,351 @@ def build_ui():
         # ---------------- Metrics ----------------
         gr.Markdown("### Compute Pairwise Metrics", elem_classes="heading-orange")
         metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn")
-        metrics_A = gr.Markdown()
-        metrics_B = gr.Markdown()
-        metrics_C = gr.Markdown()
-
+        with gr.Row(elem_classes="metrics-row"):
+            metrics_A = gr.Markdown()
+            metrics_B = gr.Markdown()
+            metrics_C = gr.Markdown()
+
         def compute_metrics_all_pairs_ui(images, captions):
-            # Show single spinner for all three
-            yield "<div class='loading-line' style='width:100%; margin-bottom:10px;'></div>"
-
+            # 3 spinners
+            yield (
+                "<div class='loading-line'></div>",
+                "<div class='loading-line'></div>",
+                "<div class='loading-line'></div>"
+            )
+
             if len(images) < 1 or len(captions) < 3:
                 msg = "<b>Upload 1 image and generate all 3 captions.</b>"
-                yield msg
+                yield (msg, msg, msg)
                 return
-
+
             imgs = images * 3
             A = compute_metrics_button(imgs, captions, 0, 1)
             B = compute_metrics_button(imgs, captions, 0, 2)
             C = compute_metrics_button(imgs, captions, 1, 2)
-
-            # Side-by-side layout for the metrics
-            html = f"""
-            <div style='display:flex; gap:40px; text-align:left; font-family:monospace;'>
-            <div>
-            <h4>BLIP-large ↔ ViT-GPT2</h4>
-            <pre>{A}</pre>
-            </div>
-            <div>
-            <h4>BLIP-large ↔ BLIP2</h4>
-            <pre>{B}</pre>
-            </div>
-            <div>
-            <h4>ViT-GPT2 ↔ BLIP2</h4>
-            <pre>{C}</pre>
+
+            yield (
+                f"### BLIP-large ↔ ViT-GPT2\n{A}",
+                f"### BLIP-large ↔ BLIP2\n{B}",
+                f"### ViT-GPT2 ↔ BLIP2\n{C}"
+            )
+
+        metrics_btn.click(
+            compute_metrics_all_pairs_ui,
+            inputs=[images_state, captions_state],
+            outputs=[metrics_A, metrics_B, metrics_C]
+        )
+
+        # ---------------- NLP ----------------
+        gr.Markdown("### NLP Analysis", elem_classes="heading-orange")
+        nlp_btn = gr.Button("Analyze Captions", elem_classes="teal-btn")
+        nlp_out = gr.HTML()
+
+        def do_nlp(captions):
+            yield "<div class='loading-line'></div>"
+            if len(captions) < 3:
+                yield "<b>All captions required.</b>"
+                return
+            labels = ["BLIP-large", "ViT-GPT2", "BLIP2"]
+            blocks = []
+            for label, cap in zip(labels, captions):
+                s, e, t = nlp_bundle(cap)
+                block = f"""
+                <div style='flex:1;padding:10px;min-width:240px;'>
+                <h3><u>{label}</u></h3>
+                <b>Sentiment</b><br>{s}<br><br>
+                <b>Entities</b><br>{e}<br><br>
+                <b>Topics</b><br>{t}
                 </div>
-            </div>
-            """
-            yield html
-        metrics_btn.click(compute_metrics_all_pairs_ui, inputs=[images_state, captions_state],
-                          outputs=[metrics_A, metrics_B, metrics_C])
-        """
+                """
+                blocks.append(block)
+            yield f"<div style='display:flex; gap:20px;'>{''.join(blocks)}</div>"
+
+        nlp_btn.click(do_nlp, inputs=[captions_state], outputs=[nlp_out])
+
+        # ---------------- VQA ----------------
+        gr.Markdown("### Visual Question Answering (VQA)", elem_classes="heading-orange")
+        with gr.Row():
+            vqa_input = gr.Textbox(label="Ask about the image")
+            vqa_btn = gr.Button("Get Answer", elem_classes="teal-btn")
+        vqa_out = gr.Markdown()
+
+        def vqa_ui(question, image):
+            yield "<div class='loading-line'></div>"
+            yield answer_vqa(question, image)
+
+        vqa_btn.click(vqa_ui, inputs=[vqa_input, preview1], outputs=[vqa_out])
+
+    return demo
+
+# ==============================
+# LAUNCH
+# ==============================
+demo = build_ui()
+demo.launch(share=True, debug=False)
+
+"""
+# ==============================
+# SECTION 1 — INSTALL + IMPORTS
+# ==============================
+
+import torch
+import gradio as gr
+from PIL import Image
+from transformers import pipeline, BlipProcessor, BlipForQuestionAnswering
+import lpips
+import clip
+from bert_score import score
+import torchvision.transforms as T
+from sentence_transformers import SentenceTransformer
+from rouge_score import rouge_scorer
+import numpy as np
+from sklearn.metrics.pairwise import cosine_similarity
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+def free_gpu_cache():
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+# ==============================
+# SECTION 2 — LOAD LIGHTWEIGHT MODELS
+# ==============================
+blip_large_captioner = pipeline(
+    "image-to-text",
+    model="Salesforce/blip-image-captioning-large",
+    device=0 if device=="cuda" else -1
+)
+
+vit_gpt2_captioner = pipeline(
+    "image-to-text",
+    model="nlpconnect/vit-gpt2-image-captioning",
+    device=0 if device=="cuda" else -1
+)
+
+# --- NLP Pipelines ---
+sentiment_model = pipeline("sentiment-analysis")
+ner_model = pipeline("ner", aggregation_strategy="simple")
+topic_model = pipeline("zero-shot-classification",
+                       model="facebook/bart-large-mnli")
+
+# --- Metrics ---
+clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
+lpips_model = lpips.LPIPS(net='alex').to(device)
+lpips_transform = T.Compose([T.ToTensor(), T.Resize((128,128))])
+sentence_model = SentenceTransformer("all-MiniLM-L6-v2")  # for cosine similarity
+
+# ==============================
+# SECTION 2b — LAZY LOAD HEAVY MODELS
+# ==============================
+blip2_captioner = None
+vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
+vqa_model = None
+
+def get_blip2():
+    global blip2_captioner
+    if blip2_captioner is None:
+        blip2_captioner = pipeline(
+            "image-to-text",
+            model="Salesforce/blip2-opt-2.7b",
+            device=0 if device=="cuda" else -1
+        )
+    return blip2_captioner
+
+def get_vqa_model():
+    global vqa_model
+    if vqa_model is None:
+        vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(device)
+    return vqa_model
+
+# ==============================
+# SECTION 3 — FUNCTIONS
+# ==============================
+def make_captions(img):
+    captions = []
+    try: captions.append(blip_large_captioner(img)[0]["generated_text"])
+    except: captions.append("BLIP-large failed.")
+    try: captions.append(vit_gpt2_captioner(img)[0]["generated_text"])
+    except: captions.append("ViT-GPT2 failed.")
+    try:
+        blip2 = get_blip2()
+        captions.append(blip2(img)[0]["generated_text"])
+    except: captions.append("BLIP2-opt failed.")
+    return captions
+
+# ---------------- Metrics Computation ---------------------
+def compute_metrics_button(images, captions, idx1, idx2):
+    # CLIP similarity
+    img1_clip = clip_preprocess(images[idx1]).unsqueeze(0).to(device)
+    img2_clip = clip_preprocess(images[idx2]).unsqueeze(0).to(device)
+    with torch.no_grad():
+        feat1 = clip_model.encode_image(img1_clip)
+        feat2 = clip_model.encode_image(img2_clip)
+    clip_sim = float(torch.cosine_similarity(feat1, feat2).item())
+
+    # LPIPS
+    img1_lp = lpips_transform(images[idx1]).unsqueeze(0).to(device) * 2 - 1
+    img2_lp = lpips_transform(images[idx2]).unsqueeze(0).to(device) * 2 - 1
+    with torch.no_grad():
+        lpips_score = float(lpips_model(img1_lp, img2_lp).item())
+
+    # BERTScore
+    _, _, F1 = score([captions[idx1]], [captions[idx2]], lang="en", verbose=False)
+    bert_f1 = float(F1.mean().item())
+
+    # Cosine similarity of embeddings
+    emb1 = sentence_model.encode([captions[idx1]])
+    emb2 = sentence_model.encode([captions[idx2]])
+    cosine_sim = float(cosine_similarity(emb1, emb2)[0][0])
+
+    # Jaccard similarity
+    tokens1 = set(captions[idx1].lower().split())
+    tokens2 = set(captions[idx2].lower().split())
+    jaccard_sim = float(len(tokens1 & tokens2) / len(tokens1 | tokens2))
+
+    # ROUGE
+    scorer = rouge_scorer.RougeScorer(['rouge1','rougeL'], use_stemmer=True)
+    rouge_scores = scorer.score(captions[idx1], captions[idx2])
+
+    return f""
+    **Metrics Comparison**
+    - CLIP Similarity: {clip_sim:.4f}
+    - LPIPS Score: {lpips_score:.4f}
+    - BERTScore F1: {bert_f1:.4f}
+    - Cosine Similarity: {cosine_sim:.4f}
+    - Jaccard Similarity: {jaccard_sim:.4f}
+    - ROUGE-1: {rouge_scores['rouge1'].fmeasure:.4f}
+    - ROUGE-L: {rouge_scores['rougeL'].fmeasure:.4f}
+    ""
+
+# ---- NLP ----
+def nlp_bundle(caption):
+    try:
+        sentiment = sentiment_model(caption)
+        sentiment = "<br>".join([f"{s['label']}: {s['score']:.2f}" for s in sentiment])
+    except: sentiment = "Sentiment failed."
+
+    try:
+        ents_list = ner_model(caption)
+        ents = "<br>".join([f"{e['entity_group']}: {e['word']}" for e in ents_list]) or "None"
+    except: ents = "NER failed."
+
+    try:
+        topics_raw = topic_model(caption, candidate_labels=["people","animals","objects","food","nature"])
+        topics = "<br>".join([f"{lbl}: {float(scr):.2f}" for lbl, scr in zip(topics_raw["labels"], topics_raw["scores"])])
+    except: topics = "Topics failed."
+
+    return sentiment, ents, topics
+
+# ---------------- VQA ----------------
+def answer_vqa(question, image):
+    if image is None or question.strip() == "":
+        return "Upload an image and enter a question."
+    model = get_vqa_model()
+    inputs = vqa_processor(images=image, text=question, return_tensors="pt").to(device)
+    with torch.no_grad():
+        generated_ids = model.generate(**inputs)
+    answer = vqa_processor.decode(generated_ids[0], skip_special_tokens=True)
+    free_gpu_cache()
+    return answer
+
+# Convert a PIL.Image to PNG byte stream
+def to_bytes(img):
+    import io
+    buf = io.BytesIO()
+    img.save(buf, format="PNG")
+    return buf.getvalue()
+
+# ==============================
+# SECTION 4 — UI (GRADIO)
+# ==============================
+def build_ui():
+    with gr.Blocks(title="Multimodal AI Image Studio") as demo:
+
+        gr.HTML(
+        <style>
+        .heading-orange h2, .heading-orange h3 { color: #ff5500 !important; }
+        .orange-btn button { background-color:#ff5500; color:white; border-radius:6px; height:36px; font-weight:bold; }
+        .teal-btn button { background-color:#008080; color:white; border-radius:6px; height:36px; font-weight:bold; }
+        .loading-line {
+            height:4px; background:linear-gradient(90deg,#008080 0%,#00cccc 50%,#008080 100%);
+            background-size:200% 100%; animation: loading 1s linear infinite;
+        }
+        @keyframes loading { 0% {background-position:200% 0;} 100% {background-position:-200% 0;} }
+        .circular-img img {
+            border-radius: 21%;
+            object-fit: cover;
+            width: 400px;
+            height: 200px;
+            box-shadow: inset -10px -10px 30px rgba(255,255,255,0.3),
+                        5px 5px 15px rgba(0,0,0,0.3);
+            border: 2px solid rgba(255,255,255,0.6);
+        }
+        </style>
+        )
+
+        gr.Markdown("## Multimodal AI Image Studio: Comparative Image-to-Text Analysis", elem_classes="heading-orange")
+        images_state = gr.State([])
+        captions_state = gr.State([])
+
+        # ---------------- Image Input ----------------
+        gr.Markdown("### Select Image Source", elem_classes="heading-orange")
+        with gr.Tabs():
+            with gr.Tab("📁 Upload Image"):
+                upload_input = gr.Image(type="pil", sources=["upload"], label="Upload Image", height=900, width=960, elem_classes="circular-img")
+                upload_btn = gr.Button("Generate Captions", elem_classes="orange-btn")
+            with gr.Tab("📷 Webcam"):
+                webcam_input = gr.Image(type="pil", sources=["webcam"], label="Webcam", height=900, width=960, elem_classes="circular-img")
+                webcam_btn = gr.Button("Capture & Generate Captions", elem_classes="orange-btn")
+            with gr.Tab("🔗 From URL"):
+                url_input = gr.Textbox(label="Paste Image URL")
+                url_btn = gr.Button("Fetch & Generate Captions", elem_classes="orange-btn")
+
+        # ---------------- Previews ----------------
+        with gr.Row():
+            with gr.Column(scale=1, min_width=200):
+                preview1 = gr.Image(type="pil", label="Preview 1", interactive=False, height=230)
+                blip_caption_box = gr.Markdown()
+            with gr.Column(scale=1, min_width=200):
+                preview2 = gr.Image(type="pil", label="Preview 2", interactive=False, height=230)
+                vit_caption_box = gr.Markdown()
+            with gr.Column(scale=1, min_width=200):
+                preview3 = gr.Image(type="pil", label="Preview 3", interactive=False, height=230)
+                blip2_caption_box = gr.Markdown()
+
+        # ---------------- Generate Captions ----------------
+        def generate_all(img, images_state, captions_state):
+            if img is None:
+                return (None, None, None, "No image.", "No image.", "No image.", [], [])
+            captions = make_captions(img)
+            return (img, img, img, captions[0], captions[1], captions[2], [img], captions)
+
+        upload_btn.click(generate_all, inputs=[upload_input, images_state, captions_state],
+                         outputs=[preview1, preview2, preview3, blip_caption_box, vit_caption_box, blip2_caption_box, images_state, captions_state])
+        webcam_btn.click(generate_all, inputs=[webcam_input, images_state, captions_state],
+                         outputs=[preview1, preview2, preview3, blip_caption_box, vit_caption_box, blip2_caption_box, images_state, captions_state])
+
+        def load_from_url(url, images_state, captions_state):
+            import requests
+            from io import BytesIO
+            try:
+                img = Image.open(BytesIO(requests.get(url).content))
+            except:
+                return (None, None, None, "Bad URL.", "Bad URL.", "Bad URL.", [], [])
+            return generate_all(img, images_state, captions_state)
+
+        url_btn.click(load_from_url, inputs=[url_input, images_state, captions_state],
+                      outputs=[preview1, preview2, preview3, blip_caption_box, vit_caption_box, blip2_caption_box, images_state, captions_state])
+
+        # ---------------- Metrics ----------------
+
+        ""
+        gr.Markdown("### Compute Pairwise Metrics", elem_classes="heading-orange")
+        metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn")
+        metrics_A = gr.Markdown()
+        metrics_B = gr.Markdown()
+        metrics_C = gr.Markdown()
+
         def compute_metrics_all_pairs_ui(images, captions):
             yield ("<div class='loading-line'></div>", "<div class='loading-line'></div>", "<div class='loading-line'></div>")
             if len(images) < 1 or len(captions) < 3:
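
The rewritten handler relies on Gradio's generator support: a `click` handler that `yield`s tuples updates every listed output on each yield, which is how the three placeholder spinners get swapped for the three result columns. A stripped-down sketch of the same pattern, with the slow work simulated by `time.sleep`:

```python
import time
import gradio as gr

def slow_triple():
    # First yield: one placeholder per output component.
    yield ("Working...", "Working...", "Working...")
    time.sleep(2)  # stand-in for the real metric computation
    # Final yield: one result per output component.
    yield ("Pair A done", "Pair B done", "Pair C done")

with gr.Blocks() as demo:
    btn = gr.Button("Run")
    out_a, out_b, out_c = gr.Markdown(), gr.Markdown(), gr.Markdown()
    btn.click(slow_triple, inputs=None, outputs=[out_a, out_b, out_c])

demo.launch()
```
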
@@ -307,7 +622,52 @@ def build_ui():
                    f"**ViT-GPT2 ↔ BLIP2**<br>{C}")

        metrics_btn.click(compute_metrics_all_pairs_ui, inputs=[images_state, captions_state],
-                          outputs=[metrics_A, metrics_B, metrics_C])"""
+                          outputs=[metrics_A, metrics_B, metrics_C])""
+
+
+        # ---------------- Metrics ----------------
+        gr.Markdown("### Compute Pairwise Metrics", elem_classes="heading-orange")
+        metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn")
+
+        with gr.Row(elem_classes="metrics-row"):
+            metrics_A = gr.Markdown()
+            metrics_B = gr.Markdown()
+            metrics_C = gr.Markdown()
+
+        def compute_metrics_all_pairs_ui(images, captions):
+
+            # 3 spinners (one for each column)
+            yield (
+                "<div class='loading-line'></div>",
+                "<div class='loading-line'></div>",
+                "<div class='loading-line'></div>"
+            )
+
+            if len(images) < 1 or len(captions) < 3:
+                msg = "<b>Upload 1 image and generate all 3 captions.</b>"
+                yield (msg, msg, msg)
+                return
+
+            # duplicate image for internal function
+            imgs = images * 3
+
+            # compute
+            A = compute_metrics_button(imgs, captions, 0, 1)
+            B = compute_metrics_button(imgs, captions, 0, 2)
+            C = compute_metrics_button(imgs, captions, 1, 2)
+
+            # return 3 separate markdown blocks (side-by-side)
+            yield (
+                f"### BLIP-large ↔ ViT-GPT2\n{A}",
+                f"### BLIP-large ↔ BLIP2\n{B}",
+                f"### ViT-GPT2 ↔ BLIP2\n{C}"
+            )
+
+        metrics_btn.click(
+            compute_metrics_all_pairs_ui,
+            inputs=[images_state, captions_state],
+            outputs=[metrics_A, metrics_B, metrics_C]
+        )

        # ---------------- NLP ----------------
        gr.Markdown("### NLP Analysis", elem_classes="heading-orange")
@@ -323,14 +683,14 @@ def build_ui():
             blocks = []
             for label, cap in zip(labels, captions):
                 s, e, t = nlp_bundle(cap)
-                block = f"""
+                block = f""
                 <div style='flex:1;padding:10px;min-width:240px;'>
                 <h3><u>{label}</u></h3>
                 <b>Sentiment</b><br>{s}<br><br>
                 <b>Entities</b><br>{e}<br><br>
                 <b>Topics</b><br>{t}
                 </div>
-                """
+                ""
                 blocks.append(block)
             yield f"<div style='display:flex; gap:20px;'>{''.join(blocks)}</div>"

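
One caveat about this HTML assembly (in the live `do_nlp` as well as this stringified copy): model output is interpolated into markup unescaped. If that ever needed hardening, the standard-library fix is small; a hedged sketch, where `safe_lines` is a hypothetical helper:

```python
import html

# Escape each model-generated item before joining with intentional <br> tags,
# so the layout markup survives but model text cannot inject HTML.
def safe_lines(items):
    return "<br>".join(html.escape(str(it)) for it in items)

print(safe_lines(["POSITIVE: 0.98", "<script>alert(1)</script>"]))
```
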
@@ -356,3 +716,5 @@ def build_ui():
 # ==============================
 demo = build_ui()
 demo.launch(share=True, debug=False)
+
+"""
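
A final note on the launch line the commit leaves unchanged: generator-based handlers such as `compute_metrics_all_pairs_ui` stream their yields through Gradio's queue. On recent Gradio releases the queue is enabled by default; on older 3.x versions the equivalent launch would need it switched on explicitly, roughly:

```python
# Only needed on older Gradio 3.x, where yielding handlers require the queue.
demo = build_ui()
demo.queue()
demo.launch(share=True, debug=False)
```
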