Creator-090 committed on
Commit
b7dcf66
·
1 Parent(s): c126626

add: implement SwinTClassifications model and video processing functions

Browse files
Files changed (1) hide show
  1. model.py +134 -0
model.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision.models import video as ptv
4
+ from torchvision.transforms import v2
5
+ from transformers import VivitImageProcessor
6
+ from decord import VideoReader
7
+ from decord.bridge import set_bridge
8
+ import gc
9
+ import tempfile
10
+ import os
11
+
12
# The 76 class labels, in alphabetical order. `predict` maps a predicted
# class index to a label via CLASSES[index], so the order of this list must
# not change.
CLASSES = [
    'afternoon', 'animal',
    'bad', 'beautiful', 'big', 'bird', 'blind',
    'cat', 'cheap', 'clothing', 'cold', 'cow', 'curved',
    'deaf', 'dog', 'dress', 'dry',
    'evening', 'expensive',
    'famous', 'fast', 'female', 'fish', 'flat', 'friday',
    'good',
    'happy', 'hat', 'healthy', 'horse', 'hot', 'hour',
    'light', 'long', 'loose', 'loud',
    'minute', 'monday', 'month', 'morning', 'mouse',
    'narrow', 'new', 'night',
    'old',
    'pant', 'pocket',
    'quiet',
    'sad', 'saturday', 'second', 'shirt', 'shoes', 'short',
    'sick', 'skirt', 'slow', 'small', 'suit', 'sunday',
    't_shirt', 'tall', 'thursday', 'time', 'today', 'tomorrow', 'tuesday',
    'ugly',
    'warm', 'wednesday', 'week', 'wet', 'wide',
    'year', 'yesterday', 'young',
]
25
+
26
# Number of frames selected from each video by `preprocess_video`.
CLIP_LENGTH = 16

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
29
+
30
class SwinTClassifications(nn.Module):
    """Swin3D-S video backbone followed by a linear classification layer.

    The backbone's own ``head`` is replaced with ``nn.Identity`` so that the
    backbone outputs features; ``classification_head`` then maps those
    features to one logit per entry of ``classes``.

    Args:
        classes: Sequence of class labels; its length sets the output size.
        weights: Forwarded unchanged to ``swin3d_s(weights=...)``.
    """

    def __init__(self, classes, weights="KINETICS400_V1"):
        super().__init__()
        self.classes = classes

        backbone = ptv.swin3d_s(weights=weights)
        # Read the feature width off the stock head before discarding it.
        feature_dim = backbone.head.in_features
        backbone.head = nn.Identity()
        self.base_model = backbone

        # Kept as a Sequential so parameter names stay
        # "classification_head.0.*" in the state dict.
        self.classification_head = nn.Sequential(
            nn.Linear(feature_dim, len(classes)),
        )

    def forward(self, x):
        """Return class logits: backbone features passed through the head."""
        features = self.base_model(x)
        return self.classification_head(features)
49
+
50
def load_model():
    """Download the trained checkpoint from the Hugging Face Hub and load it.

    Fetches ``ISL_best_model.pt`` from the ``Creator-090/isl-swin3d-model``
    repository, builds a ``SwinTClassifications`` model for ``CLASSES``,
    loads the checkpoint into it, moves it to ``DEVICE`` and switches it to
    evaluation mode.

    Returns:
        The loaded ``SwinTClassifications`` model, on ``DEVICE``, in eval mode.
    """
    from huggingface_hub import hf_hub_download

    print("Fetching model from Hugging Face Hub...")
    model_path = hf_hub_download(
        repo_id="Creator-090/isl-swin3d-model",
        filename="ISL_best_model.pt",
    )

    # weights=None: the checkpoint loaded below (strict load_state_dict)
    # replaces every parameter, so fetching the Kinetics-400 pretrained
    # weights first would only be a wasted download.
    model = SwinTClassifications(classes=CLASSES, weights=None)
    state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)

    model = model.to(DEVICE)
    model.eval()
    return model
67
+
68
def preprocess_video(video_bytes: bytes):
    """Decode a video and convert its leading frames into a model input.

    The bytes are written to a temporary ``.mp4`` file that is opened with
    decord's ``VideoReader``. The first ``CLIP_LENGTH`` frames are read; if
    the video has fewer, the last frame is repeated to fill the clip. The
    frames are then resized, center-cropped to 224x224, rescaled and
    normalized by ``VivitImageProcessor``. The temporary file is always
    removed before returning.

    Args:
        video_bytes: Raw contents of a video file.

    Returns:
        The processed clip with a leading batch dimension, intended as
        (1, C, T, H, W) with T == CLIP_LENGTH.
        NOTE(review): this layout assumes decord yields frames as
        (T, H, W, C) - confirm against the decord bridge output.

    Raises:
        ValueError: If the video contains no frames.
    """
    # Make decord hand back torch tensors.
    set_bridge("torch")

    # VideoReader is opened from a path here, so spill the bytes to disk.
    # delete=False keeps the file after the handle is closed; it is removed
    # in the finally block below.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
        f.write(video_bytes)
        tmp_path = f.name

    try:
        image_processor = VivitImageProcessor(
            do_resize=True,
            size={"shortest_edge": 224},
            do_center_crop=True,
            crop_size={"height": 224, "width": 224},
            do_rescale=True,
            rescale_factor=1/255,
            do_normalize=True,
            image_mean=[0.5, 0.5, 0.5],
            image_std=[0.5, 0.5, 0.5],
        )

        vr = VideoReader(tmp_path)
        total_frames = len(vr)
        if total_frames == 0:
            # Without this guard the padding step fails with an opaque
            # IndexError on indices[-1].
            raise ValueError("Video contains no frames")

        # Take the leading frames, then repeat the last one until the clip
        # holds exactly CLIP_LENGTH indices (a non-positive repeat count
        # adds nothing).
        indices = list(range(min(total_frames, CLIP_LENGTH)))
        indices += [indices[-1]] * (CLIP_LENGTH - len(indices))

        video = vr.get_batch(indices)
        # Move channels ahead of the spatial dims: (T, H, W, C) -> (T, C, H, W),
        # matching input_data_format='channels_first' below.
        video = v2.functional.to_dtype(video.permute(0, 3, 1, 2), torch.uint8, scale=False)

        processed = image_processor(list(video), return_tensors='pt', input_data_format='channels_first')
        pixel_values = processed['pixel_values'].squeeze(0)
        # Swap the frame and channel axes: (T, C, H, W) -> (C, T, H, W).
        pixel_values = pixel_values.permute(1, 0, 2, 3)

        return pixel_values.unsqueeze(0)  # Add batch dimension
    finally:
        # Best-effort cleanup; the file being gone already is not an error.
        try:
            os.remove(tmp_path)
        except FileNotFoundError:
            pass
111
+
112
def predict(model, video_bytes: bytes, top_k: int = 5):
    """Classify a video and return the most probable classes.

    Args:
        model: Callable model that maps the preprocessed clip to a batch of
            class scores; only the first item of the batch is used.
        video_bytes: Raw contents of a video file, passed to
            ``preprocess_video``.
        top_k: How many of the highest-probability classes to report. Values
            larger than the number of scores are clamped to that number.

    Returns:
        A dict with:
            "prediction": label of the most probable class,
            "confidence": its softmax probability as a float,
            "top_k": list of {"class", "confidence"} dicts, most probable
                first.

    Raises:
        ValueError: If ``top_k`` is less than 1.
    """
    if top_k < 1:
        # Previously this surfaced as an IndexError on results[0] (top_k == 0)
        # or an error from torch.topk (negative top_k).
        raise ValueError("top_k must be at least 1")

    pixel_values = preprocess_video(video_bytes).to(DEVICE)

    with torch.no_grad():
        outputs = model(pixel_values)
        probabilities = torch.nn.functional.softmax(outputs, dim=-1)[0]

    # torch.topk raises when k exceeds the number of elements, so clamp it.
    k = min(top_k, probabilities.shape[-1])
    top_probs, top_indices = torch.topk(probabilities, k=k)

    results = [
        {"class": CLASSES[index], "confidence": float(prob)}
        for prob, index in zip(top_probs.tolist(), top_indices.tolist())
    ]

    best = results[0]
    return {
        "prediction": best["class"],
        "confidence": best["confidence"],
        "top_k": results,
    }