# Image_captioning / image_to_text.py
# (Hugging Face Hub page header: uploaded by Prashanthsrn,
#  "Update image_to_text.py", commit 818bff5 verified)
from PIL import Image
import torch
import numpy as np
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
def preprocess_image(image, target_size=(224, 224)):
    """Resize a PIL image and return a normalized float array in CHW layout.

    Parameters
    ----------
    image : PIL.Image.Image
        Input image; converted to RGB if it is not already.
    target_size : tuple[int, int], optional
        (width, height) to resize to. Defaults to (224, 224), the ViT input size.

    Returns
    -------
    numpy.ndarray
        Array of shape (3, height, width) with values scaled to [0, 1].
    """
    # The transpose below assumes exactly 3 channels; grayscale ("L") or
    # RGBA inputs would otherwise fail or yield the wrong channel count.
    if image.mode != "RGB":
        image = image.convert("RGB")
    # Resize to the fixed model input size with a high-quality filter.
    image = image.resize(target_size, Image.LANCZOS)
    # Scale 8-bit pixel values into [0, 1].
    # NOTE(review): the vit-gpt2 checkpoint was trained with mean/std 0.5
    # normalization; plain /255 scaling may degrade captions — confirm
    # against ViTImageProcessor before relying on this path.
    image_array = np.array(image) / 255.0
    # HWC -> CHW, the layout PyTorch vision models expect.
    image_array = np.transpose(image_array, (2, 0, 1))
    return image_array
# Model components are heavyweight (~1GB download + initialization); load
# them lazily once and reuse across calls instead of reloading per caption.
_CAPTION_MODEL_ID = "nlpconnect/vit-gpt2-image-captioning"
_caption_components = {}


def _get_caption_components():
    """Load and cache the captioning model, processor, tokenizer, and device."""
    if not _caption_components:
        model = VisionEncoderDecoderModel.from_pretrained(_CAPTION_MODEL_ID)
        feature_extractor = ViTImageProcessor.from_pretrained(_CAPTION_MODEL_ID)
        tokenizer = AutoTokenizer.from_pretrained(_CAPTION_MODEL_ID)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        model.eval()  # inference only; disables dropout etc.
        _caption_components.update(
            model=model,
            feature_extractor=feature_extractor,
            tokenizer=tokenizer,
            device=device,
        )
    c = _caption_components
    return c["model"], c["feature_extractor"], c["tokenizer"], c["device"]


def generate_caption(image):
    """Generate a short caption for an image.

    Parameters
    ----------
    image : PIL.Image.Image, or a path / file-like object accepted by
        ``PIL.Image.open``.

    Returns
    -------
    str
        The generated caption (single best sequence).
    """
    model, feature_extractor, tokenizer, device = _get_caption_components()
    # Accept either an already-open PIL image or something openable.
    if not isinstance(image, Image.Image):
        image = Image.open(image)
    image = image.convert('RGB')
    # Bug fix: the original loaded the ViTImageProcessor but never used it,
    # feeding the model plain /255-scaled pixels. The checkpoint expects the
    # processor's mean/std normalization, so use it to build pixel_values.
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    # No gradients needed for generation; saves memory on large inputs.
    with torch.no_grad():
        output_ids = model.generate(pixel_values, max_length=16, num_return_sequences=1)
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return preds[0]