Update app.py

app.py CHANGED
@@ -121,12 +121,11 @@ def compute_metrics_button(images, captions, idx1, idx2):
    rouge_scores = scorer.score(captions[idx1], captions[idx2])

    return f"""
-   **Metrics Comparison**
-   - CLIP Similarity: {clip_sim:.4f}
-   - LPIPS Score: {lpips_score:.4f}
-   - BERTScore F1: {bert_f1:.4f}
-   - Cosine Similarity: {cosine_sim:.4f}
-   - Jaccard Similarity: {jaccard_sim:.4f}
+   - CLIP: {clip_sim:.4f}
+   - LPIPS: {lpips_score:.4f}
+   - BERT-F1: {bert_f1:.4f}
+   - Cosine: {cosine_sim:.4f}
+   - Jaccard: {jaccard_sim:.4f}
    - ROUGE-1: {rouge_scores['rouge1'].fmeasure:.4f}
    - ROUGE-L: {rouge_scores['rougeL'].fmeasure:.4f}
    """
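For reference, the caption-level values formatted in the block above come from plain token overlap and ROUGE F-measures between two captions. A minimal standalone sketch of that computation (not part of the commit; the helper name is made up here), mirroring what compute_metrics_button does:

# Hypothetical helper, not in app.py: reproduces the Jaccard and ROUGE numbers
# that the shortened summary string formats.
from rouge_score import rouge_scorer

def caption_metrics(cap1: str, cap2: str) -> str:
    # Jaccard similarity over lowercased whitespace tokens
    t1, t2 = set(cap1.lower().split()), set(cap2.lower().split())
    jaccard_sim = len(t1 & t2) / len(t1 | t2) if (t1 | t2) else 0.0

    # ROUGE-1 / ROUGE-L F-measures between the two captions
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
    rouge_scores = scorer.score(cap1, cap2)

    return (f"- Jaccard: {jaccard_sim:.4f}\n"
            f"- ROUGE-1: {rouge_scores['rouge1'].fmeasure:.4f}\n"
            f"- ROUGE-L: {rouge_scores['rougeL'].fmeasure:.4f}")

print(caption_metrics("a dog on a beach", "a dog running on the beach"))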
@@ -194,6 +193,16 @@ def build_ui():
                        5px 5px 15px rgba(0,0,0,0.3);
            border: 2px solid rgba(255,255,255,0.6);
        }
+
+       .metrics-row {
+           display: flex;
+           flex-direction: row;
+           gap: 20px;
+       }
+       .metrics-row > div {
+           flex: 1;
+       }
+
        </style>
        """)

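The new .metrics-row rules are consumed by a gr.Row(elem_classes="metrics-row") added further down in this diff, so the three per-pair Markdown panels share one row and split the width evenly. A minimal sketch of that wiring (not part of the commit; it passes the CSS through Blocks' css= argument instead of the app's inline gr.HTML style block):

# Sketch only: shows how the .metrics-row class attaches to a Row of Markdown panels.
import gradio as gr

css = """
.metrics-row { display: flex; flex-direction: row; gap: 20px; }
.metrics-row > div { flex: 1; }
"""

with gr.Blocks(css=css) as demo:
    with gr.Row(elem_classes="metrics-row"):
        gr.Markdown("### Pair A")
        gr.Markdown("### Pair B")
        gr.Markdown("### Pair C")

if __name__ == "__main__":
    demo.launch()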
@@ -253,45 +262,351 @@ def build_ui():
        # ---------------- Metrics ----------------
        gr.Markdown("### Compute Pairwise Metrics", elem_classes="heading-orange")
        metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn")
-
-
-
-
+       with gr.Row(elem_classes="metrics-row"):
+           metrics_A = gr.Markdown()
+           metrics_B = gr.Markdown()
+           metrics_C = gr.Markdown()
+
        def compute_metrics_all_pairs_ui(images, captions):
-           #
-           yield
-
+           # 3 spinners
+           yield (
+               "<div class='loading-line'></div>",
+               "<div class='loading-line'></div>",
+               "<div class='loading-line'></div>"
+           )
+
            if len(images) < 1 or len(captions) < 3:
                msg = "<b>Upload 1 image and generate all 3 captions.</b>"
-               yield msg
+               yield (msg, msg, msg)
                return
-
+
            imgs = images * 3
            A = compute_metrics_button(imgs, captions, 0, 1)
            B = compute_metrics_button(imgs, captions, 0, 2)
            C = compute_metrics_button(imgs, captions, 1, 2)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+           yield (
+               f"### BLIP-large ↔ ViT-GPT2\n{A}",
+               f"### BLIP-large ↔ BLIP2\n{B}",
+               f"### ViT-GPT2 ↔ BLIP2\n{C}"
+           )
+
+       metrics_btn.click(
+           compute_metrics_all_pairs_ui,
+           inputs=[images_state, captions_state],
+           outputs=[metrics_A, metrics_B, metrics_C]
+       )
+
+       # ---------------- NLP ----------------
+       gr.Markdown("### NLP Analysis", elem_classes="heading-orange")
+       nlp_btn = gr.Button("Analyze Captions", elem_classes="teal-btn")
+       nlp_out = gr.HTML()
+
+       def do_nlp(captions):
+           yield "<div class='loading-line'></div>"
+           if len(captions) < 3:
+               yield "<b>All captions required.</b>"
+               return
+           labels = ["BLIP-large", "ViT-GPT2", "BLIP2"]
+           blocks = []
+           for label, cap in zip(labels, captions):
+               s, e, t = nlp_bundle(cap)
+               block = f"""
+               <div style='flex:1;padding:10px;min-width:240px;'>
+               <h3><u>{label}</u></h3>
+               <b>Sentiment</b><br>{s}<br><br>
+               <b>Entities</b><br>{e}<br><br>
+               <b>Topics</b><br>{t}
                </div>
-
-
-           yield
-
-
-
+               """
+               blocks.append(block)
+           yield f"<div style='display:flex; gap:20px;'>{''.join(blocks)}</div>"
+
+       nlp_btn.click(do_nlp, inputs=[captions_state], outputs=[nlp_out])
+
+       # ---------------- VQA ----------------
+       gr.Markdown("### Visual Question Answering (VQA)", elem_classes="heading-orange")
+       with gr.Row():
+           vqa_input = gr.Textbox(label="Ask about the image")
+           vqa_btn = gr.Button("Get Answer", elem_classes="teal-btn")
+           vqa_out = gr.Markdown()
+
+       def vqa_ui(question, image):
+           yield "<div class='loading-line'></div>"
+           yield answer_vqa(question, image)
+
+       vqa_btn.click(vqa_ui, inputs=[vqa_input, preview1], outputs=[vqa_out])
+
+   return demo
+
+# ==============================
+# LAUNCH
+# ==============================
+demo = build_ui()
+demo.launch(share=True, debug=False)
+
+"""
+# ==============================
+# SECTION 1 – INSTALL + IMPORTS
+# ==============================
+
+import torch
+import gradio as gr
+from PIL import Image
+from transformers import pipeline, BlipProcessor, BlipForQuestionAnswering
+import lpips
+import clip
+from bert_score import score
+import torchvision.transforms as T
+from sentence_transformers import SentenceTransformer
+from rouge_score import rouge_scorer
+import numpy as np
+from sklearn.metrics.pairwise import cosine_similarity
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+def free_gpu_cache():
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+# ==============================
+# SECTION 2 – LOAD LIGHTWEIGHT MODELS
+# ==============================
+blip_large_captioner = pipeline(
+    "image-to-text",
+    model="Salesforce/blip-image-captioning-large",
+    device=0 if device=="cuda" else -1
+)
+
+vit_gpt2_captioner = pipeline(
+    "image-to-text",
+    model="nlpconnect/vit-gpt2-image-captioning",
+    device=0 if device=="cuda" else -1
+)
+
+# --- NLP Pipelines ---
+sentiment_model = pipeline("sentiment-analysis")
+ner_model = pipeline("ner", aggregation_strategy="simple")
+topic_model = pipeline("zero-shot-classification",
+                       model="facebook/bart-large-mnli")
+
+# --- Metrics ---
+clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
+lpips_model = lpips.LPIPS(net='alex').to(device)
+lpips_transform = T.Compose([T.ToTensor(), T.Resize((128,128))])
+sentence_model = SentenceTransformer("all-MiniLM-L6-v2")  # for cosine similarity
+
+# ==============================
+# SECTION 2b – LAZY LOAD HEAVY MODELS
+# ==============================
+blip2_captioner = None
+vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
+vqa_model = None
+
+def get_blip2():
+    global blip2_captioner
+    if blip2_captioner is None:
+        blip2_captioner = pipeline(
+            "image-to-text",
+            model="Salesforce/blip2-opt-2.7b",
+            device=0 if device=="cuda" else -1
+        )
+    return blip2_captioner
+
+def get_vqa_model():
+    global vqa_model
+    if vqa_model is None:
+        vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(device)
+    return vqa_model
+
+# ==============================
+# SECTION 3 – FUNCTIONS
+# ==============================
+def make_captions(img):
+    captions = []
+    try: captions.append(blip_large_captioner(img)[0]["generated_text"])
+    except: captions.append("BLIP-large failed.")
+    try: captions.append(vit_gpt2_captioner(img)[0]["generated_text"])
+    except: captions.append("ViT-GPT2 failed.")
+    try:
+        blip2 = get_blip2()
+        captions.append(blip2(img)[0]["generated_text"])
+    except: captions.append("BLIP2-opt failed.")
+    return captions
+
+# ---------------- Metrics Computation ---------------------
+def compute_metrics_button(images, captions, idx1, idx2):
+    # CLIP similarity
+    img1_clip = clip_preprocess(images[idx1]).unsqueeze(0).to(device)
+    img2_clip = clip_preprocess(images[idx2]).unsqueeze(0).to(device)
+    with torch.no_grad():
+        feat1 = clip_model.encode_image(img1_clip)
+        feat2 = clip_model.encode_image(img2_clip)
+    clip_sim = float(torch.cosine_similarity(feat1, feat2).item())
+
+    # LPIPS
+    img1_lp = lpips_transform(images[idx1]).unsqueeze(0).to(device) * 2 - 1
+    img2_lp = lpips_transform(images[idx2]).unsqueeze(0).to(device) * 2 - 1
+    with torch.no_grad():
+        lpips_score = float(lpips_model(img1_lp, img2_lp).item())
+
+    # BERTScore
+    _, _, F1 = score([captions[idx1]], [captions[idx2]], lang="en", verbose=False)
+    bert_f1 = float(F1.mean().item())
+
+    # Cosine similarity of embeddings
+    emb1 = sentence_model.encode([captions[idx1]])
+    emb2 = sentence_model.encode([captions[idx2]])
+    cosine_sim = float(cosine_similarity(emb1, emb2)[0][0])
+
+    # Jaccard similarity
+    tokens1 = set(captions[idx1].lower().split())
+    tokens2 = set(captions[idx2].lower().split())
+    jaccard_sim = float(len(tokens1 & tokens2) / len(tokens1 | tokens2))
+
+    # ROUGE
+    scorer = rouge_scorer.RougeScorer(['rouge1','rougeL'], use_stemmer=True)
+    rouge_scores = scorer.score(captions[idx1], captions[idx2])
+
+    return f""
+    **Metrics Comparison**
+    - CLIP Similarity: {clip_sim:.4f}
+    - LPIPS Score: {lpips_score:.4f}
+    - BERTScore F1: {bert_f1:.4f}
+    - Cosine Similarity: {cosine_sim:.4f}
+    - Jaccard Similarity: {jaccard_sim:.4f}
+    - ROUGE-1: {rouge_scores['rouge1'].fmeasure:.4f}
+    - ROUGE-L: {rouge_scores['rougeL'].fmeasure:.4f}
+    ""
+
+# ---- NLP ----
+def nlp_bundle(caption):
+    try:
+        sentiment = sentiment_model(caption)
+        sentiment = "<br>".join([f"{s['label']}: {s['score']:.2f}" for s in sentiment])
+    except: sentiment = "Sentiment failed."
+
+    try:
+        ents_list = ner_model(caption)
+        ents = "<br>".join([f"{e['entity_group']}: {e['word']}" for e in ents_list]) or "None"
+    except: ents = "NER failed."
+
+    try:
+        topics_raw = topic_model(caption, candidate_labels=["people","animals","objects","food","nature"])
+        topics = "<br>".join([f"{lbl}: {float(scr):.2f}" for lbl, scr in zip(topics_raw["labels"], topics_raw["scores"])])
+    except: topics = "Topics failed."
+
+    return sentiment, ents, topics
+
+# ---------------- VQA ----------------
+def answer_vqa(question, image):
+    if image is None or question.strip() == "":
+        return "Upload an image and enter a question."
+    model = get_vqa_model()
+    inputs = vqa_processor(images=image, text=question, return_tensors="pt").to(device)
+    with torch.no_grad():
+        generated_ids = model.generate(**inputs)
+    answer = vqa_processor.decode(generated_ids[0], skip_special_tokens=True)
+    free_gpu_cache()
+    return answer
+
+# Convert a PIL.Image to PNG byte stream
+def to_bytes(img):
+    import io
+    buf = io.BytesIO()
+    img.save(buf, format="PNG")
+    return buf.getvalue()
+
+# ==============================
+# SECTION 4 – UI (GRADIO)
+# ==============================
+def build_ui():
+    with gr.Blocks(title="Multimodal AI Image Studio") as demo:
+
+        gr.HTML(
+        <style>
+        .heading-orange h2, .heading-orange h3 { color: #ff5500 !important; }
+        .orange-btn button { background-color:#ff5500; color:white; border-radius:6px; height:36px; font-weight:bold; }
+        .teal-btn button { background-color:#008080; color:white; border-radius:6px; height:36px; font-weight:bold; }
+        .loading-line {
+            height:4px; background:linear-gradient(90deg,#008080 0%,#00cccc 50%,#008080 100%);
+            background-size:200% 100%; animation: loading 1s linear infinite;
+        }
+        @keyframes loading { 0% {background-position:200% 0;} 100% {background-position:-200% 0;} }
+        .circular-img img {
+            border-radius: 21%;
+            object-fit: cover;
+            width: 400px;
+            height: 200px;
+            box-shadow: inset -10px -10px 30px rgba(255,255,255,0.3),
+                        5px 5px 15px rgba(0,0,0,0.3);
+            border: 2px solid rgba(255,255,255,0.6);
+        }
+        </style>
+        )
+
+        gr.Markdown("## Multimodal AI Image Studio: Comparative Image-to-Text Analysis", elem_classes="heading-orange")
+        images_state = gr.State([])
+        captions_state = gr.State([])
+
+        # ---------------- Image Input ----------------
+        gr.Markdown("### Select Image Source", elem_classes="heading-orange")
+        with gr.Tabs():
+            with gr.Tab("📁 Upload Image"):
+                upload_input = gr.Image(type="pil", sources=["upload"], label="Upload Image", height=900, width=960, elem_classes="circular-img")
+                upload_btn = gr.Button("Generate Captions", elem_classes="orange-btn")
+            with gr.Tab("📷 Webcam"):
+                webcam_input = gr.Image(type="pil", sources=["webcam"], label="Webcam", height=900, width=960, elem_classes="circular-img")
+                webcam_btn = gr.Button("Capture & Generate Captions", elem_classes="orange-btn")
+            with gr.Tab("🌐 From URL"):
+                url_input = gr.Textbox(label="Paste Image URL")
+                url_btn = gr.Button("Fetch & Generate Captions", elem_classes="orange-btn")
+
+        # ---------------- Previews ----------------
+        with gr.Row():
+            with gr.Column(scale=1, min_width=200):
+                preview1 = gr.Image(type="pil",label="Preview 1", interactive=False, height=230)
+                blip_caption_box = gr.Markdown()
+            with gr.Column(scale=1, min_width=200):
+                preview2 = gr.Image(type="pil",label="Preview 2", interactive=False, height=230)
+                vit_caption_box = gr.Markdown()
+            with gr.Column(scale=1, min_width=200):
+                preview3 = gr.Image(type="pil",label="Preview 3", interactive=False, height=230)
+                blip2_caption_box = gr.Markdown()
+
+        # ---------------- Generate Captions ----------------
+        def generate_all(img, images_state, captions_state):
+            if img is None:
+                return (None, None, None, "No image.", "No image.", "No image.", [], [])
+            captions = make_captions(img)
+            return (img, img, img, captions[0], captions[1], captions[2], [img], captions)
+
+        upload_btn.click(generate_all, inputs=[upload_input, images_state, captions_state],
+                         outputs=[preview1, preview2, preview3, blip_caption_box, vit_caption_box, blip2_caption_box, images_state, captions_state])
+        webcam_btn.click(generate_all, inputs=[webcam_input, images_state, captions_state],
+                         outputs=[preview1, preview2, preview3, blip_caption_box, vit_caption_box, blip2_caption_box, images_state, captions_state])
+
+        def load_from_url(url, images_state, captions_state):
+            import requests
+            from io import BytesIO
+            try:
+                img = Image.open(BytesIO(requests.get(url).content))
+            except:
+                return (None, None, None, "Bad URL.", "Bad URL.", "Bad URL.", [], [])
+            return generate_all(img, images_state, captions_state)
+
+        url_btn.click(load_from_url, inputs=[url_input, images_state, captions_state],
+                      outputs=[preview1, preview2, preview3, blip_caption_box, vit_caption_box, blip2_caption_box, images_state, captions_state])
+
+        # ---------------- Metrics ----------------
+
+        ""
+        gr.Markdown("### Compute Pairwise Metrics", elem_classes="heading-orange")
+        metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn")
+        metrics_A = gr.Markdown()
+        metrics_B = gr.Markdown()
+        metrics_C = gr.Markdown()
+
        def compute_metrics_all_pairs_ui(images, captions):
            yield ("<div class='loading-line'></div>", "<div class='loading-line'></div>", "<div class='loading-line'></div>")
            if len(images) < 1 or len(captions) < 3:

@@ -307,7 +622,52 @@ def build_ui():
                f"**ViT-GPT2 ↔ BLIP2**<br>{C}")

        metrics_btn.click(compute_metrics_all_pairs_ui, inputs=[images_state, captions_state],
-                         outputs=[metrics_A, metrics_B, metrics_C])""
+                         outputs=[metrics_A, metrics_B, metrics_C])""
+
+
+        # ---------------- Metrics ----------------
+        gr.Markdown("### Compute Pairwise Metrics", elem_classes="heading-orange")
+        metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn")
+
+        with gr.Row(elem_classes="metrics-row"):
+            metrics_A = gr.Markdown()
+            metrics_B = gr.Markdown()
+            metrics_C = gr.Markdown()
+
+        def compute_metrics_all_pairs_ui(images, captions):
+
+            # 3 spinners (one for each column)
+            yield (
+                "<div class='loading-line'></div>",
+                "<div class='loading-line'></div>",
+                "<div class='loading-line'></div>"
+            )
+
+            if len(images) < 1 or len(captions) < 3:
+                msg = "<b>Upload 1 image and generate all 3 captions.</b>"
+                yield (msg, msg, msg)
+                return
+
+            # duplicate image for internal function
+            imgs = images * 3
+
+            # compute
+            A = compute_metrics_button(imgs, captions, 0, 1)
+            B = compute_metrics_button(imgs, captions, 0, 2)
+            C = compute_metrics_button(imgs, captions, 1, 2)
+
+            # return 3 separate markdown blocks (side-by-side)
+            yield (
+                f"### BLIP-large ↔ ViT-GPT2\n{A}",
+                f"### BLIP-large ↔ BLIP2\n{B}",
+                f"### ViT-GPT2 ↔ BLIP2\n{C}"
+            )
+
+        metrics_btn.click(
+            compute_metrics_all_pairs_ui,
+            inputs=[images_state, captions_state],
+            outputs=[metrics_A, metrics_B, metrics_C]
+        )

        # ---------------- NLP ----------------
        gr.Markdown("### NLP Analysis", elem_classes="heading-orange")

@@ -323,14 +683,14 @@ def build_ui():
            blocks = []
            for label, cap in zip(labels, captions):
                s, e, t = nlp_bundle(cap)
-               block = f""
+               block = f""
                <div style='flex:1;padding:10px;min-width:240px;'>
                <h3><u>{label}</u></h3>
                <b>Sentiment</b><br>{s}<br><br>
                <b>Entities</b><br>{e}<br><br>
                <b>Topics</b><br>{t}
                </div>
-               ""
+               ""
                blocks.append(block)
            yield f"<div style='display:flex; gap:20px;'>{''.join(blocks)}</div>"

@@ -356,3 +716,5 @@ def build_ui():
# ==============================
demo = build_ui()
demo.launch(share=True, debug=False)
+
+"""
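The restructured handler relies on Gradio's generator streaming: a click callback that yields a tuple updates every listed output component at once, so the first yield shows a placeholder in all three panels and the second replaces them with the finished metrics. A minimal self-contained sketch of that pattern (not part of the commit; component names and the metric values are dummy placeholders):

# Sketch only: generator callback streaming to three outputs, as the diff's
# compute_metrics_all_pairs_ui does with metrics_A / metrics_B / metrics_C.
import time
import gradio as gr

def compute_all_pairs():
    # First yield: placeholder text appears in all three outputs immediately.
    yield ("*computing…*", "*computing…*", "*computing…*")
    time.sleep(2)  # stand-in for the real pairwise metric computation
    # Second yield: final Markdown replaces the placeholders (dummy values).
    yield ("### Pair A\n- CLIP: 0.91", "### Pair B\n- CLIP: 0.88", "### Pair C\n- CLIP: 0.86")

with gr.Blocks() as demo:
    btn = gr.Button("Compute Metrics for All Pairs")
    with gr.Row():
        out_a = gr.Markdown()
        out_b = gr.Markdown()
        out_c = gr.Markdown()
    btn.click(compute_all_pairs, inputs=None, outputs=[out_a, out_b, out_c])

demo.launch()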