File size: 472 Bytes
469c325
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch.nn as nn
import torchvision.models.video as models

def build_swin_model(*, num_classes: int = 2, weights=None) -> nn.Module:
    """Build a torchvision Swin3D-T (Tiny) video model with a new classification head.

    The backbone is created via ``torchvision.models.video.swin3d_t`` and its
    final ``head`` linear layer is replaced with a freshly initialised
    ``nn.Linear`` that produces ``num_classes`` outputs. Calling with no
    arguments reproduces the original behaviour: no pretrained weights and a
    2-way (binary) head.

    Args:
        num_classes: Number of output logits of the replacement head.
            Defaults to 2 (binary classification).
        weights: Passed straight through to ``swin3d_t``. ``None`` (the
            default) builds the model from scratch; a torchvision weights
            enum can be given to start from a pretrained backbone. The head
            is always replaced, so pretrained head weights are discarded.

    Returns:
        The model with its ``head`` replaced.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    print("Initializing Video Swin Transformer...")
    model = models.swin3d_t(weights=weights)

    # Swap the head rather than passing num_classes to swin3d_t. This also
    # works when pretrained weights are requested, whose head size is fixed.
    num_features = model.head.in_features
    model.head = nn.Linear(num_features, num_classes)

    return model