File size: 2,483 Bytes
da7bf91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import sys
from pathlib import Path

import torch

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from graphwm.config_graph import GraphWMArgs
from graphwm.dataset.collate_graph_wm import collate_graph_wm
from graphwm.dataset.dataset_graph_wm import SampledDataGraphWorldModelDataset
from graphwm.models.ctrl_world_graph import CtrlWorldGraph, GraphConditioner


def move_graph_seq_to_device(graph_seq, device):
    """Return a new list with each per-frame graph batch moved to *device*.

    Each element is expected to expose a ``.to(device)`` method (as
    torch / PyG batch objects do); the originals are not mutated.
    """
    moved = []
    for frame_graph in graph_seq:
        moved.append(frame_graph.to(device))
    return moved


def main():
    """Smoke-test the graph world-model pipeline end to end.

    Loads a single sampled episode, collates it into a batch, runs the
    graph conditioner, encodes RGB frames to latents, and performs one
    full forward pass of ``CtrlWorldGraph``, printing tensor shapes and
    the resulting loss along the way. Intended as a quick sanity check,
    not a training entry point.
    """
    args = GraphWMArgs()
    # Single short clip, batch size 1 — keep the check cheap.
    args.num_history = 0
    args.num_frames = 5
    args.train_batch_size = 1
    args.sampled_resize_hw = (192, 320)
    # NOTE(review): hard-coded checkpoint paths — assumes the Ctrl-World
    # workspace layout; adjust if running elsewhere.
    args.svd_model_path = '/workspace/Ctrl-World/ckpt/stable-video-diffusion-img2vid'
    args.clip_model_path = '/workspace/Ctrl-World/ckpt/clip-vit-base-patch32'

    dataset = SampledDataGraphWorldModelDataset(
        sample_root=args.sampled_data_root,
        type_vocab=args.graph_type_vocab,
        session_id=args.sampled_session_id,
        episode_id=args.sampled_episode_id,
        num_history=args.num_history,
        num_frames=args.num_frames,
        resize_hw=args.sampled_resize_hw,
        include_depth=False,
    )
    sample = dataset[0]
    batch = collate_graph_wm([sample])

    print('sample_rgb_shape=', tuple(batch['rgb'].shape))
    print('sample_frame_ids=', batch['frame_ids'][0].tolist())
    print('graph_seq_len=', len(batch['graph_seq']))
    print('graph_t0_x_shape=', tuple(batch['graph_seq'][0].x.shape))
    print('graph_t0_edge_attr_shape=', tuple(batch['graph_seq'][0].edge_attr.shape))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('device=', device)

    graph_conditioner = GraphConditioner(args).to(device)
    # Move the graph batches to the device once; the moved list is reused
    # below for the full model forward (the original code transferred the
    # same CPU batches a second time).
    graph_seq = move_graph_seq_to_device(batch['graph_seq'], device)
    with torch.no_grad():
        graph_hidden = graph_conditioner(graph_seq)
    print('graph_hidden_shape=', tuple(graph_hidden.shape))

    model = CtrlWorldGraph(args).to(device)
    model.eval()
    # Reuse the already-moved graph sequence instead of moving it again.
    batch['graph_seq'] = graph_seq
    batch['rgb'] = batch['rgb'].to(device)

    with torch.no_grad():
        latents = model.encode_rgb_to_latents(batch['rgb'])
    print('latents_shape=', tuple(latents.shape))

    with torch.no_grad():
        loss, _ = model(batch)
    print('forward_ok=True')
    print('loss=', float(loss.item()))


# Run the smoke test only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()