"""Qualitative evaluation for a CtrlWorldGraph checkpoint.

Loads the first validation sample, runs the graph-conditioned diffusion
pipeline forward, decodes the predicted latents back to RGB, and writes
prediction / ground-truth / side-by-side comparison mp4 files.
"""

import sys
from pathlib import Path

import numpy as np
import torch

# Make the repository root importable when the script is run directly.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from graphwm.config_graph import GraphWMArgs
from graphwm.dataset.collate_graph_wm import collate_graph_wm
from graphwm.models.ctrl_world_graph import CtrlWorldGraph
from graphwm.original_ctrl_world import import_original_modules
from scripts.train_wm_graph import build_datasets


def write_video(path: Path, frames: np.ndarray, fps: int = 5) -> None:
    """Write a uint8 frame array to *path* as a video.

    Prefers ``mediapy``; if that import or write fails for any reason,
    falls back to ``imageio`` (best-effort — both writers produce an
    equivalent mp4).

    Args:
        path: output file path.
        frames: frame array, presumably (T, H, W, C) uint8 — confirm
            against callers.
        fps: playback frame rate.
    """
    try:
        import mediapy as media

        media.write_video(str(path), frames, fps=fps)
        return
    except Exception:
        import imageio.v2 as imageio

        # macro_block_size=None stops imageio/ffmpeg from resizing frames
        # whose dimensions are not multiples of 16.
        imageio.mimwrite(str(path), frames, fps=fps, macro_block_size=None)


def decode_latents_to_video(
    pipeline, latents: torch.Tensor, decode_chunk_size: int
) -> np.ndarray:
    """Decode scaled VAE latents into a uint8 RGB video array.

    Args:
        pipeline: object exposing ``vae`` with ``config.scaling_factor``
            and a ``decode(latents, num_frames=...)`` method.
        latents: tensor shaped (batch, frames, C, h, w) in scaled latent
            space.
        decode_chunk_size: number of flattened frames decoded per VAE
            call, bounding peak memory.

    Returns:
        uint8 array shaped (batch, frames, H, W, 3) with H = h * 8 and
        W = w * 8.
    """
    bsz, num_frames = latents.shape[:2]
    flat = latents.flatten(0, 1)  # (batch * frames, C, h, w)

    decoded = []
    for start in range(0, flat.shape[0], decode_chunk_size):
        # Undo the latent scaling applied at encode time before decoding.
        chunk = flat[start:start + decode_chunk_size] / pipeline.vae.config.scaling_factor
        sample = pipeline.vae.decode(chunk, num_frames=chunk.shape[0]).sample
        decoded.append(sample)

    # NOTE(review): the *8 spatial factor assumes the standard SVD-style VAE
    # downsampling — confirm against the pipeline's VAE config if swapped.
    video = torch.cat(decoded, dim=0).reshape(
        bsz, num_frames, -1, flat.shape[-2] * 8, flat.shape[-1] * 8
    )
    # Map from [-1, 1] model output range to [0, 255] uint8.
    video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).byte()
    return video.permute(0, 1, 3, 4, 2).cpu().numpy()


def main() -> None:
    """Run single-sample qualitative evaluation and save result videos."""
    args = GraphWMArgs()
    args.ckpt_path = '/workspace/Ctrl-World-Graph/model_ckpt/ctrl_world_graph/checkpoint-100000.pt'
    args.eval_batch_size = 1
    args.num_workers = 0

    _, val_ds = build_datasets(args)
    if val_ds is None or len(val_ds) == 0:
        raise ValueError('Validation dataset is empty.')

    # Single-sample batch: collate the first validation example.
    sample = val_ds[0]
    batch = collate_graph_wm([sample])

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CtrlWorldGraph(args).to(device)
    # strict=False: presumably the checkpoint omits some frozen submodules
    # (e.g. pretrained VAE weights) — verify against the training script.
    state_dict = torch.load(args.ckpt_path, map_location='cpu')
    model.load_state_dict(state_dict, strict=False)
    model.eval()

    original = import_original_modules(args.ctrl_world_root)
    CtrlWorldDiffusionPipeline = original['CtrlWorldDiffusionPipeline']

    batch['rgb'] = batch['rgb'].to(device)
    batch['graph_seq'] = [g.to(device) for g in batch['graph_seq']]

    with torch.no_grad():
        latents = model.encode_rgb_to_latents(batch['rgb'])
        graph_hidden = model.encode_graph_condition(batch).to(
            device=device, dtype=model.unet.dtype
        )
        # Condition frame is the first post-history latent; earlier frames
        # (if any) act as history context.
        current_latent = latents[:, args.num_history]
        history = latents[:, :args.num_history] if args.num_history > 0 else None

        # Invoke the upstream pipeline's __call__ unbound so the original
        # sampling loop runs against our graph-conditioned pipeline instance.
        _, pred_latents = CtrlWorldDiffusionPipeline.__call__(
            model.pipeline,
            image=current_latent,
            text=graph_hidden,
            width=args.width,
            height=args.height,
            num_frames=args.num_frames,
            history=history,
            num_inference_steps=args.num_inference_steps,
            decode_chunk_size=args.decode_chunk_size,
            max_guidance_scale=args.guidance_scale,
            fps=args.fps,
            motion_bucket_id=args.motion_bucket_id,
            output_type='latent',
            return_dict=False,
            frame_level_cond=args.frame_level_cond,
            his_cond_zero=args.his_cond_zero,
        )
        pred_video = decode_latents_to_video(
            model.pipeline, pred_latents, args.decode_chunk_size
        )[0]

    # Ground truth is already in [0, 1]; convert to (T, H, W, C) uint8.
    gt_video = (batch['rgb'][0].permute(0, 2, 3, 1).clamp(0, 1) * 255).byte().cpu().numpy()
    # Side-by-side comparison: concatenate along the width axis.
    compare_video = np.concatenate([gt_video, pred_video], axis=2)

    out_dir = Path('/workspace/Ctrl-World-Graph/eval_videos')
    out_dir.mkdir(parents=True, exist_ok=True)
    pred_path = out_dir / 'val0_pred.mp4'
    gt_path = out_dir / 'val0_gt.mp4'
    compare_path = out_dir / 'val0_compare.mp4'

    write_video(pred_path, pred_video, fps=args.fps)
    write_video(gt_path, gt_video, fps=args.fps)
    write_video(compare_path, compare_video, fps=args.fps)

    print('saved_pred=', pred_path)
    print('saved_gt=', gt_path)
    print('saved_compare=', compare_path)
    print('frame_ids=', batch['frame_ids'][0].tolist())


if __name__ == '__main__':
    main()