Spaces:
Sleeping
Sleeping
File size: 1,529 Bytes
d4d1ef4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 | import torch
from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, GPT2TokenizerFast
from tqdm import tqdm
# Single Hugging Face Hub checkpoint shared by the model, tokenizer and
# image processor so the three components always come from the same release.
_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"

# Encoder-decoder captioning model (ViT encoder + GPT-2 decoder).
model = VisionEncoderDecoderModel.from_pretrained(_CHECKPOINT)
# Tokenizer used to turn generated token ids back into text.
tokenizer = GPT2TokenizerFast.from_pretrained(_CHECKPOINT)
# Image preprocessor producing the pixel tensors the encoder expects.
image_processor = ViTImageProcessor.from_pretrained(_CHECKPOINT)
def get_caption(model, image_processor, tokenizer, image_path, **generate_kwargs):
    """
    Generate a caption for a given image using a ViT-GPT2 image captioning model.

    This function:
    - Loads the image (closing the underlying file afterwards) and ensures it
      is in RGB mode.
    - Processes the image using the ViT image processor.
    - Moves the processed inputs to the model's device.
    - Uses the encoder-decoder model to generate caption token ids.
    - Decodes the first generated sequence and returns it with surrounding
      whitespace stripped.

    Args:
        model (VisionEncoderDecoderModel): Pre-trained image captioning model.
        image_processor (ViTImageProcessor): Pre-trained image processor.
        tokenizer (GPT2TokenizerFast): Tokenizer to decode model output.
        image_path (str | os.PathLike | PIL.Image.Image): Path to the input
            image file, or an already-loaded PIL image (which is not closed).
        **generate_kwargs: Optional keyword arguments forwarded unchanged to
            ``model.generate`` (e.g. ``max_length=16, num_beams=4``). When
            omitted, the model's default generation settings are used.

    Returns:
        str: Generated caption describing the image.
    """
    if isinstance(image_path, Image.Image):
        # Caller-supplied image: the caller owns it, so do not close it here.
        image = image_path
        if image.mode != "RGB":
            image = image.convert(mode="RGB")
        inputs = image_processor(image, return_tensors="pt")
    else:
        # Image.open reads lazily and keeps the file handle open; the context
        # manager guarantees it is released once preprocessing is done.
        with Image.open(image_path) as image:
            if image.mode != "RGB":
                image = image.convert(mode="RGB")
            inputs = image_processor(image, return_tensors="pt")

    # The processor returns CPU tensors; move them to wherever the model lives
    # so this also works when the model has been moved to a GPU.
    inputs = inputs.to(model.device)
    output_ids = model.generate(**inputs, **generate_kwargs)

    # GPT-2 decoding commonly leaves leading/trailing spaces around the text.
    caption = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
    return caption.strip()
|