"""Image + text late-fusion model for fine-grained classification."""

import torch
import torch.nn as nn
from torchvision import models
from transformers import AutoModel


def _resnet50_untrained():
    """Build a randomly initialised torchvision ResNet-50.

    Uses the ``weights=`` API (torchvision >= 0.13), which avoids the
    deprecation warning emitted for ``pretrained=``; falls back to
    ``pretrained=False`` on older torchvision releases, whose ResNet
    constructor rejects the unknown ``weights`` keyword with ``TypeError``.
    """
    try:
        return models.resnet50(weights=None)
    except TypeError:
        return models.resnet50(pretrained=False)


class FineGrainedClassifier(nn.Module):
    """Late-fusion classifier over ResNet-50 image features and text embeddings.

    Image features come from a ResNet-50 backbone whose classification head is
    replaced by ``nn.Identity``. Text embeddings are supplied pre-computed to
    ``forward`` and projected by a linear layer. The two feature vectors are
    concatenated and mapped to class logits by a single linear layer.

    Args:
        num_classes: Number of output classes (size of the logit vector).
        text_dim: Size of the last dimension of ``text_embeddings`` passed to
            ``forward``. The default 768 is presumably the hidden size of
            ``jinaai/jina-embeddings-v2-base-en`` -- confirm if the embedding
            source changes.
        text_hidden_dim: Output width of the text projection layer.
        load_text_encoder: If True (the default, original behaviour), load the
            Jina text encoder into ``self.text_encoder``. ``forward`` never
            calls it, so pass False when embeddings are computed elsewhere:
            this skips the download and keeps its weights out of
            ``parameters()`` / ``state_dict()``. ``self.text_encoder`` is then
            ``None``.
    """

    def __init__(
        self,
        num_classes: int = 434,
        text_dim: int = 768,
        *,
        text_hidden_dim: int = 1024,
        load_text_encoder: bool = True,
    ) -> None:
        super().__init__()

        # Image encoder: ResNet-50 with random initial weights.
        self.resnet = _resnet50_untrained()
        # Read the backbone's feature width before dropping its head, instead
        # of hard-coding 2048, so the fusion layer always matches it.
        image_dim = self.resnet.fc.in_features
        self.resnet.fc = nn.Identity()  # expose pooled features, not logits

        # Text encoder (Jina embeddings).
        # NOTE(review): forward() never uses this module. Its parameters are
        # still registered, so they appear in model.parameters(); because they
        # receive no gradients, DistributedDataParallel can raise an
        # "unused parameters" error unless find_unused_parameters=True.
        # trust_remote_code=True executes code fetched from the Hub repo.
        if load_text_encoder:
            self.text_encoder = AutoModel.from_pretrained(
                "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True
            )
        else:
            self.text_encoder = None

        # Linear projection for the pre-computed text embeddings.
        self.text_fc = nn.Linear(text_dim, text_hidden_dim)

        # Fusion and classification over the concatenated features.
        # NOTE(review): there is no non-linearity between text_fc and
        # fusion_fc, so their composition is itself linear in the text
        # embeddings; consider an activation if that is not intended.
        self.fusion_fc = nn.Linear(image_dim + text_hidden_dim, num_classes)

    def forward(
        self, images: torch.Tensor, text_embeddings: torch.Tensor
    ) -> torch.Tensor:
        """Return class logits for a batch of images and text embeddings.

        Args:
            images: Image batch fed to the ResNet-50 backbone; presumably
                (B, 3, H, W) float tensors -- TODO confirm preprocessing.
            text_embeddings: Pre-computed embeddings of shape (B, text_dim);
                must be 2-D so they can be concatenated with the image
                features along dim 1.

        Returns:
            Logits of shape (B, num_classes).
        """
        image_features = self.resnet(images)  # (B, image_dim)
        text_features = self.text_fc(text_embeddings)  # (B, text_hidden_dim)
        combined = torch.cat((image_features, text_features), dim=1)
        return self.fusion_fc(combined)