#!/usr/bin/env python3
"""Binary video classification (normal/abnormal) with a fine-tuned InternVideo
backbone and a small MLP head, both fetched from the Hugging Face Hub."""
import argparse
import json
import os

import numpy as np
import torch
from decord import VideoReader
from huggingface_hub import snapshot_download
from transformers import AutoModel, VideoMAEImageProcessor

ID2LABEL = {0: "normal", 1: "abnormal"}


class ClassificationHead(torch.nn.Module):
    """MLP head: [Linear -> GELU -> Dropout] per hidden dim, then a final Linear."""

    def __init__(self, in_dim: int, hidden_dims, num_labels: int = 2, dropout: float = 0.1):
        super().__init__()
        dims = [in_dim] + list(hidden_dims)
        layers = []
        for i in range(len(dims) - 1):
            layers.append(torch.nn.Linear(dims[i], dims[i + 1]))
            layers.append(torch.nn.GELU())
            layers.append(torch.nn.Dropout(dropout))
        layers.append(torch.nn.Linear(dims[-1], num_labels))
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


def pool_tokens(feats: torch.Tensor, expected_feat_dim: int | None = None) -> torch.Tensor:
    """Mean-pool a 3-D feature tensor over its token axis.

    The backbone may emit (B, tokens, dim) or (B, dim, tokens). If the expected
    feature dimension is known, use it to pick the axis; otherwise assume the
    larger of the two trailing axes is the token axis.
    """
    if feats.dim() != 3:
        return feats
    _, d1, d2 = feats.shape
    if expected_feat_dim is not None:
        if d1 == expected_feat_dim:
            return feats.mean(dim=2)
        if d2 == expected_feat_dim:
            return feats.mean(dim=1)
    return feats.mean(dim=2 if d1 <= d2 else 1)


def load_frames(path: str, clip_len: int):
    """Decode `clip_len` frames sampled uniformly across the whole video."""
    vr = VideoReader(path)
    idxs = np.linspace(0, len(vr) - 1, num=clip_len, dtype=int)
    return [vr[i].asnumpy() for i in idxs]


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--repo_id", default="happy8825/internvideo_tuned")
    ap.add_argument("--video", required=True, help="Path to video file")
    ap.add_argument("--device", default="cuda", help="cuda or cpu")
    ap.add_argument("--precision", choices=["fp16", "bf16", "fp32"], default="bf16")
    args = ap.parse_args()

    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    # Pull the fine-tuned artifacts (training config + head weights) from the Hub.
    local = snapshot_download(args.repo_id)
    with open(os.path.join(local, "train_config.json"), "r") as f:
        cfg = json.load(f)
    base_model = cfg.get("base_model", "revliter/internvideo_next_large_p14_res224_f16")
    clip_len = int(cfg.get("clip_len", 16))
    frame_size = int(cfg.get("frame_size", 224))  # recorded at training time; the processor's own config drives resizing
    hidden = cfg.get("hidden", [512])
    feature_dim = cfg.get("feature_dim") or cfg.get("hidden_size")

    processor = VideoMAEImageProcessor.from_pretrained(base_model)
    backbone = AutoModel.from_pretrained(base_model, trust_remote_code=True).to(device).eval()
    if device.type == "cuda" and args.precision != "fp32":
        target = torch.float16 if args.precision == "fp16" else torch.bfloat16
        backbone.to(dtype=target)

    head = ClassificationHead(in_dim=feature_dim or backbone.config.hidden_size, hidden_dims=hidden)
    state = torch.load(os.path.join(local, "best_head.pt"), map_location="cpu")
    head.load_state_dict(state["head"])
    head.to(device).eval()

    frames = load_frames(args.video, clip_len=clip_len)
    # Processor returns (B, T, C, H, W); the backbone expects (B, C, T, H, W).
    px = processor(frames, return_tensors="pt")["pixel_values"].permute(0, 2, 1, 3, 4).to(device)

    amp_dtype = torch.float16 if args.precision == "fp16" else (
        torch.bfloat16 if args.precision == "bf16" else None
    )
    with torch.no_grad(), torch.amp.autocast(
        device_type="cuda",
        dtype=amp_dtype,
        enabled=(device.type == "cuda" and amp_dtype is not None),
    ):
        feats = backbone.extract_features(pixel_values=px)

    pooled = pool_tokens(feats, expected_feat_dim=feature_dim)
    logits = head(pooled.float())  # run the head in fp32 for numerical stability
    pred_id = int(logits.argmax(dim=-1).item())
    print(ID2LABEL.get(pred_id, str(pred_id)))


if __name__ == "__main__":
    main()
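
# Example invocation (a sketch; the file name `classify_video.py` is
# illustrative, and the script falls back to CPU when no CUDA device exists):
#
#   python classify_video.py --video clip.mp4 --precision bf16
#   # -> prints "normal" or "abnormal"
#
# With --precision fp32 the backbone stays in full precision and autocast is
# disabled, which is the safer choice when running on CPU.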