GenD-Sentinel / src /encoders /clip_encoder.py
yermandy's picture
init
c29babb
import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
class CLIPEncoder(nn.Module):
def __init__(self, model_name="openai/clip-vit-large-patch14"):
"""
Models:
1. openai/clip-vit-base-patch16 | 768 features
2. openai/clip-vit-base-patch32 | 768 features
3. openai/clip-vit-large-patch14 | 1024 features
See more in src/config.py
"""
super().__init__()
try:
self._preprocess = CLIPProcessor.from_pretrained(model_name)
except Exception:
self._preprocess = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
clip: CLIPModel = CLIPModel.from_pretrained(model_name)
# take vision model from CLIP, maps image to vision_embed_dim
self.vision_model = clip.vision_model
self.model_name = model_name
self.features_dim = self.vision_model.config.hidden_size
# take visual_projection, maps vision_embed_dim to projection_dim
self.visual_projection = clip.visual_projection
def preprocess(self, image: Image) -> torch.Tensor:
return self._preprocess(images=image, return_tensors="pt")["pixel_values"][0]
def forward(self, preprocessed_images: torch.Tensor) -> torch.Tensor:
return self.vision_model(preprocessed_images).pooler_output
def get_features_dim(self):
return self.features_dim
if __name__ == "__main__":
import autorootcwd # noqa: F401
from src.config import Backbone
from src.encoders._common import inference
model = CLIPEncoder(Backbone.CLIP_B_16.value)
inference(model)