import os
from typing import Tuple, Union

import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPProcessor, CLIPModel


class ImageEncoder(nn.Module):
    def __init__(self, clip_model_name: str = "openai/clip-vit-large-patch14-336"):
        """Initialize the image encoder using CLIP.

        Args:
            clip_model_name: HuggingFace model name for CLIP
        """
        super().__init__()

        # Store model name for lazy loading
        self.clip_model_name = clip_model_name
        self.clip_model = None
        self.processor = None
        self.valence_head = None
        self.arousal_head = None
        self.device = None
        self._initialized = False

    def _ensure_initialized(self):
        """Lazy initialization of the model components."""
        if self._initialized:
            return

        print(f"Initializing ImageEncoder with {self.clip_model_name}...")
        print("Loading CLIP model from local cache (network disabled)...")

        # Prefer loading strictly from the local Hugging Face cache that `app.py`
        # populates. If the files are genuinely missing (e.g. first run without
        # network), fall back to an online download so the user still gets a
        # working application.

        # Determine the cache directory from the environment – set in `app.py`.
        hf_cache_dir = os.environ.get("HF_HUB_CACHE", None)

        try:
            self.clip_model = CLIPModel.from_pretrained(
                self.clip_model_name,
                cache_dir=hf_cache_dir,
                local_files_only=True,  # use the cache only on the first attempt
            )
            self.processor = CLIPProcessor.from_pretrained(
                self.clip_model_name,
                cache_dir=hf_cache_dir,
                local_files_only=True,
            )
            print("CLIP model loaded successfully from local cache")
        except OSError:
            # `EnvironmentError` is an alias of `OSError` in Python 3, so catching
            # `OSError` alone covers both.
            print("Local cache for CLIP model not found – attempting a one-time online download...")
            # Note: this still respects HF_HUB_CACHE, so the files are cached for future runs.
            self.clip_model = CLIPModel.from_pretrained(
                self.clip_model_name,
                cache_dir=hf_cache_dir,
            )
            self.processor = CLIPProcessor.from_pretrained(
                self.clip_model_name,
                cache_dir=hf_cache_dir,
            )
            print("CLIP model downloaded and cached successfully")

        # Freeze CLIP parameters so only the projection heads are trainable
        for param in self.clip_model.parameters():
            param.requires_grad = False

        # Add projection heads for valence and arousal
        hidden_dim = self.clip_model.config.projection_dim
        projection_dim = hidden_dim // 2

        self.valence_head = nn.Sequential(
            nn.Linear(hidden_dim, projection_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(projection_dim, projection_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(projection_dim // 2, 1),
            nn.Tanh(),  # Output between -1 and 1
        )

        self.arousal_head = nn.Sequential(
            nn.Linear(hidden_dim, projection_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(projection_dim, projection_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(projection_dim // 2, 1),
            nn.Tanh(),  # Output between -1 and 1
        )

        # Move model to GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)
        print(f"Model moved to device: {self.device}")

        self._initialized = True

    def forward(self, images: Union[Image.Image, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass to get valence and arousal predictions.

        Args:
            images: Either PIL images or tensors in CLIP format

        Returns:
            Tuple of predicted valence and arousal scores
        """
        # Ensure model is initialized
        self._ensure_initialized()

        # Preprocess PIL images; tensors are assumed to already be in CLIP format
        if isinstance(images, Image.Image):
            inputs = self.processor(images=images, return_tensors="pt")
            pixel_values = inputs.pixel_values.to(self.device)
        else:
            pixel_values = images.to(self.device)

        # Get CLIP image features
        image_features = self.clip_model.get_image_features(pixel_values)

        # Project to valence and arousal scores
        valence = self.valence_head(image_features)
        arousal = self.arousal_head(image_features)

        return valence, arousal

    def encode_image(self, image: Image.Image) -> torch.Tensor:
        """Get the raw CLIP image embeddings.

        Args:
            image: PIL image to encode

        Returns:
            Image embedding tensor
        """
        # Ensure model is initialized
        self._ensure_initialized()

        inputs = self.processor(images=image, return_tensors="pt")
        with torch.no_grad():
            image_features = self.clip_model.get_image_features(inputs.pixel_values.to(self.device))
        return image_features
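

# A minimal usage sketch, not part of the module's public API. It assumes a
# local image at the hypothetical path "example.jpg", and note that the
# valence/arousal heads are randomly initialized until trained, so the scores
# printed here are not meaningful out of the box.
if __name__ == "__main__":
    encoder = ImageEncoder()
    image = Image.open("example.jpg").convert("RGB")

    # Raw CLIP embedding, shape (1, projection_dim)
    embedding = encoder.encode_image(image)
    print(f"Embedding shape: {tuple(embedding.shape)}")

    # Valence/arousal predictions, each squashed into [-1, 1] by the Tanh output
    with torch.no_grad():
        valence, arousal = encoder(image)
    print(f"Valence: {valence.item():.3f}, Arousal: {arousal.item():.3f}")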