# scripts/debug_sample_graph_step.py
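"""Smoke test for the graph-conditioned world model pipeline.

Loads one sampled episode, collates it into a batch of size 1, runs the
graph sequence through GraphConditioner, then pushes the full batch through
CtrlWorldGraph (latent encoding + forward pass), printing tensor shapes and
the resulting loss along the way.
"""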
import sys
from pathlib import Path
import torch
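# Make the repo root importable so the `graphwm` package resolves when this
# script is run directly from scripts/.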
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
from graphwm.config_graph import GraphWMArgs
from graphwm.dataset.collate_graph_wm import collate_graph_wm
from graphwm.dataset.dataset_graph_wm import SampledDataGraphWorldModelDataset
from graphwm.models.ctrl_world_graph import CtrlWorldGraph, GraphConditioner
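# A graph sequence is a list with one graph batch per frame; each element
# exposes `.x`, `.edge_attr`, and `.to(device)` (a torch_geometric-style Batch).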
def move_graph_seq_to_device(graph_seq, device):
    return [graph_batch.to(device) for graph_batch in graph_seq]
def main():
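    # Keep the run small and fast: no history frames, a short 5-frame clip,
    # and a single-sample batch at reduced resolution.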
    args = GraphWMArgs()
    args.num_history = 0
    args.num_frames = 5
    args.train_batch_size = 1
    args.sampled_resize_hw = (192, 320)
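    # NOTE: hard-coded local checkpoint paths; point these at your own
    # stable-video-diffusion and CLIP checkpoints.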
    args.svd_model_path = '/workspace/Ctrl-World/ckpt/stable-video-diffusion-img2vid'
    args.clip_model_path = '/workspace/Ctrl-World/ckpt/clip-vit-base-patch32'
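    # Build the dataset for a single (session, episode) pair under the
    # sampled data root configured in GraphWMArgs.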
    dataset = SampledDataGraphWorldModelDataset(
        sample_root=args.sampled_data_root,
        type_vocab=args.graph_type_vocab,
        session_id=args.sampled_session_id,
        episode_id=args.sampled_episode_id,
        num_history=args.num_history,
        num_frames=args.num_frames,
        resize_hw=args.sampled_resize_hw,
        include_depth=False,
    )
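    # Take the first sample and collate it into a batch of size 1, giving it
    # the batched structure the model expects.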
    sample = dataset[0]
    batch = collate_graph_wm([sample])
    print('sample_rgb_shape=', tuple(batch['rgb'].shape))
    print('sample_frame_ids=', batch['frame_ids'][0].tolist())
    print('graph_seq_len=', len(batch['graph_seq']))
    print('graph_t0_x_shape=', tuple(batch['graph_seq'][0].x.shape))
    print('graph_t0_edge_attr_shape=', tuple(batch['graph_seq'][0].edge_attr.shape))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('device=', device)
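    # Stage 1: run the graph conditioner on its own to check that the
    # per-frame graphs produce a conditioning tensor of the expected shape.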
    graph_conditioner = GraphConditioner(args).to(device)
    graph_seq = move_graph_seq_to_device(batch['graph_seq'], device)
    with torch.no_grad():
        graph_hidden = graph_conditioner(graph_seq)
    print('graph_hidden_shape=', tuple(graph_hidden.shape))
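    # Stage 2: the full world model, in eval mode so dropout and similar
    # train-time behavior do not affect the shape checks.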
    model = CtrlWorldGraph(args).to(device)
    model.eval()
    batch['graph_seq'] = graph_seq  # reuse the graphs already moved to `device`
    batch['rgb'] = batch['rgb'].to(device)
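    # Encode RGB frames into latents; given svd_model_path above, this
    # presumably goes through the Stable Video Diffusion VAE encoder.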
    with torch.no_grad():
        latents = model.encode_rgb_to_latents(batch['rgb'])
    print('latents_shape=', tuple(latents.shape))
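    # Full forward pass; the model returns (loss, extras) and only the
    # scalar loss is inspected here.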
    with torch.no_grad():
        loss, _ = model(batch)
    print('forward_ok=True')
    print('loss=', float(loss.item()))
if __name__ == '__main__':
    main()