# -*- coding: utf-8 -*-
"""caption.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/17BgQi1eU254RKp6BKOdC-Kfr1LqIwKmj
## Image Caption Generator
In Colab, PyTorch comes preinstalled, and so does PIL. You only need to install **transformers** from Hugging Face.
"""
#!pip install transformers
#from google.colab import drive
#drive.mount('/content/drive')
#import transformers
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
import torch
from PIL import Image
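# Load the pretrained ViT-encoder / GPT-2-decoder captioning model together
# with its feature extractor and tokenizer. (Note: newer transformers releases
# deprecate ViTFeatureExtractor in favor of ViTImageProcessor.)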
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
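# Generation settings: cap captions at 16 tokens and decode with 4-beam
# search, trading a little speed for higher-quality captions.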
max_length = 16
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
def predict_step(image_path):
    """Generate a caption for the image at the given file path."""
    i_image = Image.open(image_path)
    if i_image.mode != "RGB":
        i_image = i_image.convert(mode="RGB")
    # Preprocess the image into pixel values and move them to the model's device.
    pixel_values = feature_extractor(images=i_image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    # Generate token ids and decode them into caption strings.
    output_ids = model.generate(pixel_values, **gen_kwargs)
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]
    # Return a single caption string for the Textbox output.
    return preds[0]
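# A minimal local sanity check, assuming an image file such as "horses.png"
# sits next to this script (left commented out so the Space only serves the UI):
# print(predict_step("horses.png"))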
import gradio as gr
inputs = [
    gr.inputs.Image(type="filepath", label="Original Image")
]
outputs = [
    gr.outputs.Textbox(label="Caption")
]
title = "Image Captioning"
description = "ViT and GPT-2 are used to generate a caption for the uploaded image."
article = "<a href='https://huggingface.co/nlpconnect/vit-gpt2-image-captioning'>Model Repo on Hugging Face Model Hub</a>"
examples = [
    ["horses.png"],
    ["persons.png"],
    ["football_player.png"],
]
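# Note: the example images above are assumed to be committed alongside app.py.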
gr.Interface(
    predict_step,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
    theme="huggingface",
).launch(debug=True, enable_queue=True)