import torch
import torch.nn as nn
from transformers import CLIPModel

class VariableLengthCLIP(nn.Module):
    """Video classifier that mean-pools CLIP frame features.

    Videos within a batch must share the same number of frames.
    """

    def __init__(self, clip_model, num_classes):
        super().__init__()
        self.clip_model = clip_model
        # Replace CLIP's projection head with a linear classifier over the pooled features
        self.visual_projection = nn.Linear(clip_model.visual_projection.in_features, num_classes)

    def forward(self, x):
        # x: (batch_size, num_frames, channels, height, width)
        batch_size, num_frames, c, h, w = x.shape
        # Fold the frame dimension into the batch so the vision encoder sees single images
        x = x.view(batch_size * num_frames, c, h, w)
        features = self.clip_model.vision_model(x).pooler_output
        # Restore the frame dimension and mean-pool across frames
        features = features.view(batch_size, num_frames, -1)
        features = torch.mean(features, dim=1)
        return self.visual_projection(features)

    def unfreeze_vision_encoder(self, num_layers=2):
        # Freeze the entire vision encoder first
        for param in self.clip_model.vision_model.parameters():
            param.requires_grad = False
        # Then unfreeze only the last `num_layers` transformer layers
        for param in self.clip_model.vision_model.encoder.layers[-num_layers:].parameters():
            param.requires_grad = True

def create_model(num_classes, pretrained_model_name="openai/clip-vit-base-patch32"):
    clip_model = CLIPModel.from_pretrained(pretrained_model_name)
    return VariableLengthCLIP(clip_model, num_classes)

def load_model(num_classes, model_path, device, pretrained_model_name="openai/clip-vit-base-patch32"):
    # Create the model
    model = create_model(num_classes, pretrained_model_name)
    
    # Load the saved weights (weights_only=True requires PyTorch >= 1.13)
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    
    # strict=False tolerates missing or unexpected keys instead of raising an error
    model.load_state_dict(state_dict, strict=False)
    
    model.to(device)  # Move the model to the appropriate device
    return model
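
if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module).
    # Assumes openai/clip-vit-base-patch32, which expects 224x224 RGB frames;
    # num_classes=10 and the batch/frame counts below are arbitrary placeholders.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = create_model(num_classes=10).to(device)
    model.unfreeze_vision_encoder(num_layers=2)  # fine-tune only the last 2 encoder layers

    # Dummy batch: 2 videos, 8 frames each, 3x224x224 RGB
    frames = torch.randn(2, 8, 3, 224, 224, device=device)
    logits = model(frames)
    print(logits.shape)  # expected: torch.Size([2, 10])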