{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# MultiModalHackVAE Demo\n\n",
"This notebook demonstrates how to use the MultiModalHackVAE model from CatkinChen/nethack-vae.\n\n",
"## Installation\n\n",
"```bash\n",
"pip install torch transformers huggingface_hub\n",
"# For NetHack environment (optional):\n",
"pip install nle\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np\n",
"from huggingface_hub import hf_hub_download\n",
"import json\n",
"\n",
"# Load model config\n",
"config_path = hf_hub_download(repo_id='CatkinChen/nethack-vae', filename='config.json')\n",
"with open(config_path, 'r') as f:\n",
" config = json.load(f)\n",
"\n",
"print('Model Configuration:')\n",
"for key, value in config.items():\n",
" print(f' {key}: {value}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load the model (you'll need to import your model class)\n",
"# from your_package import MultiModalHackVAE\n",
"# model = load_model_from_huggingface('CatkinChen/nethack-vae')\n",
"\n",
"# Example synthetic data\n",
"batch_size = 1\n",
"game_chars = torch.randint(32, 127, (batch_size, 21, 79))\n",
"game_colors = torch.randint(0, 16, (batch_size, 21, 79))\n",
"blstats = torch.randn(batch_size, 27)\n",
"msg_tokens = torch.randint(0, 128, (batch_size, 256))\n",
"hero_info = torch.randint(0, 10, (batch_size, 4))\n",
"\n",
"print('Synthetic data shapes:')\n",
"print(f' game_chars: {game_chars.shape}')\n",
"print(f' game_colors: {game_colors.shape}')\n",
"print(f' blstats: {blstats.shape}')\n",
"print(f' msg_tokens: {msg_tokens.shape}')\n",
"print(f' hero_info: {hero_info.shape}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Encode to latent space\n",
"# with torch.no_grad():\n",
"# output = model(\n",
"# glyph_chars=game_chars,\n",
"# glyph_colors=game_colors,\n",
"# blstats=blstats,\n",
"# msg_tokens=msg_tokens,\n",
"# hero_info=hero_info\n",
"# )\n",
"# \n",
"# latent_mean = output['mu']\n",
"# latent_logvar = output['logvar']\n",
"# lowrank_factors = output['lowrank_factors']\n",
"# \n",
"# print(f'Latent representation shape: {latent_mean.shape}')\n",
"# print(f'Latent mean: {latent_mean[0][:5].tolist()}')\n",
"\n",
"print('Model inference example (uncomment when model is available)')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.8+"
}
},
"nbformat": 4,
"nbformat_minor": 4
}