Spaces:
Build error
Build error
| from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor, Resize | |
# Preprocessing applied to an image before it is handed to the model:
# resize to 224x224, convert to a tensor, then apply `normalize`
# (defined elsewhere in this file).
inference_transforms = Compose([Resize((224, 224)), ToTensor(), normalize])
| import matplotlib.pyplot as plt | |
def caption_image(m, path):
    """Generate candidate captions for one image and display the image.

    Args:
        m: Object exposing ``generate``; it is called with the transformed
            image batch and the decoding keyword arguments below.
        path: Local file path, or an ``http://`` / ``https://`` URL, of the
            image to caption.

    Returns:
        A tuple ``(caption_options, generated, img_matrix)``:
        the decoded and stripped caption strings, the raw output of
        ``m.generate``, and the transformed image tensor with a leading
        axis added by ``unsqueeze(0)``.

    Raises:
        requests.HTTPError: If a URL is given and the server answers with
            an error status.
    """
    # Test the scheme prefix rather than a substring, so a local path that
    # merely contains "http" is not mistaken for a URL.
    if path.startswith(("http://", "https://")):
        # A timeout keeps an unresponsive host from hanging the call forever.
        response = requests.get(path, timeout=30)
        # Fail here on 4xx/5xx instead of feeding an error page to Image.open.
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
    else:
        img = Image.open(path)

    img_matrix = inference_transforms(img).unsqueeze(0)

    # NOTE(review): num_return_sequences (5) is larger than num_beams (3);
    # some transformers releases reject that combination - confirm against
    # the installed version before changing either value.
    generated = m.generate(
        img_matrix,
        num_beams=3,
        max_length=15,
        early_stopping=True,
        do_sample=True,
        top_k=10,
        num_return_sequences=5,
    )

    caption_options = [
        arabert_tokenizer.decode(g, skip_special_tokens=True).strip()
        for g in generated
    ]

    # presumably IPython's `display`, rendering the image inline - confirm.
    display(img)
    plt.show()

    return caption_options, generated, img_matrix
# Caption a sample image with the fine-tuned model, keeping the decoded
# captions, the raw generation output and the transformed image tensor.
captions, generated, image_matrix = caption_image(finetuned_model, '/content/1.jpg')

# Bare expression: echoes the caption list when run as a notebook cell.
captions