ProfRom committed on
Commit
c39a86b
·
verified ·
1 Parent(s): 75e18d7

Poojary Sanity Check 1

Browse files
Files changed (1) hide show
  1. app.py +35 -202
app.py CHANGED
@@ -1,203 +1,36 @@
1
- # app.py — Lazy Loaded Multimodal AI System
2
- #
3
- # Models load ONLY when needed to avoid memory overflow
4
- # Works on Hugging Face free CPU Spaces
5
 
6
- import torch
7
- import gradio as gr
8
-
9
- device = torch.device("cpu")
10
-
11
-
12
- # ---------------------------------------------------------
13
- # LAZY MODEL LOADERS
14
- # ---------------------------------------------------------
15
-
16
- def load_caption_model():
17
- from transformers import BlipProcessor, BlipForConditionalGeneration
18
- model_name = "Salesforce/blip-image-captioning-base"
19
- processor = BlipProcessor.from_pretrained(model_name)
20
- model = BlipForConditionalGeneration.from_pretrained(model_name).to(device)
21
- return processor, model
22
-
23
-
24
- def load_sentiment_model():
25
- from transformers import pipeline
26
- return pipeline(
27
- "sentiment-analysis",
28
- model="distilbert-base-uncased-finetuned-sst-2-english"
29
- )
30
-
31
-
32
- def load_vqa_model():
33
- from transformers import BlipProcessor, BlipForQuestionAnswering
34
- model_name = "Salesforce/blip-vqa-base"
35
- processor = BlipProcessor.from_pretrained(model_name)
36
- model = BlipForQuestionAnswering.from_pretrained(model_name).to(device)
37
- return processor, model
38
-
39
-
40
- def load_detr_model():
41
- from transformers import DetrImageProcessor, DetrForObjectDetection
42
- processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
43
- model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50").to(device)
44
- return processor, model
45
-
46
-
47
- def load_vit_model():
48
- from transformers import ViTImageProcessor, ViTForImageClassification
49
- model_name = "google/vit-base-patch16-224"
50
- processor = ViTImageProcessor.from_pretrained(model_name)
51
- model = ViTForImageClassification.from_pretrained(model_name).to(device)
52
- return processor, model
53
-
54
-
55
- def load_llm():
56
- from transformers import AutoTokenizer, AutoModelForCausalLM
57
- name = "gpt2"
58
- tokenizer = AutoTokenizer.from_pretrained(name)
59
- model = AutoModelForCausalLM.from_pretrained(name).to(device)
60
- return tokenizer, model
61
-
62
-
63
- # ---------------------------------------------------------
64
- # TASK FUNCTIONS
65
- # ---------------------------------------------------------
66
-
67
- def generate_caption(image):
68
- processor, model = load_caption_model()
69
- inputs = processor(images=image, return_tensors="pt").to(device)
70
- with torch.no_grad():
71
- out_ids = model.generate(**inputs, max_new_tokens=30)
72
- return processor.decode(out_ids[0], skip_special_tokens=True)
73
-
74
-
75
- def analyze_sentiment(text):
76
- sentiment = load_sentiment_model()
77
- out = sentiment(text)[0]
78
- return out["label"], round(out["score"] * 100, 2)
79
-
80
-
81
- def vqa_answer(image, question):
82
- processor, model = load_vqa_model()
83
- inputs = processor(images=image, text=question, return_tensors="pt").to(device)
84
- with torch.no_grad():
85
- out = model.generate(**inputs)
86
- return processor.decode(out[0], skip_special_tokens=True)
87
-
88
-
89
- def detect_objects(image):
90
- processor, model = load_detr_model()
91
- inputs = processor(images=image, return_tensors="pt").to(device)
92
-
93
- with torch.no_grad():
94
- outputs = model(**inputs)
95
-
96
- target_sizes = torch.tensor([image.size[::-1]])
97
- results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
98
-
99
- detections = []
100
- for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
101
- if score > 0.3:
102
- detections.append(
103
- f"{model.config.id2label[label.item()]} (score {round(score.item(), 2)})"
104
- )
105
- if len(detections) == 0:
106
- return ["No high-confidence objects detected"]
107
- return detections
108
-
109
-
110
- def classify_scene(image):
111
- processor, model = load_vit_model()
112
- inputs = processor(images=image, return_tensors="pt").to(device)
113
- with torch.no_grad():
114
- logits = model(**inputs).logits
115
- label = logits.argmax(-1).item()
116
- return model.config.id2label[label]
117
-
118
-
119
- def rewrite_caption(caption, style):
120
- tokenizer, model = load_llm()
121
-
122
- if style == "Short":
123
- prompt = f"Summarize: {caption}"
124
- elif style == "Creative":
125
- prompt = f"Rewrite creatively: {caption}"
126
- elif style == "Technical":
127
- prompt = f"Rewrite in technical detail: {caption}"
128
- else:
129
- prompt = caption
130
-
131
- inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
132
- with torch.no_grad():
133
- outputs = model.generate(inputs, max_new_tokens=60)
134
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
135
-
136
-
137
- def extract_metadata(image):
138
- width, height = image.size
139
- meta = f"Dimensions: {width} x {height}\n"
140
- meta += "EXIF data detected\n" if "exif" in image.info else "No EXIF data available\n"
141
- return meta
142
-
143
-
144
- # ---------------------------------------------------------
145
- # MAIN LOGIC
146
- # ---------------------------------------------------------
147
-
148
- def process_all(image, question, style):
149
- if image is None:
150
- return ["No image"] * 8
151
-
152
- caption = generate_caption(image)
153
- sentiment_label, sentiment_score = analyze_sentiment(caption)
154
- vqa = vqa_answer(image, question) if question else "No question asked"
155
- objects = detect_objects(image)
156
- scene = classify_scene(image)
157
- rewritten = rewrite_caption(caption, style)
158
- metadata = extract_metadata(image)
159
-
160
- return caption, sentiment_label, sentiment_score, vqa, objects, scene, rewritten, metadata
161
-
162
-
163
- # ---------------------------------------------------------
164
- # GRADIO UI - BLOCKS
165
- # ---------------------------------------------------------
166
-
167
- with gr.Blocks(title="Multimodal AI System (Lazy Loaded)") as demo:
168
- gr.Markdown("# **Multimodal AI System (Emotion Removed)**")
169
-
170
- with gr.Row():
171
- image_input = gr.Image(type="pil", label="Upload Image")
172
- question_input = gr.Textbox(label="Ask a Question")
173
- style_input = gr.Dropdown(["Short", "Creative", "Technical"], label="Caption Style")
174
-
175
- run_btn = gr.Button("Run All AI Tools")
176
-
177
- caption = gr.Textbox(label="Generated Caption")
178
- sentiment_label = gr.Textbox(label="Sentiment Label")
179
- sentiment_score = gr.Number(label="Sentiment Score")
180
- vqa_output = gr.Textbox(label="VQA Answer")
181
- objects_output = gr.JSON(label="Detected Objects")
182
- scene_output = gr.Textbox(label="Scene Classification")
183
- rewritten_output = gr.Textbox(label="Rewritten Caption")
184
- metadata_output = gr.Textbox(label="Image Metadata")
185
-
186
- run_btn.click(
187
- process_all,
188
- [image_input, question_input, style_input],
189
- [
190
- caption,
191
- sentiment_label,
192
- sentiment_score,
193
- vqa_output,
194
- objects_output,
195
- scene_output,
196
- rewritten_output,
197
- metadata_output
198
- ]
199
- )
200
-
201
-
202
- if __name__ == "__main__":
203
- demo.launch()
 
 
 
 
 
1
 
2
+ #define model and processor
3
+ processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
4
+ model = AutoModelForVisualQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
5
+ device = infer_device()
6
+
7
+ # Define inference function
8
+ def process_image(image, prompt):
9
+ # Process the image and prompt using the processor
10
+ inputs = processor(image, text=prompt, return_tensors="pt").to(device, torch.float16)
11
+
12
+ try:
13
+ # Generate output from the model
14
+ output = model.generate(**inputs, max_new_tokens=10)
15
+
16
+ # Decode and return the output
17
+ decoded_output = processor.batch_decode(output, skip_special_tokens=True)[0].strip()
18
+
19
+ #remove prompt from output
20
+ if decoded_output.startswith(prompt):
21
+ return decoded_output[len(prompt):].strip()
22
+ return decoded_output
23
+ except IndexError as e:
24
+ print(f"IndexError: {e}")
25
+ return "An error occurred during processing."
26
+
27
+ # Define the Gradio interface
28
+ inputs = [
29
+ gr.Image(type="pil"),
30
+ gr.Textbox(label="Prompt", placeholder="Enter your question")
31
+ ]
32
+ outputs = gr.Textbox(label="Answer")
33
+ # Create the Gradio app
34
+ demo = gr.Interface(fn=process_image, inputs=inputs, outputs=outputs, title="Visual Question Answering", description="Upload an image and ask questions to get answers.")
35
+ # Launch the app
36
+ demo.launch()