# Alternative BLIP-based implementation (commented out):
#
# from agentlego.tools import BaseTool
# from PIL import Image
# import torch
#
#
# class ImageDescriptionTool(BaseTool):
#     default_desc = 'Uses a pretrained BLIP model to generate descriptions for images.'
#
#     def __init__(self):
#         super().__init__()
#         # Load models inside the class initialization.
#         from transformers import AutoProcessor, AutoModelForImageTextToText
#         MODEL_ID = "Salesforce/blip-image-captioning-base"
#         self.processor = AutoProcessor.from_pretrained(MODEL_ID)
#         self.model = AutoModelForImageTextToText.from_pretrained(MODEL_ID)
#         # Set up device and generation parameters.
#         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#         self.model.to(self.device)
#         self.max_length = 256
#         self.num_beams = 5
#         self.gen_kwargs = {
#             "max_length": self.max_length,
#             "num_beams": self.num_beams,
#             "early_stopping": True,
#         }
#
#     def apply(self, image_path: str) -> str:
#         try:
#             # Open the image and normalize to RGB.
#             image = Image.open(image_path)
#             if image.mode != "RGB":
#                 image = image.convert(mode="RGB")
#             # Preprocess the image.
#             inputs = self.processor(images=image, return_tensors="pt").to(self.device)
#             # Generate the caption with beam search.
#             with torch.no_grad():
#                 output_ids = self.model.generate(**inputs, **self.gen_kwargs)
#             # Decode the prediction.
#             caption = self.processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
#             return f"Description: **{caption}** (generated with BLIP base model)"
#         except Exception as e:
#             return f"Error during image description: {str(e)}"

# Active implementation: ViT-GPT2 image captioning.
from agentlego.tools import BaseTool
from PIL import Image
import torch


class ImageDescriptionTool(BaseTool):
    default_desc = 'Uses a pretrained ViT-GPT2 model to generate descriptions for images.'

    def __init__(self):
        super().__init__()
        # Load the model, image processor and tokenizer inside the class
        # initialization so the tool is self-contained.
        from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
        MODEL_ID = "nlpconnect/vit-gpt2-image-captioning"
        self.model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)
        self.feature_extractor = ViTImageProcessor.from_pretrained(MODEL_ID)
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        # Set up device and generation parameters.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.max_length = 16
        # Only max_length is passed to generate(); omitting num_beams keeps
        # the default greedy decoding.
        self.gen_kwargs = {"max_length": self.max_length}

    def apply(self, image_path: str) -> str:
        """Generate a short natural-language description of the image at ``image_path``."""
        try:
            # Open the image and normalize to RGB.
            image = Image.open(image_path)
            if image.mode != "RGB":
                image = image.convert(mode="RGB")
            # Preprocess the image into pixel values and move them to the device.
            pixel_values = self.feature_extractor(images=[image], return_tensors="pt").pixel_values
            pixel_values = pixel_values.to(self.device)
            # Generate the caption.
            with torch.no_grad():
                output_ids = self.model.generate(pixel_values, **self.gen_kwargs)
            # Decode the prediction.
            pred = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
            return f"Description: **{pred}** (generated with ViT-GPT2 model)"
        except Exception as e:
            return f"Error during image description: {str(e)}"