import torch
import torch.nn as nn
from transformers import CLIPModel, CLIPProcessor


class CLIPClassifier(nn.Module):
    def __init__(self, num_classes):
        super(CLIPClassifier, self).__init__()
        # Load pre-trained CLIP
        self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        # Freeze CLIP weights (optional)
        for param in self.clip.parameters():
            param.requires_grad = False
        # Correct image feature size
        self.image_feature_dim = 512  # CLIP ViT-B/32 returns 512-dimensional image features
        # Add a classification head
        self.classifier = nn.Linear(self.image_feature_dim, num_classes)

    def forward(self, pixel_values):
        # Extract image features with CLIP
        image_features = self.clip.get_image_features(pixel_values=pixel_values)
        # Classification
        logits = self.classifier(image_features)
        return logits


def get_processor():
    # Used to preprocess the images
    return CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
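
A minimal usage sketch for inference with the classes above; the class count, image path, and predicted-label handling are placeholder assumptions, not part of the original file:

from PIL import Image

# Hypothetical setup: adjust num_classes and the image path to your dataset.
model = CLIPClassifier(num_classes=10)
processor = get_processor()

image = Image.open("example.jpg")  # placeholder path
inputs = processor(images=image, return_tensors="pt")

model.eval()
with torch.no_grad():
    logits = model(pixel_values=inputs["pixel_values"])

# Index of the highest-scoring class
predicted_class = logits.argmax(dim=-1).item()
print(predicted_class)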