Chyd19 committed on
Commit
468012b
·
verified ·
1 Parent(s): cd443bc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +323 -0
app.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ==============================
2
+ # SECTION 1 — INSTALL + IMPORTS
3
+ # ==============================
4
+
5
+ import torch
6
+ import gradio as gr
7
+ from PIL import Image
8
+ from transformers import pipeline, BlipProcessor, BlipForQuestionAnswering
9
+ import lpips
10
+ import clip
11
+ from bert_score import score
12
+ import torchvision.transforms as T
13
+ from sentence_transformers import SentenceTransformer
14
+ from rouge_score import rouge_scorer
15
+ import numpy as np
16
+ from sklearn.metrics.pairwise import cosine_similarity
17
+
18
# Pick the compute device once at import time: the GPU when CUDA is
# available, the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
19
+
20
def free_gpu_cache():
    """Call ``torch.cuda.empty_cache()`` when CUDA is available.

    Does nothing on machines without CUDA.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
23
+
24
+ # ==============================
25
+ # SECTION 2 — LOAD LIGHTWEIGHT MODELS
26
+ # ==============================
27
# transformers pipelines address devices by index: 0 selects the first GPU,
# -1 selects the CPU.
_captioner_device = 0 if device == "cuda" else -1

# --- Captioning pipelines (created at import time) ---
blip_large_captioner = pipeline(
    "image-to-text",
    model="Salesforce/blip-image-captioning-large",
    device=_captioner_device,
)

vit_gpt2_captioner = pipeline(
    "image-to-text",
    model="nlpconnect/vit-gpt2-image-captioning",
    device=_captioner_device,
)

# --- NLP Pipelines ---
# No checkpoint is named for sentiment or NER, so transformers falls back to
# its default model for each task.
sentiment_model = pipeline("sentiment-analysis")
ner_model = pipeline("ner", aggregation_strategy="simple")
topic_model = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
)

# --- Metrics ---
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
lpips_model = lpips.LPIPS(net="alex").to(device)
# LPIPS preprocessing: convert to a tensor, then resize to 128x128.
lpips_transform = T.Compose([T.ToTensor(), T.Resize((128, 128))])
sentence_model = SentenceTransformer("all-MiniLM-L6-v2")  # for cosine similarity
50
+
51
+ # ==============================
52
+ # SECTION 2b — LAZY LOAD HEAVY MODELS
53
+ # ==============================
54
# Slots for the heavy models. ``None`` means "not loaded yet"; get_blip2()
# and get_vqa_model() fill them in on first use.
blip2_captioner = None
vqa_model = None

# Unlike the VQA model itself, its processor is loaded at import time.
vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
57
+
58
def get_blip2():
    """Return the shared BLIP-2 captioning pipeline, creating it on first call.

    The pipeline is stored in the module-level ``blip2_captioner`` so later
    calls reuse the same object instead of loading the model again.
    """
    global blip2_captioner
    if blip2_captioner is not None:
        return blip2_captioner

    # Pipeline device index: 0 for the first GPU, -1 for the CPU.
    device_index = 0 if device == "cuda" else -1
    blip2_captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip2-opt-2.7b",
        device=device_index,
    )
    return blip2_captioner
67
+
68
def get_vqa_model():
    """Return the shared BLIP VQA model, loading it onto ``device`` on first call.

    The model is stored in the module-level ``vqa_model`` so later calls
    reuse it.
    """
    global vqa_model
    if vqa_model is None:
        loaded = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
        # Assign only after the move succeeds so a failure leaves the slot empty.
        vqa_model = loaded.to(device)
    return vqa_model
73
+
74
+ # ==============================
75
+ # SECTION 3 — FUNCTIONS
76
+ # ==============================
77
def make_captions(img):
    """Caption *img* with BLIP-large, ViT-GPT2 and BLIP-2, in that order.

    Each captioner is tried independently. When one fails (including a
    failure to load BLIP-2), the error is logged and a fixed placeholder
    string takes that caption's place, so the result always has three
    entries.

    Args:
        img: The image passed straight to each captioning pipeline.

    Returns:
        list[str]: ``[blip_large, vit_gpt2, blip2]`` captions or placeholders.
    """
    import logging

    log = logging.getLogger(__name__)

    # Each entry is (callable returning the captioner, placeholder on failure).
    # The lambdas defer the global lookup into the ``try`` below; BLIP-2 is
    # loaded lazily through get_blip2().
    captioners = (
        (lambda: blip_large_captioner, "BLIP-large failed."),
        (lambda: vit_gpt2_captioner, "ViT-GPT2 failed."),
        (lambda: get_blip2(), "BLIP2-opt failed."),
    )

    captions = []
    for get_captioner, fallback in captioners:
        try:
            captions.append(get_captioner()(img)[0]["generated_text"])
        except Exception:
            # ``Exception`` (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate while keeping the best-effort behaviour.
            log.exception("Captioning failed; using placeholder %r", fallback)
            captions.append(fallback)
    return captions
88
+
89
+ # ---------------- Metrics Computation ---------------------
90
def compute_metrics_button(images, captions, idx1, idx2):
    """Compare two image/caption pairs and return a Markdown metrics report.

    Image metrics (CLIP cosine similarity, LPIPS) compare ``images[idx1]``
    with ``images[idx2]``. Text metrics (BERTScore F1, sentence-embedding
    cosine similarity, Jaccard token overlap, ROUGE-1 and ROUGE-L F-measure)
    compare ``captions[idx1]`` with ``captions[idx2]``.

    Args:
        images: Indexable collection of images (assumed to be PIL images,
            as ``clip_preprocess`` and ``.convert`` are applied to them).
        captions: Indexable collection of caption strings.
        idx1: Index of the first item in both collections.
        idx2: Index of the second item in both collections.

    Returns:
        str: A Markdown bullet list with every metric formatted to four
        decimal places.
    """
    image_a, image_b = images[idx1], images[idx2]
    caption_a, caption_b = captions[idx1], captions[idx2]

    # CLIP similarity
    clip_a = clip_preprocess(image_a).unsqueeze(0).to(device)
    clip_b = clip_preprocess(image_b).unsqueeze(0).to(device)
    with torch.no_grad():
        feat_a = clip_model.encode_image(clip_a)
        feat_b = clip_model.encode_image(clip_b)
    clip_sim = float(torch.cosine_similarity(feat_a, feat_b).item())

    # LPIPS. Convert to RGB first so RGBA / palette images do not reach the
    # network with the wrong channel count, then rescale the [0, 1] tensor
    # from ToTensor to [-1, 1].
    lp_a = lpips_transform(image_a.convert("RGB")).unsqueeze(0).to(device) * 2 - 1
    lp_b = lpips_transform(image_b.convert("RGB")).unsqueeze(0).to(device) * 2 - 1
    with torch.no_grad():
        lpips_score = float(lpips_model(lp_a, lp_b).item())

    # BERTScore
    _, _, f1 = score([caption_a], [caption_b], lang="en", verbose=False)
    bert_f1 = float(f1.mean().item())

    # Cosine similarity of sentence embeddings
    emb_a = sentence_model.encode([caption_a])
    emb_b = sentence_model.encode([caption_b])
    cosine_sim = float(cosine_similarity(emb_a, emb_b)[0][0])

    # Jaccard similarity over lower-cased whitespace tokens. When neither
    # caption has any token the union is empty; report 0.0 instead of
    # dividing by zero.
    tokens_a = set(caption_a.lower().split())
    tokens_b = set(caption_b.lower().split())
    union = tokens_a | tokens_b
    jaccard_sim = len(tokens_a & tokens_b) / len(union) if union else 0.0

    # ROUGE
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
    rouge_scores = scorer.score(caption_a, caption_b)

    return f"""
**Metrics Comparison**
- CLIP Similarity: {clip_sim:.4f}
- LPIPS Score: {lpips_score:.4f}
- BERTScore F1: {bert_f1:.4f}
- Cosine Similarity: {cosine_sim:.4f}
- Jaccard Similarity: {jaccard_sim:.4f}
- ROUGE-1: {rouge_scores['rouge1'].fmeasure:.4f}
- ROUGE-L: {rouge_scores['rougeL'].fmeasure:.4f}
"""
133
+
134
+ # ---- NLP ----
135
def nlp_bundle(caption):
    """Run sentiment, NER and zero-shot topic analysis on one caption.

    Each analysis is best-effort: when a model call fails, the error is
    logged and a fixed "... failed." string is returned for that slot.

    Args:
        caption: The caption text to analyse.

    Returns:
        tuple[str, str, str]: ``(sentiment, entities, topics)``, each an HTML
        fragment whose lines are joined with ``<br>``. ``entities`` is
        ``"None"`` when NER finds nothing.
    """
    import logging

    log = logging.getLogger(__name__)

    try:
        sentiment_results = sentiment_model(caption)
        sentiment = "<br>".join(
            f"{s['label']}: {s['score']:.2f}" for s in sentiment_results
        )
    except Exception:
        # ``Exception`` rather than a bare except: KeyboardInterrupt and
        # SystemExit must not be swallowed.
        log.exception("Sentiment analysis failed")
        sentiment = "Sentiment failed."

    try:
        ents_list = ner_model(caption)
        ents = "<br>".join(
            f"{e['entity_group']}: {e['word']}" for e in ents_list
        ) or "None"
    except Exception:
        log.exception("Named-entity recognition failed")
        ents = "NER failed."

    try:
        topics_raw = topic_model(
            caption,
            candidate_labels=["people", "animals", "objects", "food", "nature"],
        )
        topics = "<br>".join(
            f"{lbl}: {float(scr):.2f}"
            for lbl, scr in zip(topics_raw["labels"], topics_raw["scores"])
        )
    except Exception:
        log.exception("Topic classification failed")
        topics = "Topics failed."

    return sentiment, ents, topics
152
+
153
+ # ---------------- VQA ----------------
154
def answer_vqa(question, image):
    """Answer a free-text question about an image with the BLIP VQA model.

    Args:
        question: The question text. ``None``, empty and whitespace-only
            values are treated as "no question".
        image: The image to ask about, or ``None`` when nothing is loaded.

    Returns:
        str: The decoded answer, or a prompt asking for an image and a
        question when either is missing.
    """
    # ``not question`` covers None as well as "", so .strip() is never
    # called on a missing value.
    if image is None or not question or not question.strip():
        return "Upload an image and enter a question."

    try:
        model = get_vqa_model()
        inputs = vqa_processor(images=image, text=question, return_tensors="pt").to(device)
        with torch.no_grad():
            generated_ids = model.generate(**inputs)
        return vqa_processor.decode(generated_ids[0], skip_special_tokens=True)
    finally:
        # Free cached GPU memory even when loading or generation raises.
        free_gpu_cache()
164
+
165
+ # Convert a PIL.Image to PNG byte stream
166
def to_bytes(img):
    """Serialize *img* to PNG and return the encoded bytes.

    Args:
        img: An object exposing ``save(fp, format=...)``; the surrounding
            file uses it for PIL images.

    Returns:
        bytes: Everything ``img.save`` wrote to the in-memory buffer.
    """
    from io import BytesIO

    stream = BytesIO()
    img.save(stream, format="PNG")
    png_data = stream.getvalue()
    return png_data
171
+
172
+ # ==============================
173
+ # SECTION 4 — UI (GRADIO)
174
+ # ==============================
175
def build_ui():
    """Build the Gradio Blocks interface and return it (without launching).

    The page has five parts: image input (upload, webcam or URL), three
    previews with one caption each (BLIP-large, ViT-GPT2, BLIP-2), pairwise
    caption/image metrics, NLP analysis of the captions, and visual question
    answering on the first preview image.

    Returns:
        gr.Blocks: The assembled app.
    """
    with gr.Blocks(title="Multimodal AI Image Studio") as demo:

        gr.HTML("""
        <style>
        .heading-orange h2, .heading-orange h3 { color: #ff5500 !important; }
        .orange-btn button { background-color:#ff5500; color:white; border-radius:6px; height:36px; font-weight:bold; }
        .teal-btn button { background-color:#008080; color:white; border-radius:6px; height:36px; font-weight:bold; }
        .loading-line {
        height:4px; background:linear-gradient(90deg,#008080 0%,#00cccc 50%,#008080 100%);
        background-size:200% 100%; animation: loading 1s linear infinite;
        }
        @keyframes loading { 0% {background-position:200% 0;} 100% {background-position:-200% 0;} }
        .circular-img img {
        border-radius: 21%;
        object-fit: cover;
        width: 400px;
        height: 200px;
        box-shadow: inset -10px -10px 30px rgba(255,255,255,0.3),
        5px 5px 15px rgba(0,0,0,0.3);
        border: 2px solid rgba(255,255,255,0.6);
        }
        </style>
        """)

        gr.Markdown("## Multimodal AI Image Studio: Comparative Image-to-Text Analysis", elem_classes="heading-orange")
        # Per-session state: the current image (as a one-element list) and
        # its three captions.
        images_state = gr.State([])
        captions_state = gr.State([])

        # Placeholder shown in an output while its handler is still running.
        loading_html = "<div class='loading-line'></div>"

        # ---------------- Image Input ----------------
        gr.Markdown("### Select Image Source", elem_classes="heading-orange")
        with gr.Tabs():
            with gr.Tab("📁 Upload Image"):
                upload_input = gr.Image(type="pil", sources=["upload"], label="Upload Image", height=900, width=960, elem_classes="circular-img")
                upload_btn = gr.Button("Generate Captions", elem_classes="orange-btn")
            with gr.Tab("📷 Webcam"):
                webcam_input = gr.Image(type="pil", sources=["webcam"], label="Webcam", height=900, width=960, elem_classes="circular-img")
                webcam_btn = gr.Button("Capture & Generate Captions", elem_classes="orange-btn")
            with gr.Tab("🔗 From URL"):
                url_input = gr.Textbox(label="Paste Image URL")
                url_btn = gr.Button("Fetch & Generate Captions", elem_classes="orange-btn")

        # ---------------- Previews ----------------
        with gr.Row():
            with gr.Column(scale=1, min_width=200):
                preview1 = gr.Image(type="pil", label="Preview 1", interactive=False, height=230)
                blip_caption_box = gr.Markdown()
            with gr.Column(scale=1, min_width=200):
                preview2 = gr.Image(type="pil", label="Preview 2", interactive=False, height=230)
                vit_caption_box = gr.Markdown()
            with gr.Column(scale=1, min_width=200):
                preview3 = gr.Image(type="pil", label="Preview 3", interactive=False, height=230)
                blip2_caption_box = gr.Markdown()

        # All three image sources write to the same eight outputs.
        caption_outputs = [
            preview1, preview2, preview3,
            blip_caption_box, vit_caption_box, blip2_caption_box,
            images_state, captions_state,
        ]

        # ---------------- Generate Captions ----------------
        def generate_all(img, images_state, captions_state):
            """Caption *img* and return values for ``caption_outputs``.

            The two state arguments are accepted because the click handlers
            pass them, but the previous state is always replaced.
            """
            if img is None:
                return (None, None, None, "No image.", "No image.", "No image.", [], [])
            captions = make_captions(img)
            return (img, img, img, captions[0], captions[1], captions[2], [img], captions)

        upload_btn.click(generate_all, inputs=[upload_input, images_state, captions_state],
                         outputs=caption_outputs)
        webcam_btn.click(generate_all, inputs=[webcam_input, images_state, captions_state],
                         outputs=caption_outputs)

        def load_from_url(url, images_state, captions_state):
            """Download the image at *url* and caption it like an upload.

            Any failure (empty URL, network error, HTTP error status,
            undecodable image) yields the "Bad URL." outputs.
            """
            import requests
            from io import BytesIO

            bad_url = (None, None, None, "Bad URL.", "Bad URL.", "Bad URL.", [], [])
            if not url or not url.strip():
                return bad_url
            try:
                # NOTE(review): this fetches an arbitrary user-supplied URL
                # from the server; restrict hosts if that is a concern.
                # The timeout stops a stalled server from hanging the handler.
                response = requests.get(url.strip(), timeout=10)
                response.raise_for_status()
                # Image.open is lazy; convert() forces the decode inside the
                # try and normalises RGBA / palette images to RGB for the
                # downstream models.
                img = Image.open(BytesIO(response.content)).convert("RGB")
            except Exception:
                return bad_url
            return generate_all(img, images_state, captions_state)

        url_btn.click(load_from_url, inputs=[url_input, images_state, captions_state],
                      outputs=caption_outputs)

        # ---------------- Metrics ----------------
        gr.Markdown("### Compute Pairwise Metrics", elem_classes="heading-orange")
        metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn")
        metrics_A = gr.Markdown()
        metrics_B = gr.Markdown()
        metrics_C = gr.Markdown()

        def compute_metrics_all_pairs_ui(images, captions):
            """Yield a loading bar, then the metrics for each caption pair."""
            yield (loading_html, loading_html, loading_html)
            if len(images) < 1 or len(captions) < 3:
                msg = "Upload 1 image and generate all 3 captions."
                yield msg, msg, msg
                return
            # There is only one image; repeat it so indices 0..2 line up
            # with the three captions.
            imgs = images * 3
            A = compute_metrics_button(imgs, captions, 0, 1)
            B = compute_metrics_button(imgs, captions, 0, 2)
            C = compute_metrics_button(imgs, captions, 1, 2)
            yield (f"**BLIP-large ↔ ViT-GPT2**<br>{A}",
                   f"**BLIP-large ↔ BLIP2**<br>{B}",
                   f"**ViT-GPT2 ↔ BLIP2**<br>{C}")

        metrics_btn.click(compute_metrics_all_pairs_ui, inputs=[images_state, captions_state],
                          outputs=[metrics_A, metrics_B, metrics_C])

        # ---------------- NLP ----------------
        gr.Markdown("### NLP Analysis", elem_classes="heading-orange")
        nlp_btn = gr.Button("Analyze Captions", elem_classes="teal-btn")
        nlp_out = gr.HTML()

        def do_nlp(captions):
            """Yield a loading bar, then one HTML column per caption."""
            yield loading_html
            if len(captions) < 3:
                yield "<b>All captions required.</b>"
                return
            labels = ["BLIP-large", "ViT-GPT2", "BLIP2"]
            blocks = []
            for label, cap in zip(labels, captions):
                s, e, t = nlp_bundle(cap)
                block = f"""
                <div style='flex:1;padding:10px;min-width:240px;'>
                <h3><u>{label}</u></h3>
                <b>Sentiment</b><br>{s}<br><br>
                <b>Entities</b><br>{e}<br><br>
                <b>Topics</b><br>{t}
                </div>
                """
                blocks.append(block)
            yield f"<div style='display:flex; gap:20px;'>{''.join(blocks)}</div>"

        nlp_btn.click(do_nlp, inputs=[captions_state], outputs=[nlp_out])

        # ---------------- VQA ----------------
        gr.Markdown("### Visual Question Answering (VQA)", elem_classes="heading-orange")
        with gr.Row():
            vqa_input = gr.Textbox(label="Ask about the image")
            vqa_btn = gr.Button("Get Answer", elem_classes="teal-btn")
        vqa_out = gr.Markdown()

        def vqa_ui(question, image):
            """Yield a loading bar, then the VQA answer."""
            yield loading_html
            yield answer_vqa(question, image)

        # The question is asked about whatever image is in the first preview.
        vqa_btn.click(vqa_ui, inputs=[vqa_input, preview1], outputs=[vqa_out])

    return demo
318
+
319
+ # ==============================
320
+ # LAUNCH
321
+ # ==============================
322
# Build the interface at import time and start the Gradio server with a
# public share link requested and debug mode off.
demo = build_ui()
demo.launch(debug=False, share=True)