# File size: 4,241 Bytes
# da7bf91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import sys
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from graphwm.config_graph import GraphWMArgs
from graphwm.dataset.collate_graph_wm import collate_graph_wm
from graphwm.models.ctrl_world_graph import CtrlWorldGraph
from graphwm.original_ctrl_world import import_original_modules
from scripts.train_wm_graph import build_datasets


def write_video(path: Path, frames: np.ndarray, fps: int = 5):
    """Write *frames* to *path* as a video file.

    Prefers ``mediapy`` when available; any failure on that path (missing
    package or a write error) falls through to ``imageio``, which is the
    required backstop writer.
    """
    try:
        import mediapy as media
        writer = media.write_video
    except Exception:
        writer = None
    if writer is not None:
        try:
            writer(str(path), frames, fps=fps)
            return
        except Exception:
            pass  # best-effort: retry below with imageio
    import imageio.v2 as imageio
    # macro_block_size=None avoids imageio padding frames to 16-px multiples.
    imageio.mimwrite(str(path), frames, fps=fps, macro_block_size=None)


def decode_latents_to_video(pipeline, latents: torch.Tensor, decode_chunk_size: int):
    """Decode a latent tensor into a uint8 video array.

    Args:
        pipeline: object exposing ``vae.config.scaling_factor`` and
            ``vae.decode(chunk, num_frames=...)`` returning an object with a
            ``.sample`` tensor (assumed shape ``(n, C_out, h*8, w*8)`` — the
            VAE's 8x spatial upsampling, per the ``* 8`` reshape below).
        latents: tensor of shape ``(B, T, C, h, w)``.
        decode_chunk_size: max frames decoded per VAE call (bounds memory).

    Returns:
        numpy uint8 array of shape ``(B, T, H, W, C_out)`` with values in
        ``[0, 255]``; decoded samples are assumed to lie in ``[-1, 1]``.
    """
    batch, frames = latents.shape[0], latents.shape[1]
    stacked = latents.flatten(0, 1)
    inv_scale = pipeline.vae.config.scaling_factor
    pieces = [
        pipeline.vae.decode(part / inv_scale, num_frames=part.shape[0]).sample
        for part in stacked.split(decode_chunk_size, dim=0)
    ]
    height, width = stacked.shape[-2] * 8, stacked.shape[-1] * 8
    video = torch.cat(pieces, dim=0).reshape(batch, frames, -1, height, width)
    # Map [-1, 1] -> [0, 255] and clamp before the uint8 cast.
    video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).byte()
    return video.permute(0, 1, 3, 4, 2).cpu().numpy()


def main():
    """Roll out one validation sample through the model and save pred/gt/compare videos."""
    args = GraphWMArgs()
    args.ckpt_path = '/workspace/Ctrl-World-Graph/model_ckpt/ctrl_world_graph/checkpoint-100000.pt'
    args.eval_batch_size = 1
    args.num_workers = 0

    _, val_ds = build_datasets(args)
    if val_ds is None or len(val_ds) == 0:
        raise ValueError('Validation dataset is empty.')

    # Single-sample batch built through the regular collate path.
    batch = collate_graph_wm([val_ds[0]])

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CtrlWorldGraph(args).to(device)
    # strict=False tolerates keys missing from / extra in the checkpoint.
    model.load_state_dict(torch.load(args.ckpt_path, map_location='cpu'), strict=False)
    model.eval()

    pipeline_cls = import_original_modules(args.ctrl_world_root)['CtrlWorldDiffusionPipeline']

    batch['rgb'] = batch['rgb'].to(device)
    batch['graph_seq'] = [graph.to(device) for graph in batch['graph_seq']]

    with torch.no_grad():
        latents = model.encode_rgb_to_latents(batch['rgb'])
        graph_hidden = model.encode_graph_condition(batch).to(device=device, dtype=model.unet.dtype)

        # Split the latent sequence at num_history: frames before it condition
        # the rollout, the frame at it is the starting image.
        history = latents[:, :args.num_history] if args.num_history > 0 else None
        # Call the original pipeline unbound so we can pass the wrapped
        # model.pipeline instance explicitly.
        _, pred_latents = pipeline_cls.__call__(
            model.pipeline,
            image=latents[:, args.num_history],
            text=graph_hidden,
            width=args.width,
            height=args.height,
            num_frames=args.num_frames,
            history=history,
            num_inference_steps=args.num_inference_steps,
            decode_chunk_size=args.decode_chunk_size,
            max_guidance_scale=args.guidance_scale,
            fps=args.fps,
            motion_bucket_id=args.motion_bucket_id,
            output_type='latent',
            return_dict=False,
            frame_level_cond=args.frame_level_cond,
            his_cond_zero=args.his_cond_zero,
        )

    pred_video = decode_latents_to_video(model.pipeline, pred_latents, args.decode_chunk_size)[0]
    gt_video = (batch['rgb'][0].permute(0, 2, 3, 1).clamp(0, 1) * 255).byte().cpu().numpy()
    # NOTE(review): side-by-side concat assumes gt and pred have the same
    # frame count and height — confirm the rgb clip length matches num_frames.
    compare_video = np.concatenate([gt_video, pred_video], axis=2)

    out_dir = Path('/workspace/Ctrl-World-Graph/eval_videos')
    out_dir.mkdir(parents=True, exist_ok=True)

    clips = (
        ('val0_pred.mp4', pred_video, 'saved_pred='),
        ('val0_gt.mp4', gt_video, 'saved_gt='),
        ('val0_compare.mp4', compare_video, 'saved_compare='),
    )
    for filename, clip, _ in clips:
        write_video(out_dir / filename, clip, fps=args.fps)
    for filename, _, label in clips:
        print(label, out_dir / filename)
    print('frame_ids=', batch['frame_ids'][0].tolist())


# Script entry point: run the single-sample evaluation when executed directly.
if __name__ == '__main__':
    main()