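"""CLIP-based image encoder with small MLP heads that regress valence and
arousal scores in [-1, 1] from frozen CLIP image embeddings."""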
import torch
import torch.nn as nn
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
from typing import Tuple, Union
import os


class ImageEncoder(nn.Module):
    def __init__(self, clip_model_name: str = "openai/clip-vit-large-patch14-336"):
        """Initialize the image encoder using CLIP.

        Args:
            clip_model_name: HuggingFace model name for CLIP
        """
        super().__init__()

        # Store model name for lazy loading; the heavy components are created
        # on first use in _ensure_initialized().
        self.clip_model_name = clip_model_name
        self.clip_model = None
        self.processor = None
        self.valence_head = None
        self.arousal_head = None
        self.device = None
        self._initialized = False

    def _ensure_initialized(self):
        """Lazy initialization of the model components."""
        if self._initialized:
            return
| print(f"Initializing ImageEncoder with {self.clip_model_name}...") | |
| print("Loading CLIP model from local cache (network disabled)...") | |
| # Prefer loading strictly from the local Hugging Face cache that `app.py` populates. | |
| # If the files are genuinely missing (e.g. first run without network), we fall back | |
| # to an online download so the user still gets a working application. | |
| # Determine the cache directory from env – this is set in `app.py`. | |
| hf_cache_dir = os.environ.get("HF_HUB_CACHE", None) | |
| try: | |
| self.clip_model = CLIPModel.from_pretrained( | |
| self.clip_model_name, | |
| cache_dir=hf_cache_dir, | |
| local_files_only=True, # use cache only on the first attempt | |
| ) | |
| self.processor = CLIPProcessor.from_pretrained( | |
| self.clip_model_name, | |
| cache_dir=hf_cache_dir, | |
| local_files_only=True, | |
| ) | |
| print("CLIP model loaded successfully from local cache") | |
| except (OSError, EnvironmentError) as cache_err: | |
| print( | |
| "Local cache for CLIP model not found – attempting a one-time online download..." | |
| ) | |
| # Note: this will still respect HF_HUB_CACHE so the files are cached for future runs. | |
| self.clip_model = CLIPModel.from_pretrained( | |
| self.clip_model_name, | |
| cache_dir=hf_cache_dir, | |
| ) | |
| self.processor = CLIPProcessor.from_pretrained( | |
| self.clip_model_name, | |
| cache_dir=hf_cache_dir, | |
| ) | |
| print("CLIP model downloaded and cached successfully") | |
| print("CLIP model loaded successfully") | |
        # Freeze CLIP parameters
        for param in self.clip_model.parameters():
            param.requires_grad = False

        # Add projection layers for valence and arousal
        hidden_dim = self.clip_model.config.projection_dim
        projection_dim = hidden_dim // 2

        self.valence_head = nn.Sequential(
            nn.Linear(hidden_dim, projection_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(projection_dim, projection_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(projection_dim // 2, 1),
            nn.Tanh(),  # Output between -1 and 1
        )
        self.arousal_head = nn.Sequential(
            nn.Linear(hidden_dim, projection_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(projection_dim, projection_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(projection_dim // 2, 1),
            nn.Tanh(),  # Output between -1 and 1
        )

        # Move model to GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)
        print(f"Model moved to device: {self.device}")

        self._initialized = True
    def forward(self, images: Union[Image.Image, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass to get valence and arousal predictions.

        Args:
            images: Either PIL images or tensors in CLIP format

        Returns:
            Tuple of predicted valence and arousal scores
        """
        # Ensure model is initialized
        self._ensure_initialized()

        # Process images if they're PIL images
        if isinstance(images, Image.Image):
            inputs = self.processor(images=images, return_tensors="pt")
            pixel_values = inputs.pixel_values.to(self.device)
        else:
            pixel_values = images.to(self.device)

        # Get CLIP image features
        image_features = self.clip_model.get_image_features(pixel_values)

        # Project to valence and arousal scores
        valence = self.valence_head(image_features)
        arousal = self.arousal_head(image_features)

        return valence, arousal
    def encode_image(self, image: Image.Image) -> torch.Tensor:
        """Get the raw CLIP image embeddings.

        Args:
            image: PIL image to encode

        Returns:
            Image embedding tensor
        """
        # Ensure model is initialized
        self._ensure_initialized()

        inputs = self.processor(images=image, return_tensors="pt")
        with torch.no_grad():
            image_features = self.clip_model.get_image_features(
                inputs.pixel_values.to(self.device)
            )
        return image_features
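

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, under assumptions): "example.jpg" is a
# hypothetical local image path, and `app.py` is assumed to have set
# HF_HUB_CACHE (e.g. os.environ.setdefault("HF_HUB_CACHE", "/data/hf-cache"),
# a hypothetical location) before this module runs. The valence/arousal heads
# are randomly initialized here, so their outputs are only meaningful once
# the heads have been trained.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    encoder = ImageEncoder()

    image = Image.open("example.jpg")  # hypothetical path
    with torch.no_grad():
        valence, arousal = encoder(image)  # first call triggers lazy init
    print(f"valence={valence.item():+.3f}, arousal={arousal.item():+.3f}")

    embedding = encoder.encode_image(image)
    print(f"CLIP embedding shape: {tuple(embedding.shape)}")  # (1, projection_dim)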