runthebandsup committed on
Commit
b0ca6f6
·
verified ·
1 Parent(s): b426041

Add model.py with FineGrainedClassifier definition

Browse files
Files changed (1) hide show
  1. model.py +35 -0
model.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models
4
+ from transformers import AutoModel
5
+
6
class FineGrainedClassifier(nn.Module):
    """Classifier that fuses ResNet-50 image features with projected text embeddings.

    The image branch is a torchvision ResNet-50 whose final ``fc`` layer is
    replaced by ``nn.Identity`` so it emits its pooled feature vector. The text
    branch is a single linear projection applied to embeddings supplied by the
    caller. Both feature vectors are concatenated and mapped to class logits by
    one linear layer.

    NOTE(review): ``text_encoder`` is loaded in ``__init__`` but ``forward``
    never calls it; ``forward`` consumes ``text_embeddings`` as given.
    Presumably callers compute those embeddings with ``text_encoder``
    themselves -- confirm against the training/inference code. The attribute
    is kept so existing attribute access and state-dict keys are unaffected.

    Args:
        num_classes: Number of output logits.
        text_dim: Size of the last dimension of the incoming text embeddings.
        text_hidden_dim: Output width of the text projection layer. Defaults
            to 1024, the value that used to be hard-coded.
    """

    def __init__(self, num_classes=434, text_dim=768, text_hidden_dim=1024):
        super().__init__()

        # Image encoder (ResNet-50) with randomly initialised weights.
        # ``weights=None`` replaces the deprecated ``pretrained=False`` flag.
        self.resnet = models.resnet50(weights=None)
        # Read the feature width from the backbone before dropping its head
        # instead of hard-coding it (2048 for ResNet-50).
        image_dim = self.resnet.fc.in_features
        self.resnet.fc = nn.Identity()  # Remove final classification layer

        # Text encoder (Jina embeddings).
        # NOTE(security): trust_remote_code=True executes Python code shipped
        # in the model repository; consider pinning a reviewed ``revision``.
        self.text_encoder = AutoModel.from_pretrained(
            "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True
        )

        # Linear projection applied to the caller-supplied text embeddings.
        self.text_fc = nn.Linear(text_dim, text_hidden_dim)

        # Fusion and classification over the concatenated features.
        self.fusion_fc = nn.Linear(image_dim + text_hidden_dim, num_classes)

    def forward(self, images, text_embeddings):
        """Return class logits for a batch of images and text embeddings.

        Args:
            images: Input passed unchanged to the ResNet backbone.
            text_embeddings: Precomputed text embeddings whose last dimension
                equals ``text_dim``. They are concatenated with the image
                features along ``dim=1``, so both are expected to be 2-D with
                matching batch size (assumption -- confirm with callers).

        Returns:
            The output of the fusion layer: one row of ``num_classes`` logits
            per sample.
        """
        # Extract image features
        image_features = self.resnet(images)

        # Process text embeddings
        text_features = self.text_fc(text_embeddings)

        # Concatenate image and text features along the feature axis
        combined = torch.cat((image_features, text_features), dim=1)

        # Classify
        return self.fusion_fc(combined)