ProfRom committed (verified)
Commit f2f22f7 · 1 Parent(s): 7ddee74

Gailey - Sanity Check 3

Files changed (1)
  1. app.py +317 -67
app.py CHANGED
@@ -1,72 +1,322 @@

+# app.py — Lazy Loaded Multimodal AI System
+#
+# Models load ONLY when needed to avoid memory overflow
+# Works on Hugging Face free CPU Spaces
+
+import torch
import gradio as gr
-from transformers import pipeline
-
-# BLIP captioning
-caption_pipeline = pipeline(
-    task="image-to-text",
-    model="Salesforce/blip-image-captioning-base"
-)
-
-# BLIP VQA
-vqa_pipeline = pipeline(
-    task="visual-question-answering",
-    model="Salesforce/blip-vqa-base"
-)
-
-# CLIP zero-shot classification
-clip_pipeline = pipeline(
-    task="zero-shot-image-classification",
-    model="openai/clip-vit-base-patch32"
-)
-
-def process_image(image, question, labels):
-    # Caption
-    caption_result = caption_pipeline(image)
-    caption = caption_result[0]["generated_text"]
-
-    # VQA
-    if question and question.strip():
-        vqa_result = vqa_pipeline(image=image, question=question)
-        vqa_answer = vqa_result[0]["answer"]
-    else:
-        vqa_answer = "No question provided."
-
-    # CLIP Classification
-    if labels and labels.strip():
-        candidate_labels = [l.strip() for l in labels.split(",") if l.strip()]
-        if candidate_labels:
-            # NOTE: use 'images=' or positional arg
-            clip_result = clip_pipeline(images=image, candidate_labels=candidate_labels)
-            clip_output = "\n".join(
-                f"{item['label']}: {round(item['score'] * 100, 1)}%"
-                for item in clip_result
+
+device = torch.device("cpu")
+
+
+# ---------------------------------------------------------
+# LAZY MODEL LOADERS
+# ---------------------------------------------------------
+
+def load_caption_model():
+    from transformers import BlipProcessor, BlipForConditionalGeneration
+    model_name = "Salesforce/blip-image-captioning-base"
+    processor = BlipProcessor.from_pretrained(model_name)
+    model = BlipForConditionalGeneration.from_pretrained(model_name).to(device)
+    return processor, model
+
+
+def load_sentiment_model():
+    from transformers import pipeline
+    return pipeline(
+        "sentiment-analysis",
+        model="distilbert-base-uncased-finetuned-sst-2-english"
+    )
+
+
+def load_vqa_model():
+    from transformers import BlipProcessor, BlipForQuestionAnswering
+    model_name = "Salesforce/blip-vqa-base"
+    processor = BlipProcessor.from_pretrained(model_name)
+    model = BlipForQuestionAnswering.from_pretrained(model_name).to(device)
+    return processor, model
+
+
+def load_detr_model():
+    from transformers import DetrImageProcessor, DetrForObjectDetection
+    processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
+    model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50").to(device)
+    return processor, model
+
+
+def load_vit_model():
+    from transformers import ViTImageProcessor, ViTForImageClassification
+    model_name = "google/vit-base-patch16-224"
+    processor = ViTImageProcessor.from_pretrained(model_name)
+    model = ViTForImageClassification.from_pretrained(model_name).to(device)
+    return processor, model
+
+
+# NEW — more verbose, less repetitive rewrite model
+def load_llm():
+    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+    name = "google/flan-t5-large"
+    tokenizer = AutoTokenizer.from_pretrained(name)
+    model = AutoModelForSeq2SeqLM.from_pretrained(name).to(device)
+    return tokenizer, model
+
+
+# ---------------------------------------------------------
+# TASKS
+# ---------------------------------------------------------
+
+def generate_caption(image):
+    processor, model = load_caption_model()
+    inputs = processor(images=image, return_tensors="pt").to(device)
+    with torch.no_grad():
+        out_ids = model.generate(**inputs, max_new_tokens=30)
+    return processor.decode(out_ids[0], skip_special_tokens=True)
+
+
+def analyze_sentiment(text):
+    sentiment = load_sentiment_model()
+    out = sentiment(text)[0]
+    return out["label"], round(out["score"] * 100, 2)
+
+
+def vqa_answer(image, question):
+    processor, model = load_vqa_model()
+    inputs = processor(images=image, text=question, return_tensors="pt").to(device)
+    with torch.no_grad():
+        out = model.generate(**inputs)
+    return processor.decode(out[0], skip_special_tokens=True)
+
+
+def detect_objects(image):
+    processor, model = load_detr_model()
+    inputs = processor(images=image, return_tensors="pt").to(device)
+
+    with torch.no_grad():
+        outputs = model(**inputs)
+
+    target_sizes = torch.tensor([image.size[::-1]])
+    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
+
+    detections = []
+    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+        if score > 0.3:
+            detections.append(
+                f"{model.config.id2label[label.item()]} (score {round(score.item(), 2)})"
            )
-        else:
-            clip_output = "No valid labels provided."
+    if len(detections) == 0:
+        return ["No high-confidence objects detected"]
+    return detections
+
+
+def classify_scene(image):
+    processor, model = load_vit_model()
+    inputs = processor(images=image, return_tensors="pt").to(device)
+    with torch.no_grad():
+        logits = model(**inputs).logits
+    label = logits.argmax(-1).item()
+    return model.config.id2label[label]
+
+
+# ---------------------------------------------------------
+# REWRITE CAPTIONS (8 STYLE SYSTEM + LENGTH SLIDER)
+# ---------------------------------------------------------
+
+def _build_style_prompt(caption, style):
+    base = (
+        "Rewrite the following image caption. "
+        "Keep the original meaning and important details, "
+        "but change the wording significantly and avoid repeating sentences verbatim. "
+        "Do not just copy the original text.\n\n"
+        f"Original caption:\n{caption}\n\n"
+    )
+
+    if style == "Short":
+        return (
+            base
+            + "Now produce a shorter, compact version in one or two sentences."
+        )
+    elif style == "Creative":
+        return (
+            base
+            + "Rewrite it in a colorful, imaginative, and richly descriptive style."
+        )
+    elif style == "Technical":
+        return (
+            base
+            + "Rewrite it in a highly technical, analytical style using precise visual terminology."
+        )
+    elif style == "Humorous":
+        return (
+            base
+            + "Rewrite it with a fun, humorous, witty tone while keeping the meaning."
+        )
+    elif style == "Poetic":
+        return (
+            base
+            + "Rewrite it in a poetic, rhythmic, metaphorical style using sensory language."
+        )
+    elif style == "Cinematic":
+        return (
+            base
+            + "Rewrite it as if describing an epic cinematic movie scene with dramatic, vivid imagery."
+        )
+    elif style == "Journalistic":
+        return (
+            base
+            + "Rewrite it in a factual, neutral, journalistic news-reporting style."
+        )
+    elif style == "Academic":
+        return (
+            base
+            + "Rewrite it in a formal, academic style with clear, analytical phrasing."
+        )
    else:
-        clip_output = "No labels provided."
-
-    return caption, vqa_answer, clip_output
-
-
-demo = gr.Interface(
-    fn=process_image,
-    inputs=[
-        gr.Image(type="pil", label="Upload an image"),
-        gr.Textbox(label="Ask a question about the image (optional)"),
-        gr.Textbox(
-            label="Enter CLIP classification labels (comma-separated)",
-            placeholder="e.g., man, boy, park, snow, happiness",
-        ),
-    ],
-    outputs=[
-        gr.Textbox(label="Generated Caption"),
-        gr.Textbox(label="VQA Answer"),
-        gr.Textbox(label="CLIP Classification Scores"),
-    ],
-    title="Multimodal AI — Captioning + VQA + Zero-Shot Classification",
-)
-
-demo.launch()
+        # Fallback: treat unknown style as creative rewrite
+        return (
+            base
+            + "Rewrite it in a natural, descriptive style."
+        )
+
+
+def rewrite_caption(caption, style, length):
+    tokenizer, model = load_llm()
+
+    prompt = _build_style_prompt(caption, style)
+
+    # Tokenize
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+
+    # First pass: normal creative decoding
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=length,
+            do_sample=True,
+            temperature=0.9,
+            top_p=0.9,
+            no_repeat_ngram_size=3,
+            repetition_penalty=1.2,
+        )
+
+    rewritten = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
+
+    # If the model basically echoed the caption, try a second, more forceful pass.
+    if rewritten.lower().strip() == caption.lower().strip():
+        strong_prompt = (
+            "Paraphrase and expand the following caption. "
+            "Use different wording and add extra detail, but keep the meaning. "
+            "Do not repeat the original sentence exactly.\n\n"
+            f"Original caption:\n{caption}"
+        )
+        strong_inputs = tokenizer(strong_prompt, return_tensors="pt").to(device)
+
+        with torch.no_grad():
+            outputs2 = model.generate(
+                **strong_inputs,
+                max_new_tokens=length,
+                do_sample=True,
+                temperature=1.0,
+                top_p=0.95,
+                no_repeat_ngram_size=3,
+                repetition_penalty=1.3,
+            )
+        rewritten2 = tokenizer.decode(outputs2[0], skip_special_tokens=True).strip()
+
+        # Only replace if it actually changed something
+        if rewritten2 and rewritten2.lower().strip() != caption.lower().strip():
+            rewritten = rewritten2
+
+    return rewritten
+
+
+def extract_metadata(image):
+    width, height = image.size
+    meta = f"Dimensions: {width} x {height}\n"
+    meta += "EXIF data detected\n" if "exif" in image.info else "No EXIF data available\n"
+    return meta
+
+
+# ---------------------------------------------------------
+# MAIN LOOP
+# ---------------------------------------------------------
+
+def process_all(image, question, style, length):
+    if image is None:
+        return ["No image"] * 8
+
+    caption = generate_caption(image)
+    sentiment_label, sentiment_score = analyze_sentiment(caption)
+    vqa = vqa_answer(image, question) if question else "No question asked"
+    objects = detect_objects(image)
+    scene = classify_scene(image)
+    rewritten = rewrite_caption(caption, style, length)
+    metadata = extract_metadata(image)
+
+    return caption, sentiment_label, sentiment_score, vqa, objects, scene, rewritten, metadata
+
+
+# ---------------------------------------------------------
+# GRADIO UI
+# ---------------------------------------------------------
+
+with gr.Blocks(title="Multimodal AI System (Lazy Loaded)") as demo:
+    gr.Markdown("# **Multimodal AI System**")
+
+    with gr.Row():
+        image_input = gr.Image(type="pil", label="Upload Image")
+        question_input = gr.Textbox(label="Ask a Question")
+
+    style_input = gr.Dropdown(
+        [
+            "Short",
+            "Creative",
+            "Technical",
+            "Humorous",
+            "Poetic",
+            "Cinematic",
+            "Journalistic",
+            "Academic"
+        ],
+        label="Rewrite Style"
+    )
+
+    # New: length slider
+    length_slider = gr.Slider(
+        minimum=20,
+        maximum=200,
+        value=80,
+        step=10,
+        label="Rewrite Length (Max Tokens)"
+    )
+
+    run_btn = gr.Button("Run All Tools")
+
+    caption = gr.Textbox(label="Generated Caption")
+    sentiment_label = gr.Textbox(label="Sentiment Label")
+    sentiment_score = gr.Number(label="Sentiment Score")
+    vqa_output = gr.Textbox(label="VQA Answer")
+    objects_output = gr.JSON(label="Detected Objects")
+    scene_output = gr.Textbox(label="Scene Classification")
+    rewritten_output = gr.Textbox(label="Rewritten Caption")
+    metadata_output = gr.Textbox(label="Image Metadata")
+
+    run_btn.click(
+        process_all,
+        [image_input, question_input, style_input, length_slider],
+        [
+            caption,
+            sentiment_label,
+            sentiment_score,
+            vqa_output,
+            objects_output,
+            scene_output,
+            rewritten_output,
+            metadata_output
+        ]
+    )
+
+
+if __name__ == "__main__":
+    demo.launch()

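
For reviewers who want to sanity-check this commit locally, a minimal smoke test is sketched below. It is not part of the commit: it assumes app.py sits in the working directory and that a sample image exists at the hypothetical path "test.jpg". Importing app builds the Gradio Blocks without starting the server, since demo.launch() is guarded by `if __name__ == "__main__"`.

# Hypothetical smoke test (not in this commit): run process_all once
# without the Gradio UI. "test.jpg" is a placeholder path.
from PIL import Image

from app import process_all

image = Image.open("test.jpg").convert("RGB")
results = process_all(image, "What is happening?", "Creative", 80)

names = [
    "caption", "sentiment_label", "sentiment_score", "vqa_answer",
    "detected_objects", "scene", "rewritten_caption", "metadata",
]
for name, value in zip(names, results):
    print(f"{name}: {value}")

Note the trade-off behind the lazy loaders: each call to a load_* function rebuilds its model from the local Hugging Face cache, which keeps peak memory low on a free CPU Space but makes repeated runs slow. If memory allows, wrapping the loaders in functools.lru_cache(maxsize=1) would preserve the lazy first load while reusing the instantiated models on later calls.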