Spaces:
Sleeping
Sleeping
File size: 1,192 Bytes
b0ca6f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import torch
import torch.nn as nn
from torchvision import models
from transformers import AutoModel
class FineGrainedClassifier(nn.Module):
    """Image + text fusion classifier.

    A ResNet-50 backbone (created without pretrained weights, classification
    head removed) produces image features. Pre-computed text embeddings are
    projected by a linear layer, concatenated with the image features and
    mapped to class logits by a single linear layer.

    Note:
        ``text_encoder`` is instantiated here but is *not* called in
        ``forward``; callers pass already-computed text embeddings.
        The attribute is kept so the module's parameter/state-dict layout
        stays the same.
    """

    # Output width of the text projection layer.
    _TEXT_HIDDEN_DIM = 1024

    def __init__(self, num_classes: int = 434, text_dim: int = 768) -> None:
        """Build the encoders and the fusion head.

        Args:
            num_classes: Number of output logits.
            text_dim: Size of the last dimension of the text embeddings
                passed to ``forward``.
        """
        super().__init__()

        # Image encoder (ResNet50), randomly initialised. Calling the builder
        # with no arguments avoids the deprecated ``pretrained=`` keyword and
        # behaves the same on old and new torchvision releases.
        self.resnet = models.resnet50()
        # Read the backbone's feature width before dropping its head
        # (2048 for ResNet-50) instead of hard-coding it.
        image_dim = self.resnet.fc.in_features
        self.resnet.fc = nn.Identity()  # Remove final classification layer

        # Text encoder (Jina embeddings).
        # NOTE(review): trust_remote_code=True executes Python code shipped
        # with the model repository - confirm the source is trusted/pinned.
        self.text_encoder = AutoModel.from_pretrained(
            "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True
        )

        # Fully connected layer for text embeddings
        self.text_fc = nn.Linear(text_dim, self._TEXT_HIDDEN_DIM)

        # Fusion and classification
        self.fusion_fc = nn.Linear(image_dim + self._TEXT_HIDDEN_DIM, num_classes)

    def forward(
        self, images: torch.Tensor, text_embeddings: torch.Tensor
    ) -> torch.Tensor:
        """Compute class logits from images and pre-computed text embeddings.

        Args:
            images: Batch of images fed to the ResNet backbone
                (presumably ``(batch, 3, H, W)`` - confirm against caller).
            text_embeddings: Text embeddings whose last dimension equals the
                ``text_dim`` given at construction; must be 2-D with the same
                batch size as ``images`` so the concatenation along dim 1
                is valid.

        Returns:
            Tensor of unnormalised class scores, one row per batch item.
        """
        # Extract image features
        image_features = self.resnet(images)
        # Process text embeddings
        text_features = self.text_fc(text_embeddings)
        # Concatenate image and text features along the feature axis
        combined = torch.cat((image_features, text_features), dim=1)
        # Classify
        return self.fusion_fc(combined)