# isl-api / model.py
# Creator-090's picture
# fix: prevent setting torch bridge in preprocess_video for numpy compatibility
# 3067ed1
import io
import torch
import torch.nn as nn
from torchvision.models import video as ptv
from torchvision.transforms import v2
from decord import VideoReader
from decord.bridge import set_bridge
import cv2
import numpy as np
# Class labels. A label's position in this list is the index used to look it
# up from the model output (CLASSES[...] in _run_inference), so the order of
# entries must not be changed.
CLASSES = [
    # a - b
    "afternoon", "animal",
    "bad", "beautiful", "big", "bird", "blind",
    # c - d
    "cat", "cheap", "clothing", "cold", "cow", "curved",
    "deaf", "dog", "dress", "dry",
    # e - f
    "evening", "expensive",
    "famous", "fast", "female", "fish", "flat", "friday",
    # g - h
    "good",
    "happy", "hat", "healthy", "horse", "hot", "hour",
    # l - m
    "light", "long", "loose", "loud",
    "minute", "monday", "month", "morning", "mouse",
    # n - q
    "narrow", "new", "night",
    "old",
    "pant", "pocket",
    "quiet",
    # s
    "sad", "saturday", "second", "shirt", "shoes", "short", "sick",
    "skirt", "slow", "small", "suit", "sunday",
    # t - u
    "t_shirt", "tall", "thursday", "time", "today", "tomorrow", "tuesday",
    "ugly",
    # w - y
    "warm", "wednesday", "week", "wet", "wide",
    "year", "yesterday", "young",
]
# Runtime configuration, resolved once at import time.

# Number of frames in one clip handed to the model.
CLIP_LENGTH = 16

# Use the GPU when one is visible to torch, otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Half precision is enabled only on CUDA (False on the CPU-only HF free tier).
USE_FP16 = DEVICE.type == "cuda"
_DTYPE = torch.float16 if USE_FP16 else torch.float32

print(f"[model] device={DEVICE} | fp16={USE_FP16} | dtype={_DTYPE}")
# Shared preprocessing pipeline, constructed a single time at import and
# applied to every clip in _frames_to_tensor.
TRANSFORMS = v2.Compose(
    [
        # Spatial sizing: resize, then take a 224x224 centre crop.
        v2.Resize(size=224, antialias=True),
        v2.CenterCrop(size=224),
        # Convert to the working dtype, rescaling values during conversion.
        v2.ToDtype(dtype=_DTYPE, scale=True),
        # Per-channel (x - 0.5) / 0.5.
        v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
# Model
class SwinTClassifications(nn.Module):
    """Swin3D-S video backbone followed by a linear classification layer.

    The backbone's own ``head`` is swapped for ``nn.Identity`` so that it
    emits features; ``classification_head`` then maps those features to one
    logit per entry of ``classes``.

    Args:
        classes: Sequence of class labels; only its length is used here.
        weights: Forwarded to ``torchvision.models.video.swin3d_s``.
    """

    def __init__(self, classes, weights="KINETICS400_V1"):
        super().__init__()
        self.classes = classes

        backbone = ptv.swin3d_s(weights=weights)
        # Read the feature width before the original head is discarded.
        feature_dim = backbone.head.in_features
        backbone.head = nn.Identity()

        # Submodules are registered as base_model first, classification_head
        # second, which fixes the state_dict key names and order.
        self.base_model = backbone
        self.classification_head = nn.Sequential(
            nn.Linear(feature_dim, len(self.classes))
        )

    def forward(self, x):
        """Return class logits for input ``x``."""
        features = self.base_model(x)
        return self.classification_head(features)
def load_model():
    """Download the fine-tuned checkpoint, build the model and warm it up.

    The checkpoint ``ISL_best_model.pt`` is fetched from the
    ``Creator-090/isl-swin3d-model`` Hub repository and loaded into a
    ``SwinTClassifications`` instance placed on ``DEVICE``. On CUDA the model
    is converted to half precision and wrapped with ``torch.compile``.

    Returns:
        The ready-to-use model in eval mode (a compiled wrapper on CUDA).
    """
    # Imported lazily so importing this module does not require the Hub client.
    from huggingface_hub import hf_hub_download

    print(f"Loading model on {DEVICE} ...")
    model_path = hf_hub_download(
        repo_id="Creator-090/isl-swin3d-model",
        filename="ISL_best_model.pt",
    )

    # weights=None: every parameter is overwritten by the strict
    # load_state_dict below, so fetching the Kinetics-400 pretrained
    # checkpoint first would only add a redundant download at startup.
    model = SwinTClassifications(classes=CLASSES, weights=None)
    model.load_state_dict(
        torch.load(model_path, map_location=DEVICE, weights_only=True)
    )
    model = model.to(DEVICE)
    if USE_FP16:
        model = model.half()
    model.eval()

    # torch.compile only on CUDA — can error or be very slow on CPU
    if DEVICE.type == "cuda":
        print("Compiling model with torch.compile ...")
        model = torch.compile(model, mode="reduce-overhead")

    _warmup(model)
    print("Model ready.")
    return model
def _warmup(model):
    """Push an all-zero clip through ``model`` so first-call costs are paid up front.

    Runs a single pass on CPU (one pass is already slow there, ~30s for
    Swin3D) and three passes on any other device.
    """
    n_rounds = 3 if DEVICE.type != "cpu" else 1
    print(f"Warming up ({n_rounds} round(s) on {DEVICE}) ...")

    # Same layout the preprocessing functions produce: (1, C, T, H, W).
    zero_clip = torch.zeros(
        (1, 3, CLIP_LENGTH, 224, 224), dtype=_DTYPE, device=DEVICE
    )
    with torch.no_grad():
        completed = 0
        while completed < n_rounds:
            model(zero_clip)
            completed += 1

    # Wait for queued GPU work to finish before reporting completion.
    if DEVICE.type == "cuda":
        torch.cuda.synchronize()
    print("Warmup complete.")
# Preprocessing
def _frames_to_tensor(frames: list) -> torch.Tensor:
    """Turn a list of per-frame numpy arrays into a single model input tensor.

    Each frame is permuted with ``(2, 0, 1)`` (callers in this file pass
    H x W x C images), the frames are stacked along a new leading time axis,
    moved to ``DEVICE`` and run through ``TRANSFORMS``.

    Returns:
        Tensor laid out as (1, C, T, H, W).
    """
    channel_first = [torch.from_numpy(frame).permute(2, 0, 1) for frame in frames]
    clip = torch.stack(channel_first)      # (T, C, H, W)
    clip = TRANSFORMS(clip.to(DEVICE))     # (T, C, H, W), transformed
    clip = clip.permute(1, 0, 2, 3)        # (C, T, H, W)
    return clip.unsqueeze(0)               # (1, C, T, H, W)
def _pad_or_trim(frames: list, clip_length: int) -> list:
    """Return a new list of exactly ``clip_length`` items built from ``frames``.

    * Too few items: the last item is repeated until the length is reached.
    * Too many items: ``clip_length`` items are picked at evenly spaced
      positions, starting from the first one.
    * Exact length: a shallow copy is returned.

    The input list is never modified.

    Args:
        frames: Source items (decoded frames in this module).
        clip_length: Desired number of items.

    Returns:
        A new list with ``clip_length`` items (empty if ``clip_length <= 0``
        and ``frames`` is non-empty).

    Raises:
        ValueError: If ``frames`` is empty but padding is required.
    """
    count = len(frames)
    if count == clip_length:
        return list(frames)

    if count < clip_length:
        if not frames:
            raise ValueError("Cannot pad an empty frame list.")
        # Build a fresh list instead of `frames += ...`, which would extend
        # the caller's list in place.
        return list(frames) + [frames[-1]] * (clip_length - count)

    # Integer arithmetic gives the same indices as int(i * count / clip_length)
    # without going through floating point.
    return [frames[i * count // clip_length] for i in range(clip_length)]
def preprocess_video(video_bytes: bytes, clip_length: int = CLIP_LENGTH) -> torch.Tensor:
    """Decode an in-memory video and build the model input tensor.

    The first ``clip_length`` frames are used. A shorter video is padded by
    repeating its last frame.

    Args:
        video_bytes: Encoded video file contents.
        clip_length: Number of frames in the resulting clip.

    Returns:
        Tensor produced by ``_frames_to_tensor`` — (1, C, T, H, W).

    Raises:
        ValueError: If the video contains no frames.
    """
    # Don't set torch bridge — keep numpy so .asnumpy() works
    reader = VideoReader(io.BytesIO(video_bytes))
    total = len(reader)
    if total == 0:
        # Without this guard the padding below fails on idx[-1] with an
        # uninformative IndexError.
        raise ValueError("No valid frames could be decoded.")

    # NOTE(review): this takes the *first* clip_length frames, whereas
    # preprocess_frames samples uniformly over the whole sequence via
    # _pad_or_trim. Confirm which matches training before unifying them.
    idx = list(range(min(total, clip_length)))
    if len(idx) < clip_length:
        idx += [idx[-1]] * (clip_length - len(idx))

    batch = reader.get_batch(idx).asnumpy()  # numpy (T, H, W, C)
    # Iterating the array yields one (H, W, C) frame per time step.
    return _frames_to_tensor(list(batch))
def preprocess_frames(frames_list_bytes: list[bytes], clip_length: int = CLIP_LENGTH) -> torch.Tensor:
    """Decode individually encoded images and build the model input tensor.

    Frames that cannot be decoded (including empty payloads) are skipped.
    The remaining frames are padded or uniformly subsampled to
    ``clip_length`` by ``_pad_or_trim``.

    Args:
        frames_list_bytes: Encoded image bytes, one entry per frame.
        clip_length: Number of frames in the resulting clip.

    Returns:
        Tensor produced by ``_frames_to_tensor`` — (1, C, T, H, W).

    Raises:
        ValueError: If no frame could be decoded.
    """
    frames = []
    for payload in frames_list_bytes:
        if not payload:
            # cv2.imdecode rejects an empty buffer with an exception rather
            # than returning None, so treat it like any other bad frame.
            continue
        bgr = cv2.imdecode(np.frombuffer(payload, np.uint8), cv2.IMREAD_COLOR)
        if bgr is None:
            continue
        # OpenCV decodes to BGR; convert to RGB channel order.
        frames.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

    if not frames:
        raise ValueError("No valid frames could be decoded.")

    frames = _pad_or_trim(frames, clip_length)
    return _frames_to_tensor(frames)
# Inference
def _run_inference(model, pixel_values: torch.Tensor, top_k: int) -> dict:
    """Run the model on one clip and return its top-k predictions.

    Args:
        model: Callable model returning logits with the batch on dim 0.
        pixel_values: Input clip; only the first batch element is reported.
        top_k: Number of predictions requested. Values larger than the
            number of output classes are clamped to that number.

    Returns:
        Dict with ``"prediction"`` and ``"confidence"`` for the best class
        and ``"top_k"``, a list of ``{"class", "confidence"}`` dicts sorted
        by descending confidence.

    Raises:
        ValueError: If ``top_k`` is less than 1.
    """
    if top_k < 1:
        raise ValueError("top_k must be at least 1.")

    with torch.no_grad():
        if USE_FP16:
            # autocast only valid on CUDA
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                outputs = model(pixel_values)
        else:
            # CPU path — plain fp32, no autocast
            outputs = model(pixel_values)

        # Softmax in fp32 so confidences are not rounded to half precision
        # on the fp16 path (a no-op for fp32 outputs).
        probs = torch.nn.functional.softmax(outputs.float(), dim=-1)[0]
        # torch.topk raises if k exceeds the number of classes.
        k = min(top_k, probs.shape[0])
        top_probs, top_indices = torch.topk(probs, k=k)

    # One transfer to Python lists instead of one .item() call per element.
    results = [
        {"class": CLASSES[index], "confidence": float(prob)}
        for prob, index in zip(top_probs.tolist(), top_indices.tolist())
    ]
    return {
        "prediction": results[0]["class"],
        "confidence": results[0]["confidence"],
        "top_k": results,
    }
def predict(model, video_bytes: bytes, top_k: int = 5) -> dict:
    """Classify an encoded video: preprocess it, then run inference.

    Returns the result dict produced by ``_run_inference``.
    """
    return _run_inference(model, preprocess_video(video_bytes), top_k)
def predict_from_frames(model, frames_list_bytes: list[bytes], top_k: int = 5) -> dict:
    """Classify a sequence of encoded frames: preprocess them, then run inference.

    Returns the result dict produced by ``_run_inference``.
    """
    clip = preprocess_frames(frames_list_bytes)
    result = _run_inference(model, clip, top_k)
    return result