#!/usr/bin/env python
# Source: vbench-i2v/run_sample.py — uploaded by lxl-158 (commit d0d70d4, verified).
"""
VBench I2V sampling script.

Usage:
    python run_sample.py --model_name my_model --resolution 16-9
"""
import os
import argparse
# Point the VBench, Hugging Face and torch caches at the local
# pretrained_models directory. This runs before torch and the samplers module
# are imported below, because those libraries may read these variables at
# import time. The paths are relative, so they resolve against the current
# working directory.
os.environ.update({
    'VBENCH_CACHE_DIR': './vbench2_beta_i2v/pretrained_models',
    'HF_HOME': './vbench2_beta_i2v/pretrained_models/huggingface',
    'TORCH_HOME': './vbench2_beta_i2v/pretrained_models/torch',
})
import torch
from PIL import Image
from samplers import I2VSampler
class MySampler(I2VSampler):
    """I2V sampler: fill in ``setup`` and ``sample`` with your model.

    The current ``sample`` body is a placeholder that returns random noise of
    the right shape, so the surrounding pipeline can be exercised end to end.
    """

    # Output size as (width, height) for each supported aspect-ratio key.
    # The keys match the ``--resolution`` CLI choices. This is a class
    # constant so it is built once and can also be reused from ``setup``.
    RESOLUTION_MAP = {
        "1-1": (512, 512),
        "8-5": (512, 320),
        "7-4": (448, 256),
        "16-9": (576, 320),  # standard 16:9 video size -> 576x320
    }
    # Layout of the placeholder clip. Override in a subclass if needed.
    NUM_FRAMES = 16
    NUM_CHANNELS = 3

    def setup(self):
        """Hook for loading the model. Currently a no-op placeholder."""
        # TODO: load your model here.

    def sample(self, image_path: str, prompt: str, index: int) -> torch.Tensor:
        """Produce one video for the given input image and prompt.

        Args:
            image_path: Path to the input (conditioning) image. The
                placeholder does not use it.
            prompt: Text prompt. The placeholder does not use it.
            index: Sample index supplied by the caller. The placeholder does
                not use it.

        Returns:
            A ``torch.uint8`` tensor of shape
            ``(NUM_FRAMES, height, width, NUM_CHANNELS)`` with values in
            ``[0, 255]``. ``(width, height)`` comes from ``RESOLUTION_MAP``.

        Raises:
            KeyError: If ``self.resolution`` is not a key of
                ``RESOLUTION_MAP``.
        """
        # TODO: replace with your model's sampling code.
        # self.resolution is presumably stored by I2VSampler from the
        # ``resolution`` constructor kwarg. Confirm against samplers.py.
        width, height = self.RESOLUTION_MAP[self.resolution]
        # Placeholder output: uniform random noise instead of a generated clip.
        return torch.randint(
            0,
            256,
            (self.NUM_FRAMES, height, width, self.NUM_CHANNELS),
            dtype=torch.uint8,
        )
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the sampling script."""
    parser = argparse.ArgumentParser(description="VBench I2V 采样")
    # Model name (required).
    parser.add_argument("--model_name", type=str, required=True, help="模型名称")
    # Output aspect-ratio key.
    parser.add_argument(
        "--resolution",
        type=str,
        default="16-9",
        choices=["1-1", "8-5", "7-4", "16-9"],
        help="分辨率",
    )
    # Random seed.
    parser.add_argument("--seed", type=int, default=42, help="随机种子")
    # Output video frame rate.
    parser.add_argument("--fps", type=int, default=8, help="视频帧率")
    return parser


def main() -> None:
    """Parse CLI arguments, construct the sampler and run it."""
    options = _build_parser().parse_args()
    runner = MySampler(
        model_name=options.model_name,
        resolution=options.resolution,
        seed=options.seed,
        fps=options.fps,
    )
    runner.run()


if __name__ == "__main__":
    main()