File size: 3,747 Bytes
28b51fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""
Vision Transformer Branch for Hybrid Food Classifier
Uses DeiT-Base as backbone with custom head
"""
import torch
import torch.nn as nn
from transformers import DeiTModel, DeiTConfig
from typing import Tuple

class ViTBranch(nn.Module):
    """Vision Transformer branch using DeiT-Base.

    Wraps a HuggingFace ``DeiTModel`` backbone and projects its CLS token
    (global descriptor) and patch tokens (spatial descriptors) into a common
    ``feature_dim`` so they can be fused with a parallel CNN branch.
    """
    
    def __init__(
        self,
        model_name: str = "facebook/deit-base-distilled-patch16-224",
        pretrained: bool = True,
        freeze_early_layers: bool = True,
        dropout: float = 0.1,
        feature_dim: int = 768,
        num_frozen_layers: int = 8
    ):
        """
        Args:
            model_name: HuggingFace model id of the DeiT backbone.
            pretrained: If True, load pretrained weights; otherwise build a
                randomly initialized model from the config alone.
            freeze_early_layers: If True, freeze the first
                ``num_frozen_layers`` encoder layers.
            dropout: Dropout probability used in all projection heads.
            feature_dim: Output dimension of both global and spatial features.
            num_frozen_layers: Number of leading encoder layers to freeze
                (default 8 of DeiT-Base's 12, matching the original behavior).
        """
        super().__init__()
        
        self.feature_dim = feature_dim
        self.num_frozen_layers = num_frozen_layers
        
        # Load DeiT model (weights or config-only random init)
        if pretrained:
            self.vit = DeiTModel.from_pretrained(model_name)
        else:
            config = DeiTConfig.from_pretrained(model_name)
            self.vit = DeiTModel(config)
        
        # Derive dimensions from the loaded config rather than hard-coding
        # 224/16, so variants with other resolutions/patch sizes also work.
        self.hidden_size = self.vit.config.hidden_size  # 768 for base
        self.num_patches = (self.vit.config.image_size // self.vit.config.patch_size) ** 2  # 196 for 224px/16px
        
        # Freeze early layers if specified
        if freeze_early_layers:
            self._freeze_early_layers()
        
        # Feature projection to match CNN branch (global path)
        self.feature_proj = nn.Sequential(
            nn.Linear(self.hidden_size, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        
        # Spatial feature projection (for fusion with CNN spatial features)
        self.spatial_proj = nn.Sequential(
            nn.Linear(self.hidden_size, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        
        # Additional processing head applied to the global features only
        self.feature_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )
    
    def _freeze_early_layers(self) -> None:
        """Freeze the first ``self.num_frozen_layers`` encoder layers."""
        for i, layer in enumerate(self.vit.encoder.layer):
            if i < self.num_frozen_layers:
                for param in layer.parameters():
                    param.requires_grad = False
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass
        
        Args:
            x: Input tensor [B, 3, H, W]
            
        Returns:
            spatial_features: Patch features [B, num_patches, feature_dim]
            global_features: CLS token features [B, feature_dim]
        """
        # Get ViT outputs
        outputs = self.vit(pixel_values=x)
        
        # Extract features
        last_hidden_states = outputs.last_hidden_state  # [B, seq_len, hidden_size]
        
        # CLS token (first token) for global features
        cls_token = last_hidden_states[:, 0]  # [B, hidden_size]
        
        # DeiT prepends BOTH a CLS token (index 0) and a distillation token
        # (index 1) before the patch embeddings, so genuine patch tokens
        # start at index 2.  Slicing [:, 1:] (as before) would leak the
        # distillation token into the spatial features and yield
        # num_patches + 1 tokens instead of num_patches.
        patch_tokens = last_hidden_states[:, 2:]  # [B, num_patches, hidden_size]
        
        # Project both paths into the shared feature space
        global_features = self.feature_proj(cls_token)  # [B, feature_dim]
        spatial_features = self.spatial_proj(patch_tokens)  # [B, num_patches, feature_dim]
        
        # Additional processing on the global descriptor
        global_features = self.feature_head(global_features)  # [B, feature_dim]
        
        return spatial_features, global_features
    
    def get_feature_dim(self) -> int:
        """Get feature dimension"""
        return self.feature_dim
    
    def get_num_patches(self) -> int:
        """Get number of patches"""
        return self.num_patches