Monimoy committed on
Commit
a34f5dd
·
verified ·
1 Parent(s): a6fd536

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -1
app.py CHANGED
@@ -126,10 +126,56 @@ def predict(image, question):
126
  traceback.print_exc()
127
  #return f"An error occurred: {str(e)}"
128
  return f"An error occurred: {traceback.format_exc()}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
  # 4. Gradio Interface
131
  iface = gr.Interface(
132
- fn=predict,
133
  inputs=[
134
  gr.Image(label="Upload an Image"),
135
  gr.Textbox(label="Ask a Question about the Image", placeholder="What is in the image?")
 
126
  traceback.print_exc()
127
  #return f"An error occurred: {str(e)}"
128
  return f"An error occurred: {traceback.format_exc()}"
129
+
130
# 3. Inference Function
@spaces.GPU
def predict1(image_input, question):
    """
    Answer a free-form question about an uploaded image.

    Parameters
    ----------
    image_input : array-like or None
        Image as delivered by ``gr.Image`` (a numpy array), or None.
    question : str or None
        The user's question about the image.

    Returns
    -------
    str
        The generated answer, a validation message when inputs are missing,
        or the full traceback string if inference fails.
    """
    # Guard clause: both inputs are required before running the model.
    if image_input is None or question is None or question == "":
        return "Please provide both an image and a question."

    try:
        # Preprocess the raw array into a normalized (1, C, H, W) tensor.
        image = Image.fromarray(image_input).convert("RGB")
        image = image_transform(image).unsqueeze(0).to(device)

        prompt = f"Question: {question}\nAnswer:"
        encoded = text_tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            # Encode and project the image into the language model's
            # embedding space.
            image_embeddings = model.image_encoder(image)
            projected_image_embeddings = model.image_projection(image_embeddings)
            # Add a sequence axis: (batch, 1, embed_dim) so the image acts
            # as a one-token prefix of the text sequence.
            projected_image_embeddings = projected_image_embeddings.unsqueeze(1)

            # BUG FIX: previously the projected image embeddings were used
            # only for their shape — generate() was fed zero input_ids as
            # image placeholders, so the visual features never reached the
            # language model. Splice the image embeddings directly into the
            # input embedding sequence instead.
            # NOTE(review): assumes model.generate follows the HF
            # GenerationMixin contract and accepts inputs_embeds — confirm
            # against the model class definition.
            token_embeddings = model.get_input_embeddings()(encoded["input_ids"])
            inputs_embeds = torch.cat(
                [projected_image_embeddings, token_embeddings], dim=1
            )
            # Extend the attention mask with ones for the image positions.
            attention_mask = torch.cat(
                [
                    torch.ones(
                        projected_image_embeddings.shape[:2],
                        device=encoded["attention_mask"].device,
                        dtype=encoded["attention_mask"].dtype,
                    ),
                    encoded["attention_mask"],
                ],
                dim=1,
            )

            # Generate the answer conditioned on image + prompt embeddings.
            generated_tokens = model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                max_length=200,
                pad_token_id=text_tokenizer.eos_token_id,
            )

        answer = text_tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        # When generating from inputs_embeds the prompt text is usually not
        # echoed back, but strip it defensively in case the model returns it.
        answer = answer.replace(prompt, "").strip()
        return answer

    except Exception:
        # Surface the full traceback in the UI to ease debugging on Spaces.
        return f"An error occurred: {traceback.format_exc()}"
175
 
176
  # 4. Gradio Interface
177
  iface = gr.Interface(
178
+ fn=predict1,
179
  inputs=[
180
  gr.Image(label="Upload an Image"),
181
  gr.Textbox(label="Ask a Question about the Image", placeholder="What is in the image?")