"""Training script for the graph world model (CtrlWorldGraph).

Uses Hugging Face Accelerate for (optionally distributed) mixed-precision
training with gradient accumulation, TensorBoard logging, periodic
validation, and checkpointing.
"""
import os
import sys
from pathlib import Path

import torch
from torch.utils.data import DataLoader, Subset
from torch.utils.tensorboard import SummaryWriter
from accelerate import Accelerator
from accelerate.logging import get_logger
from tqdm.auto import tqdm

# Make the repository root importable when this script is run directly.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from graphwm.config_graph import GraphWMArgs
from graphwm.dataset.collate_graph_wm import collate_graph_wm
from graphwm.dataset.dataset_graph_wm import (
    GraphWorldModelDataset,
    SampledDataGraphWorldModelDataset,
)
from graphwm.models.ctrl_world_graph import CtrlWorldGraph


def build_datasets(args: GraphWMArgs):
    """Build the train dataset and, if enabled, a held-out validation split."""
    if args.use_sampled_data_loader:
        full_dataset = SampledDataGraphWorldModelDataset(
            sample_root=args.sampled_data_root,
            type_vocab=args.graph_type_vocab,
            session_id=args.sampled_session_id,
            episode_id=args.sampled_episode_id,
            num_history=args.num_history,
            num_frames=args.num_frames,
            resize_hw=args.sampled_resize_hw,
            include_depth=args.include_depth,
        )
    else:
        full_dataset = GraphWorldModelDataset(args.graph_manifest_path, args.graph_type_vocab)

    if not args.use_eval_split:
        return full_dataset, None

    # Deterministic tail split: the last `val_len` samples become validation.
    dataset_len = len(full_dataset)
    val_len = max(1, int(dataset_len * args.val_ratio))
    if dataset_len - val_len < 1:
        # Keep at least one training sample.
        val_len = max(1, dataset_len - 1)
    train_len = dataset_len - val_len
    train_indices = list(range(train_len))
    val_indices = list(range(train_len, dataset_len))
    return Subset(full_dataset, train_indices), Subset(full_dataset, val_indices)


def evaluate(model, loader, accelerator):
    """Average the generation loss over the validation loader across all processes."""
    model.eval()
    total = 0.0
    count = 0
    with torch.no_grad():
        for batch in loader:
            with accelerator.autocast():
                loss_gen, _ = model(batch)
            avg_loss = accelerator.gather(loss_gen.detach().reshape(1)).mean()
            total += float(avg_loss.item())
            count += 1
    model.train()
    return total / max(count, 1)


def main(args: GraphWMArgs):
    logger = get_logger(__name__, log_level="INFO")
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
    )

    model = CtrlWorldGraph(args)
    if args.ckpt_path:
        # Resume from a checkpoint; strict=False tolerates missing/extra keys.
        state_dict = torch.load(args.ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=False)

    train_dataset, val_dataset = build_datasets(args)
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=args.shuffle,
        num_workers=args.num_workers,
        collate_fn=collate_graph_wm,
    )
    val_loader = None
    if val_dataset is not None:
        val_loader = DataLoader(
            val_dataset,
            batch_size=args.eval_batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            collate_fn=collate_graph_wm,
        )

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)

    if val_loader is not None:
        model, optimizer, train_loader, val_loader = accelerator.prepare(
            model, optimizer, train_loader, val_loader
        )
    else:
        model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

    writer = None
    if accelerator.is_main_process and args.use_tensorboard:
        os.makedirs(args.tensorboard_log_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=args.tensorboard_log_dir)

    model.train()
    global_step = 0
    running_loss = 0.0
    running_count = 0
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Graph WM Steps")

    if accelerator.is_main_process:
        logger.info("Train samples: %s", len(train_dataset))
        if val_dataset is not None:
            logger.info("Val samples: %s", len(val_dataset))

    while global_step < args.max_train_steps:
        for batch in train_loader:
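            # Under accelerator.accumulate(model), backward() accumulates
            # gradients across micro-batches; optimizer.step()/zero_grad()
            # only take effect on the final micro-batch of each accumulation
            # window, when accelerator.sync_gradients is True.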
            with accelerator.accumulate(model):
                with accelerator.autocast():
                    loss_gen, _ = model(batch)
                # Gather the loss from all processes for logging only.
                avg_loss = accelerator.gather(loss_gen.detach().reshape(1)).mean()
                running_loss += float(avg_loss.item())
                running_count += 1
                accelerator.backward(loss_gen)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                optimizer.zero_grad()

            # An optimizer update has happened only when gradients were synced.
            if accelerator.sync_gradients:
                global_step += 1
                progress_bar.update(1)
                progress_bar.set_postfix({"loss": float(avg_loss.item())})

                if global_step % args.log_every_steps == 0:
                    train_loss = running_loss / max(running_count, 1)
                    if accelerator.is_main_process:
                        logger.info("step=%s train_loss=%.6f", global_step, train_loss)
                    if writer is not None:
                        writer.add_scalar("loss/train", train_loss, global_step)
                    running_loss = 0.0
                    running_count = 0

                if val_loader is not None and global_step % args.validation_steps == 0:
                    val_loss = evaluate(model, val_loader, accelerator)
                    if accelerator.is_main_process:
                        logger.info("step=%s val_loss=%.6f", global_step, val_loss)
                    if writer is not None:
                        writer.add_scalar("loss/val", val_loss, global_step)

                if global_step % args.checkpointing_steps == 0:
                    # All processes reach this step together, so the barrier is safe.
                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        os.makedirs(args.output_dir, exist_ok=True)
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.pt")
                        torch.save(accelerator.unwrap_model(model).state_dict(), save_path)
                        logger.info("Saved checkpoint to %s", save_path)

            if global_step >= args.max_train_steps:
                break

    if writer is not None:
        writer.close()


if __name__ == "__main__":
    args = GraphWMArgs()
    main(args)
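# Typical launch (the script path below is illustrative; adjust it to where
# this file lives in the repo):
#   accelerate launch scripts/train_graph_wm.py
# or, for a single process without Accelerate's launcher:
#   python scripts/train_graph_wm.py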