# app.py — Gradio image-captioning demo (commit de12553, author: Mayanand)
import torch
import requests
import gradio as gr
from io import BytesIO
from PIL import Image
from model import CaptionModel
from torchvision import transforms
from preprocess import Tokenizer, return_user_agent
# --- Module-level setup: tokenizer, model weights, and eval transforms. ---
# NOTE(review): these run at import time and read from ./checkpoints — the
# script assumes it is launched from the repository root; confirm paths.
tokenizer = Tokenizer('./')
tokenizer.load_tokenizer('./checkpoints/vocab-v1.pkl')
# Force CPU so the demo runs on hosts without a GPU; checkpoint is a dict
# holding the weights under 'state_dict' (see load_state_dict below).
weights = torch.load('./checkpoints/caption_model.pt', map_location=torch.device('cpu'))
model = CaptionModel(tokenizer)
model.load_state_dict(weights['state_dict'])
# Validation-time transforms applied to every input image before inference.
val_tfms = transforms.Compose([
# smaller edge of image resized to 256
transforms.Resize(256),
transforms.ToTensor(),
# normalize image for pre-trained model
# (ImageNet mean/std — standard constants for torchvision backbones)
transforms.Normalize((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225))
])
def decode_caption(idxs, tokenizer):
    """Turn a sequence of token indices into a caption string.

    Parameters
    ----------
    idxs : iterable of int
        Token indices produced by the caption model.
    tokenizer : Tokenizer
        Provides the ``idx2val`` index-to-word mapping.

    Returns
    -------
    str
        Space-joined words with the ``'<end>'`` sentinel removed.
    """
    words = (tokenizer.idx2val[i] for i in idxs)
    # Filter the sentinel before joining: the previous
    # `' '.join(...).replace('<end>', '')` left dangling spaces behind
    # (a trailing space, or a double space for a mid-sequence '<end>').
    return ' '.join(w for w in words if w != '<end>')
def predict_fn(image, link):
    """Generate a caption for an uploaded image or an image URL.

    Parameters
    ----------
    image : PIL.Image.Image | None
        Image from the Gradio upload widget (may be None when only a
        link is supplied).
    link : str | None
        Optional image URL; when non-empty it takes precedence over the
        uploaded image.

    Returns
    -------
    (PIL.Image.Image, str)
        A small preview image and the generated caption, or an error
        placeholder plus message when the download fails.
    """
    # Truthiness check: Gradio may hand back None OR '' for an empty
    # textbox; the original `link != ''` crashed on None.
    if link:
        try:
            # Bounded timeout so a dead host cannot hang the demo forever.
            resp = requests.get(link, headers=return_user_agent(), timeout=10)
            resp.raise_for_status()
            image = Image.open(BytesIO(resp.content))
        except (requests.RequestException, OSError):
            # Narrowed from a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit. OSError covers PIL open errors.
            error_image = Image.open('./error.jpg')
            error_text = 'Image from given link could not be downloaded, please try again with valid link'
            return error_image, error_text
    # Normalize uses 3-channel stats; convert grayscale/RGBA/palette
    # inputs so val_tfms does not fail on non-RGB images.
    image = image.convert('RGB')
    display_image = transforms.Resize(100)(image)
    image = val_tfms(image).unsqueeze(0)
    model.eval()
    # Inference only — skip autograd graph construction.
    with torch.no_grad():
        out = model.predict(image, torch.device('cpu'))
    caption = decode_caption(out[0], tokenizer)
    return display_image, caption
# Gradio UI: accepts either a direct upload or an image URL, and shows the
# (possibly downloaded) image alongside the generated caption.
demo = gr.Interface(
    fn=predict_fn,
    inputs=[
        gr.Image(label="Input Image", type='pil'),
        gr.Textbox(label='Enter Image Link', placeholder='Enter or Paste any Image link from Internet'),
    ],
    outputs=[
        gr.Image(label="Display Image for link as input", type='pil'),
        gr.Textbox(label="Generated Caption"),
    ],
    title="Image Captioning System",
    # Fixed user-facing typo: the dataset is Flickr8k (was "Flick8k ",
    # with a trailing space).
    description="Image Captioning Model trained on Flickr8k Dataset",
)

if __name__ == '__main__':
    # Guarded entry point so importing this module does not start a server;
    # debug=True surfaces tracebacks in the UI during development.
    demo.launch(debug=True)