| import sys |
| from pathlib import Path |
|
|
| import torch |
|
|
# Make the repository root importable so the `graphwm` package resolves
# when this script is executed directly (rather than as a module).
ROOT = Path(__file__).resolve().parents[1]
_root_str = str(ROOT)
if _root_str not in sys.path:
    sys.path.insert(0, _root_str)
|
|
| from graphwm.config_graph import GraphWMArgs |
| from graphwm.dataset.collate_graph_wm import collate_graph_wm |
| from graphwm.dataset.dataset_graph_wm import SampledDataGraphWorldModelDataset |
| from graphwm.models.ctrl_world_graph import CtrlWorldGraph, GraphConditioner |
|
|
|
|
def move_graph_seq_to_device(graph_seq, device):
    """Return a new list with every graph batch in *graph_seq* moved to *device*.

    Each element is expected to expose a ``.to(device)`` method (e.g. a
    PyG ``Batch``); the original sequence is left untouched.
    """
    moved = []
    for item in graph_seq:
        moved.append(item.to(device))
    return moved
|
|
|
|
def main():
    """Smoke-test the graph-conditioned world model on one sampled batch.

    Loads a single sample from the sampled-data dataset, collates it into a
    batch, runs the graph conditioner and the full ``CtrlWorldGraph`` forward
    pass, and prints tensor shapes plus the resulting loss. Intended for
    manual verification only; exercises CUDA when available.
    """
    args = GraphWMArgs()
    args.num_history = 0
    args.num_frames = 5
    args.train_batch_size = 1
    args.sampled_resize_hw = (192, 320)
    # NOTE(review): hard-coded absolute checkpoint paths — adjust these for
    # your environment before running.
    args.svd_model_path = '/workspace/Ctrl-World/ckpt/stable-video-diffusion-img2vid'
    args.clip_model_path = '/workspace/Ctrl-World/ckpt/clip-vit-base-patch32'

    dataset = SampledDataGraphWorldModelDataset(
        sample_root=args.sampled_data_root,
        type_vocab=args.graph_type_vocab,
        session_id=args.sampled_session_id,
        episode_id=args.sampled_episode_id,
        num_history=args.num_history,
        num_frames=args.num_frames,
        resize_hw=args.sampled_resize_hw,
        include_depth=False,
    )
    sample = dataset[0]
    batch = collate_graph_wm([sample])

    print('sample_rgb_shape=', tuple(batch['rgb'].shape))
    print('sample_frame_ids=', batch['frame_ids'][0].tolist())
    print('graph_seq_len=', len(batch['graph_seq']))
    print('graph_t0_x_shape=', tuple(batch['graph_seq'][0].x.shape))
    print('graph_t0_edge_attr_shape=', tuple(batch['graph_seq'][0].edge_attr.shape))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('device=', device)

    # Move the graph sequence to the device once and reuse it for both the
    # conditioner check and the full model forward. (The original code
    # transferred batch['graph_seq'] twice — once for the conditioner and
    # again for the model — doing the device copy work redundantly.)
    graph_seq = move_graph_seq_to_device(batch['graph_seq'], device)
    batch['graph_seq'] = graph_seq

    graph_conditioner = GraphConditioner(args).to(device)
    with torch.no_grad():
        graph_hidden = graph_conditioner(graph_seq)
    print('graph_hidden_shape=', tuple(graph_hidden.shape))

    model = CtrlWorldGraph(args).to(device)
    model.eval()
    batch['rgb'] = batch['rgb'].to(device)

    # Sanity-check the RGB → latent encoder path before the full forward.
    with torch.no_grad():
        latents = model.encode_rgb_to_latents(batch['rgb'])
    print('latents_shape=', tuple(latents.shape))

    with torch.no_grad():
        loss, _ = model(batch)
    print('forward_ok=True')
    print('loss=', float(loss.item()))
|
|
|
# Entry point: run the smoke test when executed as a script.
if __name__ == '__main__':
    main()
|
|