File size: 3,563 Bytes
aff893b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#!/usr/bin/env python3
import argparse, json, os
import numpy as np
import torch
from huggingface_hub import snapshot_download
from transformers import VideoMAEImageProcessor, AutoModel
from decord import VideoReader

# Maps the predicted class index (argmax of the head's logits) to the label printed by main().
ID2LABEL = {0: "normal", 1: "abnormal"}

class ClassificationHead(torch.nn.Module):
    """Feed-forward classifier built from a list of hidden widths.

    Each entry of ``hidden_dims`` contributes a ``Linear -> GELU -> Dropout``
    stage; a final ``Linear`` projects the last width to ``num_labels`` logits.

    Args:
        in_dim: Size of the last dimension of the input features.
        hidden_dims: Iterable of hidden layer widths (may be empty).
        num_labels: Number of output logits.
        dropout: Dropout probability used after every hidden activation.
    """

    def __init__(self, in_dim: int, hidden_dims, num_labels: int = 2, dropout: float = 0.1):
        super().__init__()
        nn = torch.nn
        widths = [in_dim, *hidden_dims]
        stack = []
        # Walk consecutive (input width, output width) pairs of the hidden stages.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.GELU(), nn.Dropout(dropout)]
        stack.append(nn.Linear(widths[-1], num_labels))
        # state_dict keys are derived from this attribute name and the layer
        # positions, so both must stay as they are.
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the logits produced by passing ``x`` through the layer stack."""
        return self.net(x)

def pool_tokens(feats: torch.Tensor, expected_feat_dim: int | None = None) -> torch.Tensor:
    if feats.dim() != 3:
        return feats
    _, d1, d2 = feats.shape
    if expected_feat_dim is not None:
        if d1 == expected_feat_dim:
            return feats.mean(dim=2)
        if d2 == expected_feat_dim:
            return feats.mean(dim=1)
    return feats.mean(dim=2 if d1 <= d2 else 1)

def load_frames(path: str, clip_len: int):
    """Read ``clip_len`` frames spread evenly across the video at ``path``.

    Frame positions are produced by ``np.linspace`` between index 0 and the
    last frame index with ``dtype=int``. Each selected frame is converted with
    ``.asnumpy()`` and the frames are returned as a list in sampling order.
    """
    reader = VideoReader(path)
    last_index = len(reader) - 1
    positions = np.linspace(0, last_index, num=clip_len, dtype=int)
    clip = []
    for pos in positions:
        clip.append(reader[pos].asnumpy())
    return clip

def main():
    """Command-line entry point: classify one video and print the predicted label.

    Workflow:
      1. Parse ``--repo_id``, ``--video``, ``--device`` and ``--precision``.
      2. Download the repo snapshot and read ``train_config.json`` from it.
      3. Build the backbone named in the config and the classification head,
         then load the head weights from ``best_head.pt``.
      4. Sample frames from the video, extract backbone features, pool them,
         and run the head.
      5. Print the label for the arg-max class (or the raw index if it has no
         entry in ``ID2LABEL``).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--repo_id", default="happy8825/internvideo_tuned")
    ap.add_argument("--video", required=True, help="Path to video file")
    ap.add_argument("--device", default="cuda", help="cuda or cpu")
    ap.add_argument("--precision", choices=["fp16","bf16","fp32"], default="bf16")
    args = ap.parse_args()

    # Falls back to CPU whenever CUDA is unavailable, whatever --device says.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    local = snapshot_download(args.repo_id)
    # Use a context manager so the config file handle is closed deterministically.
    with open(os.path.join(local, "train_config.json"), "r", encoding="utf-8") as cfg_file:
        cfg = json.load(cfg_file)
    base_model = cfg.get("base_model", "revliter/internvideo_next_large_p14_res224_f16")
    clip_len = int(cfg.get("clip_len", 16))
    frame_size = int(cfg.get("frame_size", 224))  # read from the config but not used below
    hidden = cfg.get("hidden", [512])
    feature_dim = cfg.get("feature_dim") or cfg.get("hidden_size")

    processor = VideoMAEImageProcessor.from_pretrained(base_model)
    # NOTE(security): trust_remote_code=True executes Python code shipped with
    # the model repo; only point this at repositories you trust.
    backbone = AutoModel.from_pretrained(base_model, trust_remote_code=True).to(device).eval()
    if device.type == "cuda" and args.precision != "fp32":
        target = torch.float16 if args.precision == "fp16" else torch.bfloat16
        backbone.to(dtype=target)
    # The head stays in its default dtype; its input is cast with .float() below.
    head = ClassificationHead(in_dim=feature_dim or backbone.config.hidden_size, hidden_dims=hidden)
    # NOTE(security): torch.load unpickles the checkpoint, which can run
    # arbitrary code for a malicious file. Consider weights_only=True once the
    # checkpoint is confirmed to contain only tensors/plain containers.
    state = torch.load(os.path.join(local, "best_head.pt"), map_location="cpu")
    head.load_state_dict(state["head"])
    head.to(device).eval()

    frames = load_frames(args.video, clip_len=clip_len)
    # Swap axes 1 and 2 of the processor output before handing it to the backbone.
    px = processor(frames, return_tensors="pt")["pixel_values"].permute(0,2,1,3,4).to(device)
    amp_dtype = torch.float16 if args.precision == "fp16" else (torch.bfloat16 if args.precision == "bf16" else None)
    use_amp = device.type == "cuda" and amp_dtype is not None
    # Keep the whole inference pass under no_grad so no autograd graph is built
    # for the head either; only the backbone call runs under autocast, so the
    # head is still evaluated outside of it on a float32 input.
    with torch.no_grad():
        with torch.amp.autocast(device_type="cuda", dtype=amp_dtype, enabled=use_amp):
            feats = backbone.extract_features(pixel_values=px)
        pooled = pool_tokens(feats, expected_feat_dim=feature_dim)
        logits = head(pooled.float())
    pred_id = int(logits.argmax(dim=-1).item())
    print(ID2LABEL.get(pred_id, str(pred_id)))

# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()