Spaces:
Sleeping
Sleeping
| import io | |
| import torch | |
| import torch.nn as nn | |
| from torchvision.models import video as ptv | |
| from torchvision.transforms import v2 | |
| from decord import VideoReader | |
| from decord.bridge import set_bridge | |
| import cv2 | |
| import numpy as np | |
# Class labels. The position of a label in this list is the index of the
# corresponding model output, so the order must not change.
CLASSES = (
    "afternoon animal bad beautiful big bird blind "
    "cat cheap clothing cold cow curved deaf dog "
    "dress dry evening expensive famous fast female "
    "fish flat friday good happy hat healthy horse "
    "hot hour light long loose loud minute monday "
    "month morning mouse narrow new night old pant "
    "pocket quiet sad saturday second shirt shoes "
    "short sick skirt slow small suit sunday t_shirt "
    "tall thursday time today tomorrow tuesday ugly "
    "warm wednesday week wet wide year yesterday young"
).split()
# Runtime configuration, resolved once at import time.
CLIP_LENGTH = 16  # number of frames in every clip passed to the model

# Half precision is used only when a CUDA device is present
# (so it is off on CPU-only hosts such as the HF free tier).
USE_FP16 = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_FP16 else "cpu")
_DTYPE = torch.float16 if USE_FP16 else torch.float32

print(f"[model] device={DEVICE} | fp16={USE_FP16} | dtype={_DTYPE}")
# Preprocessing pipeline shared by every request; constructed a single time
# at import so per-call work is limited to applying it.
TRANSFORMS = v2.Compose(
    [
        v2.Resize(size=224, antialias=True),
        v2.CenterCrop(size=224),
        # Convert to the model dtype and rescale integer pixel values.
        v2.ToDtype(dtype=_DTYPE, scale=True),
        v2.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
| # Model | |
class SwinTClassifications(nn.Module):
    """Swin3D-S video backbone followed by a linear classification layer.

    The backbone's built-in head is swapped for ``nn.Identity`` so that it
    returns features; ``classification_head`` then maps those features to one
    logit per entry of ``classes``.

    Args:
        classes: Sequence of class labels; its length sets the output size.
        weights: Passed through to ``swin3d_s`` to select backbone weights.
    """

    def __init__(self, classes, weights="KINETICS400_V1"):
        super().__init__()
        self.classes = classes

        backbone = ptv.swin3d_s(weights=weights)
        feature_dim = backbone.head.in_features
        backbone.head = nn.Identity()
        self.base_model = backbone

        # Deliberately a one-element Sequential: the resulting parameter
        # names (``classification_head.0.*``) are part of the state-dict
        # layout that checkpoints are saved with.
        self.classification_head = nn.Sequential(
            nn.Linear(feature_dim, len(self.classes))
        )

    def forward(self, x):
        """Return class logits for the input clip batch ``x``."""
        return self.classification_head(self.base_model(x))
def load_model():
    """Download the fine-tuned checkpoint and return a ready-to-use model.

    Steps: fetch ``ISL_best_model.pt`` from the Hugging Face Hub, build the
    network, load the checkpoint, move it to ``DEVICE`` (half precision when
    ``USE_FP16`` is set), switch to eval mode, optionally ``torch.compile`` it
    on CUDA, and run the warmup passes.

    Returns:
        The model in eval mode on ``DEVICE`` (a compiled wrapper on CUDA).
    """
    # Imported lazily so the dependency is only needed when a model is loaded.
    from huggingface_hub import hf_hub_download

    print(f"Loading model on {DEVICE} ...")
    model_path = hf_hub_download(
        repo_id="Creator-090/isl-swin3d-model",
        filename="ISL_best_model.pt",
    )

    # weights=None: the checkpoint is loaded strictly right below, so every
    # parameter gets overwritten anyway. Asking for the Kinetics weights here
    # would only add a redundant download on each cold start.
    model = SwinTClassifications(classes=CLASSES, weights=None)
    state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)

    model = model.to(DEVICE)
    if USE_FP16:
        model = model.half()
    model.eval()

    # torch.compile only on CUDA — can error or be very slow on CPU
    if DEVICE.type == "cuda":
        print("Compiling model with torch.compile ...")
        model = torch.compile(model, mode="reduce-overhead")

    _warmup(model)
    print("Model ready.")
    return model
def _warmup(model):
    """Push an all-zero dummy clip through ``model`` before serving requests.

    Runs a single pass on CPU (one pass is already slow there) and three
    passes otherwise, then waits for outstanding CUDA work when on GPU.
    """
    is_cpu = DEVICE.type == "cpu"
    rounds = 1 if is_cpu else 3
    print(f"Warming up ({rounds} round(s) on {DEVICE}) ...")

    dummy_clip = torch.zeros(
        1, 3, CLIP_LENGTH, 224, 224, device=DEVICE, dtype=_DTYPE
    )
    with torch.no_grad():
        remaining = rounds
        while remaining > 0:
            model(dummy_clip)
            remaining -= 1

    if DEVICE.type == "cuda":
        torch.cuda.synchronize()
    print("Warmup complete.")
| # Preprocessing | |
def _frames_to_tensor(frames: list) -> torch.Tensor:
    """Turn a list of HWC numpy frames into a transformed model input.

    Args:
        frames: Frames as numpy arrays laid out (H, W, C), all the same size.

    Returns:
        Tensor of shape (1, C, T, H, W) on ``DEVICE`` after ``TRANSFORMS``.
    """
    stacked = torch.stack([torch.from_numpy(frame) for frame in frames])  # (T, H, W, C)
    # Channels-first per frame; contiguous() gives the same dense layout as
    # stacking already-permuted frames would.
    clip = stacked.permute(0, 3, 1, 2).contiguous()  # (T, C, H, W)
    clip = TRANSFORMS(clip.to(DEVICE))  # (T, C, H, W) float
    # Swap time and channel axes, then add the batch dimension.
    return clip.permute(1, 0, 2, 3).unsqueeze(0)  # (1, C, T, H, W)
| def _pad_or_trim(frames: list, clip_length: int) -> list: | |
| if len(frames) < clip_length: | |
| frames += [frames[-1]] * (clip_length - len(frames)) | |
| elif len(frames) > clip_length: | |
| indices = [int(i * len(frames) / clip_length) for i in range(clip_length)] | |
| frames = [frames[i] for i in indices] | |
| return frames | |
def preprocess_video(video_bytes: bytes, clip_length: int = CLIP_LENGTH) -> torch.Tensor:
    """Decode an in-memory video and build the model input tensor.

    The first ``clip_length`` frames are read; if the video is shorter, the
    last frame index is repeated until the clip is full.

    Args:
        video_bytes: Encoded video file contents.
        clip_length: Number of frames in the resulting clip.

    Returns:
        The tensor produced by ``_frames_to_tensor`` for the selected frames.

    Raises:
        ValueError: If the video contains no frames.
    """
    # Don't set torch bridge — keep numpy so .asnumpy() works
    vr = VideoReader(io.BytesIO(video_bytes))
    total = len(vr)
    if total == 0:
        # Without this guard the padding step below would fail with an
        # opaque IndexError on idx[-1].
        raise ValueError("Video contains no frames.")

    # NOTE(review): only the leading frames are used here, whereas
    # preprocess_frames samples evenly across the whole sequence — confirm
    # which behaviour matches how the model was trained.
    idx = list(range(min(total, clip_length)))
    if len(idx) < clip_length:
        idx += [idx[-1]] * (clip_length - len(idx))

    batch = vr.get_batch(idx).asnumpy()  # numpy (T, H, W, C)
    return _frames_to_tensor(list(batch))
def preprocess_frames(frames_list_bytes: list[bytes], clip_length: int = CLIP_LENGTH) -> torch.Tensor:
    """Decode individually encoded images and build the model input tensor.

    Images that OpenCV cannot decode are skipped. The surviving frames are
    converted from BGR to RGB, padded or subsampled to ``clip_length`` and
    passed to ``_frames_to_tensor``.

    Args:
        frames_list_bytes: One encoded image per frame.
        clip_length: Number of frames in the resulting clip.

    Raises:
        ValueError: If none of the inputs could be decoded.
    """
    decoded = []
    for raw in frames_list_bytes:
        bgr = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
        if bgr is not None:
            decoded.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

    if len(decoded) == 0:
        raise ValueError("No valid frames could be decoded.")

    return _frames_to_tensor(_pad_or_trim(decoded, clip_length))
| # Inference | |
def _run_inference(model, pixel_values: torch.Tensor, top_k: int) -> dict:
    """Run the model on one clip and report the most likely classes.

    Args:
        model: Callable returning logits with the class dimension last.
        pixel_values: Input batch; only the first item's result is reported.
        top_k: Number of ranked predictions requested. Values outside
            ``[1, number of classes]`` are clamped to that range.

    Returns:
        Dict with ``"prediction"`` (best class name), ``"confidence"`` (its
        probability) and ``"top_k"`` (list of ``{"class", "confidence"}``
        dicts, best first).
    """
    with torch.no_grad():
        if USE_FP16:
            # autocast only valid on CUDA
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                logits = model(pixel_values)
        else:
            # CPU path — plain fp32, no autocast
            logits = model(pixel_values)

        # Softmax in fp32 so half-precision logits do not limit the
        # resolution of the reported probabilities (no-op for fp32 logits).
        probs = torch.softmax(logits.float(), dim=-1)[0]

        # torch.topk raises when k exceeds the number of classes, and k < 1
        # would leave no top prediction to report.
        k = max(1, min(top_k, probs.shape[-1]))
        top_probs, top_indices = torch.topk(probs, k=k)

    # One transfer to Python for each tensor instead of an .item() per value.
    results = [
        {"class": CLASSES[class_idx], "confidence": float(prob)}
        for class_idx, prob in zip(top_indices.tolist(), top_probs.tolist())
    ]
    best = results[0]
    return {
        "prediction": best["class"],
        "confidence": best["confidence"],
        "top_k": results,
    }
def predict(model, video_bytes: bytes, top_k: int = 5) -> dict:
    """Classify an encoded video.

    Preprocesses ``video_bytes`` with ``preprocess_video`` and returns the
    result dict produced by ``_run_inference`` for ``top_k`` predictions.
    """
    return _run_inference(model, preprocess_video(video_bytes), top_k)
def predict_from_frames(model, frames_list_bytes: list[bytes], top_k: int = 5) -> dict:
    """Classify a clip supplied as a list of encoded image frames.

    Preprocesses the frames with ``preprocess_frames`` and returns the result
    dict produced by ``_run_inference`` for ``top_k`` predictions.
    """
    return _run_inference(model, preprocess_frames(frames_list_bytes), top_k)