CatkinChen committed on
Commit
df4349f
·
verified ·
1 Parent(s): 7509ffd

Add demo notebook

Browse files
Files changed (1) hide show
  1. demo_notebook.ipynb +101 -0
demo_notebook.ipynb ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# MultiModalHackVAE Demo\n\n",
8
+ "This notebook demonstrates how to use the MultiModalHackVAE model from CatkinChen/nethack-vae.\n\n",
9
+ "## Installation\n\n",
10
+ "```bash\n",
11
+ "pip install torch transformers huggingface_hub\n",
12
+ "# For NetHack environment (optional):\n",
13
+ "pip install nle\n",
14
+ "```"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": null,
20
+ "metadata": {},
+ "outputs": [],
21
+ "source": [
22
+ "import torch\n",
23
+ "import numpy as np\n",
24
+ "from huggingface_hub import hf_hub_download\n",
25
+ "import json\n",
26
+ "\n",
27
+ "# Load model config\n",
28
+ "config_path = hf_hub_download(repo_id='CatkinChen/nethack-vae', filename='config.json')\n",
29
+ "with open(config_path, 'r') as f:\n",
30
+ " config = json.load(f)\n",
31
+ "\n",
32
+ "print('Model Configuration:')\n",
33
+ "for key, value in config.items():\n",
34
+ " print(f' {key}: {value}')"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": null,
40
+ "metadata": {},
+ "outputs": [],
41
+ "source": [
42
+ "# Load the model (you'll need to import your model class)\n",
43
+ "# from your_package import MultiModalHackVAE\n",
44
+ "# model = load_model_from_huggingface('CatkinChen/nethack-vae')\n",
45
+ "\n",
46
+ "# Example synthetic data\n",
47
+ "batch_size = 1\n",
48
+ "game_chars = torch.randint(32, 127, (batch_size, 21, 79))\n",
49
+ "game_colors = torch.randint(0, 16, (batch_size, 21, 79))\n",
50
+ "blstats = torch.randn(batch_size, 27)\n",
51
+ "msg_tokens = torch.randint(0, 128, (batch_size, 256))\n",
52
+ "hero_info = torch.randint(0, 10, (batch_size, 4))\n",
53
+ "\n",
54
+ "print('Synthetic data shapes:')\n",
55
+ "print(f' game_chars: {game_chars.shape}')\n",
56
+ "print(f' game_colors: {game_colors.shape}')\n",
57
+ "print(f' blstats: {blstats.shape}')\n",
58
+ "print(f' msg_tokens: {msg_tokens.shape}')\n",
59
+ "print(f' hero_info: {hero_info.shape}')"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": null,
65
+ "metadata": {},
+ "outputs": [],
66
+ "source": [
67
+ "# Encode to latent space\n",
68
+ "# with torch.no_grad():\n",
69
+ "# output = model(\n",
70
+ "# glyph_chars=game_chars,\n",
71
+ "# glyph_colors=game_colors,\n",
72
+ "# blstats=blstats,\n",
73
+ "# msg_tokens=msg_tokens,\n",
74
+ "# hero_info=hero_info\n",
75
+ "# )\n",
76
+ "# \n",
77
+ "# latent_mean = output['latent_mean']\n",
78
+ "# latent_logvar = output['latent_logvar']\n",
79
+ "# reconstructions = output['reconstructions']\n",
80
+ "# \n",
81
+ "# print(f'Latent representation shape: {latent_mean.shape}')\n",
82
+ "# print(f'Latent mean: {latent_mean[0][:5].tolist()}')\n",
83
+ "\n",
84
+ "print('Model inference example (uncomment when model is available)')"
85
+ ]
86
+ }
87
+ ],
88
+ "metadata": {
89
+ "kernelspec": {
90
+ "display_name": "Python 3",
91
+ "language": "python",
92
+ "name": "python3"
93
+ },
94
+ "language_info": {
95
+ "name": "python",
96
+ "version": "3.8+"
97
+ }
98
+ },
99
+ "nbformat": 4,
100
+ "nbformat_minor": 4
101
+ }