# Spaces: Runtime error
| import torch | |
| import torch.nn as nn | |
| import timm | |
class ImageEncoder(nn.Module):
    """Image feature extractor built on a timm backbone.

    The backbone is created with ``timm.create_model`` using global max
    pooling. ``num_classes`` is forwarded to timm unchanged; the default of
    ``0`` requests a model without a classification head, so ``forward``
    then yields pooled features rather than class scores.

    Args:
        model_name: timm architecture name passed to ``timm.create_model``.
        num_classes: Size of the backbone's own classification head
            (``0`` for no head).
        pretrained: Whether timm should load pretrained weights.
        trainable: Value assigned to ``requires_grad`` on every backbone
            parameter; ``False`` freezes the backbone.
    """

    def __init__(
        self,
        model_name: str = "mobilenetv2_100",
        num_classes: int = 0,
        pretrained: bool = True,
        trainable: bool = True,
    ) -> None:
        super().__init__()
        self.model = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=num_classes,
            global_pool="max",
        )
        # The flag is written in both directions (True and False) so the
        # resulting state never depends on timm's defaults.
        for param in self.model.parameters():
            param.requires_grad = trainable

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the backbone and return its output as-is."""
        return self.model(x)
class Mixed_Encoder(nn.Module):
    """'Mixed mode' encoder returning both logits and features.

    A headless timm backbone (``num_classes=0``, global max pooling)
    produces a feature vector per image, and a separate linear layer maps
    those features to ``num_classes`` logits. ``forward`` returns both so
    the logits can feed a classification loss while the features feed a
    triplet loss / style extractor.

    Args:
        model_name: timm architecture name passed to ``timm.create_model``.
        num_classes: Number of outputs of the linear classification head
            (e.g. the number of writers to identify).
        pretrained: Whether timm should load pretrained weights.
        trainable: Value assigned to ``requires_grad`` on every *backbone*
            parameter. The classification head is not touched by this flag
            and therefore stays trainable.

    Attributes:
        model: The headless timm backbone.
        num_features: Width of the feature vector fed to ``classifier``.
        classifier: ``nn.Linear(num_features, num_classes)``.
    """

    def __init__(
        self,
        model_name: str = "mobilenetv2_100",
        num_classes: int = 300,
        pretrained: bool = True,
        trainable: bool = True,
    ) -> None:
        super().__init__()

        # 1. Backbone without its own classification head.
        self.model = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,
            global_pool="max",
        )

        # 2. Width of the vector the headless backbone emits. Recent timm
        #    releases expose this as ``head_hidden_size``; there
        #    ``num_features`` is the channel count of ``forward_features``
        #    and can be smaller (e.g. MobileNetV3 applies a conv head after
        #    pooling). Older releases only have ``num_features``, so fall
        #    back to it. For mobilenetv2_100 both values agree.
        backbone_width = self.model.num_features
        self.num_features = (
            getattr(self.model, "head_hidden_size", None) or backbone_width
        )

        # 3. Classification head sized from the parameters, not hard-coded.
        self.classifier = nn.Linear(self.num_features, num_classes)

        # Only the backbone is (un)frozen; the loop deliberately does not
        # cover ``self.classifier``.
        for param in self.model.parameters():
            param.requires_grad = trainable

    def forward(self, x: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """Return ``(logits, features)`` for a batch of images.

        ``features`` is the raw backbone output (expected to be
        ``[batch, num_features]``); ``logits`` is the classifier applied to
        it (``[batch, num_classes]``).
        """
        features = self.model(x)
        logits = self.classifier(features)
        # Logits -> classification loss; features -> triplet loss / style
        # extractor.
        return logits, features