Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -25,9 +25,12 @@ def generate_caption(image):
|
|
| 25 |
])
|
| 26 |
image_tensor = transform(image).unsqueeze(0).to(device)
|
| 27 |
|
| 28 |
-
#
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
| 31 |
caption = tokenizer.decode(output[0], skip_special_tokens=True)
|
| 32 |
return caption
|
| 33 |
|
|
|
|
| 25 |
])
|
| 26 |
image_tensor = transform(image).unsqueeze(0).to(device)
|
| 27 |
|
| 28 |
+
# Prepend the fixed prompt to the input tensor
|
| 29 |
+
fixed_prompt_tensor = tokenizer(FIXED_PROMPT, return_tensors="pt").input_ids.to(device)
|
| 30 |
+
input_tensor = torch.cat((fixed_prompt_tensor, image_tensor), dim=1)
|
| 31 |
+
|
| 32 |
+
# Generate caption
|
| 33 |
+
output = model.generate(pixel_values=image_tensor)
|
| 34 |
caption = tokenizer.decode(output[0], skip_special_tokens=True)
|
| 35 |
return caption
|
| 36 |
|