Spaces:
Sleeping
Sleeping
Commit ·
b7dcf66
1
Parent(s): c126626
add: implement SwinTClassifications model and video processing functions
Browse files
model.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torchvision.models import video as ptv
|
| 4 |
+
from torchvision.transforms import v2
|
| 5 |
+
from transformers import VivitImageProcessor
|
| 6 |
+
from decord import VideoReader
|
| 7 |
+
from decord.bridge import set_bridge
|
| 8 |
+
import gc
|
| 9 |
+
import tempfile
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
# Label vocabulary (76 entries, alphabetically sorted). predict() maps a model
# output index straight to CLASSES[index], so the order here is significant.
CLASSES = [
    "afternoon", "animal", "bad", "beautiful", "big",
    "bird", "blind", "cat", "cheap", "clothing",
    "cold", "cow", "curved", "deaf", "dog",
    "dress", "dry", "evening", "expensive", "famous",
    "fast", "female", "fish", "flat", "friday",
    "good", "happy", "hat", "healthy", "horse",
    "hot", "hour", "light", "long", "loose",
    "loud", "minute", "monday", "month", "morning",
    "mouse", "narrow", "new", "night", "old",
    "pant", "pocket", "quiet", "sad", "saturday",
    "second", "shirt", "shoes", "short", "sick",
    "skirt", "slow", "small", "suit", "sunday",
    "t_shirt", "tall", "thursday", "time", "today",
    "tomorrow", "tuesday", "ugly", "warm", "wednesday",
    "week", "wet", "wide", "year", "yesterday",
    "young",
]
|
| 25 |
+
|
| 26 |
+
# Inference settings.
# CLIP_LENGTH: number of frames preprocess_video() feeds to the model per clip.
# DEVICE: CUDA when torch reports it as available, otherwise CPU.
CLIP_LENGTH = 16
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 29 |
+
|
| 30 |
+
class SwinTClassifications(nn.Module):
    """Swin3D-S video backbone followed by a single linear classifier.

    The backbone's own ``head`` is swapped for ``nn.Identity`` so the backbone
    returns features, which ``classification_head`` maps to one logit per
    entry in ``classes``.

    Args:
        classes: Sequence of class labels; its length sets the number of
            output logits.
        weights: Forwarded unchanged to ``torchvision.models.video.swin3d_s``.
    """

    def __init__(self, classes, weights="KINETICS400_V1"):
        super().__init__()
        self.classes = classes

        backbone = ptv.swin3d_s(weights=weights)
        # Read the feature width before the stock head is discarded below.
        feature_dim = backbone.head.in_features

        self.base_model = backbone
        self.classification_head = nn.Sequential(
            nn.Linear(feature_dim, len(classes)),
        )
        # Strip the backbone's built-in classifier; classification_head
        # takes its place.
        self.base_model.head = nn.Identity()

    def forward(self, x):
        """Run ``x`` through the backbone and return the class logits."""
        return self.classification_head(self.base_model(x))
|
| 49 |
+
|
| 50 |
+
def load_model():
    """Download the fine-tuned checkpoint from the Hugging Face Hub and load it.

    Returns:
        A ``SwinTClassifications`` instance built for ``CLASSES``, with the
        checkpoint's weights loaded, moved to ``DEVICE`` and set to eval mode.
    """
    # Function-scope import: huggingface_hub is only needed by this function.
    from huggingface_hub import hf_hub_download

    print("Fetching model from Hugging Face Hub...")
    checkpoint_path = hf_hub_download(
        repo_id="Creator-090/isl-swin3d-model",
        filename="ISL_best_model.pt",
    )

    # weights=None: load_state_dict() below uses strict key matching, so the
    # checkpoint replaces every parameter. Fetching the Kinetics-400 backbone
    # weights first would only be a wasted download.
    model = SwinTClassifications(classes=CLASSES, weights=None)
    state_dict = torch.load(checkpoint_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)

    model = model.to(DEVICE)
    model.eval()
    return model
|
| 67 |
+
|
| 68 |
+
def preprocess_video(video_bytes: bytes):
    """Decode a video and turn its first ``CLIP_LENGTH`` frames into a model input.

    The bytes are written to a temporary ``.mp4`` file (decord's
    ``VideoReader`` is given a path), the leading frames are read, and the
    frames are resized, center-cropped to 224x224, rescaled by 1/255 and
    normalized with mean/std 0.5 by ``VivitImageProcessor``. Videos shorter
    than ``CLIP_LENGTH`` are padded by repeating their last frame.

    Args:
        video_bytes: Raw bytes of an encoded video file.

    Returns:
        The processed clip permuted to (C, T, H, W) with a leading batch
        dimension of 1 added.

    Raises:
        ValueError: If the video contains no frames.
    """
    set_bridge("torch")  # have decord hand back torch tensors

    tmp_path = None
    try:
        # Created inside the try so the file is cleaned up even if the write fails.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
            tmp_path = f.name
            f.write(video_bytes)

        image_processor = VivitImageProcessor(
            do_resize=True,
            size={"shortest_edge": 224},
            do_center_crop=True,
            crop_size={"height": 224, "width": 224},
            do_rescale=True,
            rescale_factor=1/255,
            do_normalize=True,
            image_mean=[0.5, 0.5, 0.5],
            image_std=[0.5, 0.5, 0.5],
        )

        vr = VideoReader(tmp_path)
        total_frames = len(vr)
        if total_frames == 0:
            # Without this guard the padding step indexes an empty list and
            # fails with an unhelpful IndexError.
            raise ValueError("Video contains no frames")

        # Take the leading frames; pad short videos by repeating the last one
        # so exactly CLIP_LENGTH indices are requested.
        indices = list(range(min(total_frames, CLIP_LENGTH)))
        indices += [indices[-1]] * (CLIP_LENGTH - len(indices))

        video = vr.get_batch(indices)
        # Drop the reader so it no longer references the temp file when the
        # file is removed in the finally block.
        del vr

        # Move channels ahead of the spatial dims: (T, C, H, W), i.e. one
        # channels-first frame per index along dim 0.
        # NOTE(review): assumes decord returns frames as (T, H, W, C) - confirm.
        video = v2.functional.to_dtype(video.permute(0, 3, 1, 2), torch.uint8, scale=False)

        processed = image_processor(list(video), return_tensors='pt', input_data_format='channels_first')
        pixel_values = processed['pixel_values'].squeeze(0)
        pixel_values = pixel_values.permute(1, 0, 2, 3)  # (T, C, H, W) -> (C, T, H, W)

        return pixel_values.unsqueeze(0)  # add batch dimension
    finally:
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except FileNotFoundError:
                # Already gone; nothing left to clean up.
                pass
|
| 111 |
+
|
| 112 |
+
def predict(model, video_bytes: bytes, top_k: int = 5):
    """Classify a video and return the most probable classes.

    Args:
        model: Callable model returning logits with the class dimension last
            and the batch dimension first.
        video_bytes: Raw bytes of an encoded video file, handed to
            ``preprocess_video``.
        top_k: Number of ranked results requested. Values larger than the
            number of model outputs are clamped to that number.

    Returns:
        A dict with ``"prediction"`` (best class name), ``"confidence"`` (its
        softmax probability) and ``"top_k"`` (list of
        ``{"class", "confidence"}`` dicts, best first).

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        # k == 0 would yield an empty result list and fail on results[0] below.
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    pixel_values = preprocess_video(video_bytes).to(DEVICE)

    with torch.no_grad():
        outputs = model(pixel_values)
        # [0]: the batch holds a single clip.
        probabilities = torch.nn.functional.softmax(outputs, dim=-1)[0]

    # torch.topk raises when k exceeds the dimension size, so clamp it.
    k = min(top_k, probabilities.shape[-1])
    top_probs, top_indices = torch.topk(probabilities, k=k)

    results = [
        {"class": CLASSES[index], "confidence": float(prob)}
        for prob, index in zip(top_probs.tolist(), top_indices.tolist())
    ]

    return {
        "prediction": results[0]["class"],
        "confidence": results[0]["confidence"],
        "top_k": results,
    }
|