Instructions to use EndeavourDD/gnn_wm with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use EndeavourDD/gnn_wm with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("EndeavourDD/gnn_wm", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| import os | |
| import sys | |
| from pathlib import Path | |
| import torch | |
| from torch.utils.data import DataLoader, Subset | |
| from torch.utils.tensorboard import SummaryWriter | |
| from accelerate import Accelerator | |
| from accelerate.logging import get_logger | |
| from tqdm.auto import tqdm | |
# Make the repository root importable when this file is run as a script.
ROOT = Path(__file__).resolve().parents[1]
_root_str = str(ROOT)
if _root_str not in sys.path:
    sys.path.insert(0, _root_str)
| from graphwm.config_graph import GraphWMArgs | |
| from graphwm.dataset.collate_graph_wm import collate_graph_wm | |
| from graphwm.dataset.dataset_graph_wm import GraphWorldModelDataset, SampledDataGraphWorldModelDataset | |
| from graphwm.models.ctrl_world_graph import CtrlWorldGraph | |
def build_datasets(args: GraphWMArgs):
    """Construct the training (and optional validation) dataset.

    Args:
        args: Experiment configuration; selects the dataset backend and the
            train/val split parameters.

    Returns:
        ``(train_dataset, val_dataset)`` — ``val_dataset`` is ``None`` when
        ``args.use_eval_split`` is off, or when the dataset has fewer than
        two samples and cannot be split.
    """
    if args.use_sampled_data_loader:
        full_dataset = SampledDataGraphWorldModelDataset(
            sample_root=args.sampled_data_root,
            type_vocab=args.graph_type_vocab,
            session_id=args.sampled_session_id,
            episode_id=args.sampled_episode_id,
            num_history=args.num_history,
            num_frames=args.num_frames,
            resize_hw=args.sampled_resize_hw,
            include_depth=args.include_depth,
        )
    else:
        full_dataset = GraphWorldModelDataset(args.graph_manifest_path, args.graph_type_vocab)
    if not args.use_eval_split:
        return full_dataset, None
    dataset_len = len(full_dataset)
    # A split needs at least one train and one val sample. The previous
    # logic produced an EMPTY train subset for dataset_len == 1 (val_len was
    # forced to 1, leaving train_len == 0), which made the training loop in
    # main() iterate zero batches forever. Fall back to no split instead.
    if dataset_len < 2:
        return full_dataset, None
    val_len = max(1, int(dataset_len * args.val_ratio))
    if dataset_len - val_len < 1:
        # Keep at least one training sample; dataset_len >= 2 here.
        val_len = dataset_len - 1
    train_len = dataset_len - val_len
    # Deterministic tail split: the last val_len samples are held out.
    train_indices = list(range(0, train_len))
    val_indices = list(range(train_len, dataset_len))
    return Subset(full_dataset, train_indices), Subset(full_dataset, val_indices)
def evaluate(model, loader, accelerator):
    """Run one pass over ``loader`` and return the mean gathered loss.

    The model is put into eval mode for the pass and restored to train mode
    before returning. Returns 0.0 for an empty loader.
    """
    model.eval()
    losses = []
    with torch.no_grad():
        for sample in loader:
            with accelerator.autocast():
                loss, _ = model(sample)
            gathered = accelerator.gather(loss.detach().reshape(1))
            losses.append(float(gathered.mean().item()))
    model.train()
    if not losses:
        return 0.0
    return sum(losses) / len(losses)
def main(args: GraphWMArgs):
    """Train CtrlWorldGraph with HF Accelerate until ``args.max_train_steps``
    optimizer steps, with periodic logging, validation, and checkpointing.

    NOTE(review): the source arrived with all indentation stripped; the
    nesting below was reconstructed from standard Accelerate training-loop
    structure — verify against the original file.
    """
    logger = get_logger(__name__, log_level="INFO")
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
    )
    model = CtrlWorldGraph(args)
    if args.ckpt_path:
        # Resume/warm-start: strict=False tolerates missing/extra keys.
        state_dict = torch.load(args.ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=False)
    train_dataset, val_dataset = build_datasets(args)
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=args.shuffle,
        num_workers=args.num_workers,
        collate_fn=collate_graph_wm,
    )
    val_loader = None
    if val_dataset is not None:
        val_loader = DataLoader(
            val_dataset,
            batch_size=args.eval_batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            collate_fn=collate_graph_wm,
        )
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
    # accelerator.prepare wraps model/optimizer/loaders for the current
    # device and distributed setup; the val loader is only prepared when it
    # exists.
    if val_loader is not None:
        model, optimizer, train_loader, val_loader = accelerator.prepare(model, optimizer, train_loader, val_loader)
    else:
        model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
    # TensorBoard writer only on the main process to avoid duplicate logs.
    writer = None
    if accelerator.is_main_process and args.use_tensorboard:
        os.makedirs(args.tensorboard_log_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
    model.train()
    # global_step counts optimizer steps (i.e. only when gradients sync),
    # not micro-batches.
    global_step = 0
    running_loss = 0.0
    running_count = 0
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Graph WM Steps")
    if accelerator.is_main_process:
        logger.info("Train samples: %s", len(train_dataset))
        if val_dataset is not None:
            logger.info("Val samples: %s", len(val_dataset))
    # Outer while re-enters the loader for as many epochs as needed to reach
    # max_train_steps. NOTE(review): if the train loader yields no batches
    # (empty split) this loop never terminates — see build_datasets.
    while global_step < args.max_train_steps:
        for batch in train_loader:
            with accelerator.accumulate(model):
                with accelerator.autocast():
                    loss_gen, _ = model(batch)
                # Gather across processes for a comparable logged loss; the
                # raw per-process loss is what we backprop.
                avg_loss = accelerator.gather(loss_gen.detach().reshape(1)).mean()
                running_loss += float(avg_loss.item())
                running_count += 1
                accelerator.backward(loss_gen)
                # Clip only on the accumulation boundary, when grads are real.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                optimizer.zero_grad()
            # Bookkeeping happens once per optimizer step, not per micro-batch.
            if accelerator.sync_gradients:
                global_step += 1
                progress_bar.update(1)
                progress_bar.set_postfix({"loss": float(avg_loss.item())})
                if global_step % args.log_every_steps == 0:
                    # Average over the window since the last log, then reset.
                    train_loss = running_loss / max(running_count, 1)
                    if accelerator.is_main_process:
                        logger.info("step=%s train_loss=%.6f", global_step, train_loss)
                        if writer is not None:
                            writer.add_scalar("loss/train", train_loss, global_step)
                    running_loss = 0.0
                    running_count = 0
                if val_loader is not None and global_step % args.validation_steps == 0:
                    val_loss = evaluate(model, val_loader, accelerator)
                    if accelerator.is_main_process:
                        logger.info("step=%s val_loss=%.6f", global_step, val_loss)
                        if writer is not None:
                            writer.add_scalar("loss/val", val_loss, global_step)
                if global_step % args.checkpointing_steps == 0 and accelerator.is_main_process:
                    os.makedirs(args.output_dir, exist_ok=True)
                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.pt")
                    # unwrap_model strips the accelerate/DDP wrapper so the
                    # checkpoint loads into a bare CtrlWorldGraph.
                    torch.save(accelerator.unwrap_model(model).state_dict(), save_path)
                    logger.info("Saved checkpoint to %s", save_path)
            if global_step >= args.max_train_steps:
                break
    if writer is not None:
        writer.close()
if __name__ == "__main__":
    # Build the default configuration and launch training.
    main(GraphWMArgs())