Spaces:
Build error
Build error
File size: 2,104 Bytes
0b42e4c 716c14d 7c35328 0b42e4c abc1750 0b42e4c abd9184 0b42e4c de12553 0b42e4c ee35021 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import torch
import requests
import gradio as gr
from io import BytesIO
from PIL import Image
from model import CaptionModel
from torchvision import transforms
from preprocess import Tokenizer, return_user_agent
tokenizer = Tokenizer('./')
tokenizer.load_tokenizer('./checkpoints/vocab-v1.pkl')


def _load_caption_model(tok, checkpoint_path):
    """Build a CaptionModel on CPU and load its weights from *checkpoint_path*.

    Only the ``'state_dict'`` entry of the checkpoint is read. The raw
    checkpoint dict is kept local to this function so its tensors can be
    released once ``load_state_dict`` has copied them into the model, rather
    than being held for the app's lifetime by a module-level variable.
    """
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    caption_model = CaptionModel(tok)
    caption_model.load_state_dict(checkpoint['state_dict'])
    # Inference-only app: switch to eval mode once, at load time.
    caption_model.eval()
    return caption_model


model = _load_caption_model(tokenizer, './checkpoints/caption_model.pt')

# Evaluation-time preprocessing applied to every input image.
val_tfms = transforms.Compose([
    # smaller edge of image resized to 256
    transforms.Resize(256),
    transforms.ToTensor(),
    # normalize image for the pre-trained model; these are the commonly used
    # ImageNet per-channel mean/std (assumed to match the encoder's training)
    transforms.Normalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225)),
])
def decode_caption(idxs, tokenizer):
    """Convert a sequence of token indices into a caption string.

    Each index is looked up in ``tokenizer.idx2val``. Decoding stops at the
    first ``'<end>'`` token, so anything generated after it is discarded and
    the marker itself never appears in the output. The result has no
    leading or trailing whitespace.

    Args:
        idxs: iterable of token indices; presumably plain ints coming from
            ``model.predict`` (TODO confirm against CaptionModel.predict).
        tokenizer: object exposing an ``idx2val`` mapping index -> token.

    Returns:
        The tokens before ``'<end>'`` joined with single spaces.
    """
    words = []
    for idx in idxs:
        token = tokenizer.idx2val[idx]
        if token == '<end>':
            break
        words.append(token)
    return ' '.join(words).strip()
def predict_fn(image, link):
    """Generate a caption for an uploaded image or an image fetched from a URL.

    A non-blank ``link`` takes precedence over ``image``: the URL is
    downloaded and decoded with PIL. If that fails for any reason, the bundled
    ``./error.jpg`` is returned together with an explanatory message.

    Args:
        image: PIL image from the upload widget, or ``None`` if nothing was
            uploaded.
        link: text from the link textbox ('' when unused).

    Returns:
        ``(display_image, caption)``: the input resized so its smaller edge is
        100 px, and the generated caption text.
    """
    if link and link.strip():
        try:
            # Timeout so an unresponsive host cannot block the worker forever.
            resp = requests.get(link.strip(), headers=return_user_agent(), timeout=10)
            resp.raise_for_status()
            image = Image.open(BytesIO(resp.content))
            # PIL decodes lazily; force it here so corrupt data fails inside the try.
            image.load()
        except Exception:
            # UI boundary: any download/decode failure becomes a friendly error,
            # but KeyboardInterrupt/SystemExit are no longer swallowed.
            error_image = Image.open('./error.jpg')
            error_text = 'Image from given link could not be downloaded, please try again with valid link'
            return error_image, error_text
    if image is None:
        return None, 'Please upload an image or enter an image link'
    # Normalize below has 3-channel statistics; RGBA/greyscale/palette images
    # (common for downloaded PNGs/GIFs) must be converted first.
    image = image.convert('RGB')
    display_image = transforms.Resize(100)(image)
    image = val_tfms(image).unsqueeze(0)
    model.eval()
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        out = model.predict(image, torch.device('cpu'))
    caption = decode_caption(out[0], tokenizer)
    return display_image, caption
# Gradio UI: an image upload and an optional URL textbox feed predict_fn, which
# returns a preview image and the generated caption.
demo = gr.Interface(
    fn=predict_fn,
    inputs=[
        gr.Image(label="Input Image", type='pil'),
        gr.Textbox(label='Enter Image Link', placeholder='Enter or Paste any Image link from Internet'),
    ],
    outputs=[
        gr.Image(label="Display Image for link as input", type='pil'),
        gr.Textbox(label="Generated Caption"),
    ],
    title="Image Captioning System",
    description="Image Captioning Model trained on Flickr8k Dataset",
)

# Launch only when run as a script, so importing this module does not start a server.
if __name__ == '__main__':
    demo.launch(debug=True)