Spaces:
Runtime error
Runtime error
Create feature_extractor.py
Browse files- feature_extractor.py +52 -0
feature_extractor.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import timm
|
| 4 |
+
|
| 5 |
+
class ImageEncoder(nn.Module):
|
| 6 |
+
"""
|
| 7 |
+
Standard encoder to extract features from images.
|
| 8 |
+
Used during basic training or feature extraction.
|
| 9 |
+
"""
|
| 10 |
+
def __init__(self, model_name='mobilenetv2_100', num_classes=0, pretrained=True, trainable=True):
|
| 11 |
+
super().__init__()
|
| 12 |
+
self.model = timm.create_model(
|
| 13 |
+
model_name, pretrained=pretrained, num_classes=num_classes, global_pool="max"
|
| 14 |
+
)
|
| 15 |
+
for p in self.model.parameters():
|
| 16 |
+
p.requires_grad = trainable
|
| 17 |
+
|
| 18 |
+
def forward(self, x):
|
| 19 |
+
return self.model(x)
|
| 20 |
+
|
| 21 |
+
class Mixed_Encoder(nn.Module):
|
| 22 |
+
"""
|
| 23 |
+
The 'Mixed Mode' encoder required for your Hindi Space.
|
| 24 |
+
Returns both Logits (for classification) and Features (for style/triplet loss).
|
| 25 |
+
"""
|
| 26 |
+
def __init__(self, model_name='mobilenetv2_100', num_classes=300, pretrained=True, trainable=True):
|
| 27 |
+
super().__init__()
|
| 28 |
+
# 1. Create the backbone (MobileNetV2) without the final head
|
| 29 |
+
self.model = timm.create_model(
|
| 30 |
+
model_name, pretrained=pretrained, num_classes=0, global_pool="max"
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
# 2. Get the number of features the backbone produces (usually 1280 for mobilenetv2_100)
|
| 34 |
+
self.num_features = self.model.num_features
|
| 35 |
+
|
| 36 |
+
# 3. Add a classification head for the 300 writers
|
| 37 |
+
self.classifier = nn.Linear(self.num_features, num_classes)
|
| 38 |
+
|
| 39 |
+
for p in self.model.parameters():
|
| 40 |
+
p.requires_grad = trainable
|
| 41 |
+
|
| 42 |
+
def forward(self, x):
|
| 43 |
+
# Extract features (Shape: [Batch, 1280])
|
| 44 |
+
features = self.model(x)
|
| 45 |
+
|
| 46 |
+
# Calculate logits for writer identification (Shape: [Batch, 300])
|
| 47 |
+
logits = self.classifier(features)
|
| 48 |
+
|
| 49 |
+
# RETURN BOTH:
|
| 50 |
+
# Logits go to Classification Loss
|
| 51 |
+
# Features go to Triplet Loss / Style Extractor in Diffusion
|
| 52 |
+
return logits, features
|