Spaces:
Runtime error
Runtime error
File size: 1,932 Bytes
bdc783f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 | import torch
import torch.nn as nn
import timm
class ImageEncoder(nn.Module):
    """Thin wrapper around a timm backbone used as an image feature extractor.

    The backbone is built with max global pooling. With the default
    ``num_classes=0`` timm is asked for a model without a classification
    head, so ``forward`` yields whatever the pooled backbone emits.

    Args:
        model_name: timm architecture identifier passed to ``create_model``.
        num_classes: Size of the classification head requested from timm.
        pretrained: Forwarded to timm to request pretrained weights.
        trainable: Value assigned to ``requires_grad`` on every backbone
            parameter; ``False`` freezes the backbone.
    """

    def __init__(self, model_name='mobilenetv2_100', num_classes=0, pretrained=True, trainable=True):
        super().__init__()
        backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=num_classes,
            global_pool="max",
        )
        # Toggle gradient tracking for every backbone parameter in one call.
        backbone.requires_grad_(trainable)
        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped backbone and return its output."""
        output = self.model(x)
        return output
class Mixed_Encoder(nn.Module):
    """Encoder that emits classification logits together with backbone features.

    A headless timm backbone (``num_classes=0``, max global pooling) produces
    a feature vector, and a separate linear layer maps that vector to
    ``num_classes`` logits. Both tensors are returned so that one can feed a
    classification loss while the other feeds a triplet loss or a downstream
    style extractor.

    Args:
        model_name: timm architecture identifier passed to ``create_model``.
        num_classes: Output width of the linear classification head.
        pretrained: Forwarded to timm to request pretrained weights.
        trainable: Value assigned to ``requires_grad`` on every backbone
            parameter. The linear head is not touched by this flag.
    """

    def __init__(self, model_name='mobilenetv2_100', num_classes=300, pretrained=True, trainable=True):
        super().__init__()

        # Backbone only: num_classes=0 requests a model with no final head.
        self.model = timm.create_model(
            model_name,
            global_pool="max",
            num_classes=0,
            pretrained=pretrained,
        )
        # Applies to the backbone alone; the classifier below keeps its defaults.
        self.model.requires_grad_(trainable)

        # Width of the backbone's feature vector as reported by timm
        # (usually 1280 for mobilenetv2_100).
        self.num_features = self.model.num_features

        # Linear head turning backbone features into per-class logits.
        self.classifier = nn.Linear(self.num_features, num_classes)

    def forward(self, x):
        """Return ``(logits, features)`` for the input batch ``x``."""
        embedding = self.model(x)
        # Logits are for the classification loss; the raw embedding is
        # returned alongside for the triplet loss / style extractor.
        return self.classifier(embedding), embedding