| | |
| | """ |
| | VBench I2V 采样脚本 |
| | |
| | 用法: |
| | python run_sample.py --model_name my_model --resolution 16-9 |
| | """ |
| |
|
| | import os |
| | import argparse |
| |
|
| | |
# Set the VBench, Hugging Face and torch cache locations to local directories.
# This is done ahead of the `import torch` that follows; presumably so those
# libraries see the values when they are first loaded (confirm before moving).
os.environ.update({
    'VBENCH_CACHE_DIR': './vbench2_beta_i2v/pretrained_models',
    'HF_HOME': './vbench2_beta_i2v/pretrained_models/huggingface',
    'TORCH_HOME': './vbench2_beta_i2v/pretrained_models/torch',
})
| |
|
| | import torch |
| | from PIL import Image |
| | from samplers import I2VSampler |
| |
|
| |
|
class MySampler(I2VSampler):
    """I2V sampler that implements the ``sample`` method.

    As written, ``sample`` produces random frames rather than running a
    model; it only fixes the output layout and size for each resolution key.
    """

    def setup(self):
        """Do nothing; this sampler prepares no resources."""

    def sample(self, image_path: str, prompt: str, index: int) -> torch.Tensor:
        """Return a random uint8 video tensor sized for ``self.resolution``.

        ``image_path``, ``prompt`` and ``index`` are accepted but not read by
        this implementation.

        Returns:
            A ``torch.uint8`` tensor of shape ``(16, H, W, 3)`` filled with
            integers drawn from ``[0, 256)``, where ``(W, H)`` is looked up
            from the resolution key.

        Raises:
            KeyError: If ``self.resolution`` is not a key of the table below.
        """
        # Resolution key -> (width, height) in pixels.
        size_by_ratio = {
            "1-1": (512, 512),
            "8-5": (512, 320),
            "7-4": (448, 256),
            "16-9": (576, 320),
        }
        # NOTE(review): self.resolution is presumably stored by the
        # I2VSampler base class from the constructor argument - confirm.
        width, height = size_by_ratio[self.resolution]

        # Frames first, channels last: (T, H, W, C).
        num_frames, channels = 16, 3
        frame_shape = (num_frames, height, width, channels)
        return torch.randint(0, 256, frame_shape, dtype=torch.uint8)
| |
|
| |
|
if __name__ == "__main__":
    # Command-line entry point: parse the sampling options, then build the
    # sampler from them and run it.
    supported_resolutions = ["1-1", "8-5", "7-4", "16-9"]

    cli = argparse.ArgumentParser(description="VBench I2V 采样")
    cli.add_argument("--model_name", type=str, required=True, help="模型名称")
    cli.add_argument(
        "--resolution",
        type=str,
        default="16-9",
        choices=supported_resolutions,
        help="分辨率",
    )
    cli.add_argument("--seed", type=int, default=42, help="随机种子")
    cli.add_argument("--fps", type=int, default=8, help="视频帧率")
    options = cli.parse_args()

    # Every parsed option is forwarded to the sampler by keyword.
    MySampler(
        model_name=options.model_name,
        resolution=options.resolution,
        seed=options.seed,
        fps=options.fps,
    ).run()
| |
|