| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | from typing import TYPE_CHECKING |
| |
|
| | from ..models.auto import AutoModelForVision2Seq |
| | from ..utils import requires_backends |
| | from .base import PipelineTool |
| |
|
| |
|
| | if TYPE_CHECKING: |
| | from PIL import Image |
| |
|
| |
|
| | class ImageCaptioningTool(PipelineTool): |
| | default_checkpoint = "Salesforce/blip-image-captioning-base" |
| | description = ( |
| | "This is a tool that generates a description of an image. It takes an input named `image` which should be the " |
| | "image to caption, and returns a text that contains the description in English." |
| | ) |
| | name = "image_captioner" |
| | model_class = AutoModelForVision2Seq |
| |
|
| | inputs = ["image"] |
| | outputs = ["text"] |
| |
|
| | def __init__(self, *args, **kwargs): |
| | requires_backends(self, ["vision"]) |
| | super().__init__(*args, **kwargs) |
| |
|
| | def encode(self, image: "Image"): |
| | return self.pre_processor(images=image, return_tensors="pt") |
| |
|
| | def forward(self, inputs): |
| | return self.model.generate(**inputs) |
| |
|
| | def decode(self, outputs): |
| | return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0].strip() |
| |
|