| import sys |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| from torch.utils.data import DataLoader |
|
|
# Make the repository root importable (so `graphwm` and `scripts` packages
# resolve) when this file is executed directly as a script rather than as a
# module. parents[1] assumes this file lives one directory below the root —
# TODO confirm against the actual repo layout.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
|
|
| from graphwm.config_graph import GraphWMArgs |
| from graphwm.dataset.collate_graph_wm import collate_graph_wm |
| from graphwm.models.ctrl_world_graph import CtrlWorldGraph |
| from graphwm.original_ctrl_world import import_original_modules |
| from scripts.train_wm_graph import build_datasets |
|
|
|
|
def write_video(path: Path, frames: np.ndarray, fps: int = 5):
    """Write *frames* to *path* as a video file.

    Prefers ``mediapy``; on any failure (missing package or write error)
    falls back to ``imageio``. *frames* is expected to be an array of
    per-frame images in (T, H, W, C) layout — TODO confirm with callers.
    """
    try:
        import mediapy as media
        media.write_video(str(path), frames, fps=fps)
    except Exception:
        # Best-effort fallback; macro_block_size=None avoids imageio/ffmpeg
        # resizing frames to a multiple of its default macro block.
        import imageio.v2 as imageio
        imageio.mimwrite(str(path), frames, fps=fps, macro_block_size=None)
    else:
        return
|
|
|
|
def decode_latents_to_video(pipeline, latents: torch.Tensor, decode_chunk_size: int):
    """Decode a (batch, frames, C, h, w) latent tensor into uint8 video frames.

    Latents are un-scaled by the VAE's ``scaling_factor``, decoded through
    ``pipeline.vae.decode`` in groups of at most ``decode_chunk_size`` frames,
    and mapped from the decoder's [-1, 1] output range to [0, 255].

    Returns a numpy uint8 array of shape (batch, frames, H, W, C); spatial
    dims are 8x the latent dims (the VAE's fixed upsampling factor).
    """
    batch_size, frame_count = latents.shape[:2]
    merged = latents.flatten(0, 1)  # fold batch and time for chunked decoding
    scale = pipeline.vae.config.scaling_factor
    pieces = [
        pipeline.vae.decode(chunk / scale, num_frames=chunk.shape[0]).sample
        for chunk in torch.split(merged, decode_chunk_size)
    ]
    frames = torch.cat(pieces, dim=0).reshape(
        batch_size, frame_count, -1, merged.shape[-2] * 8, merged.shape[-1] * 8
    )
    frames = ((frames / 2.0 + 0.5).clamp(0, 1) * 255).byte()
    return frames.permute(0, 1, 3, 4, 2).cpu().numpy()
|
|
|
|
def main():
    """Qualitative single-sample evaluation: load a trained CtrlWorldGraph
    checkpoint, roll the diffusion pipeline on the first validation sample,
    and write predicted / ground-truth / side-by-side comparison videos.
    """
    args = GraphWMArgs()
    # NOTE(review): hard-coded workspace checkpoint path — will only work on
    # the machine this was written for; consider a CLI flag.
    args.ckpt_path = '/workspace/Ctrl-World-Graph/model_ckpt/ctrl_world_graph/checkpoint-100000.pt'
    args.eval_batch_size = 1
    args.num_workers = 0


    _, val_ds = build_datasets(args)
    if val_ds is None or len(val_ds) == 0:
        raise ValueError('Validation dataset is empty.')


    # Only the first validation sample is evaluated; collate into a batch of 1.
    sample = val_ds[0]
    batch = collate_graph_wm([sample])


    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CtrlWorldGraph(args).to(device)
    state_dict = torch.load(args.ckpt_path, map_location='cpu')
    # NOTE(review): strict=False silently ignores missing/unexpected keys —
    # confirm the checkpoint layout actually matches the model.
    model.load_state_dict(state_dict, strict=False)
    model.eval()


    # Pull the original Ctrl-World pipeline class so we can invoke its
    # __call__ unbound on this model's pipeline instance below.
    original = import_original_modules(args.ctrl_world_root)
    CtrlWorldDiffusionPipeline = original['CtrlWorldDiffusionPipeline']


    batch['rgb'] = batch['rgb'].to(device)
    batch['graph_seq'] = [g.to(device) for g in batch['graph_seq']]


    with torch.no_grad():
        # Encode frames to VAE latents and the graph sequence to the
        # conditioning embedding fed in place of text conditioning.
        latents = model.encode_rgb_to_latents(batch['rgb'])
        graph_hidden = model.encode_graph_condition(batch).to(device=device, dtype=model.unet.dtype)


        # Split latents along the time axis: the frame at index num_history is
        # the conditioning image; earlier frames (if any) are history context.
        current_latent = latents[:, args.num_history]
        history = latents[:, :args.num_history] if args.num_history > 0 else None


        # Call the pipeline's __call__ unbound so model.pipeline (possibly a
        # different class) runs the original sampling loop. output_type='latent'
        # keeps the result in latent space; decoding happens below.
        _, pred_latents = CtrlWorldDiffusionPipeline.__call__(
            model.pipeline,
            image=current_latent,
            text=graph_hidden,
            width=args.width,
            height=args.height,
            num_frames=args.num_frames,
            history=history,
            num_inference_steps=args.num_inference_steps,
            decode_chunk_size=args.decode_chunk_size,
            max_guidance_scale=args.guidance_scale,
            fps=args.fps,
            motion_bucket_id=args.motion_bucket_id,
            output_type='latent',
            return_dict=False,
            frame_level_cond=args.frame_level_cond,
            his_cond_zero=args.his_cond_zero,
        )


    # Decode prediction to (T, H, W, C) uint8; ground truth assumes
    # batch['rgb'] is already in [0, 1] — TODO confirm dataset normalization.
    pred_video = decode_latents_to_video(model.pipeline, pred_latents, args.decode_chunk_size)[0]
    gt_video = (batch['rgb'][0].permute(0, 2, 3, 1).clamp(0, 1) * 255).byte().cpu().numpy()
    # axis=2 is width: ground truth and prediction rendered side by side.
    compare_video = np.concatenate([gt_video, pred_video], axis=2)


    out_dir = Path('/workspace/Ctrl-World-Graph/eval_videos')
    out_dir.mkdir(parents=True, exist_ok=True)


    pred_path = out_dir / 'val0_pred.mp4'
    gt_path = out_dir / 'val0_gt.mp4'
    compare_path = out_dir / 'val0_compare.mp4'


    write_video(pred_path, pred_video, fps=args.fps)
    write_video(gt_path, gt_video, fps=args.fps)
    write_video(compare_path, compare_video, fps=args.fps)


    print('saved_pred=', pred_path)
    print('saved_gt=', gt_path)
    print('saved_compare=', compare_path)
    # presumably the dataset frame indices used for this sample; verify
    # collate_graph_wm provides 'frame_ids' as a tensor.
    print('frame_ids=', batch['frame_ids'][0].tolist())
|
|
|
|
# Script entry point.
if __name__ == '__main__':
    main()
|
|