from typing import Tuple

import torch
import torch.nn as nn
import timm


def _build_backbone(model_name, *, pretrained, num_classes, global_pool, trainable):
    """Create a timm model and set ``requires_grad`` on all of its parameters.

    Args:
        model_name: timm model identifier passed to ``timm.create_model``.
        pretrained: Whether timm should load pretrained weights.
        num_classes: Size of timm's own classifier head (0 removes the head,
            so the model returns pooled features).
        global_pool: timm global pooling mode (e.g. ``"max"``, ``"avg"``).
        trainable: Value assigned to ``requires_grad`` on every parameter.

    Returns:
        The created ``nn.Module``.
    """
    model = timm.create_model(
        model_name,
        pretrained=pretrained,
        num_classes=num_classes,
        global_pool=global_pool,
    )
    # NOTE: this only controls gradients. If the model contains BatchNorm
    # layers (MobileNetV2 does), their running statistics still update while
    # the module is in train mode; call .eval() on it to fully freeze it.
    model.requires_grad_(trainable)
    return model


class ImageEncoder(nn.Module):
    """Wrap a timm model as a plain image encoder.

    With the default ``num_classes=0`` timm removes its classifier head, so
    ``forward`` returns pooled backbone features. A non-zero ``num_classes``
    makes timm attach its own head and ``forward`` returns that output.

    Args:
        model_name: timm model identifier.
        num_classes: Passed straight to ``timm.create_model``.
        pretrained: Whether to load pretrained weights.
        trainable: If False, all parameters get ``requires_grad=False``.
        global_pool: timm global pooling mode (keyword-only, default ``"max"``).
    """

    def __init__(
        self,
        model_name="mobilenetv2_100",
        num_classes=0,
        pretrained=True,
        trainable=True,
        *,
        global_pool="max",
    ):
        super().__init__()
        self.model = _build_backbone(
            model_name,
            pretrained=pretrained,
            num_classes=num_classes,
            global_pool=global_pool,
            trainable=trainable,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the timm model's output for the image batch ``x``."""
        return self.model(x)


class Mixed_Encoder(nn.Module):
    """Image encoder that returns both classification logits and features.

    A headless timm backbone produces pooled features, and a linear layer
    maps them to ``num_classes`` logits. Returning both lets a caller use the
    logits for a classification loss and the features for a metric or style
    loss (for example a triplet loss) in the same pass.

    Args:
        model_name: timm model identifier.
        num_classes: Number of output logits of the linear classifier
            (default 300, presumably the number of writer identities; confirm
            with the training setup).
        pretrained: Whether to load pretrained backbone weights.
        trainable: If False, the BACKBONE parameters get
            ``requires_grad=False``; the classifier stays trainable.
        global_pool: timm global pooling mode (keyword-only, default ``"max"``).
    """

    def __init__(
        self,
        model_name="mobilenetv2_100",
        num_classes=300,
        pretrained=True,
        trainable=True,
        *,
        global_pool="max",
    ):
        super().__init__()
        # num_classes=0 makes timm drop its own head, so self.model returns
        # pooled features of width self.model.num_features.
        self.model = _build_backbone(
            model_name,
            pretrained=pretrained,
            num_classes=0,
            global_pool=global_pool,
            trainable=trainable,
        )
        # Feature width reported by timm (1280 for mobilenetv2_100).
        self.num_features = self.model.num_features
        # Created after the backbone is frozen, so its parameters keep
        # requires_grad=True regardless of `trainable`.
        self.classifier = nn.Linear(self.num_features, num_classes)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode ``x`` and classify it.

        Returns:
            ``(logits, features)`` where ``features`` has shape
            ``[batch, self.num_features]`` and ``logits`` has shape
            ``[batch, num_classes]``.
        """
        features = self.model(x)
        logits = self.classifier(features)
        return logits, features