File size: 6,270 Bytes
1cf4369
b7dcf66
 
 
 
 
 
de9af52
 
b7dcf66
86c7cf3
b7dcf66
1cf4369
 
 
 
 
 
 
 
 
b7dcf66
 
 
86c7cf3
1cf4369
 
86c7cf3
 
1cf4369
86c7cf3
1cf4369
86c7cf3
1cf4369
86c7cf3
 
 
1cf4369
 
b7dcf66
86c7cf3
b7dcf66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1cf4369
b7dcf66
 
1cf4369
86c7cf3
b7dcf66
1cf4369
b7dcf66
 
1cf4369
b7dcf66
 
 
 
 
1cf4369
 
 
 
b7dcf66
1cf4369
86c7cf3
1cf4369
86c7cf3
1cf4369
 
 
 
b7dcf66
 
1cf4369
86c7cf3
 
 
 
1cf4369
 
 
 
 
 
 
 
 
86c7cf3
1cf4369
053568e
86c7cf3
1cf4369
86c7cf3
 
 
 
 
1cf4369
 
 
 
 
 
 
 
 
de9af52
1cf4369
 
3067ed1
86c7cf3
 
 
1cf4369
 
3067ed1
 
86c7cf3
1cf4369
 
 
 
103cb2c
1cf4369
 
 
 
 
86c7cf3
1cf4369
 
103cb2c
86c7cf3
1cf4369
 
 
 
 
 
 
053568e
86c7cf3
 
 
 
 
 
1cf4369
 
 
 
 
 
 
 
 
053568e
 
 
1cf4369
053568e
 
1cf4369
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import io
import torch
import torch.nn as nn
from torchvision.models import video as ptv
from torchvision.transforms import v2
from decord import VideoReader
from decord.bridge import set_bridge
import cv2
import numpy as np

# Classes
# Label names for the classifier outputs. The list is indexed with the
# positions returned by torch.topk over the model's output (see
# _run_inference), and its length sizes the classification head, so the
# order and count must stay in sync with the trained checkpoint.
CLASSES = [
    'afternoon', 'animal', 'bad', 'beautiful', 'big', 'bird', 'blind',
    'cat', 'cheap', 'clothing', 'cold', 'cow', 'curved', 'deaf', 'dog',
    'dress', 'dry', 'evening', 'expensive', 'famous', 'fast', 'female',
    'fish', 'flat', 'friday', 'good', 'happy', 'hat', 'healthy', 'horse',
    'hot', 'hour', 'light', 'long', 'loose', 'loud', 'minute', 'monday',
    'month', 'morning', 'mouse', 'narrow', 'new', 'night', 'old', 'pant',
    'pocket', 'quiet', 'sad', 'saturday', 'second', 'shirt', 'shoes',
    'short', 'sick', 'skirt', 'slow', 'small', 'suit', 'sunday', 't_shirt',
    'tall', 'thursday', 'time', 'today', 'tomorrow', 'tuesday', 'ugly',
    'warm', 'wednesday', 'week', 'wet', 'wide', 'year', 'yesterday', 'young'
]

# Constants
# Number of frames per clip: the default for the preprocess_* helpers and the
# temporal size of the dummy warmup input.
CLIP_LENGTH = 16
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
DEVICE      = torch.device("cuda" if torch.cuda.is_available() else "cpu")
USE_FP16    = DEVICE.type == "cuda"   # False on HF free tier (CPU only)
# Dtype shared by the transform pipeline and the dummy warmup input.
_DTYPE      = torch.float16 if USE_FP16 else torch.float32

# Import-time side effect: report the selected runtime configuration.
print(f"[model] device={DEVICE} | fp16={USE_FP16} | dtype={_DTYPE}")

# Global transform pipeline (built once)
# Applied by _frames_to_tensor to a (T, C, H, W) uint8 tensor:
#   1. Resize(224): with an int size torchvision resizes the shorter edge.
#   2. CenterCrop(224): square 224x224 crop.
#   3. ToDtype(scale=True): cast to _DTYPE, rescaling uint8 values to [0, 1].
#   4. Normalize(0.5, 0.5): (x - 0.5) / 0.5, i.e. roughly [-1, 1].
TRANSFORMS = v2.Compose([
    v2.Resize(224, antialias=True),
    v2.CenterCrop(224),
    v2.ToDtype(_DTYPE, scale=True),
    v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Model
class SwinTClassifications(nn.Module):
    """Video Swin-S backbone followed by a separate linear classifier.

    The torchvision ``swin3d_s`` network provides the features: its own
    ``head`` is swapped for ``nn.Identity`` and a new linear layer sized to
    ``len(classes)`` produces the logits.

    Args:
        classes: Sequence of class labels; its length sets the number of
            output logits.
        weights: Forwarded unchanged to ``swin3d_s(weights=...)``.
    """

    def __init__(self, classes, weights="KINETICS400_V1"):
        super().__init__()
        self.classes = classes

        backbone = ptv.swin3d_s(weights=weights)
        feature_dim = backbone.head.in_features
        self.base_model = backbone

        # Kept inside nn.Sequential so the parameters are named
        # "classification_head.0.*" (presumably what saved checkpoints use).
        self.classification_head = nn.Sequential(
            nn.Linear(feature_dim, len(classes)),
        )

        # With the stock head removed the backbone returns pooled features.
        self.base_model.head = nn.Identity()

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        features = self.base_model(x)
        return self.classification_head(features)


def load_model():
    """Download the fine-tuned checkpoint and return a model ready for inference.

    Steps: fetch ``ISL_best_model.pt`` from the Hugging Face Hub, build the
    network, load the checkpoint, move it to ``DEVICE``, switch to half
    precision when ``USE_FP16`` is set, put it in eval mode, optionally wrap
    it with ``torch.compile`` (CUDA only) and run the warmup passes.

    Returns:
        The loaded model (a ``torch.compile`` wrapper on CUDA).
    """
    # Imported lazily so the Hub client is only needed when a model is loaded.
    from huggingface_hub import hf_hub_download

    print(f"Loading model on {DEVICE} ...")
    model_path = hf_hub_download(
        repo_id="Creator-090/isl-swin3d-model",
        filename="ISL_best_model.pt",
    )

    # weights=None: load_state_dict below is strict, so every parameter and
    # buffer is replaced by the checkpoint. Requesting the Kinetics-400
    # weights here would only trigger a large download that is thrown away.
    model = SwinTClassifications(classes=CLASSES, weights=None)
    state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model = model.to(DEVICE)

    if USE_FP16:
        model = model.half()

    model.eval()

    # torch.compile only on CUDA — can error or be very slow on CPU
    if DEVICE.type == "cuda":
        print("Compiling model with torch.compile ...")
        model = torch.compile(model, mode="reduce-overhead")

    _warmup(model)
    print("Model ready.")
    return model


def _warmup(model):
    """Push an all-zero clip through ``model`` a few times before real use.

    Runs one forward pass on CPU and three on any other device, without
    gradient tracking, then waits for outstanding CUDA work when on a GPU.

    Args:
        model: Callable accepting a ``(1, 3, CLIP_LENGTH, 224, 224)`` tensor.
    """
    # A single pass on CPU (a Swin3D warmup pass takes ~30s there), 3 on GPU.
    on_cpu = DEVICE.type == "cpu"
    rounds = 1 if on_cpu else 3
    print(f"Warming up ({rounds} round(s) on {DEVICE}) ...")

    dummy = torch.zeros(
        (1, 3, CLIP_LENGTH, 224, 224),
        dtype=_DTYPE,
        device=DEVICE,
    )

    with torch.no_grad():
        completed = 0
        while completed < rounds:
            model(dummy)
            completed += 1

    # Make sure the queued GPU kernels have actually finished.
    if DEVICE.type == "cuda":
        torch.cuda.synchronize()
    print("Warmup complete.")


# Preprocessing 
def _frames_to_tensor(frames: list) -> torch.Tensor:
    """Convert a list of decoded frames into a single model input batch.

    Args:
        frames: NumPy image arrays in (H, W, C) layout; the callers in this
            file pass uint8 RGB frames.

    Returns:
        Tensor of shape (1, C, T, H, W) on ``DEVICE`` after ``TRANSFORMS``.
    """
    channel_first = []
    for frame in frames:
        # (H, W, C) -> (C, H, W); from_numpy shares memory with the array.
        channel_first.append(torch.from_numpy(frame).permute(2, 0, 1))

    clip = torch.stack(channel_first)    # (T, C, H, W) uint8
    clip = TRANSFORMS(clip.to(DEVICE))   # (T, C, H, W) float
    clip = clip.transpose(0, 1)          # (C, T, H, W)
    return clip.unsqueeze(0)             # (1, C, T, H, W)


def _pad_or_trim(frames: list, clip_length: int) -> list:
    if len(frames) < clip_length:
        frames += [frames[-1]] * (clip_length - len(frames))
    elif len(frames) > clip_length:
        indices = [int(i * len(frames) / clip_length) for i in range(clip_length)]
        frames  = [frames[i] for i in indices]
    return frames


def preprocess_video(video_bytes: bytes, clip_length: int = CLIP_LENGTH) -> torch.Tensor:
    """Decode an encoded video and build the model input from its first frames.

    The first ``clip_length`` frames are used; a shorter video is padded by
    repeating its last frame.

    Args:
        video_bytes: Raw bytes of a video file readable by decord.
        clip_length: Number of frames in the resulting clip.

    Returns:
        Tensor of shape (1, C, T, H, W) produced by ``_frames_to_tensor``.

    Raises:
        ValueError: If the video contains no frames.
    """
    # Don't set torch bridge — keep numpy so .asnumpy() works
    vr    = VideoReader(io.BytesIO(video_bytes))
    total = len(vr)
    if total == 0:
        # Without this guard the padding below fails with a bare IndexError.
        raise ValueError("Video contains no decodable frames.")

    # NOTE(review): only the leading frames are used here, whereas
    # preprocess_frames samples evenly across the whole sequence via
    # _pad_or_trim. Confirm which strategy matches training before aligning.
    idx = list(range(min(total, clip_length)))
    if len(idx) < clip_length:
        idx += [idx[-1]] * (clip_length - len(idx))

    batch  = vr.get_batch(idx).asnumpy()          # numpy (T, H, W, C)
    frames = [batch[i] for i in range(batch.shape[0])]
    return _frames_to_tensor(frames)


def preprocess_frames(frames_list_bytes: list[bytes], clip_length: int = CLIP_LENGTH) -> torch.Tensor:
    """Decode individually encoded images and assemble them into a model input.

    Buffers that OpenCV cannot decode are skipped. The remaining frames are
    converted from BGR to RGB, padded or subsampled to ``clip_length`` and
    passed to ``_frames_to_tensor``.

    Args:
        frames_list_bytes: Encoded image buffers, one per frame.
        clip_length: Number of frames in the resulting clip.

    Returns:
        Tensor produced by ``_frames_to_tensor``.

    Raises:
        ValueError: If none of the buffers could be decoded.
    """
    # Lazy generator: each buffer is decoded right before it is converted.
    decoded = (
        cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
        for raw in frames_list_bytes
    )
    # imdecode returns None for unreadable data; drop those frames.
    rgb_frames = [
        cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        for bgr in decoded
        if bgr is not None
    ]

    if len(rgb_frames) == 0:
        raise ValueError("No valid frames could be decoded.")

    return _frames_to_tensor(_pad_or_trim(rgb_frames, clip_length))


# Inference 
def _run_inference(model, pixel_values: torch.Tensor, top_k: int) -> dict:
    """Run the model on one clip and return its highest-scoring classes.

    Args:
        model: Callable mapping ``pixel_values`` to logits; only the first
            row of the output is used.
        pixel_values: Model input batch.
        top_k: Number of predictions to return. Values above the number of
            model outputs are clamped to that number.

    Returns:
        Dict with ``"prediction"`` (best label), ``"confidence"`` (its
        probability) and ``"top_k"`` (list of ``{"class", "confidence"}``
        dicts, best first).

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        # results[0] below needs at least one entry.
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    with torch.no_grad():
        if USE_FP16:
            # autocast only valid on CUDA
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                outputs = model(pixel_values)
        else:
            # CPU path — plain fp32, no autocast
            outputs = model(pixel_values)

        # Softmax in fp32 so half-precision logits do not lose accuracy;
        # a no-op cast when the outputs are already fp32.
        probs = torch.nn.functional.softmax(outputs.float(), dim=-1)[0]

    # torch.topk raises when k exceeds the number of entries, so clamp it.
    k = min(top_k, probs.numel())
    top_probs, top_indices = torch.topk(probs, k=k)

    # One tolist() per tensor instead of a separate .item() call (and device
    # sync) for every element.
    results = [
        {"class": CLASSES[index], "confidence": confidence}
        for index, confidence in zip(top_indices.tolist(), top_probs.tolist())
    ]
    return {
        "prediction": results[0]["class"],
        "confidence": results[0]["confidence"],
        "top_k":      results,
    }


def predict(model, video_bytes: bytes, top_k: int = 5) -> dict:
    """Classify an encoded video.

    Args:
        model: Model passed through to ``_run_inference``.
        video_bytes: Raw bytes of the video file.
        top_k: Number of predictions to return.

    Returns:
        The result dict built by ``_run_inference``.
    """
    return _run_inference(model, preprocess_video(video_bytes), top_k)


def predict_from_frames(model, frames_list_bytes: list[bytes], top_k: int = 5) -> dict:
    """Classify a clip supplied as individually encoded frames.

    Args:
        model: Model passed through to ``_run_inference``.
        frames_list_bytes: Encoded image buffers, one per frame.
        top_k: Number of predictions to return.

    Returns:
        The result dict built by ``_run_inference``.
    """
    clip = preprocess_frames(frames_list_bytes)
    result = _run_inference(model, clip, top_k)
    return result