"""Encoders for images, captions, and text."""
from __future__ import annotations
import torch
from PIL import Image
from transformers import (
AutoProcessor,
BlipForConditionalGeneration,
CLIPModel,
CLIPProcessor,
)
from .config import CONFIG
from .utils import torch_no_grad
def get_device() -> torch.device:
    """Resolve the torch device configured in ``CONFIG.models.device``.

    ``"auto"`` selects CUDA when available and falls back to CPU; any other
    value is used verbatim after validation.

    Returns:
        torch.device: the resolved device.

    Raises:
        RuntimeError: if ``"cuda"`` is configured but CUDA is unavailable.
    """
    cfg_device = CONFIG.models.device
    if cfg_device == "auto":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Raise instead of `assert`: assertions are stripped under `python -O`,
    # which would silently skip this validation and fail later with a more
    # confusing CUDA error.
    if cfg_device == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available")
    return torch.device(cfg_device)
class ImageEncoder:
    """Encodes PIL images into L2-normalized CLIP image embeddings."""

    def __init__(self):
        # Resolve the device once instead of calling get_device() on every
        # encode() call; the configured device cannot change mid-run.
        self.device = get_device()
        self.model = CLIPModel.from_pretrained(CONFIG.models.image_encoder).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(CONFIG.models.image_encoder, use_fast=True)

    def encode(self, images: list[Image.Image]) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            images: batch of PIL images.

        Returns:
            A ``(len(images), dim)`` float tensor of unit-norm CLIP image
            features, moved to CPU.
        """
        with torch_no_grad():
            inputs = self.processor(
                images=images,
                return_tensors="pt",
            ).to(self.device)
            features = self.model.get_image_features(**inputs)
            # L2-normalize so downstream cosine similarity reduces to a
            # plain dot product.
            features = torch.nn.functional.normalize(features, p=2, dim=-1)
        return features.cpu()
class TextEncoder:
    """Encodes text strings into L2-normalized CLIP text embeddings."""

    def __init__(self):
        # Resolve the device once instead of calling get_device() on every
        # encode() call; the configured device cannot change mid-run.
        self.device = get_device()
        self.model = CLIPModel.from_pretrained(CONFIG.models.vlm_model).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(CONFIG.models.vlm_model, use_fast=True)

    def encode(self, texts: list[str]) -> torch.Tensor:
        """Embed a batch of texts.

        Args:
            texts: batch of input strings.

        Returns:
            A ``(len(texts), dim)`` float tensor of unit-norm CLIP text
            features, moved to CPU.
        """
        with torch_no_grad():
            inputs = self.processor(
                text=texts,
                return_tensors="pt",
                padding=True,
                # CLIP's text tower has a fixed (77-token) context window;
                # without truncation, longer captions raise at inference.
                truncation=True,
            ).to(self.device)
            features = self.model.get_text_features(**inputs)
            # L2-normalize so downstream cosine similarity reduces to a
            # plain dot product.
            features = torch.nn.functional.normalize(features, p=2, dim=-1)
        return features.cpu()
class CaptionGenerator:
    """Generates natural-language captions for images with a BLIP model."""

    def __init__(self):
        # Resolve the device once instead of calling get_device() on every
        # generate() call; the configured device cannot change mid-run.
        self.device = get_device()
        self.model = BlipForConditionalGeneration.from_pretrained(
            CONFIG.models.caption_model
        ).to(self.device)
        self.processor = AutoProcessor.from_pretrained(CONFIG.models.caption_model, use_fast=True)

    def generate(self, images: list[Image.Image], max_length: int = 64) -> list[str]:
        """Caption a batch of images.

        Args:
            images: batch of PIL images.
            max_length: maximum generated sequence length in tokens.

        Returns:
            One whitespace-stripped caption string per input image.
        """
        with torch_no_grad():
            inputs = self.processor(
                images=images,
                return_tensors="pt",
            ).to(self.device)
            # Beam search (num_beams=3) trades a little speed for more
            # fluent captions than greedy decoding.
            output_ids = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=3,
            )
        captions = self.processor.batch_decode(output_ids, skip_special_tokens=True)
        return [caption.strip() for caption in captions]
|