# Hugging Face Spaces page residue (commit metadata), preserved as comments
# so the file remains valid Python:
# Mayanand's picture
# Update app.py
# 0b42e4c
# raw / history / blame — 1.71 kB
import torch
import requests
import gradio as gr
from io import BytesIO
from PIL import Image
from model import CaptionModel
from utils import Tokenizer, return_user_agent
# --- One-time model setup (executes at import time) ---
# Tokenizer is rooted at the app directory; its pickled vocabulary (built at
# training time) is loaded from ./checkpoints.
tokenizer = Tokenizer('./')
tokenizer.load_tokenizer('./checkpoints/vocab-v1.pkl')
# The "epoch=..-step=..ckpt" name and the 'state_dict' key suggest a PyTorch
# Lightning checkpoint — TODO confirm. Loaded onto CPU for Spaces hosting.
weights = torch.load('./checkpoints/epoch=87-step=7144.ckpt', map_location=torch.device('cpu'))
model = CaptionModel(tokenizer)
model.load_state_dict(weights['state_dict'])
def decode_caption(idxs, tokenizer):
    """Convert a sequence of token indices into a caption string.

    Args:
        idxs: iterable of integer token indices produced by the model.
        tokenizer: object exposing an ``idx2val`` mapping (index -> token).

    Returns:
        The decoded caption with the ``'<end>'`` marker removed and
        surrounding whitespace stripped (the original left a trailing
        space behind after stripping the end token).
    """
    words = (tokenizer.idx2val[i] for i in idxs)
    return ' '.join(words).replace('<end>', '').strip()
def predict_fn(image, link):
    """Caption an image supplied either as an upload or as a URL.

    Args:
        image: PIL image from the Gradio image input (may be None).
        link: optional image URL; when non-empty it overrides ``image``.

    Returns:
        (display_image, caption) — the image that was captioned and the
        generated caption string, or an error image/message when the
        download fails.
    """
    if link:  # also handles None, which Gradio may pass for an empty textbox
        try:
            # Timeout prevents a dead link from hanging the worker; the
            # status check stops HTML error pages from being fed to PIL.
            resp = requests.get(link, headers=return_user_agent(), timeout=10)
            resp.raise_for_status()
            image = Image.open(BytesIO(resp.content))
        except Exception:
            # Narrowed from a bare `except:` (which also caught SystemExit/
            # KeyboardInterrupt); any download or decode failure lands here.
            error_image = Image.open('./error.jpg')
            error_text = 'Image from given link could not be downloaded, please try again with valid link'
            return error_image, error_text
    display_image = image
    # NOTE(review): val_tfms is never defined or imported in this file —
    # presumably the validation transforms from utils/model; confirm the import.
    image = val_tfms(image).unsqueeze(0)
    model.eval()
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        out = model.predict(image, torch.device('cpu'))
    caption = decode_caption(out[0], tokenizer)
    return display_image, caption
# Gradio UI: two inputs (uploaded image, optional image URL) mapped to
# predict_fn's two return values (the displayed image and its caption).
demo = gr.Interface(
    fn=predict_fn,
    inputs=[
        gr.Image(label="Input Image", type='pil'),
        gr.Textbox(label='Enter Image Link', placeholder='Enter or Paste any Image link from Internet')
    ],
    outputs=[
        gr.Image(label="Display Image for link as input", type='pil'),
        gr.Textbox(label="Generated Caption"),
    ],
    title="Image Captioning System",
    # Fix: user-facing typo — the dataset is "Flickr8k", and the original
    # string carried a stray trailing space.
    description="Image Captioning Model trained on Flickr8k Dataset",
)
demo.launch()