Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -115,21 +115,33 @@ def predict(image_input, question):
|
|
| 115 |
encoded = text_tokenizer(prompt, return_tensors="pt").to(device)
|
| 116 |
|
| 117 |
with torch.no_grad():
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
answer = text_tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
|
| 129 |
answer = answer.replace(prompt, "").strip() # Remove prompt from answer
|
| 130 |
|
| 131 |
return answer
|
| 132 |
|
|
|
|
|
|
|
|
|
|
| 133 |
except Exception as e:
|
| 134 |
#return f"An error occurred: {str(e)}"
|
| 135 |
return f"An error occurred: {traceback.format_exc()}"
|
|
|
|
| 115 |
encoded = text_tokenizer(prompt, return_tensors="pt").to(device)
|
| 116 |
|
| 117 |
with torch.no_grad():
|
| 118 |
+
# Get image embeddings
|
| 119 |
+
image_embeddings = model.image_encoder(image)
|
| 120 |
+
projected_image_embeddings = model.image_projection(image_embeddings)
|
| 121 |
+
|
| 122 |
+
# Reshape image embeddings to (batch_size, 1, phi3_embed_dim)
|
| 123 |
+
projected_image_embeddings = projected_image_embeddings.unsqueeze(1)
|
| 124 |
+
|
| 125 |
+
# Concatenate along the sequence dimension (dim=1)
|
| 126 |
+
extended_attention_mask = torch.cat([torch.ones(projected_image_embeddings.shape[:2], device=encoded["attention_mask"].device), encoded["attention_mask"]], dim=1)
|
| 127 |
+
extended_input_ids = torch.cat([torch.zeros(projected_image_embeddings.shape[:2], dtype=torch.long, device=encoded["input_ids"].device), encoded["input_ids"]], dim=1)
|
| 128 |
+
|
| 129 |
+
# Generate answer
|
| 130 |
+
generated_tokens = model.phi3.generate(
|
| 131 |
+
input_ids=extended_input_ids,
|
| 132 |
+
attention_mask=extended_attention_mask,
|
| 133 |
+
max_length=200,
|
| 134 |
+
pad_token_id=text_tokenizer.eos_token_id
|
| 135 |
+
)
|
| 136 |
|
| 137 |
answer = text_tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
|
| 138 |
answer = answer.replace(prompt, "").strip() # Remove prompt from answer
|
| 139 |
|
| 140 |
return answer
|
| 141 |
|
| 142 |
+
except Exception as e:
|
| 143 |
+
return f"An error occurred: {str(e)}"
|
| 144 |
+
|
| 145 |
except Exception as e:
|
| 146 |
#return f"An error occurred: {str(e)}"
|
| 147 |
return f"An error occurred: {traceback.format_exc()}"
|