File size: 1,192 Bytes
b0ca6f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import torch.nn as nn
from torchvision import models
from transformers import AutoModel

class FineGrainedClassifier(nn.Module):
    """Late-fusion classifier over images and precomputed text embeddings.

    Image features come from a ResNet-50 trunk without pretrained weights,
    whose final ``fc`` layer is replaced by ``nn.Identity``. Text embeddings
    are projected by a single linear layer. The two feature vectors are
    concatenated and mapped to class logits by one linear layer.

    Args:
        num_classes: Number of output logits.
        text_dim: Width of the last dimension of ``text_embeddings`` passed to
            :meth:`forward`. The default 768 presumably matches the hidden
            size of the default Jina encoder; confirm if the encoder changes.
        text_hidden_dim: Width of the projected text features. The default
            1024 is the value that was previously hard-coded.
        text_encoder_name: Hugging Face model id loaded into
            ``self.text_encoder``.
        load_text_encoder: If False, skip loading the text encoder and set
            ``self.text_encoder`` to ``None``. ``forward`` never calls the
            encoder, so this only matters to code that uses the attribute
            directly or to checkpoints that contain its weights.
    """

    def __init__(
        self,
        num_classes=434,
        text_dim=768,
        *,
        text_hidden_dim=1024,
        text_encoder_name="jinaai/jina-embeddings-v2-base-en",
        load_text_encoder=True,
    ):
        super().__init__()

        # Image encoder: ResNet-50 with randomly initialised weights.
        # ``weights=None`` replaces the ``pretrained=False`` flag that
        # torchvision deprecated in 0.13.
        self.resnet = models.resnet50(weights=None)
        # Read the feature width from the classification head before removing
        # it, instead of hard-coding 2048, so the fusion layer always matches
        # the backbone.
        image_dim = self.resnet.fc.in_features
        self.resnet.fc = nn.Identity()  # expose pooled features, not logits

        # Text encoder (Jina embeddings).
        # NOTE(review): forward() does not call this encoder; it consumes
        # precomputed embeddings. When loaded, the encoder's parameters are
        # still registered on this module, so they appear in state_dict(),
        # .parameters() and .to(device).
        # trust_remote_code=True executes model code fetched from the Hub,
        # so only use it with a trusted model id.
        if load_text_encoder:
            self.text_encoder = AutoModel.from_pretrained(
                text_encoder_name, trust_remote_code=True
            )
        else:
            self.text_encoder = None

        # Linear projection of the precomputed text embeddings.
        self.text_fc = nn.Linear(text_dim, text_hidden_dim)

        # Fusion and classification over the concatenated features.
        self.fusion_fc = nn.Linear(image_dim + text_hidden_dim, num_classes)

    def forward(self, images, text_embeddings):
        """Return class logits for a batch of images and text embeddings.

        Args:
            images: Image batch accepted by the ResNet-50 trunk. Presumably
                ``(batch, 3, H, W)``; confirm against the data pipeline.
            text_embeddings: Precomputed embeddings whose last dimension is
                ``text_dim``. Assumed to be ``(batch, text_dim)`` so that they
                concatenate with the image features along dim 1.

        Returns:
            Logits from ``fusion_fc``, with ``num_classes`` in the last
            dimension.
        """
        image_features = self.resnet(images)  # backbone features (fc removed)

        # NOTE(review): there is no activation between text_fc and fusion_fc,
        # so the text path is a composition of two linear maps.
        text_features = self.text_fc(text_embeddings)

        # Feature-wise concatenation; the batch sizes of both inputs must match.
        combined = torch.cat((image_features, text_features), dim=1)

        return self.fusion_fc(combined)