File size: 1,932 Bytes
bdc783f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import torch.nn as nn
import timm

class ImageEncoder(nn.Module):
    """Thin ``nn.Module`` wrapper around a single timm backbone.

    The backbone is built with ``timm.create_model`` using max global
    pooling, and ``forward`` delegates straight to it.

    Args:
        model_name: Architecture identifier handed to ``timm.create_model``.
        num_classes: Forwarded to timm unchanged.
            NOTE(review): in timm a value of 0 is expected to drop the
            classifier head so pooled features are returned -- confirm
            against the timm version in use.
        pretrained: Forwarded to timm unchanged.
        trainable: Value assigned to ``requires_grad`` on every backbone
            parameter (``False`` freezes the backbone weights).
    """
    def __init__(self, model_name='mobilenetv2_100', num_classes=0, pretrained=True, trainable=True):
        super().__init__()
        backbone_kwargs = dict(
            pretrained=pretrained,
            num_classes=num_classes,
            global_pool="max",
        )
        self.model = timm.create_model(model_name, **backbone_kwargs)

        # Apply the same gradient flag to every backbone weight.
        for weight in self.model.parameters():
            weight.requires_grad = trainable

    def forward(self, x):
        """Return whatever the wrapped backbone produces for ``x``."""
        backbone_out = self.model(x)
        return backbone_out

class Mixed_Encoder(nn.Module):
    """Backbone plus linear classifier that exposes both outputs.

    ``forward`` returns a ``(logits, features)`` tuple: the logits are meant
    for a classification loss, and the backbone features are meant for a
    feature-level objective (the original notes mention triplet loss and
    style extraction).

    Args:
        model_name: Architecture identifier handed to ``timm.create_model``.
        num_classes: Output width of the linear classifier. The backbone
            itself is always created with ``num_classes=0``.
        pretrained: Forwarded to timm unchanged.
        trainable: Value assigned to ``requires_grad`` on every backbone
            parameter. The classifier's parameters are not touched here.
    """
    def __init__(self, model_name='mobilenetv2_100', num_classes=300, pretrained=True, trainable=True):
        super().__init__()

        # Headless backbone: the classifier is a separate layer defined
        # below, so timm is asked for zero classes and max global pooling.
        backbone_kwargs = dict(pretrained=pretrained, num_classes=0, global_pool="max")
        self.model = timm.create_model(model_name, **backbone_kwargs)

        # Freeze/unfreeze only the backbone weights.
        for weight in self.model.parameters():
            weight.requires_grad = trainable

        # Feature width as reported by the backbone.
        # NOTE(review): the original comment says this is usually 1280 for
        # mobilenetv2_100 -- confirm for other architectures.
        self.num_features = self.model.num_features

        # Linear head mapping backbone features to ``num_classes`` logits.
        self.classifier = nn.Linear(self.num_features, num_classes)

    def forward(self, x):
        """Return ``(logits, features)`` for the input batch ``x``.

        ``features`` is the raw backbone output.
        NOTE(review): presumably shaped ``[batch, num_features]`` -- confirm.
        ``logits`` is that output passed through the linear classifier.
        """
        features = self.model(x)
        return self.classifier(features), features