from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor, Resize

# Preprocessing used at inference time: resize to a fixed 224x224, convert the
# PIL image to a tensor, then apply `normalize` (defined elsewhere in the
# notebook).
inference_transforms = Compose(
    [
        Resize((224, 224)),
        ToTensor(),
        normalize,
    ]
)

import matplotlib.pyplot as plt

# Upper bound (seconds) on waiting for a remote image, so a stalled server
# cannot hang the cell forever.
_DOWNLOAD_TIMEOUT_S = 30


def _load_image(path):
    """Open an image from an http(s) URL or from a local file path."""
    # Only a real URL scheme prefix counts as remote; a plain substring test
    # would misroute local paths that merely contain "http".
    if str(path).lower().startswith(("http://", "https://")):
        response = requests.get(path, timeout=_DOWNLOAD_TIMEOUT_S)
        # Fail on 4xx/5xx here instead of handing an error page to PIL.
        response.raise_for_status()
        return Image.open(BytesIO(response.content))
    return Image.open(path)


def caption_image(m, path):
    """Generate candidate captions for a single image.

    Args:
        m: Object exposing a ``generate(...)`` method that accepts the
            keyword arguments used below.
        path: Either an ``http://`` / ``https://`` URL or a local file path.

    Returns:
        A tuple ``(caption_options, generated, img_matrix)``:
        the decoded and stripped caption strings, the raw output of
        ``m.generate``, and the preprocessed input with a leading batch
        dimension added.

    Raises:
        requests.HTTPError: If a remote image request returns an error status.
    """
    img = _load_image(path)
    # NOTE(review): the image is not converted to RGB before normalisation;
    # grayscale/RGBA inputs may not match what `normalize` expects - confirm.

    # unsqueeze(0) adds the leading batch dimension.
    img_matrix = inference_transforms(img).unsqueeze(0)

    # NOTE(review): num_return_sequences (5) is larger than num_beams (3);
    # some transformers versions reject this combination - confirm against
    # the installed version.
    generated = m.generate(
        img_matrix,
        num_beams=3,
        max_length=15,
        early_stopping=True,
        do_sample=True,
        top_k=10,
        num_return_sequences=5,
    )

    caption_options = [
        arabert_tokenizer.decode(g, skip_special_tokens=True).strip()
        for g in generated
    ]

    display(img)
    # No figure is drawn in this function; this only shows figures that are
    # already open elsewhere.
    plt.show()

    return caption_options, generated, img_matrix


captions, generated, image_matrix = caption_image(
    finetuned_model, '/content/1.jpg'
)

captions