| |
"""
批量推理: 11 张图片 × 随机 WASD 方向变换
每 ~8 latents (~32 帧) 变换一个方向, 并以 30% 概率在方向段之前插入 2-4 latents 的左/右转向
视频长度 125 帧 (pose 共 31 latents, ~5s @24fps)
"""
| import random |
| import subprocess |
| import os |
| import time |
|
|
# Fixed seed: both the random poses and the per-run inference seeds are drawn
# from the global `random` module, so the whole batch is reproducible.
random.seed(42)


# Input image file names, resolved relative to IMAGE_DIR.
IMAGES = [
    "1.png",
    "2.png",
    "3.png",
    "4.png",
    "5.png",
    "6.jpeg",
    "7.png",
    "8.png",
    "9.png",
    "10.png",
    "test.png",
]


# Translation moves sampled for the pose segments.
DIRECTIONS = ["w", "a", "s", "d"]

# Translation keys plus rotation moves.
# NOTE(review): pose sampling only uses DIRECTIONS and left/right; this list
# looks unused by the script itself — confirm before relying on it.
ALL_MOVES = ["w", "a", "s", "d", "left", "right", "up", "down"]


IMAGE_DIR = "/root/HY-WorldPlay/assets/img"
OUTPUT_BASE = "/root/test_results/batch"


def _find_model_dir(name_pattern):
    """Return the first directory under /root/models whose name matches *name_pattern*.

    Runs ``find`` (max depth 3, directories only), drops every path that
    contains "temp", and keeps the first remaining hit in find's traversal
    order. The pipeline's exit status is that of ``head``, so an empty string
    is returned (rather than an exception) when nothing matches.
    """
    command = (
        f"find /root/models -maxdepth 3 -name '{name_pattern}' -type d"
        " | grep -v temp | head -1"
    )
    raw = subprocess.check_output(command, shell=True)
    return raw.decode().strip()


MODEL_PATH = _find_model_dir("HunyuanVideo*")
WP_PATH = _find_model_dir("HY-WorldPlay")

# NOTE(review): if WP_PATH resolved to "", this silently becomes a path under "/".
ACTION_CKPT = f"{WP_PATH}/ar_distilled_action_model/diffusion_pytorch_model.safetensors"
|
|
|
|
def generate_random_pose(total_latents=31, segment_latents=8, *,
                         directions=None, rng=None):
    """Build a random camera pose string covering exactly ``total_latents`` latents.

    The result is a comma-separated list of ``"<move>-<n_latents>"`` items,
    e.g. ``"w-7,left-3,d-9"``. Each loop step first picks a translation
    direction. With probability 0.3 it then emits a ``left``/``right`` turn of
    2-4 latents. Finally it emits a translation segment of roughly
    ``segment_latents`` (jittered by ``randint(-2, 2)``). Every item is clipped
    to the remaining budget, so the lengths sum to ``total_latents``.

    Args:
        total_latents: Number of latents the pose must cover. Values ``<= 0``
            produce an empty string.
        segment_latents: Nominal translation-segment length. After jitter the
            length is clamped to at least 1.
        directions: Translation moves to choose from. Defaults to the
            module-level ``DIRECTIONS``.
        rng: Object providing ``choice``, ``random`` and ``randint`` (e.g. a
            ``random.Random`` instance). Defaults to the global ``random``
            module, so results follow ``random.seed``.

    Returns:
        The pose string.
    """
    if directions is None:
        directions = DIRECTIONS
    if rng is None:
        rng = random

    segments = []
    remaining = total_latents
    while remaining > 0:
        # The direction is drawn before the turn decision; keep this order so
        # a given seed keeps reproducing the same pose.
        direction = rng.choice(directions)

        if rng.random() < 0.3:
            turn = rng.choice(["left", "right"])
            turn_len = min(rng.randint(2, 4), remaining)
            segments.append(f"{turn}-{turn_len}")
            remaining -= turn_len
            if remaining <= 0:
                break

        # Clamp to >= 1. With segment_latents <= 2 the jitter can push the
        # length to <= 0. Stopping there would leave the pose shorter than
        # total_latents.
        seg_len = max(1, min(segment_latents + rng.randint(-2, 2), remaining))
        segments.append(f"{direction}-{seg_len}")
        remaining -= seg_len
    return ",".join(segments)
|
|
|
|
def run_inference(image_name, pose, output_dir):
    """Run one inference job in a child process and capture its output.

    Launches ``/root/scripts/run_fp8_turbo3_gpu.py`` with ``python3`` from
    ``/root/HY-WorldPlay``. The job uses a fixed 480p / 16:9 (832x480), 4-step
    "ar" configuration, the image ``IMAGE_DIR/<image_name>`` and the given
    ``pose``. The per-job ``--seed`` comes from the global ``random`` module,
    so it is reproducible after ``random.seed``.

    Args:
        image_name: File name inside ``IMAGE_DIR``.
        pose: Pose string passed verbatim to ``--pose``.
        output_dir: Directory passed to ``--output_path``.

    Returns:
        ``(returncode, stdout, stderr)`` of the child process; stdout and
        stderr are decoded text.
    """
    image_path = os.path.join(IMAGE_DIR, image_name)

    cmd = [
        "python3", "/root/scripts/run_fp8_turbo3_gpu.py",
        "--model_path", MODEL_PATH,
        "--action_ckpt", ACTION_CKPT,
        "--prompt", "Explore a vivid 3D world with smooth camera movement.",
        "--image_path", image_path,
        "--resolution", "480p",
        "--aspect_ratio", "16:9",
        # NOTE(review): the caller builds poses for 31 latents; this value is
        # presumably a frame count (125 = 4 * 31 + 1?) — keep the two in sync.
        "--video_length", "125",
        "--seed", str(random.randint(0, 99999)),
        "--rewrite", "false",
        "--sr", "false",
        "--pose", pose,
        "--output_path", output_dir,
        "--few_step", "true",
        "--num_inference_steps", "4",
        "--model_type", "ar",
        "--use_vae_parallel", "false",
        "--use_sageattn", "true",
        "--use_fp8_gemm", "false",
        "--transformer_resident_ar_rollout", "true",
        "--width", "832",
        "--height", "480",
    ]

    env = os.environ.copy()
    # Prepend the repo to PYTHONPATH. The separator is added only when a
    # value already exists. A trailing separator would leave an empty entry,
    # which Python treats as the current working directory on sys.path.
    existing = env.get("PYTHONPATH")
    env["PYTHONPATH"] = (
        "/root/HY-WorldPlay" + os.pathsep + existing
        if existing
        else "/root/HY-WorldPlay"
    )
    env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    result = subprocess.run(cmd, env=env, cwd="/root/HY-WorldPlay",
                            capture_output=True, text=True)
    return result.returncode, result.stdout, result.stderr
|
|
|
|
def _extract_metrics(output):
    """Pick the peak-VRAM and total-time report lines out of a child's output.

    Looks for lines containing "峰值显存" (peak GPU memory) and "总耗时"
    (total elapsed time). When a marker appears more than once, the last
    matching line wins. A missing marker yields "".

    Args:
        output: Combined stdout/stderr text of one inference run.

    Returns:
        ``(peak_mem, total_time)``, each a stripped line or "".
    """
    peak_mem = ""
    total_time = ""
    # splitlines() also breaks on "\r", so carriage-return progress output
    # cannot get glued onto a metric line.
    for line in output.splitlines():
        if "峰值显存" in line:
            peak_mem = line.strip()
        if "总耗时" in line:
            total_time = line.strip()
    return peak_mem, total_time


def main():
    """Run every image in IMAGES once with a random pose, then print a summary.

    Each image gets its own output directory under OUTPUT_BASE, named after
    the file stem. Per-run status, wall time and the metrics reported by the
    child are printed as the batch progresses.
    """
    # Model discovery yields "" when no directory matched. Every run would
    # then fail inside the child, so stop before launching any.
    if not MODEL_PATH or not WP_PATH:
        raise SystemExit(
            "错误: 未在 /root/models 下找到模型目录 "
            f"(MODEL_PATH={MODEL_PATH!r}, WP_PATH={WP_PATH!r})"
        )

    os.makedirs(OUTPUT_BASE, exist_ok=True)
    n_images = len(IMAGES)

    print("=" * 60)
    print(f"批量推理: {n_images} 张图片 × 随机 WASD")
    print("=" * 60)

    results = []

    for idx, img in enumerate(IMAGES, start=1):
        # 31 latents, intended to match --video_length 125 in run_inference
        # (assumed 4 frames per latent — confirm).
        pose = generate_random_pose(total_latents=31, segment_latents=8)
        # splitext strips only the extension, so names with extra dots
        # (e.g. "a.b.png") do not collapse onto the same directory.
        output_dir = os.path.join(OUTPUT_BASE, os.path.splitext(img)[0])
        os.makedirs(output_dir, exist_ok=True)

        print(f"\n[{idx}/{n_images}] {img}")
        print(f" pose: {pose}")
        print(f" output: {output_dir}")

        # monotonic() is unaffected by wall-clock adjustments during long runs.
        t0 = time.monotonic()
        code, stdout, stderr = run_inference(img, pose, output_dir)
        elapsed = time.monotonic() - t0

        success = code == 0
        # Join with "\n" so stdout's last line and stderr's first line stay apart.
        peak_mem, total_time = _extract_metrics(stdout + "\n" + stderr)

        status = "✅" if success else "❌"
        print(f" {status} {elapsed:.0f}s | {peak_mem} | {total_time}")

        if not success:
            # Show the tail of stderr as a hint at the failure cause.
            for err_line in stderr.strip().splitlines()[-10:]:
                print(f" ERR: {err_line}")

        results.append({
            "image": img,
            "pose": pose,
            "success": success,
            "time": elapsed,
        })

    print("\n" + "=" * 60)
    print("汇总")
    print("=" * 60)
    ok = sum(1 for r in results if r["success"])
    print(f"成功: {ok}/{n_images}")
    print(f"总耗时: {sum(r['time'] for r in results):.0f}s")
    for r in results:
        s = "✅" if r["success"] else "❌"
        print(f" {s} {r['image']:12s} {r['time']:.0f}s pose={r['pose']}")

    print(f"\n视频输出目录: {OUTPUT_BASE}/")
    print("完成!")


if __name__ == "__main__":
    main()
|
|