# Spaces: Build error
# File size: 983 Bytes
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor, Resize

# Deterministic preprocessing applied to every image before captioning:
# resize to exactly 224x224, convert the PIL image to a float tensor, then
# apply `normalize`.
# NOTE(review): `normalize` is not defined anywhere in this file; it must be
# created before this point, presumably as `Normalize(mean=..., std=...)`
# matching the statistics the model was trained with -- confirm.
inference_transforms = Compose(
    [
        Resize((224, 224)),
        ToTensor(),
        normalize,
    ]
)
import matplotlib.pyplot as plt
def caption_image(m, path, **generate_kwargs):
    """Generate candidate captions for a single image.

    The image at ``path`` is loaded, fetched over HTTP(S) when ``path`` is a
    URL and otherwise opened from disk. It is converted to RGB, preprocessed
    with ``inference_transforms`` and passed to ``m.generate``. Captions are
    decoded with the module-level ``arabert_tokenizer``. As a side effect the
    original image is rendered with ``display``.

    Args:
        m: Model object exposing ``generate(inputs, **kwargs)``; presumably a
            Hugging Face vision encoder-decoder model (confirm).
        path: ``http://``/``https://`` URL or local file path (``str`` or
            path-like) of the image.
        **generate_kwargs: Optional overrides for the default generation
            settings (sampled beam search, 5 candidates, ``max_length=15``).

    Returns:
        ``(caption_options, generated, img_matrix)`` -- the decoded, stripped
        caption strings; the raw sequences returned by ``m.generate``; and the
        preprocessed input tensor with a leading batch dimension of 1 (moved
        to ``m.device`` when the model exposes one).

    Raises:
        requests.HTTPError: if fetching a URL returns an error status.
        requests.Timeout: if connecting to or reading from the server exceeds
            30 seconds.
    """
    # Match on the scheme prefix, not a substring: a local path such as
    # "data/http_images/1.jpg" must still be read from disk.
    if str(path).startswith(("http://", "https://")):
        from io import BytesIO

        # Fail loudly on 4xx/5xx (instead of handing an HTML error page to
        # PIL) and never block forever on an unresponsive server.
        response = requests.get(path, timeout=30)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
    else:
        img = Image.open(path)

    # Grayscale, palette and RGBA images would otherwise become tensors with a
    # channel count other than 3; convert only the copy fed to the model.
    rgb_img = img.convert("RGB")
    img_matrix = inference_transforms(rgb_img).unsqueeze(0)  # add batch dim

    # Keep the input on the model's device (e.g. after ``model.to("cuda")``).
    device = getattr(m, "device", None)
    if device is not None:
        img_matrix = img_matrix.to(device)

    generation_settings = {
        "num_beams": 3,
        "max_length": 15,
        "early_stopping": True,
        "do_sample": True,
        "top_k": 10,
        "num_return_sequences": 5,
    }
    generation_settings.update(generate_kwargs)
    generated = m.generate(img_matrix, **generation_settings)

    caption_options = [
        arabert_tokenizer.decode(seq, skip_special_tokens=True).strip()
        for seq in generated
    ]
    # NOTE(review): `display` is assumed to be IPython's (predefined inside
    # notebooks); it is not imported in this file.
    display(img)
    plt.show()
    return caption_options, generated, img_matrix
# Demo: caption a sample image with the fine-tuned model and show the options.
# NOTE(review): `finetuned_model` is not defined in this file and
# '/content/1.jpg' looks like a Colab path -- confirm both exist before running.
captions, generated, image_matrix = caption_image(
    finetuned_model, '/content/1.jpg'
)
print(captions)