|
|
""" |
|
|
Vision Transformer Branch for Hybrid Food Classifier |
|
|
Uses DeiT-Base as backbone with custom head |
|
|
""" |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import DeiTModel, DeiTConfig |
|
|
from typing import Tuple |
|
|
|
|
|
class ViTBranch(nn.Module):
    """Vision Transformer branch using a Hugging Face DeiT backbone.

    The backbone's final hidden states are split into the CLS token and the
    patch tokens. Each is projected to ``feature_dim`` by its own
    Linear -> LayerNorm -> GELU -> Dropout stack. The projected CLS token is
    then passed through one more such stack (``feature_head``).
    """

    def __init__(
        self,
        model_name: str = "facebook/deit-base-distilled-patch16-224",
        pretrained: bool = True,
        freeze_early_layers: bool = True,
        dropout: float = 0.1,
        feature_dim: int = 768,
    ):
        """Build the backbone and projection heads.

        Args:
            model_name: Hugging Face model id (or local path) for the DeiT
                weights/config.
            pretrained: If True, load pretrained weights with
                ``DeiTModel.from_pretrained``. Otherwise, build a randomly
                initialised model from the checkpoint's config.
            freeze_early_layers: If True, call ``_freeze_early_layers()``,
                which freezes the embeddings and the first encoder layers.
            dropout: Dropout probability used in every projection stack.
            feature_dim: Output dimension of both returned feature tensors.
        """
        super().__init__()

        self.feature_dim = feature_dim

        if pretrained:
            self.vit = DeiTModel.from_pretrained(model_name)
        else:
            config = DeiTConfig.from_pretrained(model_name)
            self.vit = DeiTModel(config)

        vit_config = self.vit.config
        self.hidden_size = vit_config.hidden_size
        # Derive the patch grid from the backbone config instead of assuming
        # 224px images with 16px patches.
        self.num_patches = (vit_config.image_size // vit_config.patch_size) ** 2

        # DeiT prepends a CLS token and (when the embeddings define one) a
        # distillation token in front of the patch tokens. Both must be skipped
        # to obtain the patch tokens alone.
        has_distillation_token = (
            getattr(self.vit.embeddings, "distillation_token", None) is not None
        )
        self.num_prefix_tokens = 2 if has_distillation_token else 1

        # forward() never uses the backbone's pooler output, so its parameters
        # would never receive gradients. That breaks DistributedDataParallel
        # unless find_unused_parameters=True. Freeze the pooler instead of
        # dropping it, so previously saved state_dicts still load strictly.
        pooler = getattr(self.vit, "pooler", None)
        if pooler is not None:
            pooler.requires_grad_(False)

        if freeze_early_layers:
            self._freeze_early_layers()

        # Projects the CLS token to feature_dim.
        self.feature_proj = self._make_projection(self.hidden_size, feature_dim, dropout)
        # Projects every patch token to feature_dim.
        self.spatial_proj = self._make_projection(self.hidden_size, feature_dim, dropout)
        # Extra refinement applied to the projected CLS features only.
        self.feature_head = self._make_projection(feature_dim, feature_dim, dropout)

    @staticmethod
    def _make_projection(in_dim: int, out_dim: int, dropout: float) -> nn.Sequential:
        """Return a Linear -> LayerNorm -> GELU -> Dropout stack."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

    def _freeze_early_layers(self, num_layers: int = 8, freeze_embeddings: bool = True):
        """Freeze the embeddings and the first ``num_layers`` encoder layers.

        Freezing only encoder blocks while the embeddings below them stay
        trainable still forces backprop through every frozen block to reach
        the embeddings. The embeddings are therefore frozen together with
        the blocks by default.

        Args:
            num_layers: Number of leading encoder layers to freeze. Values
                larger than the encoder depth freeze every layer; values
                below 0 are treated as 0.
            freeze_embeddings: Whether to also freeze the embedding module
                (patch projection, position embeddings, special tokens).
        """
        if freeze_embeddings:
            self.vit.embeddings.requires_grad_(False)
        for layer in self.vit.encoder.layer[: max(num_layers, 0)]:
            layer.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the backbone and project its tokens.

        Args:
            x: Input tensor [B, 3, H, W]. H and W are expected to match the
                backbone's configured image size.

        Returns:
            spatial_features: Patch features [B, num_patches, feature_dim]
            global_features: CLS token features [B, feature_dim]
        """
        outputs = self.vit(pixel_values=x)

        # Sequence layout: [special prefix tokens..., patch tokens...].
        hidden_states = outputs.last_hidden_state

        cls_token = hidden_states[:, 0]
        # Skip CLS *and* the distillation token (if present); slicing [:, 1:]
        # would leak the distillation token into the spatial features.
        patch_tokens = hidden_states[:, self.num_prefix_tokens:]

        # Same application order as before, so dropout RNG consumption is
        # unchanged.
        global_features = self.feature_proj(cls_token)
        spatial_features = self.spatial_proj(patch_tokens)
        global_features = self.feature_head(global_features)

        return spatial_features, global_features

    def get_feature_dim(self) -> int:
        """Return the output feature dimension."""
        return self.feature_dim

    def get_num_patches(self) -> int:
        """Return the number of patch tokens in ``spatial_features``."""
        return self.num_patches