""" Vision Transformer Branch for Hybrid Food Classifier Uses DeiT-Base as backbone with custom head """ import torch import torch.nn as nn from transformers import DeiTModel, DeiTConfig from typing import Tuple class ViTBranch(nn.Module): """Vision Transformer branch using DeiT-Base""" def __init__( self, model_name: str = "facebook/deit-base-distilled-patch16-224", pretrained: bool = True, freeze_early_layers: bool = True, dropout: float = 0.1, feature_dim: int = 768 ): super(ViTBranch, self).__init__() self.feature_dim = feature_dim # Load DeiT model if pretrained: self.vit = DeiTModel.from_pretrained(model_name) else: config = DeiTConfig.from_pretrained(model_name) self.vit = DeiTModel(config) # Get model dimensions self.hidden_size = self.vit.config.hidden_size # 768 for base self.num_patches = (224 // 16) ** 2 # 196 patches for 224x224 image # Freeze early layers if specified if freeze_early_layers: self._freeze_early_layers() # Feature projection to match CNN branch self.feature_proj = nn.Sequential( nn.Linear(self.hidden_size, feature_dim), nn.LayerNorm(feature_dim), nn.GELU(), nn.Dropout(dropout) ) # Spatial feature projection (for fusion with CNN spatial features) self.spatial_proj = nn.Sequential( nn.Linear(self.hidden_size, feature_dim), nn.LayerNorm(feature_dim), nn.GELU(), nn.Dropout(dropout) ) # Additional processing head self.feature_head = nn.Sequential( nn.Linear(feature_dim, feature_dim), nn.LayerNorm(feature_dim), nn.GELU(), nn.Dropout(dropout) ) def _freeze_early_layers(self): """Freeze early layers of the ViT""" # Freeze first 8 transformer layers (out of 12) layers_to_freeze = 8 for i, layer in enumerate(self.vit.encoder.layer): if i < layers_to_freeze: for param in layer.parameters(): param.requires_grad = False def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Forward pass Args: x: Input tensor [B, 3, H, W] Returns: spatial_features: Patch features [B, num_patches, feature_dim] global_features: CLS token features [B, feature_dim] """ # Get ViT outputs outputs = self.vit(pixel_values=x) # Extract features last_hidden_states = outputs.last_hidden_state # [B, seq_len, hidden_size] # CLS token (first token) for global features cls_token = last_hidden_states[:, 0] # [B, hidden_size] # Patch tokens for spatial features patch_tokens = last_hidden_states[:, 1:] # [B, num_patches, hidden_size] # Project features global_features = self.feature_proj(cls_token) # [B, feature_dim] spatial_features = self.spatial_proj(patch_tokens) # [B, num_patches, feature_dim] # Additional processing global_features = self.feature_head(global_features) # [B, feature_dim] return spatial_features, global_features def get_feature_dim(self) -> int: """Get feature dimension""" return self.feature_dim def get_num_patches(self) -> int: """Get number of patches""" return self.num_patches