Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -126,10 +126,56 @@ def predict(image, question):
|
|
| 126 |
traceback.print_exc()
|
| 127 |
#return f"An error occurred: {str(e)}"
|
| 128 |
return f"An error occurred: {traceback.format_exc()}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
# 4. Gradio Interface
|
| 131 |
iface = gr.Interface(
|
| 132 |
-
fn=
|
| 133 |
inputs=[
|
| 134 |
gr.Image(label="Upload an Image"),
|
| 135 |
gr.Textbox(label="Ask a Question about the Image", placeholder="What is in the image?")
|
|
|
|
| 126 |
traceback.print_exc()
|
| 127 |
#return f"An error occurred: {str(e)}"
|
| 128 |
return f"An error occurred: {traceback.format_exc()}"
|
| 129 |
+
|
| 130 |
+
# 3. Inference Function
|
| 131 |
+
@spaces.GPU
def predict1(image_input, question):
    """Answer a free-text question about an uploaded image.

    Args:
        image_input: The image to ask about, in whatever form
            ``Image.fromarray`` accepts (presumably the NumPy array that
            ``gr.Image`` hands over by default -- confirm against the
            interface definition). ``None`` when no image was uploaded.
        question: The user's question as text. ``None``, empty, or
            whitespace-only questions are rejected.

    Returns:
        The decoded model output with the prompt text removed, or a
        human-readable message when an input is missing or an error occurs.
    """
    # A question made only of whitespace carries no more information than an
    # empty one, so it is rejected the same way.
    if image_input is None or question is None or not question.strip():
        return "Please provide both an image and a question."

    try:
        image = Image.fromarray(image_input).convert("RGB")
        image = image_transform(image).unsqueeze(0).to(device)

        prompt = f"Question: {question}\nAnswer:"
        encoded = text_tokenizer(prompt, return_tensors="pt").to(device)
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]

        with torch.no_grad():
            # Encode the image and project it into the language model's space.
            image_embeddings = model.image_encoder(image)
            projected_image_embeddings = model.image_projection(image_embeddings)

            # Add a sequence axis so the image occupies one position.
            projected_image_embeddings = projected_image_embeddings.unsqueeze(1)
            image_slot_shape = projected_image_embeddings.shape[:2]

            # NOTE(review): only the *shape* of the projected embeddings is
            # used below; the embeddings themselves are never handed to
            # model.generate, so the answer presumably does not depend on the
            # image unless generate() picks it up some other way -- confirm.

            # Prepend one attended position for the image. The dtype is taken
            # from the tokenizer's mask so torch.cat does not promote the
            # integer mask to float.
            image_mask = torch.ones(
                image_slot_shape,
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )
            extended_attention_mask = torch.cat([image_mask, attention_mask], dim=1)

            # Placeholder token id 0 for the image position.
            image_placeholder_ids = torch.zeros(
                image_slot_shape,
                dtype=torch.long,
                device=input_ids.device,
            )
            extended_input_ids = torch.cat([image_placeholder_ids, input_ids], dim=1)

            # NOTE(review): if this is the Hugging Face generate(), max_length
            # counts the prompt tokens as well as the generated ones -- confirm
            # that a 200-token total budget is intended.
            generated_tokens = model.generate(
                input_ids=extended_input_ids,
                attention_mask=extended_attention_mask,
                max_length=200,
                pad_token_id=text_tokenizer.eos_token_id,
            )

        answer = text_tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        # The decoded text still starts with the prompt; drop it.
        answer = answer.replace(prompt, "").strip()

        return answer

    except Exception:
        # Log the failure server-side as well, matching predict().
        traceback.print_exc()
        return f"An error occurred: {traceback.format_exc()}"
|
| 175 |
|
| 176 |
# 4. Gradio Interface
|
| 177 |
iface = gr.Interface(
|
| 178 |
+
fn=predict1,
|
| 179 |
inputs=[
|
| 180 |
gr.Image(label="Upload an Image"),
|
| 181 |
gr.Textbox(label="Ask a Question about the Image", placeholder="What is in the image?")
|