| | import torch |
| | import torch.nn as nn |
| | from transformers import CLIPModel, CLIPProcessor |
| |
|
class CLIPClassifier(nn.Module):
    """Linear-probe image classifier on top of a frozen CLIP backbone.

    The pretrained CLIP model is used purely as a fixed feature extractor
    (all of its parameters have ``requires_grad = False``); only the final
    linear head is trainable.

    Args:
        num_classes: Number of output classes for the linear head.
        model_name: Hugging Face checkpoint to load. Defaults to the same
            ViT-B/32 checkpoint the original code hard-coded, so existing
            callers are unaffected.
    """

    def __init__(self, num_classes, model_name="openai/clip-vit-base-patch32"):
        super().__init__()

        # Pretrained CLIP backbone (downloads/loads the checkpoint).
        self.clip = CLIPModel.from_pretrained(model_name)

        # Freeze the entire backbone so only the classifier head trains.
        for param in self.clip.parameters():
            param.requires_grad = False

        # Width of the image embedding produced by get_image_features().
        # Read from the config instead of hard-coding 512 so non-default
        # CLIP checkpoints also work (projection_dim == 512 for ViT-B/32,
        # preserving the original behavior).
        self.image_feature_dim = self.clip.config.projection_dim

        # Trainable linear classification head.
        self.classifier = nn.Linear(self.image_feature_dim, num_classes)

    def forward(self, pixel_values):
        """Compute class logits for a batch of preprocessed images.

        Args:
            pixel_values: Image tensor as produced by the CLIP processor
                (assumed shape (batch, 3, H, W) — matches CLIPProcessor output).

        Returns:
            Logits tensor of shape (batch, num_classes).
        """
        # Frozen CLIP image embeddings: (batch, projection_dim).
        image_features = self.clip.get_image_features(pixel_values=pixel_values)
        logits = self.classifier(image_features)
        return logits
| |
|
def get_processor(model_name="openai/clip-vit-base-patch32"):
    """Return the CLIP processor matching the classifier's backbone.

    Args:
        model_name: Hugging Face checkpoint to load the processor for.
            Defaults to the same ViT-B/32 checkpoint used by
            ``CLIPClassifier``, so existing zero-argument callers keep
            working unchanged.

    Returns:
        A ``CLIPProcessor`` instance for image (and text) preprocessing.
    """
    return CLIPProcessor.from_pretrained(model_name)
| |
|