# HY-WorldPlay-FP8 / scripts / batch_inference.py
# (repository-page header retained as comments so the file remains valid Python)
# vibegavin's picture
# Initial release: FP8 quantized weights + turbo3 scripts + video demos
# 881f988 verified
#!/usr/bin/env python3
"""
批量推理: 11 张图片 × 随机 WASD 方向变换
每 ~30 latents (~120 帧) 变换一个方向
视频长度 497 帧 (~125 latents, ~20s @24fps)
"""
import random
import subprocess
import os
import time
random.seed(42)  # fixed seed so the sampled poses / per-run seeds are reproducible
# Source images for the batch run; resolved against IMAGE_DIR below.
IMAGES = [
    "1.png", "2.png", "3.png", "4.png", "5.png",
    "6.jpeg", "7.png", "8.png", "9.png", "10.png", "test.png"
]
# Basic WASD translation moves sampled by generate_random_pose().
DIRECTIONS = ["w", "a", "s", "d"]
# Include turn actions as well.
# NOTE(review): this list is currently unused — generate_random_pose() samples
# turns from a hard-coded ["left", "right"] pair instead.
ALL_MOVES = ["w", "a", "s", "d", "left", "right", "up", "down"]
IMAGE_DIR = "/root/HY-WorldPlay/assets/img"
OUTPUT_BASE = "/root/test_results/batch"
# Locate checkpoints on disk at import time: `grep -v temp` skips scratch
# copies, `head -1` takes the first match. Result is "" if nothing is found.
MODEL_PATH = subprocess.check_output(
    "find /root/models -maxdepth 3 -name 'HunyuanVideo*' -type d | grep -v temp | head -1",
    shell=True
).decode().strip()
WP_PATH = subprocess.check_output(
    "find /root/models -maxdepth 3 -name 'HY-WorldPlay' -type d | grep -v temp | head -1",
    shell=True
).decode().strip()
# Action-model weights live inside the HY-WorldPlay checkpoint directory.
ACTION_CKPT = f"{WP_PATH}/ar_distilled_action_model/diffusion_pytorch_model.safetensors"
def generate_random_pose(total_latents=31, segment_latents=8,
                         directions=("w", "a", "s", "d"),
                         turns=("left", "right")):
    """Build a random pose/action string such as ``"w-7,left-3,s-8"``.

    The result is a comma-separated list of ``<move>-<length>`` segments whose
    lengths sum to exactly ``total_latents``. Before each translation segment,
    a short turn of 2-4 latents is mixed in with 30% probability.

    Args:
        total_latents: total number of latents the pose must cover.
        segment_latents: nominal translation-segment length; each segment is
            jittered by ±2 latents.
        directions: pool of translation moves to sample from.
        turns: pool of turn moves to sample from.

    Returns:
        Pose string consumed by the inference script's ``--pose`` flag.
    """
    segments = []
    remaining = total_latents
    while remaining > 0:
        direction = random.choice(directions)
        # 30% of the time, insert a short turn before the translation segment.
        if random.random() < 0.3:
            turn = random.choice(turns)
            turn_len = min(random.randint(2, 4), remaining)
            segments.append(f"{turn}-{turn_len}")
            remaining -= turn_len
            if remaining <= 0:
                break
        # Jitter the segment length, clamped to [1, remaining] so the lengths
        # always sum to total_latents. (Previously, segment_latents <= 2 could
        # produce a non-positive length and terminate early, undershooting the
        # requested total.)
        seg_len = min(max(segment_latents + random.randint(-2, 2), 1), remaining)
        segments.append(f"{direction}-{seg_len}")
        remaining -= seg_len
    return ",".join(segments)
def run_inference(image_name, pose, output_dir):
    """Launch one video-generation run; return (returncode, stdout, stderr)."""
    # Flag → value mapping; dict insertion order preserves the CLI order.
    flags = {
        "--model_path": MODEL_PATH,
        "--action_ckpt": ACTION_CKPT,
        "--prompt": "Explore a vivid 3D world with smooth camera movement.",
        "--image_path": os.path.join(IMAGE_DIR, image_name),
        "--resolution": "480p",
        "--aspect_ratio": "16:9",
        "--video_length": "125",
        "--seed": str(random.randint(0, 99999)),
        "--rewrite": "false",
        "--sr": "false",
        "--pose": pose,
        "--output_path": output_dir,
        "--few_step": "true",
        "--num_inference_steps": "4",
        "--model_type": "ar",
        "--use_vae_parallel": "false",
        "--use_sageattn": "true",
        "--use_fp8_gemm": "false",
        "--transformer_resident_ar_rollout": "true",
        "--width": "832",
        "--height": "480",
    }
    cmd = ["python3", "/root/scripts/run_fp8_turbo3_gpu.py"]
    for flag, value in flags.items():
        cmd += [flag, value]
    # Child inherits the current env, plus the repo on PYTHONPATH and an
    # expandable CUDA allocator to reduce fragmentation-driven OOMs.
    env = dict(os.environ)
    env["PYTHONPATH"] = "/root/HY-WorldPlay:" + env.get("PYTHONPATH", "")
    env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    proc = subprocess.run(cmd, env=env, cwd="/root/HY-WorldPlay",
                          capture_output=True, text=True)
    return proc.returncode, proc.stdout, proc.stderr
if __name__ == "__main__":
    os.makedirs(OUTPUT_BASE, exist_ok=True)
    # Derive the image count instead of hard-coding "11", so the progress and
    # summary lines stay correct when IMAGES changes.
    n_images = len(IMAGES)
    print("=" * 60)
    print(f"批量推理: {n_images} 张图片 × 随机 WASD")
    print("=" * 60)
    results = []
    for i, img in enumerate(IMAGES):
        pose = generate_random_pose(total_latents=31, segment_latents=8)
        # One output directory per image, named after the filename stem
        # (splitext is robust to dots elsewhere in the name).
        output_dir = os.path.join(OUTPUT_BASE, os.path.splitext(img)[0])
        os.makedirs(output_dir, exist_ok=True)
        print(f"\n[{i+1}/{n_images}] {img}")
        print(f" pose: {pose}")
        print(f" output: {output_dir}")
        t0 = time.time()
        code, stdout, stderr = run_inference(img, pose, output_dir)
        elapsed = time.time() - t0
        # Extract key metrics the child process prints (peak GPU memory /
        # total time, labelled in Chinese) from its combined output.
        success = code == 0
        peak_mem = ""
        total_time = ""
        for line in (stdout + stderr).split("\n"):
            if "峰值显存" in line:
                peak_mem = line.strip()
            if "总耗时" in line:
                total_time = line.strip()
        status = "✅" if success else "❌"
        print(f" {status} {elapsed:.0f}s | {peak_mem} | {total_time}")
        if not success:
            # Print only the last 10 stderr lines to keep the log readable.
            for err_line in stderr.strip().split("\n")[-10:]:
                print(f" ERR: {err_line}")
        results.append({
            "image": img,
            "pose": pose,
            "success": success,
            "time": elapsed,
        })
    # Summary
    print("\n" + "=" * 60)
    print("汇总")
    print("=" * 60)
    ok = sum(1 for r in results if r["success"])
    print(f"成功: {ok}/{n_images}")
    print(f"总耗时: {sum(r['time'] for r in results):.0f}s")
    for r in results:
        s = "✅" if r["success"] else "❌"
        print(f" {s} {r['image']:12s} {r['time']:.0f}s pose={r['pose']}")
    print(f"\n视频输出目录: {OUTPUT_BASE}/")
    print("完成!")