adityaardak commited on
Commit
9c88c94
·
verified ·
1 Parent(s): 4a30650

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -62
app.py CHANGED
@@ -1,11 +1,12 @@
1
  import gradio as gr
2
  import torch
3
- from PIL import Image
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
 
6
- # ---- CPU-only config ----
 
 
7
  MID = "apple/FastVLM-0.5B"
8
- IMAGE_TOKEN_INDEX = -200 # special image token id used by FastVLM
9
 
10
  tok = None
11
  model = None
@@ -13,50 +14,43 @@ model = None
13
  def load_model():
14
  global tok, model
15
  if tok is None or model is None:
16
- print("Loading model (CPU)…")
17
  tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
18
- # Force CPU + float32 (fp16 is unsafe on CPU)
19
  model = AutoModelForCausalLM.from_pretrained(
20
  MID,
21
  torch_dtype=torch.float32,
22
  device_map="cpu",
23
  trust_remote_code=True,
24
  )
25
- print("Model loaded successfully on CPU!")
26
  return tok, model
27
 
28
- def caption_image(image, custom_prompt=None):
29
- """
30
- Generate a caption for the input image (CPU-only).
31
- """
32
  if image is None:
33
  return "Please upload an image first."
34
 
35
  try:
36
  tok, model = load_model()
37
 
38
- # Convert image to RGB if needed
39
  if image.mode != "RGB":
40
  image = image.convert("RGB")
41
 
42
- prompt = custom_prompt if custom_prompt else "Describe this image in detail."
43
-
44
- # Single-turn chat with an <image> placeholder
45
  messages = [{"role": "user", "content": f"<image>\n{prompt}"}]
46
- rendered = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
 
 
 
 
47
 
48
- # Split around the literal "<image>"
49
  pre, post = rendered.split("<image>", 1)
50
 
51
- # Tokenize text around the image token
52
  pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
53
  post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
54
 
55
- # Derive device/dtype from the loaded model (CPU here, but future-proof)
56
  model_device = next(model.parameters()).device
57
  model_dtype = next(model.parameters()).dtype
58
 
59
- # Insert IMAGE token id at placeholder position
60
  img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype, device=model_device)
61
  input_ids = torch.cat(
62
  [pre_ids.to(model_device), img_tok, post_ids.to(model_device)],
@@ -64,23 +58,22 @@ def caption_image(image, custom_prompt=None):
64
  )
65
  attention_mask = torch.ones_like(input_ids, device=model_device)
66
 
67
- # Preprocess image using model's vision tower
68
- px = model.get_vision_tower().image_processor(
69
- images=image, return_tensors="pt"
70
  )["pixel_values"].to(device=model_device, dtype=model_dtype)
71
 
72
- # Generate caption (deterministic)
73
  with torch.no_grad():
74
  out = model.generate(
75
  inputs=input_ids,
76
  attention_mask=attention_mask,
77
- images=px,
78
- max_new_tokens=128,
79
- do_sample=False, # temperature is ignored when sampling is off
80
  )
81
 
82
- # Decode and slice to the assistant part if present
83
  generated_text = tok.decode(out[0], skip_special_tokens=True)
 
84
  if "Assistant:" in generated_text:
85
  response = generated_text.split("Assistant:", 1)[-1].strip()
86
  elif "assistant" in generated_text:
@@ -91,53 +84,160 @@ def caption_image(image, custom_prompt=None):
91
  return response
92
 
93
  except Exception as e:
94
- return f"Error generating caption: {str(e)}"
95
-
96
- # ---- Gradio UI (CPU) ----
97
- with gr.Blocks(title="FastVLM Image Captioning (CPU)") as demo:
98
- gr.Markdown(
99
- """
100
- # 🖼️ FastVLM Image Captioning (CPU)
101
- Upload an image to generate a detailed caption using Apple's FastVLM-0.5B.
102
- This build runs on **CPU only**. Expect slower generation than GPU.
103
- """
104
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  with gr.Row():
107
- with gr.Column():
108
- image_input = gr.Image(type="pil", label="Upload Image", elem_id="image-upload")
109
- custom_prompt = gr.Textbox(
110
- label="Custom Prompt (Optional)",
111
- placeholder="Leave empty for default: 'Describe this image in detail.'",
 
 
 
 
 
 
 
 
 
 
 
 
112
  lines=2
113
  )
 
114
  with gr.Row():
115
- clear_btn = gr.ClearButton([image_input, custom_prompt])
116
- generate_btn = gr.Button("Generate Caption", variant="primary")
 
 
 
 
 
 
 
 
117
 
118
- with gr.Column():
119
  output = gr.Textbox(
120
- label="Generated Caption",
121
- lines=8,
122
- max_lines=15,
123
  show_copy_button=True
124
  )
125
 
126
- generate_btn.click(fn=caption_image, inputs=[image_input, custom_prompt], outputs=output)
127
-
128
- # Also generate on image upload if no custom prompt
129
- def _auto_caption(img, prompt):
130
- return caption_image(img, prompt) if (img is not None and not prompt) else None
131
 
132
- image_input.change(fn=_auto_caption, inputs=[image_input, custom_prompt], outputs=output)
 
 
 
 
 
 
133
 
134
- gr.Markdown(
135
- """
136
- ---
137
- **Model:** [apple/FastVLM-0.5B](https://huggingface.co/apple/FastVLM-0.5B)
138
- **Note:** CPU-only run. For speed, switch to a CUDA GPU build or a GPU Space.
139
- """
140
- )
141
 
142
  if __name__ == "__main__":
143
  demo.launch(
 
1
  import gradio as gr
2
  import torch
 
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
 
5
+ # -----------------------------
6
+ # Model configuration
7
+ # -----------------------------
8
  MID = "apple/FastVLM-0.5B"
9
+ IMAGE_TOKEN_INDEX = -200
10
 
11
  tok = None
12
  model = None
 
14
  def load_model():
15
  global tok, model
16
  if tok is None or model is None:
17
+ print("Loading model on CPU...")
18
  tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
 
19
  model = AutoModelForCausalLM.from_pretrained(
20
  MID,
21
  torch_dtype=torch.float32,
22
  device_map="cpu",
23
  trust_remote_code=True,
24
  )
25
+ print("Model loaded successfully!")
26
  return tok, model
27
 
28
+
29
+ def run_fastvlm(image, prompt):
 
 
30
  if image is None:
31
  return "Please upload an image first."
32
 
33
  try:
34
  tok, model = load_model()
35
 
 
36
  if image.mode != "RGB":
37
  image = image.convert("RGB")
38
 
 
 
 
39
  messages = [{"role": "user", "content": f"<image>\n{prompt}"}]
40
+ rendered = tok.apply_chat_template(
41
+ messages,
42
+ add_generation_prompt=True,
43
+ tokenize=False
44
+ )
45
 
 
46
  pre, post = rendered.split("<image>", 1)
47
 
 
48
  pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
49
  post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
50
 
 
51
  model_device = next(model.parameters()).device
52
  model_dtype = next(model.parameters()).dtype
53
 
 
54
  img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype, device=model_device)
55
  input_ids = torch.cat(
56
  [pre_ids.to(model_device), img_tok, post_ids.to(model_device)],
 
58
  )
59
  attention_mask = torch.ones_like(input_ids, device=model_device)
60
 
61
+ pixel_values = model.get_vision_tower().image_processor(
62
+ images=image,
63
+ return_tensors="pt"
64
  )["pixel_values"].to(device=model_device, dtype=model_dtype)
65
 
 
66
  with torch.no_grad():
67
  out = model.generate(
68
  inputs=input_ids,
69
  attention_mask=attention_mask,
70
+ images=pixel_values,
71
+ max_new_tokens=220,
72
+ do_sample=False
73
  )
74
 
 
75
  generated_text = tok.decode(out[0], skip_special_tokens=True)
76
+
77
  if "Assistant:" in generated_text:
78
  response = generated_text.split("Assistant:", 1)[-1].strip()
79
  elif "assistant" in generated_text:
 
84
  return response
85
 
86
  except Exception as e:
87
+ return f"Error: {str(e)}"
88
+
89
+
90
+ def build_prompt(mode, user_context):
91
+ context_part = f"\nExtra user context: {user_context}" if user_context.strip() else ""
92
+
93
+ prompts = {
94
+ "Scene Description":
95
+ f"""
96
+ You are an AI assistant helping a visually impaired person.
97
+ Describe the image in simple, human-friendly language.
98
+
99
+ Return output in this format:
100
+ 1. Quick Summary
101
+ 2. Main Objects Seen
102
+ 3. Relative Position of Important Objects
103
+ 4. Helpful Note
104
+
105
+ Keep the language simple and practical.{context_part}
106
+ """,
107
+
108
+ "Hazard Detection":
109
+ f"""
110
+ You are an AI safety assistant helping a visually impaired person.
111
+ Analyze the image for possible hazards.
112
+
113
+ Return output in this format:
114
+ 1. Quick Summary
115
+ 2. Possible Hazards
116
+ 3. Risk Level (Low/Medium/High)
117
+ 4. Safety Advice
118
+
119
+ Be practical and avoid exaggeration.{context_part}
120
+ """,
121
+
122
+ "Important Object Summary":
123
+ f"""
124
+ You are an AI visual assistant.
125
+ Identify the most important objects in the image that a visually impaired person should know about.
126
+
127
+ Return output in this format:
128
+ 1. Key Objects
129
+ 2. What Looks Most Important
130
+ 3. Why These Objects Matter
131
+ 4. Short Spoken Summary
132
+
133
+ Keep it easy to understand.{context_part}
134
+ """,
135
+
136
+ "Safe Action Suggestion":
137
+ f"""
138
+ You are an AI guidance assistant for a visually impaired person.
139
+ Based on the image, suggest the next safest action.
140
+
141
+ Return output in this format:
142
+ 1. What the Scene Looks Like
143
+ 2. What Needs Attention
144
+ 3. Recommended Action
145
+ 4. One-Line Safety Tip
146
+
147
+ Do not assume too much. Give cautious guidance.{context_part}
148
+ """
149
+ }
150
+
151
+ return prompts.get(mode, prompts["Scene Description"])
152
+
153
+
154
+ def analyze_image(image, mode, user_context):
155
+ if image is None:
156
+ return "Please upload an image."
157
+
158
+ prompt = build_prompt(mode, user_context)
159
+ return run_fastvlm(image, prompt)
160
+
161
+
162
+ def exhibition_pitch(mode):
163
+ pitches = {
164
+ "Scene Description":
165
+ "This mode explains the surrounding environment in simple words so a visually impaired person can understand the scene.",
166
+ "Hazard Detection":
167
+ "This mode checks whether the image contains obstacles or risky elements such as vehicles, stairs, clutter, or unsafe walking areas.",
168
+ "Important Object Summary":
169
+ "This mode highlights the most useful objects in the scene so the user can focus on what matters most.",
170
+ "Safe Action Suggestion":
171
+ "This mode provides the next practical action the user should consider, based on the visual situation."
172
+ }
173
+ return pitches.get(mode, "")
174
+
175
+
176
+ with gr.Blocks(title="VisionMate AI - Smart Visual Assistant") as demo:
177
+ gr.Markdown("""
178
+ # 👁️ VisionMate AI
179
+ ## Smart Visual Assistant for Visually Impaired People
180
+
181
+ Upload an image and let the AI explain the scene, identify hazards, summarize important objects, or suggest the safest next action.
182
+
183
+ ### Exhibition Theme
184
+ **AI for Social Good**
185
+ """)
186
 
187
  with gr.Row():
188
+ with gr.Column(scale=1):
189
+ image_input = gr.Image(type="pil", label="Upload Scene Image")
190
+
191
+ mode = gr.Radio(
192
+ choices=[
193
+ "Scene Description",
194
+ "Hazard Detection",
195
+ "Important Object Summary",
196
+ "Safe Action Suggestion"
197
+ ],
198
+ value="Scene Description",
199
+ label="Select Assistance Mode"
200
+ )
201
+
202
+ user_context = gr.Textbox(
203
+ label="Optional Context",
204
+ placeholder="Example: Person is walking alone on a road / indoor corridor / market area",
205
  lines=2
206
  )
207
+
208
  with gr.Row():
209
+ analyze_btn = gr.Button("Analyze Scene", variant="primary")
210
+ clear_btn = gr.ClearButton([image_input, user_context])
211
+
212
+ with gr.Column(scale=1):
213
+ mode_explanation = gr.Textbox(
214
+ label="Mode Purpose",
215
+ value=exhibition_pitch("Scene Description"),
216
+ interactive=False,
217
+ lines=4
218
+ )
219
 
 
220
  output = gr.Textbox(
221
+ label="AI Assistance Output",
222
+ lines=16,
223
+ max_lines=25,
224
  show_copy_button=True
225
  )
226
 
227
+ mode.change(fn=exhibition_pitch, inputs=mode, outputs=mode_explanation)
228
+ analyze_btn.click(fn=analyze_image, inputs=[image_input, mode, user_context], outputs=output)
 
 
 
229
 
230
+ gr.Markdown("""
231
+ ---
232
+ ### Suggested Demo Images for Exhibition
233
+ - A road with vehicles and pedestrians
234
+ - A classroom or hallway
235
+ - A kitchen or home environment
236
+ - A supermarket shelf or crowded place
237
 
238
+ ### Expected Impact
239
+ This project shows how computer vision and multimodal AI can improve accessibility and independence for visually impaired users.
240
+ """)
 
 
 
 
241
 
242
  if __name__ == "__main__":
243
  demo.launch(