File size: 2,104 Bytes
0b42e4c
 
 
 
 
 
716c14d
7c35328
0b42e4c
 
 
 
abc1750
0b42e4c
 
 
abd9184
 
 
 
 
 
 
 
 
0b42e4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de12553
0b42e4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee35021
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import requests
import gradio as gr
from io import BytesIO
from PIL import Image
from model import CaptionModel
from torchvision import transforms
from preprocess import Tokenizer, return_user_agent

# Vocabulary built during training; loaded from the pickled checkpoint so the
# index->token mapping matches the weights below.
tokenizer = Tokenizer('./')
tokenizer.load_tokenizer('./checkpoints/vocab-v1.pkl')

# Load trained weights onto CPU so the demo runs without a GPU.
# NOTE(review): checkpoint is a dict holding the model under 'state_dict'.
weights = torch.load('./checkpoints/caption_model.pt', map_location=torch.device('cpu'))
model = CaptionModel(tokenizer)
model.load_state_dict(weights['state_dict'])

# Validation-time transforms applied to the input image before inference.
val_tfms = transforms.Compose([
            # smaller edge of image resized to 256
            transforms.Resize(256),
            transforms.ToTensor(),
            # normalize image for pre-trained model
            # (ImageNet mean/std — presumably the encoder backbone expects these)
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225))
        ])

def decode_caption(idxs, tokenizer):
    """Convert a sequence of token indices into a caption string.

    Args:
        idxs: iterable of integer token indices produced by the model.
        tokenizer: object exposing ``idx2val``, a mapping index -> token.

    Returns:
        The space-joined caption with the '<end>' sentinel removed.
    """
    # Filter the '<end>' sentinel out before joining. The previous
    # str.replace('<end>', '') approach left stray/trailing spaces
    # in the caption (e.g. "a cat " instead of "a cat").
    words = [tokenizer.idx2val[i] for i in idxs]
    return ' '.join(w for w in words if w != '<end>')

def predict_fn(image, link):
    """Generate a caption for an uploaded image or an image URL.

    Args:
        image: PIL image from the Gradio image widget (may be None when
            only a link is supplied).
        link: optional URL string; when non-empty it takes precedence
            over the uploaded image.

    Returns:
        (display_image, caption) tuple for the two Gradio outputs; on a
        failed download, a placeholder error image and an error message.
    """
    if link != '':
        try:
            # Timeout + status check so a bad URL fails fast instead of
            # hanging forever or feeding an HTML error page to Image.open.
            resp = requests.get(link, headers=return_user_agent(), timeout=10)
            resp.raise_for_status()
            image = Image.open(BytesIO(resp.content))
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt and
            # SystemExit are no longer swallowed.
            error_image = Image.open('./error.jpg')
            error_text = 'Image from given link could not be downloaded, please try again with valid link'
            return error_image, error_text

    # Small copy for the UI; the full-size image goes through val_tfms.
    display_image = transforms.Resize(100)(image)
    image = val_tfms(image).unsqueeze(0)
    model.eval()
    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        out = model.predict(image, torch.device('cpu'))

    caption = decode_caption(out[0], tokenizer)
    return display_image, caption


# Gradio UI: accepts either an uploaded image or an image URL and shows
# the display image alongside the generated caption.
demo = gr.Interface(
    fn=predict_fn,
    inputs=[
        gr.Image(label="Input Image", type='pil'),
        gr.Textbox(label='Enter Image Link', placeholder='Enter or Paste any Image link from Internet')
        ],
    outputs=[
        gr.Image(label="Display Image for link as input", type='pil'),
        gr.Textbox(label="Generated Caption"),
    ],
    title="Image Captioning System",
    # Fixed user-facing typo: "Flick8k " -> "Flickr8k".
    description="Image Captioning Model trained on Flickr8k Dataset",
)

demo.launch(debug=True)