internvideo_tuned / inference_example.py
happy8825's picture
Add inference example
aff893b verified
#!/usr/bin/env python3
import argparse, json, os
import numpy as np
import torch
from huggingface_hub import snapshot_download
from transformers import VideoMAEImageProcessor, AutoModel
from decord import VideoReader
# Maps the classifier's predicted class index (argmax over the head's logits)
# to the label string printed by main().
ID2LABEL = dict(enumerate(("normal", "abnormal")))
class ClassificationHead(torch.nn.Module):
    """MLP classifier applied to pooled backbone features.

    Every width in ``hidden_dims`` adds a Linear -> GELU -> Dropout stage;
    a final Linear projects to ``num_labels`` logits. All layers live in one
    ``torch.nn.Sequential`` stored as ``self.net``, so parameter names are
    ``net.<index>.weight`` / ``net.<index>.bias``, and the layer order must
    stay fixed for saved checkpoints to load.
    """

    def __init__(self, in_dim: int, hidden_dims, num_labels: int = 2, dropout: float = 0.1):
        super().__init__()
        widths = [in_dim, *hidden_dims]
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages += [
                torch.nn.Linear(fan_in, fan_out),
                torch.nn.GELU(),
                torch.nn.Dropout(dropout),
            ]
        # Output projection from the last width (in_dim if no hidden layers).
        stages.append(torch.nn.Linear(widths[-1], num_labels))
        self.net = torch.nn.Sequential(*stages)

    def forward(self, x):
        """Return raw (unnormalized) logits; the last axis has size num_labels."""
        return self.net(x)
def pool_tokens(feats: torch.Tensor, expected_feat_dim: int | None = None) -> torch.Tensor:
if feats.dim() != 3:
return feats
_, d1, d2 = feats.shape
if expected_feat_dim is not None:
if d1 == expected_feat_dim:
return feats.mean(dim=2)
if d2 == expected_feat_dim:
return feats.mean(dim=1)
return feats.mean(dim=2 if d1 <= d2 else 1)
def load_frames(path: str, clip_len: int):
    """Sample ``clip_len`` frames spread evenly across the video at ``path``.

    Frame indices come from ``np.linspace(0, num_frames - 1, clip_len)``
    truncated to ints. They always include the first and last frame (for
    ``clip_len >= 2``) and repeat indices when the video is shorter than
    ``clip_len``.

    Args:
        path: Video file readable by decord's ``VideoReader``.
        clip_len: Number of frames to return; must be at least 1.

    Returns:
        A list of ``clip_len`` numpy arrays, one per sampled frame, as
        decoded by decord.

    Raises:
        ValueError: If ``clip_len`` < 1 or the video reports zero frames.
    """
    if clip_len < 1:
        raise ValueError(f"clip_len must be >= 1, got {clip_len}")
    vr = VideoReader(path)
    num_frames = len(vr)
    if num_frames == 0:
        # Without this guard linspace(0, -1, ...) yields meaningless indices.
        raise ValueError(f"video has no frames: {path}")
    idxs = np.linspace(0, num_frames - 1, num=clip_len, dtype=int)
    # Decode all sampled frames in one call instead of clip_len separate seeks.
    batch = vr.get_batch(idxs.tolist()).asnumpy()
    return list(batch)
def main():
    """Classify one video as normal/abnormal and print the predicted label.

    Downloads ``--repo_id`` (expects ``train_config.json`` and
    ``best_head.pt``), loads the backbone named in the config plus the MLP
    head, samples ``clip_len`` frames from ``--video``, and prints the label
    from ``ID2LABEL`` (or the raw class index if it is unmapped).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--repo_id", default="happy8825/internvideo_tuned")
    ap.add_argument("--video", required=True, help="Path to video file")
    ap.add_argument("--device", default="cuda", help="cuda or cpu")
    ap.add_argument("--precision", choices=["fp16", "bf16", "fp32"], default="bf16")
    args = ap.parse_args()

    # Fall back to CPU whenever CUDA is unavailable, whatever --device says.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    local = snapshot_download(args.repo_id)
    with open(os.path.join(local, "train_config.json"), "r", encoding="utf-8") as f:
        cfg = json.load(f)
    base_model = cfg.get("base_model", "revliter/internvideo_next_large_p14_res224_f16")
    clip_len = int(cfg.get("clip_len", 16))
    hidden = cfg.get("hidden", [512])
    # NOTE(review): train_config may define "frame_size", but it is not applied;
    # resize/crop come from base_model's processor config — confirm they match training.

    processor = VideoMAEImageProcessor.from_pretrained(base_model)
    backbone = AutoModel.from_pretrained(base_model, trust_remote_code=True).to(device).eval()

    # Reduced precision is only used on CUDA; fp32 (or CPU) runs without autocast.
    amp_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}.get(args.precision)
    use_amp = device.type == "cuda" and amp_dtype is not None
    if use_amp:
        backbone.to(dtype=amp_dtype)

    # Resolve the feature width once and use it for both the head and the
    # pooling axis, so pooling never falls back to its shape heuristic.
    feature_dim = int(cfg.get("feature_dim") or cfg.get("hidden_size") or backbone.config.hidden_size)
    head = ClassificationHead(in_dim=feature_dim, hidden_dims=hidden)
    # torch.load unpickles the file: only point --repo_id at repositories you trust.
    state = torch.load(os.path.join(local, "best_head.pt"), map_location="cpu")
    # Accept either a {"head": state_dict, ...} wrapper or a bare state_dict.
    head_state = state["head"] if isinstance(state, dict) and "head" in state else state
    head.load_state_dict(head_state)
    head.to(device).eval()

    frames = load_frames(args.video, clip_len=clip_len)
    # Swap axes 1 and 2 of the processor output (presumably frames <-> channels,
    # giving (B, C, T, H, W) for the backbone — confirm against the backbone).
    px = processor(frames, return_tensors="pt")["pixel_values"].permute(0, 2, 1, 3, 4).to(device)

    with torch.no_grad():
        with torch.amp.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
            feats = backbone.extract_features(pixel_values=px)
        pooled = pool_tokens(feats, expected_feat_dim=feature_dim)
        # The head stays in fp32, so feed it fp32 features outside autocast.
        logits = head(pooled.float())
    pred_id = int(logits.argmax(dim=-1).item())
    print(ID2LABEL.get(pred_id, str(pred_id)))
if __name__ == "__main__":
    # Run only when executed as a script; importing the module has no side effects.
    raise SystemExit(main())