{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MultiModalHackVAE Demo\n",
    "\n",
    "This notebook demonstrates how to use the MultiModalHackVAE model from [CatkinChen/nethack-vae](https://huggingface.co/CatkinChen/nethack-vae). It downloads the model configuration from the Hugging Face Hub, builds a synthetic NetHack observation, and, once a model has been loaded, encodes that observation into the latent space.\n",
    "\n",
    "## Installation\n",
    "\n",
    "```bash\n",
    "pip install torch transformers huggingface_hub\n",
    "# For NetHack environment (optional):\n",
    "pip install nle\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from huggingface_hub import hf_hub_download\n",
    "\n",
    "# Hub repository that hosts the config (and weights) used throughout this notebook.\n",
    "REPO_ID = 'CatkinChen/nethack-vae'\n",
    "# Seed for the synthetic inputs below, so every run produces identical tensors.\n",
    "SEED = 42"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model configuration\n",
    "\n",
    "Download `config.json` from the Hub and print each of its entries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config_path = hf_hub_download(repo_id=REPO_ID, filename='config.json')\n",
    "with open(config_path, 'r') as f:\n",
    "    config = json.load(f)\n",
    "\n",
    "print('Model Configuration:')\n",
    "for key, value in config.items():\n",
    "    print(f' {key}: {value}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the model\n",
    "\n",
    "The model class is not bundled with this notebook. Import `MultiModalHackVAE` from the package that defines it and load the weights from the Hub. Until you do, `model` stays `None` and the inference cell at the end is skipped, so the notebook still runs top to bottom."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: replace with the package that provides MultiModalHackVAE, e.g.\n",
    "# from your_package import MultiModalHackVAE, load_model_from_huggingface\n",
    "# model = load_model_from_huggingface(REPO_ID)\n",
    "model = None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Synthetic inputs\n",
    "\n",
    "Random stand-ins for a single NetHack observation (batch of one). Replace them with real NLE observations to encode actual game states."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-seed inside this cell so re-running it alone yields the same tensors.\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "BATCH_SIZE = 1\n",
    "MAP_SHAPE = (21, 79)  # NetHack map: rows x columns\n",
    "\n",
    "game_chars = torch.randint(32, 127, (BATCH_SIZE, *MAP_SHAPE))  # printable ASCII codes 32..126\n",
    "game_colors = torch.randint(0, 16, (BATCH_SIZE, *MAP_SHAPE))\n",
    "blstats = torch.randn(BATCH_SIZE, 27)\n",
    "msg_tokens = torch.randint(0, 128, (BATCH_SIZE, 256))\n",
    "hero_info = torch.randint(0, 10, (BATCH_SIZE, 4))\n",
    "\n",
    "synthetic_inputs = {\n",
    "    'game_chars': game_chars,\n",
    "    'game_colors': game_colors,\n",
    "    'blstats': blstats,\n",
    "    'msg_tokens': msg_tokens,\n",
    "    'hero_info': hero_info,\n",
    "}\n",
    "\n",
    "print('Synthetic data shapes:')\n",
    "for name, tensor in synthetic_inputs.items():\n",
    "    print(f' {name}: {tensor.shape}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Encode to latent space\n",
    "\n",
    "Runs the encoder on the synthetic inputs. This cell runs only after a model has been loaded in the *Load the model* section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if model is None:\n",
    "    print('Model inference example (load a model in the \"Load the model\" cell to run it)')\n",
    "else:\n",
    "    model.eval()  # disable training-only behaviour (e.g. dropout) for inference\n",
    "    with torch.no_grad():\n",
    "        output = model(\n",
    "            glyph_chars=game_chars,\n",
    "            glyph_colors=game_colors,\n",
    "            blstats=blstats,\n",
    "            msg_tokens=msg_tokens,\n",
    "            hero_info=hero_info,\n",
    "        )\n",
    "\n",
    "    latent_mean = output['mu']\n",
    "    latent_logvar = output['logvar']\n",
    "    lowrank_factors = output['lowrank_factors']\n",
    "\n",
    "    print(f'Latent representation shape: {latent_mean.shape}')\n",
    "    print(f'Latent mean: {latent_mean[0][:5].tolist()}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.8+"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}