Image_captioning / image_to_text.py
Prashanthsrn's picture
Update image_to_text.py
0959f52 verified
raw
history blame
1.06 kB
import functools

import torch
from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
@functools.lru_cache(maxsize=1)
def _load_captioning_components():
    """Load the captioning model, image processor, tokenizer and device once.

    Loading the pretrained checkpoint is by far the most expensive step, so
    the result is cached and reused by every later call instead of being
    rebuilt per image.

    Returns:
        A ``(model, feature_extractor, tokenizer, device)`` tuple, with the
        model already moved to ``device``.
    """
    checkpoint = "nlpconnect/vit-gpt2-image-captioning"
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = VisionEncoderDecoderModel.from_pretrained(checkpoint)
    model.to(device)
    feature_extractor = ViTImageProcessor.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return model, feature_extractor, tokenizer, device


def generate_caption(image, *, max_length=16):
    """Generate a text caption for a single image.

    Args:
        image: Either a ``PIL.Image.Image`` instance, or anything accepted by
            ``PIL.Image.open`` (for example a file path or a binary file
            object).
        max_length: Value passed to ``model.generate`` as ``max_length``.
            Defaults to 16, the value that was previously hard-coded.

    Returns:
        The first decoded caption string, with special tokens removed.
    """
    model, feature_extractor, tokenizer, device = _load_captioning_components()

    # Normalise the input to an RGB PIL image.
    if isinstance(image, Image.Image):
        rgb_image = image.convert("RGB")
    else:
        # The context manager releases the file opened by PIL as soon as the
        # RGB copy has been made, instead of leaving it to garbage collection.
        with Image.open(image) as opened:
            rgb_image = opened.convert("RGB")

    pixel_values = feature_extractor(
        images=[rgb_image], return_tensors="pt"
    ).pixel_values
    pixel_values = pixel_values.to(device)

    # Generate a single caption and decode it back to text.
    output_ids = model.generate(
        pixel_values, max_length=max_length, num_return_sequences=1
    )
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return preds[0]