Mayanand commited on
Commit
0b42e4c
·
1 Parent(s): 9bd8267

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -0
app.py CHANGED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import requests
3
+ import gradio as gr
4
+ from io import BytesIO
5
+ from PIL import Image
6
+ from model import CaptionModel
7
+ from utils import Tokenizer, return_user_agent
8
+
9
+ tokenizer = Tokenizer('./')
10
+ tokenizer.load_tokenizer('./checkpoints/vocab-v1.pkl')
11
+
12
+ weights = torch.load('./checkpoints/epoch=87-step=7144.ckpt', map_location=torch.device('cpu'))
13
+ model = CaptionModel(tokenizer)
14
+ model.load_state_dict(weights['state_dict'])
15
+
16
+ def decode_caption(idxs, tokenizer):
17
+ temp = []
18
+ for i in idxs:
19
+ temp.append(tokenizer.idx2val[i])
20
+ return ' '.join(temp).replace('<end>', '')
21
+
22
+ def predict_fn(image, link):
23
+
24
+ if link != '':
25
+ try:
26
+ resp = requests.get(link, headers=return_user_agent())
27
+ image = Image.open(BytesIO(resp.content))
28
+ except:
29
+ error_image = Image.open('./error.jpg')
30
+ error_text = 'Image from given link could not be downloaded, please try again with valid link'
31
+ return error_image, error_text
32
+
33
+ display_image = image
34
+ image = val_tfms(image).unsqueeze(0)
35
+ model.eval()
36
+ out = model.predict(image, torch.device('cpu'))
37
+
38
+ caption = decode_caption(out[0], tokenizer)
39
+ return display_image, caption
40
+
41
+
42
+ demo = gr.Interface(
43
+ fn=predict_fn,
44
+ inputs=[
45
+ gr.Image(label="Input Image", type='pil'),
46
+ gr.Textbox(label='Enter Image Link', placeholder='Enter or Paste any Image link from Internet')
47
+ ],
48
+ outputs=[
49
+ gr.Image(label="Display Image for link as input", type='pil'),
50
+ gr.Textbox(label="Generated Caption"),
51
+ ],
52
+ title="Image Captioning System",
53
+ description="Image Captioning Model trained on Flick8k Dataset ",
54
+ )
55
+
56
+ demo.launch()