Spaces:
Build error
Build error
| import torch | |
| import requests | |
| import gradio as gr | |
| from io import BytesIO | |
| from PIL import Image | |
| from model import CaptionModel | |
| from torchvision import transforms | |
| from preprocess import Tokenizer, return_user_agent | |
# Vocabulary that maps between token strings and integer ids.
tokenizer = Tokenizer('./')
tokenizer.load_tokenizer('./checkpoints/vocab-v1.pkl')

# Map every tensor in the checkpoint to CPU so it loads without a GPU.
weights = torch.load('./checkpoints/caption_model.pt', map_location='cpu')
model = CaptionModel(tokenizer)
model.load_state_dict(weights['state_dict'])

# Standard ImageNet per-channel mean / std, as expected by torchvision
# pre-trained backbones.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# Inference-time preprocessing: resize the shorter edge to 256 px (aspect
# ratio preserved), convert to a float tensor, then normalize per channel.
val_tfms = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])
def decode_caption(idxs, tokenizer):
    """Convert a sequence of token ids into a caption string.

    Each id is looked up in ``tokenizer.idx2val``. Decoding stops at the first
    ``'<end>'`` token; that token and anything after it (e.g. padding) are
    dropped. The result is stripped of surrounding whitespace.

    Args:
        idxs: Iterable of token ids -- plain ints or anything ``int()`` accepts
            (such as 0-d tensors yielded by iterating a 1-d tensor).
        tokenizer: Object exposing an ``idx2val`` mapping from id to token.

    Returns:
        The space-joined caption text.
    """
    words = []
    for idx in idxs:
        # int() so tensor elements work as keys of an int-keyed dict
        # (tensors hash by identity, so they would never match).
        word = tokenizer.idx2val[int(idx)]
        if word == '<end>':
            break
        words.append(word)
    return ' '.join(words).strip()
# Seconds to wait for a remote host before giving up on an image link;
# without a timeout a non-responsive server would hang the request forever.
_LINK_TIMEOUT_SECONDS = 10


def _error_response(message):
    """Return the placeholder error image paired with *message* (the two UI outputs)."""
    return Image.open('./error.jpg'), message


def _fetch_image(link):
    """Download *link* and fully decode it as a PIL image; raises on any failure."""
    resp = requests.get(link, headers=return_user_agent(), timeout=_LINK_TIMEOUT_SECONDS)
    # Fail fast on HTTP errors instead of trying to decode an error page.
    resp.raise_for_status()
    image = Image.open(BytesIO(resp.content))
    # Image.open is lazy; force the decode here so corrupt/truncated data
    # raises inside the caller's try block rather than later.
    image.load()
    return image


def predict_fn(image, link):
    """Generate a caption for an uploaded image or an image fetched from a URL.

    A non-blank ``link`` takes precedence over the uploaded ``image``.

    Args:
        image: PIL image from the upload widget, or None when nothing was uploaded.
        link: URL text from the textbox (may be empty or None).

    Returns:
        A ``(display_image, caption)`` pair. On failure ``display_image`` is the
        placeholder error image and ``caption`` is an explanatory message.
    """
    link = (link or '').strip()
    if link:
        try:
            image = _fetch_image(link)
        except Exception:
            # User-facing boundary: any download/decode failure is reported back
            # to the user. Unlike a bare except, this does not trap KeyboardInterrupt.
            return _error_response(
                'Image from given link could not be downloaded, please try again with valid link'
            )
    if image is None:
        return _error_response('Please upload an image or enter an image link')

    # Normalize in val_tfms expects 3 channels; RGBA / greyscale / palette
    # images (common for downloaded files) would otherwise fail.
    image = image.convert('RGB')
    display_image = transforms.Resize(100)(image)
    batch = val_tfms(image).unsqueeze(0)  # add a leading batch dimension

    model.eval()
    with torch.no_grad():  # inference only: skip building the autograd graph
        out = model.predict(batch, torch.device('cpu'))
    caption = decode_caption(out[0], tokenizer)
    return display_image, caption
# Gradio UI: takes an uploaded image and/or an image URL (predict_fn uses the
# URL when one is given) and shows a preview image plus the generated caption.
demo = gr.Interface(
    fn=predict_fn,
    inputs=[
        gr.Image(label="Input Image", type='pil'),
        gr.Textbox(label='Enter Image Link', placeholder='Enter or Paste any Image link from Internet'),
    ],
    outputs=[
        gr.Image(label="Display Image for link as input", type='pil'),
        gr.Textbox(label="Generated Caption"),
    ],
    title="Image Captioning System",
    description="Image Captioning Model trained on the Flickr8k dataset",
)
# Per the Gradio docs, debug=True blocks the main thread and surfaces errors
# in the console output.
demo.launch(debug=True)