# mayaram's picture
# Update app.py
# 7e7a932
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor, Resize
# Preprocessing pipeline applied to each image at inference time:
# resize to a fixed 224x224, convert the PIL image to a tensor, then apply
# `normalize` (defined elsewhere in the file; presumably a torchvision
# Normalize matching the model's training statistics -- TODO confirm).
inference_transforms = Compose([Resize(size=(224, 224)), ToTensor(), normalize])
import matplotlib.pyplot as plt
def caption_image(m, path, **generate_kwargs):
    """Generate candidate captions for an image with a captioning model.

    The image is loaded from a URL (when ``path`` starts with ``http://`` or
    ``https://``, case-insensitively) or from the local filesystem, converted
    to RGB, preprocessed with the module-level ``inference_transforms``,
    batched, and passed to ``m.generate``. Every returned sequence is decoded
    with ``arabert_tokenizer``. As a side effect the image is shown with
    ``display`` and ``plt.show()`` is called.

    Args:
        m: Model exposing a ``generate`` method that takes the batched image
            tensor as its first argument (presumably a Hugging Face
            encoder-decoder model -- TODO confirm against the caller).
        path: Image URL or local file path (``str`` or ``os.PathLike``).
        **generate_kwargs: Optional overrides for the keyword arguments passed
            to ``m.generate`` (e.g. ``num_beams=5``). Settings not overridden
            keep the defaults used below.

    Returns:
        A tuple ``(caption_options, generated, img_matrix)``:
        the list of decoded, whitespace-stripped caption strings; the raw
        output of ``m.generate``; and the preprocessed input batch (one image,
        with a leading batch dimension added by ``unsqueeze(0)``).

    Raises:
        requests.HTTPError: if downloading a URL returns an error status.
    """
    path_str = str(path)
    # Match the URL scheme only at the start. A substring test ('http' in path)
    # would misroute local paths such as '/data/http_dump/1.jpg'.
    if path_str.lower().startswith(("http://", "https://")):
        # Use a timeout so a stalled server cannot hang the call indefinitely.
        # Fail loudly on HTTP errors instead of handing an error page to PIL.
        response = requests.get(path_str, timeout=30)
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = path
    # convert("RGB") forces three channels. Grayscale, RGBA or palette images
    # would otherwise produce a tensor with a different channel count. It also
    # loads the pixel data, so the `with` can safely close the file handle.
    with Image.open(source) as opened:
        img = opened.convert("RGB")
    img_matrix = inference_transforms(img).unsqueeze(0)

    generation_settings = dict(
        num_beams=3,
        max_length=15,
        early_stopping=True,
        do_sample=True,
        top_k=10,
        num_return_sequences=5,
    )
    generation_settings.update(generate_kwargs)
    generated = m.generate(img_matrix, **generation_settings)

    caption_options = [
        arabert_tokenizer.decode(g, skip_special_tokens=True).strip()
        for g in generated
    ]
    display(img)
    plt.show()
    return caption_options, generated, img_matrix
# Example run: caption a local test image with the fine-tuned model
# (`finetuned_model` is defined elsewhere in the file).
captions, generated, image_matrix = caption_image(
    m=finetuned_model,
    path="/content/1.jpg",
)
# Bare expression: echoes the captions in a notebook cell; it has no effect
# when this file is run as a plain script.
captions