keysun89 commited on
Commit
bdc783f
Β·
verified Β·
1 Parent(s): beecf6a

Create feature_extractor.py

Browse files
Files changed (1) hide show
  1. feature_extractor.py +52 -0
feature_extractor.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+
5
+ class ImageEncoder(nn.Module):
6
+ """
7
+ Standard encoder to extract features from images.
8
+ Used during basic training or feature extraction.
9
+ """
10
+ def __init__(self, model_name='mobilenetv2_100', num_classes=0, pretrained=True, trainable=True):
11
+ super().__init__()
12
+ self.model = timm.create_model(
13
+ model_name, pretrained=pretrained, num_classes=num_classes, global_pool="max"
14
+ )
15
+ for p in self.model.parameters():
16
+ p.requires_grad = trainable
17
+
18
+ def forward(self, x):
19
+ return self.model(x)
20
+
21
+ class Mixed_Encoder(nn.Module):
22
+ """
23
+ The 'Mixed Mode' encoder required for your Hindi Space.
24
+ Returns both Logits (for classification) and Features (for style/triplet loss).
25
+ """
26
+ def __init__(self, model_name='mobilenetv2_100', num_classes=300, pretrained=True, trainable=True):
27
+ super().__init__()
28
+ # 1. Create the backbone (MobileNetV2) without the final head
29
+ self.model = timm.create_model(
30
+ model_name, pretrained=pretrained, num_classes=0, global_pool="max"
31
+ )
32
+
33
+ # 2. Get the number of features the backbone produces (usually 1280 for mobilenetv2_100)
34
+ self.num_features = self.model.num_features
35
+
36
+ # 3. Add a classification head for the 300 writers
37
+ self.classifier = nn.Linear(self.num_features, num_classes)
38
+
39
+ for p in self.model.parameters():
40
+ p.requires_grad = trainable
41
+
42
+ def forward(self, x):
43
+ # Extract features (Shape: [Batch, 1280])
44
+ features = self.model(x)
45
+
46
+ # Calculate logits for writer identification (Shape: [Batch, 300])
47
+ logits = self.classifier(features)
48
+
49
+ # RETURN BOTH:
50
+ # Logits go to Classification Loss
51
+ # Features go to Triplet Loss / Style Extractor in Diffusion
52
+ return logits, features