Spaces:
Build error
Build error
| import os | |
| import torch | |
| import gradio as gr | |
| from transformers import VisionEncoderDecoderModel, AutoFeatureExtractor, AutoTokenizer | |
def create_caption_transformer(img):
    """
    Create a caption for an image using a vision encoder-decoder transformer
    (trained on the 'Flickr image dataset', per the original author).

    Relies on the module-level ``model``, ``feature_extractor`` and ``tokenizer``
    objects, which must be loaded before this function is called.

    :param img: the image as a numpy array (as passed by the Gradio "image"
        input), or None when no image was supplied
    :return: the caption text, cut at its first period and stripped of
        surrounding whitespace; an empty string if ``img`` is None
    """
    # Gradio passes None when the user submits without uploading an image;
    # the feature extractor cannot process that, so return an empty caption.
    if img is None:
        return ""
    # Send the pixel tensor to whatever device the model lives on, rather than
    # a hard-coded 'cpu', so relocating the model never causes a mismatch.
    pixel_values = feature_extractor(img, return_tensors="pt").pixel_values.to(model.device)
    caption_ids = model.generate(pixel_values)[0]
    caption_text = tokenizer.decode(caption_ids, skip_special_tokens=True)
    # Keep only the first sentence (generation may run on past it) and drop
    # the whitespace that decoding can leave at the edges.
    return caption_text.split('.')[0].strip()
# Example images offered below the Gradio input. Only image files are kept
# (a stray file such as .DS_Store would otherwise be offered as an example),
# and the list is sorted so the example order is stable across runs.
IMAGES_EXAMPLES_FOLDER = 'examples/'
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp')
images = sorted(
    name for name in os.listdir(IMAGES_EXAMPLES_FOLDER)
    if name.lower().endswith(_IMAGE_EXTENSIONS)
)
IMAGES_EXAMPLES = [os.path.join(IMAGES_EXAMPLES_FOLDER, img) for img in images]

# The captioning model is loaded from the current working directory
# (presumably where the Space's fine-tuned weights/config live — confirm);
# the feature extractor and tokenizer come from the ViT and GPT-2 checkpoints.
model = VisionEncoderDecoderModel.from_pretrained(os.getcwd()).to('cpu')
feature_extractor = AutoFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# Build the Interface first so `iface` refers to it; launch() returns server
# information, not the Interface object.
iface = gr.Interface(fn=create_caption_transformer,
                     inputs="image",
                     outputs='text',
                     examples=IMAGES_EXAMPLES)
iface.launch(share=True)