Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| """app | |
| Automatically generated by Colaboratory. | |
| Original file is located at | |
| https://colab.research.google.com/drive/1Uvn7yZCyrMpOYNPb7K0G45tQZJVx8LyX | |
| """ | |
| from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer | |
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
MODEL_ID = "nlpconnect/vit-gpt2-image-captioning"

# Load the captioning checkpoint plus its matching image preprocessor and
# text tokenizer; all three come from the same Hub repository.
model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)
feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)

# Decoding settings forwarded to ``model.generate``: beam search with
# ``num_beams`` beams and a generated length capped at ``max_length``.
max_length = 16
num_beams = 4
gen_kwargs = dict(max_length=max_length, num_beams=num_beams)
def predict_step(image):
    """Generate captions for one image or a batch of images.

    Args:
        image: A PIL image, or a list/tuple of PIL images. PIL images whose
            mode is not ``"RGB"`` (e.g. RGBA PNGs, grayscale, palette) are
            converted to RGB before preprocessing. ``None`` means "no image".

    Returns:
        A list with one whitespace-stripped caption per input image; an empty
        list when no image was given.
    """
    if image is None:
        return []
    images = list(image) if isinstance(image, (list, tuple)) else [image]
    if not images:
        return []

    # The original notebook converted every loaded image to RGB; uploads such
    # as RGBA PNGs or grayscale images would otherwise reach the ViT
    # preprocessor with a channel count it presumably does not expect.
    images = [
        img.convert(mode="RGB")
        if isinstance(img, Image.Image) and img.mode != "RGB"
        else img
        for img in images
    ]

    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    output_ids = model.generate(pixel_values, **gen_kwargs)
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return [pred.strip() for pred in preds]
def _caption_for_ui(image):
    """Return a single caption string for the Gradio Textbox.

    ``predict_step`` returns a list of captions; a Textbox needs one string,
    so the first (only) caption is returned. An empty upload is reported to
    the user instead of crashing inside the preprocessor.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    preds = predict_step(image)
    return preds[0] if preds else ""


# Gradio 4 removed the ``gr.inputs`` / ``gr.outputs`` namespaces; the
# top-level component classes work on Gradio 3.x as well.
inputs = [gr.Image(type="pil", label="Original Image")]
outputs = [gr.Textbox(label="Caption")]
title = "Image Captioning using ViT + GPT2"
description = "ViT and GPT2 are used here to generate Image Caption for the user uploaded image."
# Link to the checkpoint actually loaded by this app.
article = (
    "<a href='https://huggingface.co/nlpconnect/vit-gpt2-image-captioning'>"
    "Model Repository on Hugging Face Model Hub</a>"
)

demo = gr.Interface(
    _caption_for_ui,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
)
# ``launch(enable_queue=...)`` was removed in Gradio 4; enabling the queue on
# the interface itself is the supported replacement.
demo.queue()
demo.launch(debug=True)