# -*- coding: utf-8 -*-
"""caption.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/17BgQi1eU254RKp6BKOdC-Kfr1LqIwKmj

## Image Caption Generator

In Colab, PyTorch comes preinstalled, as does PIL for image handling.
The only package you need to install is **transformers** from Hugging Face.
"""

# Colab-only setup, kept for reference:
# !pip install transformers
# from google.colab import drive
# drive.mount('/content/drive')
# import transformers

import torch
from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer

import gradio as gr

# All three components come from the same pretrained checkpoint.
MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"

model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Use the GPU when one is available, otherwise fall back to the CPU.
# (The original notebook repeated these two lines; once is enough.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Decoding settings forwarded to model.generate(): captions are capped at
# 16 tokens and produced with 4-beam search.
max_length = 16
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}


def predict_step1(image_paths):
    """Generate captions for a single image file.

    Args:
        image_paths: Path to ONE image file. Despite the plural name, only a
            single path is opened. The Gradio input below is declared with
            ``type="filepath"``, so it supplies a path string.

    Returns:
        A list of caption strings, each stripped of surrounding whitespace.
        One entry is produced per generated sequence; ``gen_kwargs`` does not
        request more than one.
    """
    # Bug fix: the original called ``PIL.Image.open``, but only the name
    # ``Image`` is imported from PIL, so ``PIL`` was unbound (NameError).
    # The ``with`` block also guarantees the file handle is released.
    with Image.open(image_paths) as img:
        # The ViT feature extractor expects 3-channel RGB input, so
        # grayscale, palette or RGBA images are converted first.
        if img.mode != "RGB":
            img = img.convert(mode="RGB")
        pixel_values = feature_extractor(images=img, return_tensors="pt").pixel_values

    pixel_values = pixel_values.to(device)
    output_ids = model.generate(pixel_values, **gen_kwargs)
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return [pred.strip() for pred in preds]


def predict_step(image_path):
    """Return the best caption for *image_path* as a plain string.

    This is the callback the Gradio interface refers to as ``predict_step``;
    the original file never defined it. A Textbox output expects text, while
    ``predict_step1`` returns a list, so the first caption is unwrapped here.
    Returns an empty string if no caption was produced.
    """
    captions = predict_step1(image_path)
    return captions[0] if captions else ""


# NOTE(review): ``gr.inputs`` / ``gr.outputs`` are legacy Gradio namespaces
# (deprecated in 3.x, removed in 4.x); pin a compatible gradio version or
# migrate to ``gr.Image`` / ``gr.Textbox``.
inputs = [
    gr.inputs.Image(type="filepath", label="Original Image"),
]
outputs = [
    gr.outputs.Textbox(label="Caption"),
]
title = "Image Captioning"
description = "ViT and GPT2 are used to generate Image Caption for the uploaded image."
# HTML shown beneath the interface. The original text advertised a model repo
# but carried no link (the anchor tag had been lost), so it is restored here.
article = "<a href='https://huggingface.co/nlpconnect/vit-gpt2-image-captioning' target='_blank'>Model Repo on Hugging Face Model Hub</a>"

# Example inputs offered in the UI. These image files are assumed to sit next
# to this script -- TODO confirm they are shipped with the app.
examples = [
    ["horses.png"],
    ["persons.png"],
    ["football_player.png"],
]


def _caption_for_textbox(image_path):
    """Return the first caption from ``predict_step1`` as a plain string.

    The original launch call referenced an undefined ``predict_step``
    (NameError). It also would have handed a list to a Textbox output, which
    would display it as "['...']". Returns "" if no caption was produced.
    """
    captions = predict_step1(image_path)
    return captions[0] if captions else ""


# NOTE(review): ``theme="huggingface"`` and ``enable_queue`` are legacy Gradio
# options (removed in Gradio 4.x); keep them only with a matching pinned version.
gr.Interface(
    _caption_for_textbox,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
    theme="huggingface",
).launch(debug=True, enable_queue=True)