File size: 2,876 Bytes
61d3625
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""Encoders for images, captions, and text."""
from __future__ import annotations

import torch
from PIL import Image
from transformers import (
    AutoProcessor,
    BlipForConditionalGeneration,
    CLIPModel,
    CLIPProcessor,
)

from .config import CONFIG
from .utils import torch_no_grad


def get_device() -> torch.device:
    """Resolve the torch device from CONFIG.models.device.

    "auto" selects CUDA when available, otherwise CPU. An explicit "cuda"
    request raises if CUDA is unavailable.

    Returns:
        torch.device: The resolved device.

    Raises:
        RuntimeError: If the config asks for "cuda" but CUDA is unavailable.
    """
    cfg_device = CONFIG.models.device
    if cfg_device == "auto":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Raise instead of assert: asserts are stripped under `python -O`, which
    # would let an impossible "cuda" config fail later with an obscure error.
    if cfg_device == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("CUDA requested in config but CUDA is not available")
    return torch.device(cfg_device)


class ImageEncoder:
    """Encodes PIL images into L2-normalized CLIP image embeddings."""

    def __init__(self):
        # Resolve the device once here instead of calling get_device() on
        # every encode() call.
        self.device = get_device()
        self.model = CLIPModel.from_pretrained(CONFIG.models.image_encoder)
        self.processor = CLIPProcessor.from_pretrained(CONFIG.models.image_encoder, use_fast=True)

        # eval() disables dropout / batch-norm updates so inference is
        # deterministic; from_pretrained does not guarantee eval mode.
        self.model = self.model.to(self.device).eval()

    def encode(self, images: list[Image.Image]) -> torch.Tensor:
        """Encode a batch of images.

        Args:
            images: PIL images to embed.

        Returns:
            A CPU tensor of unit-norm image features, one row per image.
        """
        with torch_no_grad():
            inputs = self.processor(
                images=images,
                return_tensors="pt",
            ).to(self.device)
            outputs = self.model.get_image_features(**inputs)
            # Normalize so cosine similarity reduces to a dot product.
            outputs = torch.nn.functional.normalize(outputs, p=2, dim=-1)
        # Move back to CPU so callers can store/compare without holding GPU memory.
        return outputs.cpu()


class TextEncoder:
    """Encodes text strings into L2-normalized CLIP text embeddings."""

    def __init__(self):
        # Resolve the device once here instead of calling get_device() on
        # every encode() call.
        self.device = get_device()
        # NOTE(review): this uses CONFIG.models.vlm_model while ImageEncoder
        # uses CONFIG.models.image_encoder — both must point at the same CLIP
        # checkpoint for image/text embeddings to be comparable; verify config.
        self.model = CLIPModel.from_pretrained(CONFIG.models.vlm_model)
        self.processor = CLIPProcessor.from_pretrained(CONFIG.models.vlm_model, use_fast=True)

        # eval() disables dropout / batch-norm updates so inference is
        # deterministic; from_pretrained does not guarantee eval mode.
        self.model = self.model.to(self.device).eval()

    def encode(self, texts: list[str]) -> torch.Tensor:
        """Encode a batch of texts.

        Args:
            texts: Strings to embed.

        Returns:
            A CPU tensor of unit-norm text features, one row per string.
        """
        with torch_no_grad():
            inputs = self.processor(
                text=texts,
                return_tensors="pt",
                padding=True,  # texts vary in length; pad to a rectangular batch
            ).to(self.device)
            outputs = self.model.get_text_features(**inputs)
            # Normalize so cosine similarity reduces to a dot product.
            outputs = torch.nn.functional.normalize(outputs, p=2, dim=-1)
        # Move back to CPU so callers can store/compare without holding GPU memory.
        return outputs.cpu()


class CaptionGenerator:
    """Generates natural-language captions for images with a BLIP model."""

    def __init__(self):
        # Resolve the device once here instead of calling get_device() on
        # every generate() call.
        self.device = get_device()
        self.model = BlipForConditionalGeneration.from_pretrained(CONFIG.models.caption_model)
        self.processor = AutoProcessor.from_pretrained(CONFIG.models.caption_model, use_fast=True)

        # eval() matters here: BLIP contains dropout layers, so train-mode
        # generation would be nondeterministic.
        self.model = self.model.to(self.device).eval()

    def generate(self, images: list[Image.Image], max_length: int = 64) -> list[str]:
        """Caption a batch of images.

        Args:
            images: PIL images to caption.
            max_length: Maximum generated sequence length in tokens.

        Returns:
            One whitespace-stripped caption per input image.
        """
        with torch_no_grad():
            inputs = self.processor(
                images=images,
                return_tensors="pt",
            ).to(self.device)
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=3,  # beam search: better captions than greedy decoding
            )
            captions = self.processor.batch_decode(outputs, skip_special_tokens=True)
        return [caption.strip() for caption in captions]