Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torchvision import models | |
| from transformers import AutoModel | |
class FineGrainedClassifier(nn.Module):
    """Classify a sample from an image and a precomputed text embedding.

    Image features come from a ResNet-50 whose final classification layer
    is replaced by ``nn.Identity``. The text embedding is projected by a
    linear layer, concatenated with the image features along ``dim=1``,
    and a single linear layer maps the fused vector to ``num_classes``
    outputs.
    """

    def __init__(self, num_classes=434, text_dim=768):
        """Build the image encoder, text encoder and fusion head.

        Args:
            num_classes: Number of outputs produced by the final layer.
            text_dim: Size of the last dimension of the text embeddings
                passed to ``forward``.
        """
        super().__init__()

        # Image encoder (ResNet50) without pretrained weights.
        # ``weights=None`` replaces the deprecated ``pretrained=False``
        # keyword, which emits a warning on recent torchvision releases.
        self.resnet = models.resnet50(weights=None)
        self.resnet.fc = nn.Identity()  # Remove final classification layer

        # Text encoder (Jina embeddings). It is registered as a submodule
        # but is not called in ``forward``, which takes embeddings that
        # were already computed.
        # NOTE(review): ``trust_remote_code=True`` lets code shipped with
        # the model repository run locally -- confirm the source is trusted.
        self.text_encoder = AutoModel.from_pretrained(
            "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True
        )

        # Projection applied to the text embeddings before fusion.
        self.text_fc = nn.Linear(text_dim, 1024)

        # Fusion and classification: 2048 image features are assumed from
        # the headless ResNet50, plus the 1024 projected text features.
        self.fusion_fc = nn.Linear(2048 + 1024, num_classes)

    def forward(self, images, text_embeddings):
        """Return the classifier output for a batch.

        Args:
            images: Batch of images fed to the ResNet-50 encoder
                (presumably ``(batch, 3, H, W)`` -- TODO confirm).
            text_embeddings: Precomputed text embeddings whose last
                dimension is ``text_dim``; they must be concatenable
                with the image features along ``dim=1``.

        Returns:
            The raw output of the final linear layer (no activation is
            applied here), with ``num_classes`` values per sample.
        """
        # Extract image features.
        image_features = self.resnet(images)

        # Project the text embeddings.
        text_features = self.text_fc(text_embeddings)

        # Concatenate image and text features along the feature axis.
        combined = torch.cat((image_features, text_features), dim=1)

        # Classify.
        return self.fusion_fc(combined)