ProfRom committed
Commit 5206179 · verified · 1 Parent(s): 01c9097

Gailey - Sanity Check 2

Files changed (1): app.py +202 -98
app.py CHANGED
@@ -1,100 +1,204 @@
 
-import gradio as gr
+# app.py Lazy Loaded Multimodal AI System
+#
+# Models load ONLY when needed to avoid memory overflow
+# Works on Hugging Face free CPU Spaces
+
 import torch
-import os
-import tempfile
-from huggingface_hub import login
-from transformers import AutoProcessor, AutoModelForVisualQuestionAnswering, infer_device, PaliGemmaForConditionalGeneration
-from accelerate import Accelerator
-
-# login to Hugging Face
-# login(token=os.getenv('HF_TOKEN'))
-
-# Set the device
-device = infer_device()
-
-# MODEL 1: BLIP-VQA
-processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
-model = AutoModelForVisualQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(device)
-
-# Define inference function for Model 1
-def process_image(image, prompt):
-    inputs = processor(image, text=prompt, return_tensors="pt").to(device, torch.float16)
-
-    try:
-        # Generate output from the model
-        output = model.generate(**inputs, max_new_tokens=10)
-
-        # Decode and return the output
-        decoded_output = processor.batch_decode(output, skip_special_tokens=True)[0].strip()
-
-        # remove prompt from output
-        if decoded_output.startswith(prompt):
-            return decoded_output[len(prompt):].strip()
-        return decoded_output
-    except Exception as e:
-        print(f"Error in Model 1: {e}")
-        return "An error occurred during processing for Model 1."
-
-
-# MODEL 2: PaliGemma
-processor2 = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
-model2 = PaliGemmaForConditionalGeneration.from_pretrained(
-    "google/paligemma-3b-mix-224",
-    torch_dtype=torch.bfloat16
-).to(device)
-
-
-# Define inference function for Model 2
-def process_image2(image, prompt):
-    inputs2 = processor2(
-        text=prompt,
-        images=image,
-        return_tensors="pt"
-    ).to(device, model2.dtype)
-
-    try:
-        output = model2.generate(**inputs2, max_new_tokens=10)
-        decoded_output = processor2.batch_decode(
-            output[:, inputs2["input_ids"].shape[1]:],
-            skip_special_tokens=True
-        )[0].strip()
-
-        return decoded_output
-    except Exception as e:
-        print(f"Error in Model 2: {e}")
-        return "An error occurred during processing for Model 2. Ensure your hardware supports bfloat16 or adjust the torch_dtype."
-
-
-# GRADIO INTERFACE
-inputs_model1 = [
-    gr.Image(type="pil"),
-    gr.Textbox(label="Prompt", placeholder="Enter your question")
-]
-inputs_model2 = [
-    gr.Image(type="pil"),
-    gr.Textbox(label="Prompt", placeholder="Enter your question")
-]
-
-outputs_model1 = gr.Textbox(label="Answer")
-outputs_model2 = gr.Textbox(label="Answer")
-
-# Create the Gradio apps for each model
-model1_inf = gr.Interface(
-    fn=process_image,
-    inputs=inputs_model1,
-    outputs=outputs_model1,
-    title="Model 1: BLIP-VQA-Base",
-    description="Ask a question about the uploaded image using BLIP."
-)
-
-model2_inf = gr.Interface(
-    fn=process_image2,
-    inputs=inputs_model2,
-    outputs=outputs_model2,
-    title="Model 2: PaliGemma",
-    description="Ask a question about the uploaded image using PaliGemma."
-)
-
-demo = gr.TabbedInterface([model1_inf, model2_inf], ["Model 1 (BLIP)", "Model 2 (PaliGemma)"])
-demo.launch(share=True)
+import gradio as gr
+
+device = torch.device("cpu")
+
+
+# ---------------------------------------------------------
+# LAZY MODEL LOADERS
+# ---------------------------------------------------------
+
+def load_caption_model():
+    from transformers import BlipProcessor, BlipForConditionalGeneration
+    model_name = "Salesforce/blip-image-captioning-base"
+    processor = BlipProcessor.from_pretrained(model_name)
+    model = BlipForConditionalGeneration.from_pretrained(model_name).to(device)
+    return processor, model
+
+
+def load_sentiment_model():
+    from transformers import pipeline
+    return pipeline(
+        "sentiment-analysis",
+        model="distilbert-base-uncased-finetuned-sst-2-english"
+    )
+
+
+def load_vqa_model():
+    from transformers import BlipProcessor, BlipForQuestionAnswering
+    model_name = "Salesforce/blip-vqa-base"
+    processor = BlipProcessor.from_pretrained(model_name)
+    model = BlipForQuestionAnswering.from_pretrained(model_name).to(device)
+    return processor, model
+
+
+def load_detr_model():
+    from transformers import DetrImageProcessor, DetrForObjectDetection
+    processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
+    model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50").to(device)
+    return processor, model
+
+
+def load_vit_model():
+    from transformers import ViTImageProcessor, ViTForImageClassification
+    model_name = "google/vit-base-patch16-224"
+    processor = ViTImageProcessor.from_pretrained(model_name)
+    model = ViTForImageClassification.from_pretrained(model_name).to(device)
+    return processor, model
+
+
+def load_llm():
+    from transformers import AutoTokenizer, AutoModelForCausalLM
+    name = "gpt2"
+    tokenizer = AutoTokenizer.from_pretrained(name)
+    model = AutoModelForCausalLM.from_pretrained(name).to(device)
+    return tokenizer, model
+
+
+# ---------------------------------------------------------
+# TASKS
+# ---------------------------------------------------------
+
+def generate_caption(image):
+    processor, model = load_caption_model()
+    inputs = processor(images=image, return_tensors="pt").to(device)
+    with torch.no_grad():
+        out_ids = model.generate(**inputs, max_new_tokens=30)
+    return processor.decode(out_ids[0], skip_special_tokens=True)
+
+
+def analyze_sentiment(text):
+    sentiment = load_sentiment_model()
+    out = sentiment(text)[0]
+    return out["label"], round(out["score"] * 100, 2)
+
+
+def vqa_answer(image, question):
+    processor, model = load_vqa_model()
+    inputs = processor(images=image, text=question, return_tensors="pt").to(device)
+    with torch.no_grad():
+        out = model.generate(**inputs)
+    return processor.decode(out[0], skip_special_tokens=True)
+
+
+def detect_objects(image):
+    processor, model = load_detr_model()
+    inputs = processor(images=image, return_tensors="pt").to(device)
+
+    with torch.no_grad():
+        outputs = model(**inputs)
+
+    target_sizes = torch.tensor([image.size[::-1]])
+    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
+
+    detections = []
+    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+        if score > 0.3:
+            detections.append(
+                f"{model.config.id2label[label.item()]} (score {round(score.item(), 2)})"
+            )
+    if len(detections) == 0:
+        return ["No high-confidence objects detected"]
+    return detections
+
+
+def classify_scene(image):
+    processor, model = load_vit_model()
+    inputs = processor(images=image, return_tensors="pt").to(device)
+    with torch.no_grad():
+        logits = model(**inputs).logits
+    label = logits.argmax(-1).item()
+    return model.config.id2label[label]
+
+
+def rewrite_caption(caption, style):
+    tokenizer, model = load_llm()
+
+    if style == "Short":
+        prompt = f"Summarize: {caption}"
+    elif style == "Creative":
+        prompt = f"Rewrite creatively: {caption}"
+    elif style == "Technical":
+        prompt = f"Rewrite in technical detail: {caption}"
+    else:
+        prompt = caption
+
+    inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
+    with torch.no_grad():
+        outputs = model.generate(inputs, max_new_tokens=60)
+    return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+
+def extract_metadata(image):
+    width, height = image.size
+    meta = f"Dimensions: {width} x {height}\n"
+    meta += "EXIF data detected\n" if "exif" in image.info else "No EXIF data available\n"
+    return meta
+
+
+# ---------------------------------------------------------
+# MAIN LOOP
+# ---------------------------------------------------------
+
+def process_all(image, question, style):
+    if image is None:
+        return ["No image"] * 8
+
+    caption = generate_caption(image)
+    sentiment_label, sentiment_score = analyze_sentiment(caption)
+    vqa = vqa_answer(image, question) if question else "No question asked"
+    objects = detect_objects(image)
+    scene = classify_scene(image)
+    rewritten = rewrite_caption(caption, style)
+    metadata = extract_metadata(image)
+
+    return caption, sentiment_label, sentiment_score, vqa, objects, scene, rewritten, metadata
+
+
+# ---------------------------------------------------------
+# GRADIO UI
+# ---------------------------------------------------------
+
+with gr.Blocks(title="Multimodal AI System (Lazy Loaded)") as demo:
+    gr.Markdown("# **Multimodal AI System**")
+
+    with gr.Row():
+        image_input = gr.Image(type="pil", label="Upload Image")
+        question_input = gr.Textbox(label="Ask a Question")
+        style_input = gr.Dropdown(["Short", "Creative", "Technical"], label="Caption Style")
+
+    run_btn = gr.Button("Run All Tools")
+
+    caption = gr.Textbox(label="Generated Caption")
+    sentiment_label = gr.Textbox(label="Sentiment Label")
+    sentiment_score = gr.Number(label="Sentiment Score")
+    vqa_output = gr.Textbox(label="VQA Answer")
+    objects_output = gr.JSON(label="Detected Objects")
+    scene_output = gr.Textbox(label="Scene Classification")
+    rewritten_output = gr.Textbox(label="Rewritten Caption")
+    metadata_output = gr.Textbox(label="Image Metadata")
+
+    run_btn.click(
+        process_all,
+        [image_input, question_input, style_input],
+        [
+            caption,
+            sentiment_label,
+            sentiment_score,
+            vqa_output,
+            objects_output,
+            scene_output,
+            rewritten_output,
+            metadata_output
+        ]
+    )
+
+
+if __name__ == "__main__":
+    demo.launch()
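
A note on the lazy-loading pattern this commit introduces: process_all calls each load_* function on every click of "Run All Tools", so every model is re-instantiated per request. If repeated clicks prove slow, one possible refinement is to memoize the loaders with functools.lru_cache, which keeps the lazy first-use behavior while trading resident memory for per-click latency. A minimal sketch, not part of this commit (the _cached name is hypothetical):

# Sketch only: memoized variant of one loader from the diff above.
# functools.lru_cache keeps the (processor, model) pair resident after
# the first call, so later clicks skip re-instantiation entirely.
from functools import lru_cache

@lru_cache(maxsize=1)
def load_caption_model_cached():  # hypothetical name, mirrors load_caption_model
    from transformers import BlipProcessor, BlipForConditionalGeneration
    model_name = "Salesforce/blip-image-captioning-base"
    processor = BlipProcessor.from_pretrained(model_name)
    model = BlipForConditionalGeneration.from_pretrained(model_name)
    return processor, model

The uncached version keeps steady-state memory lower (models are garbage-collected after each call), which may matter more than speed on a free CPU Space; the cached version suits Spaces with headroom for all six models.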
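
Since the commit message is a sanity check, the pipeline can also be exercised once without starting the UI. A hypothetical smoke test, assuming the new file is saved as app.py (the first run downloads all six models):

# Hypothetical smoke test for the committed pipeline (not in the diff).
# Importing app builds the Blocks UI but does not launch it, because
# demo.launch() is guarded by the __main__ check.
from PIL import Image
from app import process_all

img = Image.new("RGB", (224, 224), color="gray")  # synthetic test image
results = process_all(img, "What color is the image?", "Short")
print(results[0])  # generated caption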