|
|
|
|
|
import argparse, json, os |
|
|
import numpy as np |
|
|
import torch |
|
|
from huggingface_hub import snapshot_download |
|
|
from transformers import VideoMAEImageProcessor, AutoModel |
|
|
from decord import VideoReader |
|
|
|
|
|
# Maps the classifier's argmax index to the printed label name.
ID2LABEL = dict(enumerate(("normal", "abnormal")))
|
|
|
|
|
class ClassificationHead(torch.nn.Module):
    """MLP classifier applied to pooled backbone features.

    For every hidden width a ``Linear -> GELU -> Dropout`` triple is stacked,
    followed by a final ``Linear`` producing ``num_labels`` logits. The layer
    order (and therefore the ``net.<i>.*`` state_dict keys) is fixed so that
    saved head checkpoints keep loading.

    Args:
        in_dim: width of the input feature vector.
        hidden_dims: hidden layer widths, either an iterable of ints or a
            single int (one hidden layer). An empty iterable yields a single
            linear layer from ``in_dim`` to ``num_labels``.
        num_labels: number of output logits.
        dropout: dropout probability applied after each hidden activation.
    """

    def __init__(self, in_dim: int, hidden_dims, num_labels: int = 2, dropout: float = 0.1):
        super().__init__()
        # Configs may store a single hidden width as a bare int (e.g. 512);
        # treat it as one hidden layer instead of failing in list(...).
        if isinstance(hidden_dims, int):
            hidden_dims = [hidden_dims]
        dims = [in_dim, *hidden_dims]
        layers = []
        for d_in, d_out in zip(dims, dims[1:]):
            layers += [torch.nn.Linear(d_in, d_out), torch.nn.GELU(), torch.nn.Dropout(dropout)]
        layers.append(torch.nn.Linear(dims[-1], num_labels))
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Map features ``(..., in_dim)`` to logits ``(..., num_labels)``."""
        return self.net(x)
|
|
|
|
|
def pool_tokens(feats: torch.Tensor, expected_feat_dim: int | None = None) -> torch.Tensor: |
|
|
if feats.dim() != 3: |
|
|
return feats |
|
|
_, d1, d2 = feats.shape |
|
|
if expected_feat_dim is not None: |
|
|
if d1 == expected_feat_dim: |
|
|
return feats.mean(dim=2) |
|
|
if d2 == expected_feat_dim: |
|
|
return feats.mean(dim=1) |
|
|
return feats.mean(dim=2 if d1 <= d2 else 1) |
|
|
|
|
|
def load_frames(path: str, clip_len: int):
    """Sample ``clip_len`` frames evenly spaced across the video at ``path``.

    Indices run from the first to the last frame via ``np.linspace``
    (truncated to int), so videos shorter than ``clip_len`` repeat frames.

    Args:
        path: video file readable by decord's ``VideoReader``.
        clip_len: number of frames to return; must be >= 1.

    Returns:
        A list of ``clip_len`` numpy arrays, one per sampled frame.

    Raises:
        ValueError: if ``clip_len`` < 1 or the video has no frames.
    """
    # Validate before opening the file: clip_len=0 would silently return [].
    if clip_len < 1:
        raise ValueError(f"clip_len must be >= 1, got {clip_len}")
    vr = VideoReader(path)
    num_frames = len(vr)
    # An empty video would yield negative indices from linspace(0, -1, ...).
    if num_frames == 0:
        raise ValueError(f"video has no frames: {path}")
    idxs = np.linspace(0, num_frames - 1, num=clip_len, dtype=int)
    # One batched decode instead of a separate random-access read per frame.
    return list(vr.get_batch(idxs.tolist()).asnumpy())
|
|
|
|
|
def main():
    """Classify one video with the fine-tuned head stored in ``--repo_id``.

    Downloads the repo snapshot and reads ``train_config.json`` for the
    backbone name and preprocessing settings. It then loads the backbone and the
    ``best_head.pt`` head weights, runs ``extract_features`` on ``clip_len``
    evenly sampled frames, pools the tokens, and prints the predicted label
    (via ``ID2LABEL``) to stdout.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--repo_id", default="happy8825/internvideo_tuned")
    ap.add_argument("--video", required=True, help="Path to video file")
    ap.add_argument("--device", default="cuda", help="cuda or cpu")
    ap.add_argument("--precision", choices=["fp16","bf16","fp32"], default="bf16")
    args = ap.parse_args()

    # Uses CPU whenever CUDA is unavailable, regardless of --device.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    local = snapshot_download(args.repo_id)
    with open(os.path.join(local, "train_config.json"), "r", encoding="utf-8") as f:
        cfg = json.load(f)
    base_model = cfg.get("base_model", "revliter/internvideo_next_large_p14_res224_f16")
    clip_len = int(cfg.get("clip_len", 16))
    frame_size = int(cfg.get("frame_size", 224))
    hidden = cfg.get("hidden", [512])
    feature_dim = cfg.get("feature_dim") or cfg.get("hidden_size")

    processor = VideoMAEImageProcessor.from_pretrained(base_model)
    # NOTE(review): trust_remote_code executes Python shipped with the model repo.
    backbone = AutoModel.from_pretrained(base_model, trust_remote_code=True).to(device).eval()
    # Reduced precision (weights + autocast) is only used on CUDA.
    amp_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}.get(args.precision)
    use_amp = device.type == "cuda" and amp_dtype is not None
    if use_amp:
        backbone.to(dtype=amp_dtype)

    # Resolve the feature width once so the head and the pooling agree on it.
    in_dim = int(feature_dim or backbone.config.hidden_size)
    head = ClassificationHead(in_dim=in_dim, hidden_dims=hidden)
    # NOTE(review): torch.load unpickles the downloaded checkpoint; only point
    # --repo_id at trusted repos (or switch to weights_only=True if compatible).
    state = torch.load(os.path.join(local, "best_head.pt"), map_location="cpu")
    # Checkpoints normally wrap the weights under "head"; accept a bare state_dict too.
    head_state = state["head"] if isinstance(state, dict) and "head" in state else state
    head.load_state_dict(head_state)
    head.to(device).eval()

    frames = load_frames(args.video, clip_len=clip_len)
    # frame_size comes from the training config; use it for resize/crop so
    # inference preprocessing follows the recorded training resolution.
    px = processor(
        frames,
        size={"shortest_edge": frame_size},
        crop_size={"height": frame_size, "width": frame_size},
        return_tensors="pt",
    )["pixel_values"]
    # Swap axes 1 and 2; per the VideoMAE processor docs its output is
    # (batch, frames, channels, H, W), giving (batch, channels, frames, H, W).
    px = px.permute(0, 2, 1, 3, 4).to(device)

    with torch.no_grad():
        with torch.amp.autocast(device_type="cuda", dtype=amp_dtype, enabled=use_amp):
            feats = backbone.extract_features(pixel_values=px)
        # The head is kept in fp32 and runs outside autocast.
        pooled = pool_tokens(feats, expected_feat_dim=in_dim)
        logits = head(pooled.float())

    pred_id = int(logits.argmax(dim=-1).item())
    print(ID2LABEL.get(pred_id, str(pred_id)))


if __name__ == "__main__":
    main()
|
|
|