Chyd19 commited on
Commit
b3e8e19
·
verified ·
1 Parent(s): c5510ff

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -349
app.py DELETED
@@ -1,349 +0,0 @@
1
-
2
- # SECTION 1 — INSTALL + IMPORTS.
3
- import torch
4
- import gradio as gr
5
- from PIL import Image
6
- from transformers import pipeline, BlipProcessor, BlipForQuestionAnswering
7
- import lpips
8
- import clip
9
- from bert_score import score
10
- import torchvision.transforms as T
11
- from sentence_transformers import SentenceTransformer
12
- from rouge_score import rouge_scorer
13
- import numpy as np
14
- from sklearn.metrics.pairwise import cosine_similarity
15
-
16
- device = "cuda" if torch.cuda.is_available() else "cpu"
17
-
18
- def free_gpu_cache():
19
- if torch.cuda.is_available():
20
- torch.cuda.empty_cache()
21
-
22
-
23
- # SECTION 2 — LOAD LIGHTWEIGHT MODELS
24
- blip_large_captioner = pipeline(
25
- "image-to-text",
26
- model="Salesforce/blip-image-captioning-large",
27
- device=0 if device=="cuda" else -1
28
- )
29
-
30
- vit_gpt2_captioner = pipeline(
31
- "image-to-text",
32
- model="nlpconnect/vit-gpt2-image-captioning",
33
- device=0 if device=="cuda" else -1
34
- )
35
-
36
- # --- NLP Pipelines ---
37
- sentiment_model = pipeline("sentiment-analysis")
38
- ner_model = pipeline("ner", aggregation_strategy="simple")
39
- topic_model = pipeline("zero-shot-classification",
40
- model="facebook/bart-large-mnli")
41
-
42
- # --- Metrics ---
43
- clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
44
- lpips_model = lpips.LPIPS(net='alex').to(device)
45
- lpips_transform = T.Compose([T.ToTensor(), T.Resize((128,128))])
46
- sentence_model = SentenceTransformer("all-MiniLM-L6-v2") # for cosine similarity
47
-
48
-
49
- # SECTION 2b LOAD HEAVY MODELS
50
- blip2_captioner = None
51
- vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
52
- vqa_model = None
53
-
54
- def get_blip2():
55
- global blip2_captioner
56
- if blip2_captioner is None:
57
- blip2_captioner = pipeline(
58
- "image-to-text",
59
- model="Salesforce/blip2-opt-2.7b",
60
- device=0 if device=="cuda" else -1
61
- )
62
- return blip2_captioner
63
-
64
- def get_vqa_model():
65
- global vqa_model
66
- if vqa_model is None:
67
- vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(device)
68
- return vqa_model
69
-
70
-
71
- # SECTION 3 — FUNCTIONS DEFINITION
72
- def make_captions(img):
73
- captions = []
74
- try: captions.append(blip_large_captioner(img)[0]["generated_text"])
75
- except: captions.append("BLIP-large failed.")
76
- try: captions.append(vit_gpt2_captioner(img)[0]["generated_text"])
77
- except: captions.append("ViT-GPT2 failed.")
78
- try:
79
- blip2 = get_blip2()
80
- captions.append(blip2(img)[0]["generated_text"])
81
- except: captions.append("BLIP2-opt failed.")
82
- return captions
83
-
84
- # ---------------- Metrics Computation ---------------------
85
- def compute_metrics_button(images, captions, idx1, idx2):
86
- # CLIP similarity
87
- img1_clip = clip_preprocess(images[idx1]).unsqueeze(0).to(device)
88
- img2_clip = clip_preprocess(images[idx2]).unsqueeze(0).to(device)
89
- with torch.no_grad():
90
- feat1 = clip_model.encode_image(img1_clip)
91
- feat2 = clip_model.encode_image(img2_clip)
92
- clip_sim = float(torch.cosine_similarity(feat1, feat2).item())
93
-
94
- # LPIPS
95
- img1_lp = lpips_transform(images[idx1]).unsqueeze(0).to(device) * 2 - 1
96
- img2_lp = lpips_transform(images[idx2]).unsqueeze(0).to(device) * 2 - 1
97
- with torch.no_grad():
98
- lpips_score = float(lpips_model(img1_lp, img2_lp).item())
99
-
100
- # BERTScore
101
- _, _, F1 = score([captions[idx1]], [captions[idx2]], lang="en", verbose=False)
102
- bert_f1 = float(F1.mean().item())
103
-
104
- # Cosine similarity of embeddings
105
- emb1 = sentence_model.encode([captions[idx1]])
106
- emb2 = sentence_model.encode([captions[idx2]])
107
- cosine_sim = float(cosine_similarity(emb1, emb2)[0][0])
108
-
109
- # Jaccard similarity
110
- tokens1 = set(captions[idx1].lower().split())
111
- tokens2 = set(captions[idx2].lower().split())
112
- jaccard_sim = float(len(tokens1 & tokens2) / len(tokens1 | tokens2))
113
-
114
- # ROUGE
115
- scorer = rouge_scorer.RougeScorer(['rouge1','rougeL'], use_stemmer=True)
116
- rouge_scores = scorer.score(captions[idx1], captions[idx2])
117
-
118
- return f"""
119
- - CLIP: {clip_sim:.4f}
120
- - LPIPS: {lpips_score:.4f}
121
- - BERT-F1: {bert_f1:.4f}
122
- - Cosine: {cosine_sim:.4f}
123
- - Jaccard: {jaccard_sim:.4f}
124
- - ROUGE-1: {rouge_scores['rouge1'].fmeasure:.4f}
125
- - ROUGE-L: {rouge_scores['rougeL'].fmeasure:.4f}
126
- """
127
-
128
- # ---- NLP ----
129
- def nlp_bundle(caption):
130
- try:
131
- sentiment = sentiment_model(caption)
132
- sentiment = "<br>".join([f"{s['label']}: {s['score']:.2f}" for s in sentiment])
133
- except: sentiment = "Sentiment failed."
134
-
135
- try:
136
- ents_list = ner_model(caption)
137
- ents = "<br>".join([f"{e['entity_group']}: {e['word']}" for e in ents_list]) or "None"
138
- except: ents = "NER failed."
139
-
140
- try:
141
- topics_raw = topic_model(caption, candidate_labels=["people","animals","objects","food","nature"])
142
- topics = "<br>".join([f"{lbl}: {float(scr):.2f}" for lbl, scr in zip(topics_raw["labels"], topics_raw["scores"])])
143
- except: topics = "Topics failed."
144
-
145
- return sentiment, ents, topics
146
-
147
- # ---------------- VQA ----------------
148
- def answer_vqa(question, image):
149
- if image is None or question.strip() == "":
150
- return "Upload an image and enter a question."
151
- model = get_vqa_model()
152
- inputs = vqa_processor(images=image, text=question, return_tensors="pt").to(device)
153
- with torch.no_grad():
154
- generated_ids = model.generate(**inputs)
155
- answer = vqa_processor.decode(generated_ids[0], skip_special_tokens=True)
156
- free_gpu_cache()
157
- return answer
158
-
159
- # Convert a PIL.Image to PNG byte stream
160
- def to_bytes(img):
161
- import io
162
- buf = io.BytesIO()
163
- img.save(buf, format="PNG")
164
- return buf.getvalue()
165
-
166
-
167
- # SECTION 4 — UI (GRADIO)
168
- def build_ui():
169
- with gr.Blocks(title="Multimodal AI Image Studio") as demo:
170
-
171
- gr.HTML("""
172
- <style>
173
- .heading-orange h2, .heading-orange h3 { color: #ff5500 !important; }
174
- .orange-btn button { background-color:#ff5500; color:white; border-radius:6px; height:36px; font-weight:bold; }
175
- .teal-btn button { background-color:#008080; color:white; border-radius:6px; height:36px; font-weight:bold; }
176
- .loading-line {
177
- height:4px; background:linear-gradient(90deg,#008080 0%,#00cccc 50%,#008080 100%);
178
- background-size:200% 100%; animation: loading 1s linear infinite;
179
- }
180
- @keyframes loading { 0% {background-position:200% 0;} 100% {background-position:-200% 0;} }
181
- .circular-img img {
182
- border-radius: 21%;
183
- object-fit: cover;
184
- width: 400px;
185
- height: 200px;
186
- box-shadow: inset -10px -10px 30px rgba(255,255,255,0.3),
187
- 5px 5px 15px rgba(0,0,0,0.3);
188
- border: 2px solid rgba(255,255,255,0.6);
189
- }
190
- .metrics-row {
191
- display: flex;
192
- flex-direction: row;
193
- gap: 20px;
194
- }
195
- .metrics-row > div {
196
- flex: 1;
197
- }
198
- </style>
199
- """)
200
-
201
- gr.Markdown("## Multimodal AI Image Studio: Comparative Image-to-Text Analysis", elem_classes="heading-orange")
202
- images_state = gr.State([])
203
- captions_state = gr.State([])
204
-
205
- # ---------------- Image Input ----------------
206
- gr.Markdown("### Select Image Source", elem_classes="heading-orange")
207
- with gr.Tabs():
208
- with gr.Tab("📁 Upload Image"):
209
- upload_input = gr.Image(type="pil", sources=["upload"], label="Upload Image", height=600, elem_classes="circular-img")
210
- upload_btn = gr.Button("Generate Captions", elem_classes="orange-btn")
211
- with gr.Tab("📷 Webcam"):
212
- webcam_input = gr.Image(type="pil", sources=["webcam"], label="Webcam", height=600, elem_classes="circular-img")
213
- webcam_btn = gr.Button("Capture & Generate Captions", elem_classes="orange-btn")
214
- with gr.Tab("🔗 From URL"):
215
- url_input = gr.Textbox(label="Paste Image URL")
216
- url_btn = gr.Button("Fetch & Generate Captions", elem_classes="orange-btn")
217
-
218
- # ---------------- Previews ----------------
219
- with gr.Row():
220
- with gr.Column(scale=1, min_width=200):
221
- preview1 = gr.Image(type="pil",label="Preview 1", interactive=False, height=230)
222
- blip_caption_box = gr.Markdown()
223
- with gr.Column(scale=1, min_width=200):
224
- preview2 = gr.Image(type="pil",label="Preview 2", interactive=False, height=230)
225
- vit_caption_box = gr.Markdown()
226
- with gr.Column(scale=1, min_width=200):
227
- preview3 = gr.Image(type="pil",label="Preview 3", interactive=False, height=230)
228
- blip2_caption_box = gr.Markdown()
229
-
230
- # ---------------- Generate Captions ----------------
231
- def generate_all(img, images_state, captions_state):
232
- if img is None:
233
- return (None, None, None, "No image.", "No image.", "No image.", [], [])
234
- captions = make_captions(img)
235
- return (img, img, img, captions[0], captions[1], captions[2], [img], captions)
236
-
237
- upload_btn.click(generate_all, inputs=[upload_input, images_state, captions_state],
238
- outputs=[preview1, preview2, preview3, blip_caption_box, vit_caption_box, blip2_caption_box, images_state, captions_state])
239
- webcam_btn.click(generate_all, inputs=[webcam_input, images_state, captions_state],
240
- outputs=[preview1, preview2, preview3, blip_caption_box, vit_caption_box, blip2_caption_box, images_state, captions_state])
241
-
242
- def load_from_url(url, images_state, captions_state):
243
- import requests
244
- from io import BytesIO
245
- try:
246
- img = Image.open(BytesIO(requests.get(url).content))
247
- except:
248
- return (None, None, None, "Bad URL.", "Bad URL.", "Bad URL.", [], [])
249
- return generate_all(img, images_state, captions_state)
250
-
251
- url_btn.click(load_from_url, inputs=[url_input, images_state, captions_state],
252
- outputs=[preview1, preview2, preview3, blip_caption_box, vit_caption_box, blip2_caption_box, images_state, captions_state])
253
-
254
- # ---------------- Metrics ----------------
255
- gr.Markdown("### Compute Pairwise Metrics", elem_classes="heading-orange")
256
- metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn")
257
- with gr.Row(elem_classes="metrics-row"):
258
- metrics_A = gr.Markdown()
259
- metrics_B = gr.Markdown()
260
- metrics_C = gr.Markdown()
261
-
262
- def compute_metrics_all_pairs_ui(images, captions):
263
- # 3 spinners
264
- yield (
265
- "<div class='loading-line'></div>",
266
- "<div class='loading-line'></div>",
267
- "<div class='loading-line'></div>"
268
- )
269
-
270
- if len(images) < 1 or len(captions) < 3:
271
- msg = "<b>Upload 1 image and generate all 3 captions.</b>"
272
- yield (msg, msg, msg)
273
- return
274
-
275
- imgs = images * 3
276
- A = compute_metrics_button(imgs, captions, 0, 1)
277
- B = compute_metrics_button(imgs, captions, 0, 2)
278
- C = compute_metrics_button(imgs, captions, 1, 2)
279
-
280
- yield (
281
- f"### BLIP-large ↔ ViT-GPT2\n{A}",
282
- f"### BLIP-large ↔ BLIP2\n{B}",
283
- f"### ViT-GPT2 ↔ BLIP2\n{C}"
284
- )
285
-
286
- metrics_btn.click(
287
- compute_metrics_all_pairs_ui,
288
- inputs=[images_state, captions_state],
289
- outputs=[metrics_A, metrics_B, metrics_C]
290
- )
291
-
292
- # ---------------- NLP ----------------
293
- gr.Markdown("### NLP Analysis", elem_classes="heading-orange")
294
- nlp_btn = gr.Button("Analyze Captions", elem_classes="teal-btn")
295
-
296
- with gr.Row(elem_classes="metrics-row"): # reuse metrics-row for flex layout
297
- nlp_A = gr.Markdown()
298
- nlp_B = gr.Markdown()
299
- nlp_C = gr.Markdown()
300
-
301
- def do_nlp_all(captions):
302
- # 3 spinners like metrics
303
- yield (
304
- "<div class='loading-line'></div>",
305
- "<div class='loading-line'></div>",
306
- "<div class='loading-line'></div>"
307
- )
308
-
309
- if len(captions) < 3:
310
- msg = "<b>All 3 captions required.</b>"
311
- yield (msg, msg, msg)
312
- return
313
-
314
- labels = ["BLIP-large", "ViT-GPT2", "BLIP2"]
315
- results = []
316
- for label, cap in zip(labels, captions):
317
- s, e, t = nlp_bundle(cap)
318
- block = f"""
319
- <h3><u>{label}</u></h3>
320
- <b>Sentiment</b><br>{s}<br><br>
321
- <b>Entities</b><br>{e}<br><br>
322
- <b>Topics</b><br>{t}
323
- """
324
- results.append(block)
325
-
326
- yield (results[0], results[1], results[2])
327
-
328
- nlp_btn.click(do_nlp_all, inputs=[captions_state], outputs=[nlp_A, nlp_B, nlp_C])
329
-
330
-
331
- # ---------------- VQA ----------------
332
- gr.Markdown("### Visual Question Answering (VQA)", elem_classes="heading-orange")
333
- with gr.Row():
334
- vqa_input = gr.Textbox(label="Ask about the image")
335
- vqa_btn = gr.Button("Get Answer", elem_classes="teal-btn")
336
- vqa_out = gr.Markdown()
337
-
338
- def vqa_ui(question, image):
339
- yield "<div class='loading-line'></div>"
340
- yield answer_vqa(question, image)
341
-
342
- vqa_btn.click(vqa_ui, inputs=[vqa_input, preview1], outputs=[vqa_out])
343
-
344
- return demo
345
-
346
- # LAUNCH
347
- # ==============================
348
- demo = build_ui()
349
- demo.launch(debug=False)