"""Smoke test for the graph world-model pipeline.

Loads one sample from the sampled-data dataset, collates it, runs the
graph conditioner and the full CtrlWorldGraph forward pass, and prints
intermediate shapes plus the final loss.
"""

import sys
from pathlib import Path

import torch

# Make the repository root importable when this file is run directly.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from graphwm.config_graph import GraphWMArgs
from graphwm.dataset.collate_graph_wm import collate_graph_wm
from graphwm.dataset.dataset_graph_wm import SampledDataGraphWorldModelDataset
from graphwm.models.ctrl_world_graph import CtrlWorldGraph, GraphConditioner


def move_graph_seq_to_device(graph_seq, device):
    """Return a new list with every graph batch in *graph_seq* moved to *device*."""
    return [graph_batch.to(device) for graph_batch in graph_seq]


def main():
    """Run a single-sample smoke test and print shapes, device, and loss."""
    args = GraphWMArgs()
    args.num_history = 0
    args.num_frames = 5
    args.train_batch_size = 1
    args.sampled_resize_hw = (192, 320)
    # NOTE(review): hard-coded checkpoint paths — adjust for your environment.
    args.svd_model_path = '/workspace/Ctrl-World/ckpt/stable-video-diffusion-img2vid'
    args.clip_model_path = '/workspace/Ctrl-World/ckpt/clip-vit-base-patch32'

    dataset = SampledDataGraphWorldModelDataset(
        sample_root=args.sampled_data_root,
        type_vocab=args.graph_type_vocab,
        session_id=args.sampled_session_id,
        episode_id=args.sampled_episode_id,
        num_history=args.num_history,
        num_frames=args.num_frames,
        resize_hw=args.sampled_resize_hw,
        include_depth=False,
    )

    # Batch of one sample, mirroring the training collate path.
    sample = dataset[0]
    batch = collate_graph_wm([sample])
    print('sample_rgb_shape=', tuple(batch['rgb'].shape))
    print('sample_frame_ids=', batch['frame_ids'][0].tolist())
    print('graph_seq_len=', len(batch['graph_seq']))
    print('graph_t0_x_shape=', tuple(batch['graph_seq'][0].x.shape))
    print('graph_t0_edge_attr_shape=', tuple(batch['graph_seq'][0].edge_attr.shape))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('device=', device)

    graph_conditioner = GraphConditioner(args).to(device)
    # Move the graph sequence to the device ONCE and reuse it below for the
    # full model; the original moved batch['graph_seq'] a second time, paying
    # a redundant host-to-device transfer.
    graph_seq = move_graph_seq_to_device(batch['graph_seq'], device)
    with torch.no_grad():
        graph_hidden = graph_conditioner(graph_seq)
    print('graph_hidden_shape=', tuple(graph_hidden.shape))

    model = CtrlWorldGraph(args).to(device)
    model.eval()
    batch['graph_seq'] = graph_seq
    batch['rgb'] = batch['rgb'].to(device)

    with torch.no_grad():
        latents = model.encode_rgb_to_latents(batch['rgb'])
    print('latents_shape=', tuple(latents.shape))

    with torch.no_grad():
        loss, _ = model(batch)
    print('forward_ok=True')
    print('loss=', float(loss.item()))


if __name__ == '__main__':
    main()