Chyd19 commited on
Commit
c5510ff
Β·
verified Β·
1 Parent(s): 06487e3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +349 -0
app.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # SECTION 1 β€” INSTALL + IMPORTS.
3
+ import torch
4
+ import gradio as gr
5
+ from PIL import Image
6
+ from transformers import pipeline, BlipProcessor, BlipForQuestionAnswering
7
+ import lpips
8
+ import clip
9
+ from bert_score import score
10
+ import torchvision.transforms as T
11
+ from sentence_transformers import SentenceTransformer
12
+ from rouge_score import rouge_scorer
13
+ import numpy as np
14
+ from sklearn.metrics.pairwise import cosine_similarity
15
+
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+
18
+ def free_gpu_cache():
19
+ if torch.cuda.is_available():
20
+ torch.cuda.empty_cache()
21
+
22
+
23
+ # SECTION 2 β€” LOAD LIGHTWEIGHT MODELS
24
+ blip_large_captioner = pipeline(
25
+ "image-to-text",
26
+ model="Salesforce/blip-image-captioning-large",
27
+ device=0 if device=="cuda" else -1
28
+ )
29
+
30
+ vit_gpt2_captioner = pipeline(
31
+ "image-to-text",
32
+ model="nlpconnect/vit-gpt2-image-captioning",
33
+ device=0 if device=="cuda" else -1
34
+ )
35
+
36
+ # --- NLP Pipelines ---
37
+ sentiment_model = pipeline("sentiment-analysis")
38
+ ner_model = pipeline("ner", aggregation_strategy="simple")
39
+ topic_model = pipeline("zero-shot-classification",
40
+ model="facebook/bart-large-mnli")
41
+
42
+ # --- Metrics ---
43
+ clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
44
+ lpips_model = lpips.LPIPS(net='alex').to(device)
45
+ lpips_transform = T.Compose([T.ToTensor(), T.Resize((128,128))])
46
+ sentence_model = SentenceTransformer("all-MiniLM-L6-v2") # for cosine similarity
47
+
48
+
49
+ # SECTION 2b LOAD HEAVY MODELS
50
+ blip2_captioner = None
51
+ vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
52
+ vqa_model = None
53
+
54
+ def get_blip2():
55
+ global blip2_captioner
56
+ if blip2_captioner is None:
57
+ blip2_captioner = pipeline(
58
+ "image-to-text",
59
+ model="Salesforce/blip2-opt-2.7b",
60
+ device=0 if device=="cuda" else -1
61
+ )
62
+ return blip2_captioner
63
+
64
+ def get_vqa_model():
65
+ global vqa_model
66
+ if vqa_model is None:
67
+ vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(device)
68
+ return vqa_model
69
+
70
+
71
+ # SECTION 3 β€” FUNCTIONS DEFINITION
72
+ def make_captions(img):
73
+ captions = []
74
+ try: captions.append(blip_large_captioner(img)[0]["generated_text"])
75
+ except: captions.append("BLIP-large failed.")
76
+ try: captions.append(vit_gpt2_captioner(img)[0]["generated_text"])
77
+ except: captions.append("ViT-GPT2 failed.")
78
+ try:
79
+ blip2 = get_blip2()
80
+ captions.append(blip2(img)[0]["generated_text"])
81
+ except: captions.append("BLIP2-opt failed.")
82
+ return captions
83
+
84
+ # ---------------- Metrics Computation ---------------------
85
+ def compute_metrics_button(images, captions, idx1, idx2):
86
+ # CLIP similarity
87
+ img1_clip = clip_preprocess(images[idx1]).unsqueeze(0).to(device)
88
+ img2_clip = clip_preprocess(images[idx2]).unsqueeze(0).to(device)
89
+ with torch.no_grad():
90
+ feat1 = clip_model.encode_image(img1_clip)
91
+ feat2 = clip_model.encode_image(img2_clip)
92
+ clip_sim = float(torch.cosine_similarity(feat1, feat2).item())
93
+
94
+ # LPIPS
95
+ img1_lp = lpips_transform(images[idx1]).unsqueeze(0).to(device) * 2 - 1
96
+ img2_lp = lpips_transform(images[idx2]).unsqueeze(0).to(device) * 2 - 1
97
+ with torch.no_grad():
98
+ lpips_score = float(lpips_model(img1_lp, img2_lp).item())
99
+
100
+ # BERTScore
101
+ _, _, F1 = score([captions[idx1]], [captions[idx2]], lang="en", verbose=False)
102
+ bert_f1 = float(F1.mean().item())
103
+
104
+ # Cosine similarity of embeddings
105
+ emb1 = sentence_model.encode([captions[idx1]])
106
+ emb2 = sentence_model.encode([captions[idx2]])
107
+ cosine_sim = float(cosine_similarity(emb1, emb2)[0][0])
108
+
109
+ # Jaccard similarity
110
+ tokens1 = set(captions[idx1].lower().split())
111
+ tokens2 = set(captions[idx2].lower().split())
112
+ jaccard_sim = float(len(tokens1 & tokens2) / len(tokens1 | tokens2))
113
+
114
+ # ROUGE
115
+ scorer = rouge_scorer.RougeScorer(['rouge1','rougeL'], use_stemmer=True)
116
+ rouge_scores = scorer.score(captions[idx1], captions[idx2])
117
+
118
+ return f"""
119
+ - CLIP: {clip_sim:.4f}
120
+ - LPIPS: {lpips_score:.4f}
121
+ - BERT-F1: {bert_f1:.4f}
122
+ - Cosine: {cosine_sim:.4f}
123
+ - Jaccard: {jaccard_sim:.4f}
124
+ - ROUGE-1: {rouge_scores['rouge1'].fmeasure:.4f}
125
+ - ROUGE-L: {rouge_scores['rougeL'].fmeasure:.4f}
126
+ """
127
+
128
+ # ---- NLP ----
129
+ def nlp_bundle(caption):
130
+ try:
131
+ sentiment = sentiment_model(caption)
132
+ sentiment = "<br>".join([f"{s['label']}: {s['score']:.2f}" for s in sentiment])
133
+ except: sentiment = "Sentiment failed."
134
+
135
+ try:
136
+ ents_list = ner_model(caption)
137
+ ents = "<br>".join([f"{e['entity_group']}: {e['word']}" for e in ents_list]) or "None"
138
+ except: ents = "NER failed."
139
+
140
+ try:
141
+ topics_raw = topic_model(caption, candidate_labels=["people","animals","objects","food","nature"])
142
+ topics = "<br>".join([f"{lbl}: {float(scr):.2f}" for lbl, scr in zip(topics_raw["labels"], topics_raw["scores"])])
143
+ except: topics = "Topics failed."
144
+
145
+ return sentiment, ents, topics
146
+
147
+ # ---------------- VQA ----------------
148
+ def answer_vqa(question, image):
149
+ if image is None or question.strip() == "":
150
+ return "Upload an image and enter a question."
151
+ model = get_vqa_model()
152
+ inputs = vqa_processor(images=image, text=question, return_tensors="pt").to(device)
153
+ with torch.no_grad():
154
+ generated_ids = model.generate(**inputs)
155
+ answer = vqa_processor.decode(generated_ids[0], skip_special_tokens=True)
156
+ free_gpu_cache()
157
+ return answer
158
+
159
+ # Convert a PIL.Image to PNG byte stream
160
+ def to_bytes(img):
161
+ import io
162
+ buf = io.BytesIO()
163
+ img.save(buf, format="PNG")
164
+ return buf.getvalue()
165
+
166
+
167
+ # SECTION 4 β€” UI (GRADIO)
168
+ def build_ui():
169
+ with gr.Blocks(title="Multimodal AI Image Studio") as demo:
170
+
171
+ gr.HTML("""
172
+ <style>
173
+ .heading-orange h2, .heading-orange h3 { color: #ff5500 !important; }
174
+ .orange-btn button { background-color:#ff5500; color:white; border-radius:6px; height:36px; font-weight:bold; }
175
+ .teal-btn button { background-color:#008080; color:white; border-radius:6px; height:36px; font-weight:bold; }
176
+ .loading-line {
177
+ height:4px; background:linear-gradient(90deg,#008080 0%,#00cccc 50%,#008080 100%);
178
+ background-size:200% 100%; animation: loading 1s linear infinite;
179
+ }
180
+ @keyframes loading { 0% {background-position:200% 0;} 100% {background-position:-200% 0;} }
181
+ .circular-img img {
182
+ border-radius: 21%;
183
+ object-fit: cover;
184
+ width: 400px;
185
+ height: 200px;
186
+ box-shadow: inset -10px -10px 30px rgba(255,255,255,0.3),
187
+ 5px 5px 15px rgba(0,0,0,0.3);
188
+ border: 2px solid rgba(255,255,255,0.6);
189
+ }
190
+ .metrics-row {
191
+ display: flex;
192
+ flex-direction: row;
193
+ gap: 20px;
194
+ }
195
+ .metrics-row > div {
196
+ flex: 1;
197
+ }
198
+ </style>
199
+ """)
200
+
201
+ gr.Markdown("## Multimodal AI Image Studio: Comparative Image-to-Text Analysis", elem_classes="heading-orange")
202
+ images_state = gr.State([])
203
+ captions_state = gr.State([])
204
+
205
+ # ---------------- Image Input ----------------
206
+ gr.Markdown("### Select Image Source", elem_classes="heading-orange")
207
+ with gr.Tabs():
208
+ with gr.Tab("πŸ“ Upload Image"):
209
+ upload_input = gr.Image(type="pil", sources=["upload"], label="Upload Image", height=600, elem_classes="circular-img")
210
+ upload_btn = gr.Button("Generate Captions", elem_classes="orange-btn")
211
+ with gr.Tab("πŸ“· Webcam"):
212
+ webcam_input = gr.Image(type="pil", sources=["webcam"], label="Webcam", height=600, elem_classes="circular-img")
213
+ webcam_btn = gr.Button("Capture & Generate Captions", elem_classes="orange-btn")
214
+ with gr.Tab("πŸ”— From URL"):
215
+ url_input = gr.Textbox(label="Paste Image URL")
216
+ url_btn = gr.Button("Fetch & Generate Captions", elem_classes="orange-btn")
217
+
218
+ # ---------------- Previews ----------------
219
+ with gr.Row():
220
+ with gr.Column(scale=1, min_width=200):
221
+ preview1 = gr.Image(type="pil",label="Preview 1", interactive=False, height=230)
222
+ blip_caption_box = gr.Markdown()
223
+ with gr.Column(scale=1, min_width=200):
224
+ preview2 = gr.Image(type="pil",label="Preview 2", interactive=False, height=230)
225
+ vit_caption_box = gr.Markdown()
226
+ with gr.Column(scale=1, min_width=200):
227
+ preview3 = gr.Image(type="pil",label="Preview 3", interactive=False, height=230)
228
+ blip2_caption_box = gr.Markdown()
229
+
230
+ # ---------------- Generate Captions ----------------
231
+ def generate_all(img, images_state, captions_state):
232
+ if img is None:
233
+ return (None, None, None, "No image.", "No image.", "No image.", [], [])
234
+ captions = make_captions(img)
235
+ return (img, img, img, captions[0], captions[1], captions[2], [img], captions)
236
+
237
+ upload_btn.click(generate_all, inputs=[upload_input, images_state, captions_state],
238
+ outputs=[preview1, preview2, preview3, blip_caption_box, vit_caption_box, blip2_caption_box, images_state, captions_state])
239
+ webcam_btn.click(generate_all, inputs=[webcam_input, images_state, captions_state],
240
+ outputs=[preview1, preview2, preview3, blip_caption_box, vit_caption_box, blip2_caption_box, images_state, captions_state])
241
+
242
+ def load_from_url(url, images_state, captions_state):
243
+ import requests
244
+ from io import BytesIO
245
+ try:
246
+ img = Image.open(BytesIO(requests.get(url).content))
247
+ except:
248
+ return (None, None, None, "Bad URL.", "Bad URL.", "Bad URL.", [], [])
249
+ return generate_all(img, images_state, captions_state)
250
+
251
+ url_btn.click(load_from_url, inputs=[url_input, images_state, captions_state],
252
+ outputs=[preview1, preview2, preview3, blip_caption_box, vit_caption_box, blip2_caption_box, images_state, captions_state])
253
+
254
+ # ---------------- Metrics ----------------
255
+ gr.Markdown("### Compute Pairwise Metrics", elem_classes="heading-orange")
256
+ metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn")
257
+ with gr.Row(elem_classes="metrics-row"):
258
+ metrics_A = gr.Markdown()
259
+ metrics_B = gr.Markdown()
260
+ metrics_C = gr.Markdown()
261
+
262
+ def compute_metrics_all_pairs_ui(images, captions):
263
+ # 3 spinners
264
+ yield (
265
+ "<div class='loading-line'></div>",
266
+ "<div class='loading-line'></div>",
267
+ "<div class='loading-line'></div>"
268
+ )
269
+
270
+ if len(images) < 1 or len(captions) < 3:
271
+ msg = "<b>Upload 1 image and generate all 3 captions.</b>"
272
+ yield (msg, msg, msg)
273
+ return
274
+
275
+ imgs = images * 3
276
+ A = compute_metrics_button(imgs, captions, 0, 1)
277
+ B = compute_metrics_button(imgs, captions, 0, 2)
278
+ C = compute_metrics_button(imgs, captions, 1, 2)
279
+
280
+ yield (
281
+ f"### BLIP-large ↔ ViT-GPT2\n{A}",
282
+ f"### BLIP-large ↔ BLIP2\n{B}",
283
+ f"### ViT-GPT2 ↔ BLIP2\n{C}"
284
+ )
285
+
286
+ metrics_btn.click(
287
+ compute_metrics_all_pairs_ui,
288
+ inputs=[images_state, captions_state],
289
+ outputs=[metrics_A, metrics_B, metrics_C]
290
+ )
291
+
292
+ # ---------------- NLP ----------------
293
+ gr.Markdown("### NLP Analysis", elem_classes="heading-orange")
294
+ nlp_btn = gr.Button("Analyze Captions", elem_classes="teal-btn")
295
+
296
+ with gr.Row(elem_classes="metrics-row"): # reuse metrics-row for flex layout
297
+ nlp_A = gr.Markdown()
298
+ nlp_B = gr.Markdown()
299
+ nlp_C = gr.Markdown()
300
+
301
+ def do_nlp_all(captions):
302
+ # 3 spinners like metrics
303
+ yield (
304
+ "<div class='loading-line'></div>",
305
+ "<div class='loading-line'></div>",
306
+ "<div class='loading-line'></div>"
307
+ )
308
+
309
+ if len(captions) < 3:
310
+ msg = "<b>All 3 captions required.</b>"
311
+ yield (msg, msg, msg)
312
+ return
313
+
314
+ labels = ["BLIP-large", "ViT-GPT2", "BLIP2"]
315
+ results = []
316
+ for label, cap in zip(labels, captions):
317
+ s, e, t = nlp_bundle(cap)
318
+ block = f"""
319
+ <h3><u>{label}</u></h3>
320
+ <b>Sentiment</b><br>{s}<br><br>
321
+ <b>Entities</b><br>{e}<br><br>
322
+ <b>Topics</b><br>{t}
323
+ """
324
+ results.append(block)
325
+
326
+ yield (results[0], results[1], results[2])
327
+
328
+ nlp_btn.click(do_nlp_all, inputs=[captions_state], outputs=[nlp_A, nlp_B, nlp_C])
329
+
330
+
331
+ # ---------------- VQA ----------------
332
+ gr.Markdown("### Visual Question Answering (VQA)", elem_classes="heading-orange")
333
+ with gr.Row():
334
+ vqa_input = gr.Textbox(label="Ask about the image")
335
+ vqa_btn = gr.Button("Get Answer", elem_classes="teal-btn")
336
+ vqa_out = gr.Markdown()
337
+
338
+ def vqa_ui(question, image):
339
+ yield "<div class='loading-line'></div>"
340
+ yield answer_vqa(question, image)
341
+
342
+ vqa_btn.click(vqa_ui, inputs=[vqa_input, preview1], outputs=[vqa_out])
343
+
344
+ return demo
345
+
346
+ # LAUNCH
347
+ # ==============================
348
+ demo = build_ui()
349
+ demo.launch(debug=False)