ProfRom committed on
Commit adbc5fd · verified · 1 Parent(s): e95fe57

Smallwood - Sanity Check 3

Files changed (1)
  1. app.py +122 -199
app.py CHANGED
@@ -1,204 +1,127 @@
-# app.py — Lazy Loaded Multimodal AI System
-#
-# Models load ONLY when needed to avoid memory overflow
-# Works on Hugging Face free CPU Spaces
-
-import torch
 import gradio as gr
-
-device = torch.device("cpu")
-
-
-# ---------------------------------------------------------
-# LAZY MODEL LOADERS
-# ---------------------------------------------------------
-
-def load_caption_model():
-    from transformers import BlipProcessor, BlipForConditionalGeneration
-    model_name = "Salesforce/blip-image-captioning-base"
-    processor = BlipProcessor.from_pretrained(model_name)
-    model = BlipForConditionalGeneration.from_pretrained(model_name).to(device)
-    return processor, model
-
-
-def load_sentiment_model():
-    from transformers import pipeline
-    return pipeline(
-        "sentiment-analysis",
-        model="distilbert-base-uncased-finetuned-sst-2-english"
-    )
-
-
-def load_vqa_model():
-    from transformers import BlipProcessor, BlipForQuestionAnswering
-    model_name = "Salesforce/blip-vqa-base"
-    processor = BlipProcessor.from_pretrained(model_name)
-    model = BlipForQuestionAnswering.from_pretrained(model_name).to(device)
-    return processor, model
-
-
-def load_detr_model():
-    from transformers import DetrImageProcessor, DetrForObjectDetection
-    processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
-    model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50").to(device)
-    return processor, model
-
-
-def load_vit_model():
-    from transformers import ViTImageProcessor, ViTForImageClassification
-    model_name = "google/vit-base-patch16-224"
-    processor = ViTImageProcessor.from_pretrained(model_name)
-    model = ViTForImageClassification.from_pretrained(model_name).to(device)
-    return processor, model
-
-
-def load_llm():
-    from transformers import AutoTokenizer, AutoModelForCausalLM
-    name = "gpt2"
-    tokenizer = AutoTokenizer.from_pretrained(name)
-    model = AutoModelForCausalLM.from_pretrained(name).to(device)
-    return tokenizer, model
-
-
-# ---------------------------------------------------------
-# TASKS
-# ---------------------------------------------------------
-
-def generate_caption(image):
-    processor, model = load_caption_model()
-    inputs = processor(images=image, return_tensors="pt").to(device)
-    with torch.no_grad():
-        out_ids = model.generate(**inputs, max_new_tokens=30)
-    return processor.decode(out_ids[0], skip_special_tokens=True)
-
-
-def analyze_sentiment(text):
-    sentiment = load_sentiment_model()
-    out = sentiment(text)[0]
-    return out["label"], round(out["score"] * 100, 2)
-
-
-def vqa_answer(image, question):
-    processor, model = load_vqa_model()
-    inputs = processor(images=image, text=question, return_tensors="pt").to(device)
-    with torch.no_grad():
-        out = model.generate(**inputs)
-    return processor.decode(out[0], skip_special_tokens=True)
-
-
-def detect_objects(image):
-    processor, model = load_detr_model()
-    inputs = processor(images=image, return_tensors="pt").to(device)
-
-    with torch.no_grad():
-        outputs = model(**inputs)
-
-    target_sizes = torch.tensor([image.size[::-1]])
-    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
-
-    detections = []
-    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
-        if score > 0.3:
-            detections.append(
-                f"{model.config.id2label[label.item()]} (score {round(score.item(), 2)})"
-            )
-    if len(detections) == 0:
-        return ["No high-confidence objects detected"]
-    return detections
-
-
-def classify_scene(image):
-    processor, model = load_vit_model()
-    inputs = processor(images=image, return_tensors="pt").to(device)
-    with torch.no_grad():
-        logits = model(**inputs).logits
-    label = logits.argmax(-1).item()
-    return model.config.id2label[label]
-
-
-def rewrite_caption(caption, style):
-    tokenizer, model = load_llm()
-
-    if style == "Short":
-        prompt = f"Summarize: {caption}"
-    elif style == "Creative":
-        prompt = f"Rewrite creatively: {caption}"
-    elif style == "Technical":
-        prompt = f"Rewrite in technical detail: {caption}"
-    else:
-        prompt = caption
-
-    inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
-    with torch.no_grad():
-        outputs = model.generate(inputs, max_new_tokens=60)
-    return tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-
-def extract_metadata(image):
-    width, height = image.size
-    meta = f"Dimensions: {width} x {height}\n"
-    meta += "EXIF data detected\n" if "exif" in image.info else "No EXIF data available\n"
-    return meta
-
-
-# ---------------------------------------------------------
-# MAIN LOOP
-# ---------------------------------------------------------
-
-def process_all(image, question, style):
-    if image is None:
-        return ["No image"] * 8
-
-    caption = generate_caption(image)
-    sentiment_label, sentiment_score = analyze_sentiment(caption)
-    vqa = vqa_answer(image, question) if question else "No question asked"
-    objects = detect_objects(image)
-    scene = classify_scene(image)
-    rewritten = rewrite_caption(caption, style)
-    metadata = extract_metadata(image)
-
-    return caption, sentiment_label, sentiment_score, vqa, objects, scene, rewritten, metadata
-
-
-# ---------------------------------------------------------
-# GRADIO UI
-# ---------------------------------------------------------
-
-with gr.Blocks(title="Multimodal AI System (Lazy Loaded)") as demo:
-    gr.Markdown("# **Multimodal AI System**")
-
-    with gr.Row():
-        image_input = gr.Image(type="pil", label="Upload Image")
-        question_input = gr.Textbox(label="Ask a Question")
-        style_input = gr.Dropdown(["Short", "Creative", "Technical"], label="Caption Style")
-
-    run_btn = gr.Button("Run All Tools")
-
-    caption = gr.Textbox(label="Generated Caption")
-    sentiment_label = gr.Textbox(label="Sentiment Label")
-    sentiment_score = gr.Number(label="Sentiment Score")
-    vqa_output = gr.Textbox(label="VQA Answer")
-    objects_output = gr.JSON(label="Detected Objects")
-    scene_output = gr.Textbox(label="Scene Classification")
-    rewritten_output = gr.Textbox(label="Rewritten Caption")
-    metadata_output = gr.Textbox(label="Image Metadata")
-
-    run_btn.click(
-        process_all,
-        [image_input, question_input, style_input],
         [
-            caption,
-            sentiment_label,
-            sentiment_score,
-            vqa_output,
-            objects_output,
-            scene_output,
-            rewritten_output,
-            metadata_output
-        ]
     )

-
-if __name__ == "__main__":
-    demo.launch()
 import gradio as gr
+from transformers import pipeline
+from PIL import ImageDraw, ImageFont
+import textwrap
+
+# --- LOAD MODELS ---
+print("Loading Models...")
+caption_pipeline = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
+classification_pipeline = pipeline("image-classification", model="google/vit-base-patch16-224")
+sentiment_pipeline = pipeline("sentiment-analysis")
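+# NOTE: no checkpoint is pinned for the sentiment task above, so the pipeline falls back
+# to the library default (the same distilbert-base-uncased-finetuned-sst-2-english model
+# the previous revision loaded explicitly); pinning it keeps the Space reproducible.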
+
+# --- DRAWING FUNCTION ---
+def add_caption_to_image(image, text):
+    draw = ImageDraw.Draw(image)
+    image_width, image_height = image.size
+
+    # 1. Setup Font
+    try:
+        font = ImageFont.truetype("DejaVuSans.ttf", 20)
+    except IOError:
+        font = ImageFont.load_default()
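+        # (load_default() is a fixed-size bitmap fallback that ignores the requested
+        # 20 px size; DejaVuSans.ttf is typically present on the Space's Linux image.)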
+
+    # 2. Wrap Text
+    avg_char_width = 12
+    chars_per_line = max(10, int((image_width - 40) / avg_char_width))
+    lines = textwrap.wrap(text, width=chars_per_line)
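+    # (textwrap counts characters, not pixels: avg_char_width = 12 is a rough estimate
+    # for the 20 px font, so unusually wide captions can still overflow the box.)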
+
+    # 3. Calculate Box Size
+    line_height = 24
+    total_text_height = len(lines) * line_height
+    y_start = image_height - total_text_height - 20
+
+    max_line_width = 0
+    for line in lines:
+        bbox = draw.textbbox((0, 0), line, font=font)
+        w = bbox[2] - bbox[0]
+        if w > max_line_width:
+            max_line_width = w
+
+    box_x = (image_width - max_line_width) / 2
+
+    # 4. Draw Box
+    padding = 10
+    draw.rectangle(
         [
+            (box_x - padding, y_start - padding),
+            (box_x + max_line_width + padding, y_start + total_text_height + padding)
+        ],
+        fill=(0, 0, 0, 180)
     )
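+    # (On an RGB image, ImageDraw ignores the alpha in (0, 0, 0, 180) and draws the box
+    # solid black; ImageDraw.Draw(image, "RGBA") would be needed for real blending.)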
+
+    # 5. Draw Text
+    current_y = y_start
+    for line in lines:
+        bbox = draw.textbbox((0, 0), line, font=font)
+        line_width = bbox[2] - bbox[0]
+        line_x = (image_width - line_width) / 2
+        draw.text((line_x, current_y), line, font=font, fill="white")
+        current_y += line_height
+
+    return image
+
+# --- ANALYSIS FUNCTION ---
+def multimodal_analysis(input_image):
+    if input_image is None:
+        return None, "Upload image first", "N/A"
+
+    processed_image = input_image.copy()
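+    # (add_caption_to_image draws in place, so work on a copy and leave the
+    # original Gradio input untouched.)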
+
+    # 1. Caption
+    try:
+        caption = caption_pipeline(input_image)[0]['generated_text']
+    except Exception:  # a bare except would also swallow KeyboardInterrupt/SystemExit
+        return processed_image, "Error", "Error"
+
+    # 2. Draw
+    final_img = add_caption_to_image(processed_image, caption)
+
+    # 3. Classify
+    try:
+        res = classification_pipeline(input_image)
+        cls_str = f"{res[0]['label']} ({res[0]['score']:.2f})"
+    except Exception:
+        cls_str = "Error"
+
+    # 4. Sentiment
+    try:
+        sent = sentiment_pipeline(caption)[0]['label']
+    except Exception:
+        sent = "Error"
+
+    return final_img, cls_str, sent
+
+# --- INTERFACE (Removed Theme to fix crash) ---
+with gr.Blocks() as demo:
+    gr.Markdown("# 🤖 Multimodal AI Analyst")
+    gr.Markdown("Select an example image below to see: **Image Captioning**, **Vision Classification**, and **NLP Sentiment Analysis** working together.")
+
+    with gr.Row():
+        with gr.Column():
+            image_input = gr.Image(type="pil", label="Input Image")
+            submit_btn = gr.Button("🔍 Analyze Image", variant="primary")
+
+        with gr.Column():
+            output_image = gr.Image(label="AI Caption Result")
+            with gr.Row():
+                output_class = gr.Textbox(label="Object Class")
+                output_sent = gr.Textbox(label="Caption Sentiment")
+
+    # EXACT FILES FROM YOUR LIST
+    examples = [
+        ["Ashe Catcum with Pikachu.png"],
+        ["Beautiful sunrise over ocean.png"],
+        ["Cat on a couch.png"],
+        ["Female Crying.png"],
+        ["Lions Football team huddle.png"],
+        ["michael jordan trophy.png"],
+        ["Puppies playing in grass.png"],
+        ["Red Ferrari.png"],
+        ["Siamese cat.png"],
+        ["Stormy dark sky lightning.png"]
+    ]
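+    # (The example paths are relative, so these PNGs must sit alongside app.py
+    # in the Space repository.)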
+
+    gr.Examples(examples=examples, inputs=image_input)
+    submit_btn.click(fn=multimodal_analysis, inputs=image_input, outputs=[output_image, output_class, output_sent])
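+    # (multimodal_analysis returns (captioned image, class string, sentiment label),
+    # matching the three outputs in order.)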
+
+demo.launch()