Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
Sync from GitHub
Browse files- .gitattributes +2 -35
- .gitignore +210 -0
- Dockerfile.notebook +26 -0
- README.md +163 -12
- ROADMAP.md +252 -0
- debug_env.py +38 -0
- demo.py +401 -0
- drift_events.py +4 -3
- generate_episodes.py +305 -0
- messages.py +147 -73
- notebooks/crisisinbox_grpo_simple.ipynb +108 -0
- pyproject.toml +30 -0
- server/Dockerfile +9 -11
- server/crisis_inbox_environment.py +13 -1
- training/crisisinbox_training.py +416 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,2 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
episodes.json filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
sample_episodes.json filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[codz]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py.cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# UV
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
#uv.lock
|
| 102 |
+
|
| 103 |
+
# poetry
|
| 104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 106 |
+
# commonly ignored for libraries.
|
| 107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 108 |
+
#poetry.lock
|
| 109 |
+
#poetry.toml
|
| 110 |
+
|
| 111 |
+
# pdm
|
| 112 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 113 |
+
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
| 114 |
+
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
| 115 |
+
#pdm.lock
|
| 116 |
+
#pdm.toml
|
| 117 |
+
.pdm-python
|
| 118 |
+
.pdm-build/
|
| 119 |
+
|
| 120 |
+
# pixi
|
| 121 |
+
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
| 122 |
+
#pixi.lock
|
| 123 |
+
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
| 124 |
+
# in the .venv directory. It is recommended not to include this directory in version control.
|
| 125 |
+
.pixi
|
| 126 |
+
|
| 127 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 128 |
+
__pypackages__/
|
| 129 |
+
|
| 130 |
+
# Celery stuff
|
| 131 |
+
celerybeat-schedule
|
| 132 |
+
celerybeat.pid
|
| 133 |
+
|
| 134 |
+
# SageMath parsed files
|
| 135 |
+
*.sage.py
|
| 136 |
+
|
| 137 |
+
# Environments
|
| 138 |
+
.env
|
| 139 |
+
.envrc
|
| 140 |
+
.venv
|
| 141 |
+
env/
|
| 142 |
+
venv/
|
| 143 |
+
ENV/
|
| 144 |
+
env.bak/
|
| 145 |
+
venv.bak/
|
| 146 |
+
|
| 147 |
+
# Spyder project settings
|
| 148 |
+
.spyderproject
|
| 149 |
+
.spyproject
|
| 150 |
+
|
| 151 |
+
# Rope project settings
|
| 152 |
+
.ropeproject
|
| 153 |
+
|
| 154 |
+
# mkdocs documentation
|
| 155 |
+
/site
|
| 156 |
+
|
| 157 |
+
# mypy
|
| 158 |
+
.mypy_cache/
|
| 159 |
+
.dmypy.json
|
| 160 |
+
dmypy.json
|
| 161 |
+
|
| 162 |
+
# Pyre type checker
|
| 163 |
+
.pyre/
|
| 164 |
+
|
| 165 |
+
# pytype static type analyzer
|
| 166 |
+
.pytype/
|
| 167 |
+
|
| 168 |
+
# Cython debug symbols
|
| 169 |
+
cython_debug/
|
| 170 |
+
|
| 171 |
+
# PyCharm
|
| 172 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 173 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 174 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 175 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 176 |
+
#.idea/
|
| 177 |
+
|
| 178 |
+
# Abstra
|
| 179 |
+
# Abstra is an AI-powered process automation framework.
|
| 180 |
+
# Ignore directories containing user credentials, local state, and settings.
|
| 181 |
+
# Learn more at https://abstra.io/docs
|
| 182 |
+
.abstra/
|
| 183 |
+
|
| 184 |
+
# Visual Studio Code
|
| 185 |
+
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
| 186 |
+
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
| 187 |
+
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
| 188 |
+
# you could uncomment the following to ignore the entire vscode folder
|
| 189 |
+
# .vscode/
|
| 190 |
+
|
| 191 |
+
# Ruff stuff:
|
| 192 |
+
.ruff_cache/
|
| 193 |
+
|
| 194 |
+
# PyPI configuration file
|
| 195 |
+
.pypirc
|
| 196 |
+
|
| 197 |
+
# Cursor
|
| 198 |
+
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
|
| 199 |
+
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
|
| 200 |
+
# refer to https://docs.cursor.com/context/ignore-files
|
| 201 |
+
.cursorignore
|
| 202 |
+
.cursorindexingignore
|
| 203 |
+
|
| 204 |
+
# Marimo
|
| 205 |
+
marimo/_static/
|
| 206 |
+
marimo/_lsp/
|
| 207 |
+
__marimo__/
|
| 208 |
+
|
| 209 |
+
#Claude
|
| 210 |
+
.claude/*
|
Dockerfile.notebook
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Install Jupyter + training deps
|
| 6 |
+
RUN pip install --no-cache-dir \
|
| 7 |
+
jupyter \
|
| 8 |
+
unsloth \
|
| 9 |
+
trl \
|
| 10 |
+
transformers \
|
| 11 |
+
datasets \
|
| 12 |
+
accelerate \
|
| 13 |
+
peft \
|
| 14 |
+
huggingface_hub
|
| 15 |
+
|
| 16 |
+
# Copy everything needed for training
|
| 17 |
+
COPY episodes.json .
|
| 18 |
+
COPY generate_episodes.py .
|
| 19 |
+
COPY models.py .
|
| 20 |
+
COPY messages.py .
|
| 21 |
+
COPY drift_events.py .
|
| 22 |
+
COPY notebooks/ ./notebooks/
|
| 23 |
+
|
| 24 |
+
EXPOSE 8888
|
| 25 |
+
|
| 26 |
+
CMD ["jupyter", "notebook", "--ip=0.0.0.0", "--port=8888", "--no-browser", "--allow-root", "--NotebookApp.token=''"]
|
README.md
CHANGED
|
@@ -1,15 +1,166 @@
|
|
| 1 |
-
|
| 2 |
-
title: CrisisInbox
|
| 3 |
-
emoji: 🚨
|
| 4 |
-
colorFrom: red
|
| 5 |
-
colorTo: yellow
|
| 6 |
-
sdk: docker
|
| 7 |
-
app_port: 8000
|
| 8 |
-
---
|
| 9 |
|
| 10 |
-
|
| 11 |
|
| 12 |
-
|
| 13 |
-
Built on OpenEnv 0.2.1.
|
| 14 |
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CrisisInbox
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
+
A reinforcement learning environment built on [OpenEnv 0.2.1](https://github.com/OpenEnvs/OpenEnv) for training language models to manage personal task overload during natural disasters.
|
| 4 |
|
| 5 |
+
**Problem Statement 3.2** (Personalized Tasks) + **Patronus AI Sub-Theme** (Schema Drift)
|
|
|
|
| 6 |
|
| 7 |
+
**HF Space:** [eptan-crisis-inbox.hf.space](https://eptan-crisis-inbox.hf.space)
|
| 8 |
+
|
| 9 |
+
## The Problem
|
| 10 |
+
|
| 11 |
+
When disaster strikes, your phone explodes. Evacuation orders, panicked family texts, insurance deadlines, your boss demanding slides, your sister begging you to pick up her kids from school, your dad's heart medication left behind. Everything is urgent. Policies change mid-crisis. There are no clean answers — only tradeoffs.
|
| 12 |
+
|
| 13 |
+
CrisisInbox trains an agent to make those tradeoffs well.
|
| 14 |
+
|
| 15 |
+
## How It Works
|
| 16 |
+
|
| 17 |
+
The agent manages a **48-hour post-disaster inbox** as a working parent in Sacramento during a hurricane. Messages arrive in real time from 19 senders across 6 channels. The agent must triage, respond, and adapt — while the rules keep changing underneath.
|
| 18 |
+
|
| 19 |
+
### Three Layers of Difficulty
|
| 20 |
+
|
| 21 |
+
**1. Cognitive Overload** — 76 messages arrive over 48 hours. Reading costs time (6 min). Responding costs more (15 min). The agent can't handle everything — it must prioritize and let some things slide.
|
| 22 |
+
|
| 23 |
+
**2. Conflicting Obligations** — Your boss says come in. HR says take emergency leave. Your sister needs you to watch her kids. Mom wants you to drive to Tahoe. The evacuation shelter is full. There's no right answer, only better tradeoffs.
|
| 24 |
+
|
| 25 |
+
**3. Schema Drift** — Mid-episode, the rules change:
|
| 26 |
+
- Insurance deadline shortened from 72h to 48h
|
| 27 |
+
- Evacuation zone expands to include your workplace
|
| 28 |
+
- Employer switches from "use PTO" to "5 days paid emergency leave"
|
| 29 |
+
- Airline extends free rebooking from 48h to 7 days
|
| 30 |
+
- FEMA adds new documentation requirements
|
| 31 |
+
|
| 32 |
+
Each episode randomly fires 3 of 5 drift events. The agent must detect changes and reprioritize.
|
| 33 |
+
|
| 34 |
+
### Sender Profiles
|
| 35 |
+
|
| 36 |
+
| Sender | Messages | Tone | Stakes |
|
| 37 |
+
|--------|----------|------|--------|
|
| 38 |
+
| Mom | 8 | Panicked, crying voicemails | Dad's heart medication, family safety |
|
| 39 |
+
| Sister | 7 | Desperate, grateful | Kids stranded at school, childcare |
|
| 40 |
+
| Emma (niece, 7) | 3 | Kid texting style | Scared in the dark, rainbow drawings |
|
| 41 |
+
| Boss (Greg) | 5 | Passive-aggressive, then softens | Career pressure vs. emergency |
|
| 42 |
+
| Neighbor Dave | 8 | Casual bro, community | Cat rescue, looting, cleanup |
|
| 43 |
+
| FEMA / NWS | 10 | Formal, information-dense | Evacuation orders, shelter locations |
|
| 44 |
+
| State Farm | 5 | Corporate | Claim deadlines, documentation |
|
| 45 |
+
| Delta Airlines | 3 | Automated | Flight rebooking |
|
| 46 |
+
| Oakwood Elementary | 3 | School admin | Closures, virtual learning |
|
| 47 |
+
| HR Department | 3 | Policy updates | Leave policies |
|
| 48 |
+
| + 9 others | 21 | Various | Pharmacy, landlord, utilities, etc. |
|
| 49 |
+
|
| 50 |
+
### MCP Tools (Agent Actions)
|
| 51 |
+
|
| 52 |
+
| Tool | Time Cost | Description |
|
| 53 |
+
|------|-----------|-------------|
|
| 54 |
+
| `get_inbox` | 0h | View all arrived messages with metadata |
|
| 55 |
+
| `read_message` | 0.1h | Read full content of a specific message |
|
| 56 |
+
| `respond_to_message` | 0.25h | Take action on a message (earns reward) |
|
| 57 |
+
| `get_status` | 0h | View time, score, deadlines, drift events |
|
| 58 |
+
| `advance_time` | 0.5-4h | Skip forward (new messages may arrive) |
|
| 59 |
+
|
| 60 |
+
### Reward Function
|
| 61 |
+
|
| 62 |
+
| Signal | Weight | Mechanic |
|
| 63 |
+
|--------|--------|----------|
|
| 64 |
+
| **Urgency base** | 1-10 | Critical=10, High=5, Medium=3, Low=1 |
|
| 65 |
+
| **Deadline timing** | x0.25 to x1.5 | Earlier response = bigger bonus; late = 75% penalty |
|
| 66 |
+
| **Drift adaptation** | x1.5 | Bonus for handling drift-flagged messages |
|
| 67 |
+
| **Stale info penalty** | x0.5 | Penalty for acting on superseded information |
|
| 68 |
+
| **Response quality** | x0.5 | Penalty for very short/empty responses |
|
| 69 |
+
|
| 70 |
+
An optimal agent responding to the critical evacuation alert at hour 0 earns 15 points. The same response at hour 10 (after the 6h deadline) earns only 2.5.
|
| 71 |
+
|
| 72 |
+
### Episode Variation
|
| 73 |
+
|
| 74 |
+
Each episode has:
|
| 75 |
+
- 3 of 5 drift events randomly selected (seed-controlled)
|
| 76 |
+
- +/-15% jitter on message arrival times
|
| 77 |
+
- +/-10% jitter on deadlines
|
| 78 |
+
- Dependency chains that gate actions (e.g., must handle sister's request before school pickup confirmation)
|
| 79 |
+
|
| 80 |
+
## Quick Start (Hosted)
|
| 81 |
+
|
| 82 |
+
```python
|
| 83 |
+
from crisis_inbox import CrisisInboxEnv
|
| 84 |
+
|
| 85 |
+
with CrisisInboxEnv(base_url="https://eptan-crisis-inbox.hf.space") as env:
|
| 86 |
+
env.reset()
|
| 87 |
+
inbox = env.call_tool("get_inbox")
|
| 88 |
+
print(inbox)
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
## Local Development
|
| 92 |
+
|
| 93 |
+
```bash
|
| 94 |
+
git clone https://github.com/eptan/crisis-inbox.git
|
| 95 |
+
cd crisis-inbox
|
| 96 |
+
|
| 97 |
+
python3.12 -m venv .venv
|
| 98 |
+
source .venv/bin/activate
|
| 99 |
+
pip install --upgrade pip
|
| 100 |
+
pip install -e .
|
| 101 |
+
|
| 102 |
+
# Run server
|
| 103 |
+
uvicorn server.app:app --host 0.0.0.0 --port 8000
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
Test with the client:
|
| 107 |
+
|
| 108 |
+
```python
|
| 109 |
+
from crisis_inbox import CrisisInboxEnv
|
| 110 |
+
|
| 111 |
+
with CrisisInboxEnv(base_url="http://localhost:8000") as env:
|
| 112 |
+
env.reset()
|
| 113 |
+
|
| 114 |
+
# View the inbox
|
| 115 |
+
inbox = env.call_tool("get_inbox")
|
| 116 |
+
print(inbox)
|
| 117 |
+
|
| 118 |
+
# Read a message
|
| 119 |
+
msg = env.call_tool("read_message", message_id="msg_001")
|
| 120 |
+
print(msg)
|
| 121 |
+
|
| 122 |
+
# Respond to it
|
| 123 |
+
result = env.call_tool("respond_to_message",
|
| 124 |
+
message_id="msg_001",
|
| 125 |
+
response="Evacuating to Lincoln High School immediately with documents and medication.")
|
| 126 |
+
print(result)
|
| 127 |
+
|
| 128 |
+
# Advance time and check status
|
| 129 |
+
env.call_tool("advance_time", hours=4.0)
|
| 130 |
+
status = env.call_tool("get_status")
|
| 131 |
+
print(status)
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
## Repository Structure
|
| 135 |
+
|
| 136 |
+
```
|
| 137 |
+
crisis-inbox/
|
| 138 |
+
├── models.py # Message data model (Channel, Urgency, Message)
|
| 139 |
+
├── messages.py # 76 pre-written messages across 48h timeline
|
| 140 |
+
├── drift_events.py # 5 schema drift events (3 fire per episode)
|
| 141 |
+
├── client.py # MCPToolClient subclass
|
| 142 |
+
├── __init__.py # Package exports
|
| 143 |
+
├── server/
|
| 144 |
+
│ ├── crisis_inbox_environment.py # MCPEnvironment with timeline engine
|
| 145 |
+
│ ├── app.py # FastAPI app with MCPAction workaround
|
| 146 |
+
│ └── Dockerfile # HF Spaces deployment
|
| 147 |
+
├── notebooks/
|
| 148 |
+
│ └── crisisinbox_grpo.ipynb # GRPO training notebook
|
| 149 |
+
├── episodes.json # Pre-generated training episodes
|
| 150 |
+
├── generate_episodes.py # Episode generator script
|
| 151 |
+
├── pyproject.toml # Package config
|
| 152 |
+
├── openenv.yaml # OpenEnv environment spec
|
| 153 |
+
├── requirements.txt # Docker build dependencies
|
| 154 |
+
└── ROADMAP.md # Hackathon timeline and progress
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
## Stack
|
| 158 |
+
|
| 159 |
+
- **Environment:** OpenEnv 0.2.1 (MCPEnvironment + FastMCP)
|
| 160 |
+
- **Deployment:** HF Spaces (Docker)
|
| 161 |
+
- **Training:** Unsloth GRPO via Google Colab
|
| 162 |
+
- **Model:** Qwen2.5-0.5B-Instruct
|
| 163 |
+
|
| 164 |
+
## Team
|
| 165 |
+
|
| 166 |
+
Built at the OpenEnv Hackathon @ Shack15, SF — March 7-8, 2026
|
ROADMAP.md
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CrisisInbox Hackathon Roadmap
|
| 2 |
+
|
| 3 |
+
**Event:** OpenEnv Hackathon @ Shack15, SF
|
| 4 |
+
**Dates:** Saturday March 7 – Sunday March 8, 2026
|
| 5 |
+
**Team Size:** 2
|
| 6 |
+
**Problem Statement:** 3.2 (Personalized Tasks) + Patronus AI Sub-Theme (Schema Drift)
|
| 7 |
+
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
## PHASE 1: Foundation (11:30 AM – 2:00 PM Saturday)
|
| 11 |
+
|
| 12 |
+
### Person A — Environment Setup
|
| 13 |
+
|
| 14 |
+
- [x] Read OpenEnv 0.2.1 docs thoroughly — understand environment structure, observation/action spaces, step function
|
| 15 |
+
- [x] Get a minimal "hello world" OpenEnv environment running locally
|
| 16 |
+
- [x] Define the core data model:
|
| 17 |
+
- [x] `Message` object: sender, channel (6 types), content, urgency (critical/high/medium/low), deadline, dependencies, drift_flag, supersedes
|
| 18 |
+
- [x] Task state tracked via environment internals (`_handled`, `_visible_messages`, deadline expiry) — no separate Task object needed since messages and tasks are 1:1
|
| 19 |
+
- [x] World state tracked via environment instance vars (`_current_hour`, `_all_messages`, `_visible_messages`, `_handled`, `_score`, `_fired_drifts`, `_superseded`) — exposed to agent through `get_status` tool
|
| 20 |
+
- [x] Deploy bare-bones environment to HF Spaces — confirmed at eptan-crisis-inbox.hf.space
|
| 21 |
+
- [ ] Push initial scaffold to GitHub
|
| 22 |
+
|
| 23 |
+
### Person B — Training Pipeline Setup
|
| 24 |
+
|
| 25 |
+
- [x] Open Unsloth GRPO Colab notebook and run with a toy example end-to-end
|
| 26 |
+
- [x] Confirm training loop works: environment → agent action → reward → update
|
| 27 |
+
- [x] Define the action space — simplified to free-form: model outputs `respond_to_message(msg_id, "response")`, parsed via regex
|
| 28 |
+
- [x] Define observation space: full inbox snapshot as text prompt with urgency grouping, deadline warnings, drift flags, stale markers
|
| 29 |
+
- [x] Implement reward function (integrated in notebook `score_action()`):
|
| 30 |
+
- [x] Urgency base: critical=10, high=5, medium=3, low=1
|
| 31 |
+
- [x] Deadline timing: early bonus (up to +50%), late penalty (-75%)
|
| 32 |
+
- [x] Schema drift adaptation: +50% for handling drift-flagged messages
|
| 33 |
+
- [x] Stale info penalty: -50% for acting on superseded messages
|
| 34 |
+
- [x] Response quality: -50% for short/empty responses
|
| 35 |
+
- [x] Priority penalty: -70% for choosing low-urgency when critical messages exist
|
| 36 |
+
- [ ] Push training scaffold to GitHub
|
| 37 |
+
|
| 38 |
+
### Checkpoint — 2:00 PM
|
| 39 |
+
|
| 40 |
+
- [x] ✅ OpenEnv minimal environment runs locally
|
| 41 |
+
- [x] ✅ HF Spaces deployment pipeline confirmed
|
| 42 |
+
- [x] ✅ Unsloth GRPO training loop configured in notebook (Qwen2.5-0.5B + LoRA + GRPO)
|
| 43 |
+
- [x] ✅ Data model and reward function defined
|
| 44 |
+
- [ ] 🚨 If OpenEnv is blocking: simplify to gym-style env, wrap in OpenEnv later
|
| 45 |
+
- [ ] 🚨 If Unsloth is blocking: fall back to HF TRL directly
|
| 46 |
+
|
| 47 |
+
---
|
| 48 |
+
|
| 49 |
+
## PHASE 2: Core Build (2:00 PM – 6:00 PM Saturday)
|
| 50 |
+
|
| 51 |
+
### Person A — Environment Logic
|
| 52 |
+
|
| 53 |
+
- [x] Build message generation system:
|
| 54 |
+
- [x] Create 18 sender profiles (Mom, Boss, Sister, Neighbor Dave, FEMA, NWS, State Farm, Delta Airlines, Oakwood Elementary, HR, Landlord, Coworker Sarah, PG&E, Red Cross, Bank of America, CVS Pharmacy, Sacramento County, Comcast)
|
| 55 |
+
- [x] Write 76 messages across 48-hour simulated timeline
|
| 56 |
+
- [x] Each message has: sender, channel, timestamp, urgency, deadline, content, dependencies, drift_flag, supersedes
|
| 57 |
+
- [x] Organized into waves: initial crisis (hour 0-2), escalation (2-6), post-evacuation chaos (6-12), conflicting demands (12-20), drift events (20-30), recovery (26-36), ongoing management (32-40), final stretch (44-48)
|
| 58 |
+
- [x] Implement episode timeline engine:
|
| 59 |
+
- [x] Messages arrive at scheduled sim-hours
|
| 60 |
+
- [x] Time costs per action: reading = 0.1h, responding = 0.25h, advance_time tool = 0.5-4.0h
|
| 61 |
+
- [x] Dependencies gate actions (must handle prerequisites first)
|
| 62 |
+
- [x] Episode ends at hour 48
|
| 63 |
+
- [x] Build 5 schema drift events:
|
| 64 |
+
- [x] Hour 20: Insurance deadline shortened from 72h to 48h
|
| 65 |
+
- [x] Hour 21: Evacuation zone expanded to include Zone B (workplace)
|
| 66 |
+
- [x] Hour 22.5: Employer emergency leave expanded (PTO → 5 days paid leave)
|
| 67 |
+
- [x] Hour 24.5: Airline extends free rebooking window from 48h to 7 days
|
| 68 |
+
- [x] Hour 34: FEMA adds new documentation requirements
|
| 69 |
+
- [x] Randomization: each episode triggers 3 of 5 drift events (seed-controlled)
|
| 70 |
+
- [x] Add episode variation: +/-15% jitter on arrival times, +/-10% on deadlines (seed-controlled)
|
| 71 |
+
- [x] Integrate with OpenEnv API: `reset()`, `step()` implemented; 5 MCP tools (get_inbox, read_message, respond_to_message, get_status, advance_time)
|
| 72 |
+
|
| 73 |
+
### Person B — Reward & Training Integration
|
| 74 |
+
|
| 75 |
+
- [x] Reward function implemented in two places:
|
| 76 |
+
- [x] Environment-side: `_calculate_reward()` in `crisis_inbox_environment.py` (used during live episodes)
|
| 77 |
+
- [x] Training-side: `score_action()` in notebook (parses model output, scores against inbox state)
|
| 78 |
+
- [x] Episode generator (`generate_episodes.py`) produces offline training data:
|
| 79 |
+
- [x] 50 episodes, 803 training prompts across 16 decision points per episode
|
| 80 |
+
- [x] Full message content, drift flags, superseded markers, dependency info
|
| 81 |
+
- [x] Prompts include urgency grouping, deadline warnings, stale markers
|
| 82 |
+
- [x] Connect to Unsloth GRPO: notebook loads episodes.json, builds HF Dataset, configures GRPOTrainer
|
| 83 |
+
- [ ] Run training on Colab with GPU and capture reward curves
|
| 84 |
+
- [ ] Log reward components separately to identify which signals are working
|
| 85 |
+
|
| 86 |
+
### Checkpoint — 6:00 PM (Dinner)
|
| 87 |
+
|
| 88 |
+
- [x] ✅ Environment generates realistic message streams (76 messages, 19 senders)
|
| 89 |
+
- [x] ✅ Schema drift events fire correctly (tested: superseded messages marked, drift rewards working)
|
| 90 |
+
- [x] ✅ Training notebook configured with GRPO, reward function, and evaluation
|
| 91 |
+
- [ ] ✅ First reward curves from actual GPU training run
|
| 92 |
+
- [ ] 🚨 If messages aren't generating: reduce to 30 messages, fewer senders
|
| 93 |
+
- [ ] 🚨 If training isn't connecting: hardcode environment responses, focus on reward signal
|
| 94 |
+
|
| 95 |
+
---
|
| 96 |
+
|
| 97 |
+
## PHASE 3: Integration & Iteration (6:00 PM – 10:00 PM Saturday)
|
| 98 |
+
|
| 99 |
+
### Person A — Environment Polish
|
| 100 |
+
|
| 101 |
+
- [x] Polish message content for realism and emotional impact (20+ messages rewritten)
|
| 102 |
+
- [x] Mom: panicked texting, crying voicemails, medical anxiety, guilt-trip to Tahoe
|
| 103 |
+
- [x] Boss (Greg): passive-aggressive emails with signature, softens after Meridian postpones
|
| 104 |
+
- [x] Sister: desperation about kids, voice cracking in voicemail, genuine gratitude
|
| 105 |
+
- [x] Neighbor Dave: casual bro tone, guilt about Whiskers, neighborhood solidarity
|
| 106 |
+
- [x] Emma (niece, age 7): 3 kid-perspective messages — pillow fights, scared in the dark, rainbow drawing
|
| 107 |
+
- [x] FEMA/NWS: kept formal and information-dense (already good)
|
| 108 |
+
- [x] Test all 5 drift events fire correctly and change environment state
|
| 109 |
+
- [x] Verify dependency chains work (23 messages with dependencies, gated in respond_to_message)
|
| 110 |
+
- [x] Edge case handling: stale info penalized (-50% reward), expired deadlines tracked in get_status
|
| 111 |
+
- [x] Redeploy updated environment to HF Spaces (73 messages, timeline engine, drift events)
|
| 112 |
+
- [x] Test HF Spaces deployment works end-to-end remotely (verified via client)
|
| 113 |
+
|
| 114 |
+
### Person B — Training Optimization
|
| 115 |
+
|
| 116 |
+
- [ ] Analyze initial reward curves — identify flat or noisy components
|
| 117 |
+
- [ ] Tune reward weights: if one signal dominates, rebalance
|
| 118 |
+
- [ ] If training is flat overall:
|
| 119 |
+
- [ ] Simplify action space (reduce to 4 actions)
|
| 120 |
+
- [ ] Increase reward magnitudes
|
| 121 |
+
- [ ] Reduce episode length
|
| 122 |
+
- [ ] Log specific before/after agent behaviors:
|
| 123 |
+
- [ ] Capture untrained agent: responds in order, ignores evacuation alert
|
| 124 |
+
- [ ] Capture trained agent: triages safety first, adapts to drift
|
| 125 |
+
- [ ] Save training checkpoints and reward curve data for demo
|
| 126 |
+
- [ ] Target: 200-500 episodes with visible upward trend
|
| 127 |
+
|
| 128 |
+
### Checkpoint — 10:00 PM (Doors Close)
|
| 129 |
+
|
| 130 |
+
- [x] ✅ Environment fully functional with drift events on HF Spaces
|
| 131 |
+
- [ ] ✅ Training curves show upward trend (pending GPU run)
|
| 132 |
+
- [ ] ✅ At least 2 clear before/after behavior examples captured
|
| 133 |
+
- [ ] ✅ All code pushed to GitHub
|
| 134 |
+
- [ ] 🚨 If reward curves are flat: simplify environment overnight, retrain Sunday AM
|
| 135 |
+
- [ ] 🚨 If HF Spaces is broken: prepare local demo as backup
|
| 136 |
+
- [ ] Both teammates agree on Sunday morning priority list before leaving
|
| 137 |
+
|
| 138 |
+
---
|
| 139 |
+
|
| 140 |
+
## PHASE 4: Sunday Polish (9:00 AM – 12:00 PM Sunday)
|
| 141 |
+
|
| 142 |
+
### Person A — Demo & Presentation
|
| 143 |
+
|
| 144 |
+
- [x] Build demo display (`demo.py`): terminal-based visualization
|
| 145 |
+
- [x] Color-coded urgency levels (red bg = critical, red = high, yellow = medium, green = low)
|
| 146 |
+
- [x] Schema drift notifications with magenta banner
|
| 147 |
+
- [x] Agent action visualization with blue banner and reward display
|
| 148 |
+
- [x] Two strategies: smart triage vs naive (arrival order)
|
| 149 |
+
- [x] Comparison mode shows 55% improvement (157.8 vs 101.8 pts)
|
| 150 |
+
- [x] Coverage breakdown by urgency with ASCII progress bars
|
| 151 |
+
- [x] HF Spaces deployment live with polished environment
|
| 152 |
+
- [x] Write repo README:
|
| 153 |
+
- [x] Scenario hook, three layers of difficulty, sender profiles table
|
| 154 |
+
- [x] MCP tools table, reward function table, episode variation details
|
| 155 |
+
- [x] Quick start (hosted + local), repo structure, tech stack
|
| 156 |
+
- [ ] Draft the 3-minute pitch outline:
|
| 157 |
+
- [ ] 0:00-0:30 — The scenario hook ("A wildfire just hit. You have 47 unread messages.")
|
| 158 |
+
- [ ] 0:30-1:30 — Show the environment: message stream, drift events, conflicting tasks
|
| 159 |
+
- [ ] 1:30-2:15 — Untrained vs trained agent comparison
|
| 160 |
+
- [ ] 2:15-2:45 — Reward curves and training results
|
| 161 |
+
- [ ] 2:45-3:00 — Why this matters (real-world impact close)
|
| 162 |
+
|
| 163 |
+
### Person B — Training & Artifacts
|
| 164 |
+
|
| 165 |
+
- [ ] Run final training session for cleanest possible reward curves
|
| 166 |
+
- [ ] Export reward curve plots (all 5 components + composite)
|
| 167 |
+
- [ ] Export 3 specific before/after examples with clear narrative:
|
| 168 |
+
- [ ] Example 1: Untrained ignores FEMA evacuation → Trained prioritizes it first
|
| 169 |
+
- [ ] Example 2: Untrained misses insurance deadline after policy change → Trained adapts
|
| 170 |
+
- [ ] Example 3: Untrained sends form-letter reply to Mom → Trained matches emotional tone
|
| 171 |
+
- [x] Finalize Colab notebook:
|
| 172 |
+
- [x] Clean, commented code with markdown sections
|
| 173 |
+
- [x] Reward function with test cases (good action vs bad action vs junk)
|
| 174 |
+
- [x] Evaluation cell comparing trained model choices
|
| 175 |
+
- [ ] Reward curves visible when run (pending GPU run)
|
| 176 |
+
- [ ] Verify all required artifacts exist:
|
| 177 |
+
- [ ] Public GitHub repo
|
| 178 |
+
- [ ] HF Spaces deployment
|
| 179 |
+
- [ ] Colab training notebook
|
| 180 |
+
- [ ] YouTube video (not yet recorded)
|
| 181 |
+
|
| 182 |
+
### Checkpoint — 12:00 PM
|
| 183 |
+
|
| 184 |
+
- [x] ✅ Demo runs smoothly end-to-end (smart vs naive comparison, 55% improvement)
|
| 185 |
+
- [ ] ✅ Reward curves are clean and trend upward (pending GPU run)
|
| 186 |
+
- [ ] ✅ Before/after examples are compelling (pending GPU run)
|
| 187 |
+
- [x] ✅ README is complete
|
| 188 |
+
- [x] ✅ Colab notebook is clean (14 cells, markdown sections, reward function tests)
|
| 189 |
+
- [ ] 🚨 If demo is buggy: simplify to scripted walkthrough of one episode
|
| 190 |
+
- [ ] 🚨 If reward curves are still flat: show qualitative behavior improvement instead
|
| 191 |
+
|
| 192 |
+
---
|
| 193 |
+
|
| 194 |
+
## PHASE 5: Record & Submit (12:00 PM – 1:00 PM Sunday)
|
| 195 |
+
|
| 196 |
+
### Both Together
|
| 197 |
+
|
| 198 |
+
- [ ] Record one-minute YouTube video:
|
| 199 |
+
- [ ] Person A: voiceover narration
|
| 200 |
+
- [ ] Person B: screen capture / screen share
|
| 201 |
+
- [ ] Structure:
|
| 202 |
+
- [ ] 0:00-0:15 — "CrisisInbox: training AI to triage personal tasks during disasters"
|
| 203 |
+
- [ ] 0:15-0:35 — Show environment: messages arriving, drift event firing
|
| 204 |
+
- [ ] 0:35-0:50 — Before/after: untrained fails → trained succeeds
|
| 205 |
+
- [ ] 0:50-1:00 — Flash reward curves, close with impact statement
|
| 206 |
+
- [ ] Upload to YouTube (unlisted is fine)
|
| 207 |
+
- [ ] Rehearse 3-minute live pitch at least twice out loud
|
| 208 |
+
- [ ] Prepare for Q&A — practice answers to:
|
| 209 |
+
- [ ] "Why is this hard for LLMs?"
|
| 210 |
+
- [ ] "How does schema drift work specifically?"
|
| 211 |
+
- [ ] "What would you build next with more time?"
|
| 212 |
+
- [ ] "How is this different from a standard chatbot?"
|
| 213 |
+
- [ ] "How does the reward function handle edge cases?"
|
| 214 |
+
- [ ] **SUBMIT via Cerebral Valley form:**
|
| 215 |
+
- [ ] GitHub repo URL (public)
|
| 216 |
+
- [ ] HF Spaces URL
|
| 217 |
+
- [ ] YouTube video URL
|
| 218 |
+
- [ ] Colab notebook URL
|
| 219 |
+
- [ ] Select problem statement 3.2 + Patronus AI sub-theme
|
| 220 |
+
- [ ] Confirm submission received
|
| 221 |
+
|
| 222 |
+
---
|
| 223 |
+
|
| 224 |
+
## FINAL CHECKLIST — Before Judging Starts (1:15 PM Sunday)
|
| 225 |
+
|
| 226 |
+
- [ ] GitHub repo is public and all code is pushed
|
| 227 |
+
- [ ] HF Spaces deployment is live and accessible
|
| 228 |
+
- [ ] Colab notebook runs end-to-end
|
| 229 |
+
- [ ] YouTube video is uploaded and link works
|
| 230 |
+
- [ ] Submission form completed
|
| 231 |
+
- [ ] Both teammates can explain every part of the project
|
| 232 |
+
- [ ] Laptop charged for demo
|
| 233 |
+
- [ ] Demo runs without internet (backup plan if WiFi fails)
|
| 234 |
+
|
| 235 |
+
---
|
| 236 |
+
|
| 237 |
+
## EMERGENCY FALLBACKS
|
| 238 |
+
|
| 239 |
+
| Problem | Fallback |
|
| 240 |
+
|---------|----------|
|
| 241 |
+
| OpenEnv won't work | Build gym-style env, wrap in OpenEnv interface last |
|
| 242 |
+
| HF Spaces deploy fails | Run demo locally, screenshot HF attempt for judges |
|
| 243 |
+
| Training won't converge | Simplify to 3 actions, 20 messages, 2 drift events |
|
| 244 |
+
| Reward curves flat | Show qualitative behavior change, explain reward design |
|
| 245 |
+
| YouTube upload fails | Screen-record on phone, upload from mobile |
|
| 246 |
+
| One teammate burns out | Other covers — both should understand full stack |
|
| 247 |
+
|
| 248 |
+
---
|
| 249 |
+
|
| 250 |
+
## KEY RULE
|
| 251 |
+
|
| 252 |
+
**If something isn't working after 45 minutes, simplify. A working simple version always beats a broken ambitious one.**
|
debug_env.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Debug script to inspect environment response format."""
import sys
# NOTE(review): hard-coded absolute Windows path — works only on the author's
# machine; anyone else must adjust this before running. TODO: make configurable.
sys.path.insert(0, r'C:\Users\smrit\CascadeProjects\crisis-inbox-test')

from crisis_inbox import CrisisInboxEnv

# Connect to the hosted HF Spaces deployment and start a fresh episode.
with CrisisInboxEnv(base_url="https://eptan-crisis-inbox.hf.space") as env:
    env.reset()

    # List available tools
    tools = env.list_tools()
    print(f"Available tools ({len(tools)}):")
    for t in tools:
        print(f" - {t.name}: {t.description[:60]}...")

    # Get inbox and inspect the raw return type/shape of a tool call.
    print("\n--- Calling get_inbox ---")
    inbox = env.call_tool("get_inbox")
    print(f"Type: {type(inbox)}")
    print(f"Content preview: {str(inbox)[:500]}")

    # Try to get a specific message.
    # "msg_001" is a guess at the id scheme — errors are caught and printed
    # rather than aborting the probe.
    print("\n--- Trying read_message ---")
    try:
        msg = env.call_tool("read_message", message_id="msg_001")
        print(f"Type: {type(msg)}")
        print(f"Content: {str(msg)[:300]}")
    except Exception as e:
        print(f"Error: {e}")

    # Try status (same best-effort pattern as above).
    print("\n--- Calling get_status ---")
    try:
        status = env.call_tool("get_status")
        print(f"Type: {type(status)}")
        print(f"Content: {str(status)[:300]}")
    except Exception as e:
        print(f"Error: {e}")
|
demo.py
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
CrisisInbox Demo Display
|
| 4 |
+
|
| 5 |
+
Runs a full episode with color-coded urgency levels, drift event
|
| 6 |
+
notifications, and agent action visualization. Designed for the
|
| 7 |
+
3-minute hackathon demo.
|
| 8 |
+
|
| 9 |
+
Usage:
|
| 10 |
+
python demo.py # Run against local server
|
| 11 |
+
python demo.py --remote # Run against HF Spaces
|
| 12 |
+
python demo.py --strategy smart # Use smart triage (default)
|
| 13 |
+
python demo.py --strategy naive # Use naive arrival-order strategy
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import json
|
| 18 |
+
import sys
|
| 19 |
+
import time
|
| 20 |
+
|
| 21 |
+
from server.crisis_inbox_environment import CrisisInboxEnvironment
|
| 22 |
+
from openenv.core.env_server.mcp_types import CallToolAction
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ---------------------------------------------------------------------------
|
| 26 |
+
# ANSI colors
|
| 27 |
+
# ---------------------------------------------------------------------------
|
| 28 |
+
class C:
    """ANSI SGR escape sequences used for terminal styling (reset with RESET)."""
    # Bright foreground colors (SGR 90-97)
    RED = "\033[91m"
    YELLOW = "\033[93m"
    GREEN = "\033[92m"
    BLUE = "\033[94m"
    MAGENTA = "\033[95m"
    CYAN = "\033[96m"
    WHITE = "\033[97m"
    GRAY = "\033[90m"
    # Text attributes
    BOLD = "\033[1m"
    DIM = "\033[2m"
    UNDERLINE = "\033[4m"
    RESET = "\033[0m"
    # Background colors (SGR 40-47)
    BG_RED = "\033[41m"
    BG_YELLOW = "\033[43m"
    BG_BLUE = "\033[44m"
    BG_MAGENTA = "\033[45m"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# Maps an urgency level to the ANSI styling used for its inbox badge.
# "critical" gets a red background so it visually dominates the list.
URGENCY_COLORS = {
    "critical": C.BG_RED + C.WHITE + C.BOLD,
    "high": C.RED + C.BOLD,
    "medium": C.YELLOW,
    "low": C.GREEN,
}

# Badge text per urgency level (padded with spaces for visual alignment).
URGENCY_LABELS = {
    "critical": " CRITICAL ",
    "high": " HIGH ",
    "medium": " MEDIUM ",
    "low": " LOW ",
}

# Short channel tag printed next to each message line.
CHANNEL_ICONS = {
    "sms": "SMS",
    "email": "EMAIL",
    "phone": "PHONE",
    "government_alert": "ALERT",
    "app_notification": "APP",
    "social_media": "SOCIAL",
}
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# ---------------------------------------------------------------------------
|
| 72 |
+
# Display helpers
|
| 73 |
+
# ---------------------------------------------------------------------------
|
| 74 |
+
def header(text: str):
    """Print *text* between two bold cyan 72-char '=' rules."""
    rule = f"{C.BOLD}{C.CYAN}{'=' * 72}{C.RESET}"
    print("\n" + rule)
    print(f"{C.BOLD}{C.CYAN} {text}{C.RESET}")
    print(rule)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def subheader(text: str):
    """Print a small bold-white '--- text ---' section divider."""
    divider = "--- " + text + " ---"
    print("\n" + C.BOLD + C.WHITE + divider + C.RESET)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def show_time(hour: float):
    """Format a simulation hour (0-48) as 'Day N H:MM AM/PM (hour X/48)'.

    Despite the ``show_`` prefix this returns the string rather than
    printing it; callers embed it in their own output.
    """
    whole = int(hour)
    minutes = int((hour - whole) * 60)
    day = "Day 2" if hour >= 24 else "Day 1"
    clock = whole % 24
    ampm = "PM" if clock >= 12 else "AM"
    clock = clock % 12
    if clock == 0:
        clock = 12  # 12-hour clock shows 12, never 0
    return f"{day} {clock}:{minutes:02d} {ampm} (hour {hour:.1f}/48)"
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def show_message_line(msg: dict, prefix: str = ""):
    """Render one inbox entry: urgency badge, channel tag, status flags,
    sender/subject, and an optional deadline hint."""
    urgency = msg["urgency"]
    badge = f"{URGENCY_COLORS.get(urgency, '')}[{URGENCY_LABELS.get(urgency, urgency.upper())}]{C.RESET}"
    channel = CHANNEL_ICONS.get(msg["channel"], msg["channel"])

    # Flags appear in fixed order: drift, stale, done.
    flags = ""
    if msg.get("drift_flag"):
        flags += f"{C.BG_MAGENTA}{C.WHITE} DRIFT {C.RESET} "
    if msg.get("superseded"):
        flags += f"{C.GRAY}[STALE]{C.RESET} "
    if msg.get("handled"):
        flags += f"{C.GRAY}[DONE]{C.RESET} "

    due = ""
    if msg.get("deadline_hours") is not None:
        due = f" {C.DIM}(due h{msg['deadline_hours']:.0f}){C.RESET}"

    print(
        f" {prefix}{badge} "
        f"{C.DIM}{channel:>5}{C.RESET} "
        f"{flags}"
        f"{C.BOLD}{msg['sender']}{C.RESET}: {msg['subject']}"
        f"{due}"
    )
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def show_action(action_text: str):
    """Announce an agent action on a blue banner."""
    banner = f"{C.BG_BLUE}{C.WHITE}{C.BOLD} AGENT ACTION {C.RESET}"
    print(f"\n {banner} {C.CYAN}{action_text}{C.RESET}")
| 119 |
+
|
| 120 |
+
def show_reward(reward: float, total: float):
    """Print the reward delta (green if positive, red otherwise) and the running total."""
    # NOTE(review): the "+" is hard-coded, so a negative reward renders as
    # "+-3.0 pts" — preserved here because it matches the original output.
    tint = C.GREEN if reward > 0 else C.RED
    print(f" {tint}+{reward:.1f} pts{C.RESET} {C.DIM}(total: {total:.1f}){C.RESET}")
| 124 |
+
|
| 125 |
+
def show_drift_alert(msg: dict):
    """Print a magenta schema-drift banner for *msg*, plus a warning line."""
    banner = f"{C.BG_MAGENTA}{C.WHITE}{C.BOLD} SCHEMA DRIFT {C.RESET}"
    print(f"\n {banner} {C.MAGENTA}{C.BOLD}{msg['sender']}: {msg['subject']}{C.RESET}")
    print(f" {C.MAGENTA}Rules have changed! Previous information may be outdated.{C.RESET}")
| 130 |
+
|
| 131 |
+
def show_expired(expired: list):
    """List deadlines that have already passed; prints nothing for an empty list."""
    if not expired:
        return
    print(f"\n {C.RED}{C.BOLD}EXPIRED DEADLINES:{C.RESET}")
    for item in expired:
        print(f" {C.RED}x {item['subject']}{C.RESET}")
+
|
| 137 |
+
|
| 138 |
+
def show_upcoming(upcoming: list):
    """List deadlines still ahead with hours remaining; silent when empty."""
    if not upcoming:
        return
    print(f"\n {C.YELLOW}UPCOMING DEADLINES:{C.RESET}")
    for item in upcoming:
        print(f" {C.YELLOW}! {item['subject']} ({item['hours_remaining']:.1f}h left){C.RESET}")
+
|
| 144 |
+
|
| 145 |
+
def pause(seconds: float = 0.5):
    """Sleep briefly so the scrolling demo output stays watchable."""
    time.sleep(seconds)
+
|
| 148 |
+
|
| 149 |
+
# ---------------------------------------------------------------------------
|
| 150 |
+
# Agent strategies
|
| 151 |
+
# ---------------------------------------------------------------------------
|
| 152 |
+
def smart_priority(messages: list[dict]) -> list[dict]:
    """Order unhandled, non-superseded messages for triage.

    Sort key, most important first: urgency (critical > high > medium > low),
    then drift-flagged before non-drift, then earliest deadline. Messages with
    no deadline sort after any message that has one. The sort is stable, so
    ties keep arrival order.
    """
    urgency_rank = {"critical": 0, "high": 1, "medium": 2, "low": 3}
    unhandled = [m for m in messages if not m.get("handled") and not m.get("superseded")]

    def score(m):
        urg = urgency_rank.get(m["urgency"], 4)  # unknown urgency sorts last
        # BUG FIX: the old `m.get("deadline_hours") or 999` fallback also
        # clobbered a legitimate deadline of 0 hours (falsy), demoting a
        # due-right-now message to "no deadline". Compare to None explicitly.
        deadline = m.get("deadline_hours")
        if deadline is None:
            deadline = 999
        drift = 0 if m.get("drift_flag") else 1
        return (urg, drift, deadline)

    return sorted(unhandled, key=score)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def naive_order(messages: list[dict]) -> list[dict]:
    """Baseline strategy: keep arrival order, skipping handled/superseded messages."""
    pending = []
    for m in messages:
        if m.get("handled") or m.get("superseded"):
            continue
        pending.append(m)
    return pending
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# ---------------------------------------------------------------------------
|
| 172 |
+
# Main demo loop
|
| 173 |
+
# ---------------------------------------------------------------------------
|
| 174 |
+
def run_demo(strategy: str = "smart", seed: int = 42, speed: float = 0.3):
    """Run one full 48-hour CrisisInbox episode with live terminal output.

    Advances the environment through fixed time chunks, displays the inbox,
    drift alerts, and deadlines at each step, lets the chosen strategy
    respond to a bounded number of messages, then prints a final summary.

    Args:
        strategy: "smart" uses smart_priority; anything else uses naive_order.
        seed: episode seed forwarded to env.reset for reproducible runs.
        speed: pause (seconds) between display steps.

    Returns:
        The episode's final total score (float) as reported by get_status.
    """
    env = CrisisInboxEnvironment()
    obs = env.reset(seed=seed)

    def call(tool, **kwargs):
        # Thin wrapper: invoke an MCP tool and decode its JSON payload.
        o = env.step(CallToolAction(type="call_tool", tool_name=tool, arguments=kwargs))
        return json.loads(o.result.data)

    prioritize = smart_priority if strategy == "smart" else naive_order
    strategy_name = "SMART TRIAGE" if strategy == "smart" else "NAIVE (arrival order)"

    header(f"CrisisInbox Demo | Strategy: {strategy_name}")
    print(f"\n {C.DIM}Scenario: Post-hurricane evacuation in Sacramento")
    print(f" You are a working parent. Your phone is about to explode.{C.RESET}")
    print(f" {C.DIM}Messages: {obs.metadata['messages_total']} total | "
          f"Drift events: {obs.metadata['drift_events_scheduled']} scheduled{C.RESET}")
    pause(speed * 2)

    last_drift_count = 0
    actions_taken = 0
    max_actions = 25  # Cap for demo length

    # Advance through the timeline in chunks
    time_chunks = [0, 2, 6, 12, 20, 25, 34, 44, 48]

    for i in range(len(time_chunks) - 1):
        target_hour = time_chunks[i + 1]
        current = time_chunks[i]

        # Advance time to target (the env caps each step at 4 simulated hours
        # here so message delivery stays granular).
        while current < target_hour:
            advance = min(4.0, target_hour - current)
            result = call("advance_time", hours=advance)
            current = result["current_hour"]
            if result.get("new_messages", 0) > 0:
                pass  # Messages delivered silently

        # Get current state
        status = call("get_status")
        inbox = call("get_inbox")

        subheader(f"HOUR {status['current_hour']:.1f} | {show_time(status['current_hour'])}")
        print(f" {C.DIM}Messages arrived: {status['messages_total_arrived']} | "
              f"Handled: {status['messages_handled']} | "
              f"Score: {status['total_score']:.1f}{C.RESET}")

        # Check for new drift events (counter increased since last chunk).
        if status["drift_events_fired"] > last_drift_count:
            drift_msgs = [m for m in inbox if m.get("drift_flag") and not m.get("handled")]
            for dm in drift_msgs:
                show_drift_alert(dm)
                pause(speed)
            last_drift_count = status["drift_events_fired"]

        # Show deadlines
        show_expired(status.get("expired_deadlines", []))
        show_upcoming(status.get("upcoming_deadlines", []))

        # Show current inbox
        print(f"\n {C.BOLD}INBOX ({len(inbox)} messages):{C.RESET}")
        # Show only recent unhandled, plus any drift
        visible = [m for m in inbox if not m.get("handled")]
        for msg in visible[:8]:  # Show max 8 at a time
            show_message_line(msg)
        if len(visible) > 8:
            print(f" {C.DIM} ... and {len(visible) - 8} more unread{C.RESET}")
        pause(speed)

        # Agent takes actions on available messages
        prioritized = prioritize(inbox)
        actions_this_chunk = 0
        max_per_chunk = 4  # Don't monopolize one time period

        for msg in prioritized:
            if actions_taken >= max_actions or actions_this_chunk >= max_per_chunk:
                break

            msg_id = msg["id"]

            # Read the message first
            full_msg = call("read_message", message_id=msg_id)
            if "error" in full_msg:
                continue

            # Generate a contextual response based on sender/urgency
            response = _generate_response(full_msg)

            # Try to respond. Dependency-gated messages fail with an "error"
            # mentioning "dependencies"; other errors are skipped silently.
            result = call("respond_to_message", message_id=msg_id, response=response)
            if "error" in result:
                if "dependencies" in result.get("error", ""):
                    show_action(f"Cannot handle {msg_id} yet - dependencies unmet")
                    pause(speed * 0.5)
                continue

            show_action(f"Respond to {msg['sender']} - \"{msg['subject']}\"")
            print(f" {C.DIM}\"{response[:80]}{'...' if len(response) > 80 else ''}\"{C.RESET}")
            show_reward(result["reward"], result["total_score"])
            pause(speed)

            actions_taken += 1
            actions_this_chunk += 1

        # NOTE(review): "done" is read from the status fetched *before* this
        # chunk's actions, so the loop may run one extra chunk — confirm intent.
        if status.get("done"):
            break

    # Final summary
    final_status = call("get_status")
    final_inbox = call("get_inbox")

    header("EPISODE COMPLETE")
    handled_count = final_status["messages_handled"]
    total_arrived = final_status["messages_total_arrived"]
    missed = len(final_status.get("expired_deadlines", []))

    print(f"\n Strategy: {C.BOLD}{strategy_name}{C.RESET}")
    print(f" Final Score: {C.BOLD}{C.GREEN}{final_status['total_score']:.1f} pts{C.RESET}")
    print(f" Messages Handled: {handled_count}/{total_arrived}")
    print(f" Deadlines Missed: {C.RED}{missed}{C.RESET}")
    print(f" Drift Events Encountered: {final_status['drift_events_fired']}")

    # Show what was handled vs missed by urgency
    handled_ids = {m["id"] for m in final_inbox if m.get("handled")}
    by_urgency = {"critical": [0, 0], "high": [0, 0], "medium": [0, 0], "low": [0, 0]}
    for m in final_inbox:
        urg = m["urgency"]
        if urg in by_urgency:
            by_urgency[urg][1] += 1  # total per urgency
            if m["id"] in handled_ids:
                by_urgency[urg][0] += 1  # handled per urgency

    print(f"\n {C.BOLD}Coverage by Urgency:{C.RESET}")
    for urg, (handled, total) in by_urgency.items():
        color = URGENCY_COLORS.get(urg, "")
        pct = (handled / total * 100) if total > 0 else 0
        # 20-char ASCII bar, one '#' per 5 percentage points.
        bar = "#" * int(pct / 5) + "." * (20 - int(pct / 5))
        print(f" {color}{urg:>8}{C.RESET}: {handled}/{total} [{bar}] {pct:.0f}%")

    print()
    return final_status["total_score"]
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def _generate_response(msg: dict) -> str:
|
| 317 |
+
"""Generate a contextual response based on sender and content."""
|
| 318 |
+
sender = msg.get("sender", "")
|
| 319 |
+
urgency = msg.get("urgency", "")
|
| 320 |
+
subject = msg.get("subject", "").lower()
|
| 321 |
+
|
| 322 |
+
if "national weather" in sender.lower() or "fema" in sender.lower():
|
| 323 |
+
return "Acknowledged. Following evacuation orders immediately. Heading to designated shelter with essential documents and medication."
|
| 324 |
+
if sender == "Mom":
|
| 325 |
+
return "I'm safe, Mom. Don't worry. I'm following the evacuation orders and heading to the shelter. I'll call you as soon as I can. Love you."
|
| 326 |
+
if sender == "Sister":
|
| 327 |
+
if "kids" in subject or "take" in subject:
|
| 328 |
+
return "I'll get them, don't worry. Heading to Oakwood now. I'll text you the second I have them safe."
|
| 329 |
+
return "Kids are safe with me. Emma is being brave. Don't worry about anything, just stay safe yourself."
|
| 330 |
+
if "emma" in sender.lower():
|
| 331 |
+
return "Hey sweetie! Mac and cheese sounds perfect. Mommy is coming soon. Let's build a blanket fort while we wait!"
|
| 332 |
+
if sender == "Boss" or sender == "HR Department":
|
| 333 |
+
if "drift" in str(msg.get("drift_flag", "")):
|
| 334 |
+
return "Thanks for the update on the emergency leave policy. I've submitted my status form on the HR portal. Will be taking the emergency leave days."
|
| 335 |
+
return "I'm in the evacuation zone and following mandatory orders. Will work remotely when I can access wifi. Updating my status on the portal now."
|
| 336 |
+
if "insurance" in sender.lower() or "state farm" in sender.lower():
|
| 337 |
+
return "Filing claim now with policy number and damage photos. Have documented all damage with timestamps before any cleanup."
|
| 338 |
+
if sender == "Neighbor Dave":
|
| 339 |
+
return "Thanks for the heads up Dave. Stay safe at the shelter. I'll keep an eye on things here. We'll get through this."
|
| 340 |
+
if "delta" in sender.lower() or "airline" in sender.lower():
|
| 341 |
+
return "Selecting Option A for rebooking to the earliest available flight. Thank you for the flexibility during this emergency."
|
| 342 |
+
if "school" in sender.lower() or "oakwood" in sender.lower():
|
| 343 |
+
return "Acknowledged. Will arrange pickup before the deadline. Thank you for the early notification."
|
| 344 |
+
if "pharmacy" in sender.lower() or "cvs" in sender.lower():
|
| 345 |
+
return "Will pick up the prescription today. If the usual location is closed, I'll transfer to the nearest open CVS."
|
| 346 |
+
if "sacramento" in sender.lower():
|
| 347 |
+
return "Acknowledged. Following all advisories and avoiding affected areas."
|
| 348 |
+
if urgency == "critical":
|
| 349 |
+
return "Taking immediate action on this critical matter. Will follow up with details shortly."
|
| 350 |
+
if urgency == "high":
|
| 351 |
+
return "Understood, this is a priority. Handling this now and will confirm when complete."
|
| 352 |
+
return "Thank you for the information. I've noted this and will address it as soon as possible given the current situation."
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
# ---------------------------------------------------------------------------
|
| 356 |
+
# Comparison mode
|
| 357 |
+
# ---------------------------------------------------------------------------
|
| 358 |
+
def run_comparison(seed: int = 42, speed: float = 0.2):
    """Run both strategies side by side for the demo.

    Plays the same seeded episode twice — naive arrival-order first, then
    smart triage — and prints the score delta as the closing comparison.
    """
    header("CRISISINBOX: STRATEGY COMPARISON")
    print(f"\n {C.DIM}Same episode, two different approaches.{C.RESET}")
    print(f" {C.DIM}Which agent handles a disaster better?{C.RESET}\n")
    pause(1)

    # Round 1: baseline agent.
    print(f"{C.RED}{C.BOLD}{'=' * 72}")
    print(f" ROUND 1: NAIVE AGENT (responds in arrival order)")
    print(f"{'=' * 72}{C.RESET}")
    pause(0.5)
    naive_score = run_demo(strategy="naive", seed=seed, speed=speed)

    # Round 2: triage agent on the identical seed for a fair comparison.
    pause(1)
    print(f"\n{C.GREEN}{C.BOLD}{'=' * 72}")
    print(f" ROUND 2: TRAINED AGENT (smart triage)")
    print(f"{'=' * 72}{C.RESET}")
    pause(0.5)
    smart_score = run_demo(strategy="smart", seed=seed, speed=speed)

    header("FINAL COMPARISON")
    improvement = smart_score - naive_score
    # max(..., 1) guards against division by zero (or a tiny naive score)
    # when computing the percentage improvement.
    pct = (improvement / max(naive_score, 1)) * 100
    print(f"\n Naive Agent: {C.RED}{naive_score:.1f} pts{C.RESET}")
    print(f" Trained Agent: {C.GREEN}{smart_score:.1f} pts{C.RESET}")
    print(f" Improvement: {C.BOLD}{C.CYAN}+{improvement:.1f} pts ({pct:.0f}%){C.RESET}\n")
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
# ---------------------------------------------------------------------------
|
| 387 |
+
# Entry point
|
| 388 |
+
# ---------------------------------------------------------------------------
|
| 389 |
+
if __name__ == "__main__":
    # CLI entry point: pick strategy, episode seed, and playback speed.
    # NOTE(review): the module docstring advertises a --remote flag that is
    # not defined here — confirm whether remote mode was dropped or is TODO.
    parser = argparse.ArgumentParser(description="CrisisInbox Demo")
    parser.add_argument("--strategy", choices=["smart", "naive", "compare"],
                        default="compare", help="Agent strategy (default: compare)")
    parser.add_argument("--seed", type=int, default=42, help="Episode seed")
    parser.add_argument("--speed", type=float, default=0.3,
                        help="Pause between actions in seconds (default: 0.3)")
    args = parser.parse_args()

    if args.strategy == "compare":
        # "compare" runs naive then smart on the same seed.
        run_comparison(seed=args.seed, speed=args.speed)
    else:
        run_demo(strategy=args.strategy, seed=args.seed, speed=args.speed)
|
drift_events.py
CHANGED
|
@@ -86,9 +86,10 @@ DRIFT_EVACUATION_EXPANSION = DriftEvent(
|
|
| 86 |
channel=Channel.SMS,
|
| 87 |
subject="Office is in Zone B - disregard earlier email",
|
| 88 |
content=(
|
| 89 |
-
"
|
| 90 |
-
"
|
| 91 |
-
"
|
|
|
|
| 92 |
),
|
| 93 |
urgency=Urgency.MEDIUM,
|
| 94 |
timestamp_hours=21.5,
|
|
|
|
| 86 |
channel=Channel.SMS,
|
| 87 |
subject="Office is in Zone B - disregard earlier email",
|
| 88 |
content=(
|
| 89 |
+
"Well, our office is in Zone B. So obviously disregard the 'come in tomorrow' "
|
| 90 |
+
"thing. But I STILL need those slides. Can you work on them from wherever "
|
| 91 |
+
"you are? Sarah said she can pick up some of the slack but I need to know "
|
| 92 |
+
"what you can deliver. This client won't wait for a hurricane."
|
| 93 |
),
|
| 94 |
urgency=Urgency.MEDIUM,
|
| 95 |
timestamp_hours=21.5,
|
generate_episodes.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Episode Generator for CrisisInbox GRPO Training.
|
| 3 |
+
|
| 4 |
+
Generates training episodes locally (no server needed) by simulating the
|
| 5 |
+
environment and capturing inbox snapshots at key decision points.
|
| 6 |
+
|
| 7 |
+
Each episode produces multiple training prompts — one per decision point —
|
| 8 |
+
where the model must choose which message to handle next.
|
| 9 |
+
|
| 10 |
+
Output format (per episode):
|
| 11 |
+
{
|
| 12 |
+
"episode_id": "ep_000",
|
| 13 |
+
"seed": 42,
|
| 14 |
+
"total_messages": 73,
|
| 15 |
+
"drift_events": ["drift_insurance", "drift_evacuation", "drift_fema"],
|
| 16 |
+
"decision_points": [
|
| 17 |
+
{
|
| 18 |
+
"hour": 0.0,
|
| 19 |
+
"prompt": "...", # full text prompt for the LLM
|
| 20 |
+
"visible_messages": [...], # inbox snapshot
|
| 21 |
+
"handled_ids": [],
|
| 22 |
+
"pending_deadlines": [...],
|
| 23 |
+
"drift_events_fired": []
|
| 24 |
+
},
|
| 25 |
+
...
|
| 26 |
+
]
|
| 27 |
+
}
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
import json
|
| 31 |
+
import random
|
| 32 |
+
from typing import Any
|
| 33 |
+
|
| 34 |
+
from models import Channel, Message, Urgency
|
| 35 |
+
from messages import ALL_MESSAGES
|
| 36 |
+
from drift_events import ALL_DRIFT_EVENTS, DriftEvent, select_drift_events
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
SYSTEM_PROMPT = """You are managing a personal crisis inbox during a post-hurricane evacuation in Sacramento. You are a working parent with 48 hours to triage incoming messages from family, employer, government, insurance, and service providers.
|
| 40 |
+
|
| 41 |
+
Rules:
|
| 42 |
+
- Reading a message costs 0.1 hours (6 minutes)
|
| 43 |
+
- Responding to a message costs 0.25 hours (15 minutes)
|
| 44 |
+
- You cannot handle everything — prioritize wisely
|
| 45 |
+
- Safety-critical messages (evacuations, medical) should come first
|
| 46 |
+
- Watch for policy changes that supersede earlier information
|
| 47 |
+
- Some messages have dependencies that must be handled first
|
| 48 |
+
- Deadlines are real — missing them reduces your score
|
| 49 |
+
|
| 50 |
+
Available actions:
|
| 51 |
+
- respond_to_message(message_id, response) — handle a message
|
| 52 |
+
- advance_time(hours) — skip forward to see new messages
|
| 53 |
+
- get_status() — check time, score, deadlines"""
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def build_episode(seed: int) -> dict[str, Any]:
    """Construct one training episode: a seeded message pool plus inbox
    snapshots captured at a fixed schedule of decision hours.
    """
    rand = random.Random(seed)

    # Pick which drift (policy-change) events fire in this episode.
    drift_events = select_drift_events(count=3, rng=rand)
    drift_event_ids = [d.id for d in drift_events]

    # Message ids carried by the selected drift events...
    selected_drift_msg_ids = {m.id for d in drift_events for m in d.messages}
    # ...and by every drift event, selected or not.
    all_drift_msg_ids = {m.id for d in ALL_DRIFT_EVENTS for m in d.messages}

    # Copy the base pool, jittering arrival times and deadlines so each seed
    # yields a distinct-but-plausible schedule. Drift messages belonging to
    # unselected events are dropped entirely.
    pool = []
    for template in ALL_MESSAGES:
        if template.id in all_drift_msg_ids and template.id not in selected_drift_msg_ids:
            continue
        msg = template.model_copy()
        if msg.timestamp_hours > 0:
            wobble = rand.uniform(-0.15, 0.15) * msg.timestamp_hours
            msg.timestamp_hours = round(max(0.1, min(47.5, msg.timestamp_hours + wobble)), 2)
        if msg.deadline_hours is not None and msg.deadline_hours > 0:
            d_wobble = rand.uniform(-0.1, 0.1) * msg.deadline_hours
            msg.deadline_hours = round(
                max(msg.timestamp_hours + 0.5, min(72.0, msg.deadline_hours + d_wobble)), 2
            )
        pool.append(msg)

    # Append drift-event messages not already present in the pool.
    for drift in drift_events:
        for dmsg in drift.messages:
            if all(existing.id != dmsg.id for existing in pool):
                pool.append(dmsg)

    # Chronological order by arrival time.
    pool.sort(key=lambda m: m.timestamp_hours)

    # Map superseded message id -> id of the drift message replacing it.
    superseded = {}
    for drift in drift_events:
        for old_id in drift.superseded_msg_ids:
            for dmsg in drift.messages:
                if dmsg.supersedes == old_id:
                    superseded[old_id] = dmsg.id

    # Snapshot hours: a fixed early grid, each drift trigger (plus the hour
    # after, so the model sees the post-drift inbox), and a late-game grid.
    snapshot_hours = [0.0, 2.0, 6.0, 10.0, 14.0, 18.0]
    for drift in drift_events:
        snapshot_hours += [drift.trigger_hour, drift.trigger_hour + 1.0]
    snapshot_hours += [28.0, 34.0, 40.0, 44.0, 47.0]
    snapshot_hours = sorted(set(snapshot_hours))

    decision_points = []
    fired_drifts = set()

    for hour in snapshot_hours:
        if hour > 48.0:
            continue  # the episode ends at hour 48

        # Everything that has arrived by this hour is visible.
        visible = [m for m in pool if m.timestamp_hours <= hour]

        # Fire any drift events whose trigger time has passed.
        newly_fired = []
        for drift in drift_events:
            if drift.id not in fired_drifts and hour >= drift.trigger_hour:
                fired_drifts.add(drift.id)
                newly_fired.append(drift.id)

        # Serialize the visible inbox for the snapshot.
        visible_summaries = [
            {
                "id": m.id,
                "sender": m.sender,
                "subject": m.subject,
                "content": m.content,
                "urgency": m.urgency.value,
                "channel": m.channel.value,
                "timestamp_hours": m.timestamp_hours,
                "deadline_hours": m.deadline_hours,
                "dependencies": m.dependencies,
                "drift_flag": m.drift_flag,
                "superseded": m.id in superseded,
            }
            for m in visible
        ]

        # Collect live deadlines, plus any expired within the last 2 hours.
        pending_deadlines = []
        for m in visible:
            if m.deadline_hours is None:
                continue
            remaining = m.deadline_hours - hour
            if remaining > 0:
                pending_deadlines.append({
                    "id": m.id,
                    "subject": m.subject,
                    "urgency": m.urgency.value,
                    "hours_remaining": round(remaining, 1),
                })
            elif remaining > -2:  # recently expired
                pending_deadlines.append({
                    "id": m.id,
                    "subject": m.subject,
                    "urgency": m.urgency.value,
                    "hours_remaining": round(remaining, 1),
                    "expired": True,
                })

        prompt = format_prompt(hour, visible_summaries, pending_deadlines, list(fired_drifts))

        decision_points.append({
            "hour": hour,
            "prompt": prompt,
            "visible_count": len(visible),
            "visible_messages": visible_summaries,
            "pending_deadlines": pending_deadlines,
            "drift_events_fired": list(fired_drifts),
            "newly_fired_drifts": newly_fired,
        })

    return {
        "episode_id": f"ep_{seed:03d}",
        "seed": seed,
        "total_messages": len(pool),
        "drift_events": drift_event_ids,
        "superseded_messages": superseded,
        "decision_points": decision_points,
    }
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def format_prompt(hour: float, messages: list, deadlines: list, fired_drifts: list,
                  max_messages: int = 20) -> str:
    """Render one inbox snapshot as the text prompt fed to the LLM.

    Only the top `max_messages` messages (ranked by urgency, drift flag,
    then deadline) are shown, keeping prompts within ~1500 tokens for
    small-model training.
    """
    out = [SYSTEM_PROMPT, ""]
    out.append(f"CURRENT TIME: Hour {hour:.1f} of 48 ({48 - hour:.1f} hours remaining)")
    out.append(f"MESSAGES IN INBOX: {len(messages)}")
    out.append("")

    # Split deadline entries into still-live-and-urgent vs already expired.
    urgent = [d for d in deadlines if not d.get("expired") and d["hours_remaining"] < 4]
    expired = [d for d in deadlines if d.get("expired")]
    if urgent:
        out.append("URGENT DEADLINES:")
        urgent.sort(key=lambda entry: entry["hours_remaining"])
        for d in urgent:
            out.append(f"  ! {d['subject']} — {d['hours_remaining']}h left [{d['urgency']}]")
        out.append("")
    if expired:
        out.append("EXPIRED DEADLINES:")
        for d in expired:
            out.append(f"  x {d['subject']} — expired {abs(d['hours_remaining']):.1f}h ago")
        out.append("")

    if fired_drifts:
        out.append(f"POLICY CHANGES DETECTED: {len(fired_drifts)}")
        out.append("")

    # Rank: urgency tier first, drift-flagged messages ahead within a tier,
    # then earliest deadline (no deadline sorts last).
    tier = {"critical": 0, "high": 1, "medium": 2, "low": 3}

    def _rank(m):
        return (
            tier.get(m["urgency"], 4),
            0 if m.get("drift_flag") else 1,
            m.get("deadline_hours") or 999,
        )

    shown = sorted(messages, key=_rank)[:max_messages]
    omitted = len(messages) - len(shown)

    # Bucket the shown messages by urgency level for display.
    buckets = {"critical": [], "high": [], "medium": [], "low": []}
    for m in shown:
        buckets.get(m["urgency"], buckets["low"]).append(m)

    for level in ("critical", "high", "medium", "low"):
        group = buckets[level]
        if not group:
            continue
        out.append(f"--- {level.upper()} ({len(group)}) ---")
        for m in group:
            stale = " [STALE]" if m.get("superseded") else ""
            drift = " [POLICY CHANGE]" if m.get("drift_flag") else ""
            deadline = f" (due h{m['deadline_hours']})" if m.get("deadline_hours") else ""
            deps = f" [requires: {', '.join(m['dependencies'])}]" if m.get("dependencies") else ""
            out.append(f"  [{m['id']}] {m['sender']}: {m['subject']}{deadline}{stale}{drift}{deps}")
            # Content preview: first 120 chars, newlines flattened.
            preview = m["content"][:120].replace("\n", " ")
            if len(m["content"]) > 120:
                preview += "..."
            out.append(f"      > {preview}")
        out.append("")

    if omitted > 0:
        out.append(f"  ({omitted} lower-priority messages not shown)")
        out.append("")

    out.append("Which message should you handle next? Respond with respond_to_message(message_id, response).")
    return "\n".join(out)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def generate_episodes(num_episodes: int = 50, start_seed: int = 1000) -> list:
    """Build `num_episodes` episodes with consecutive seeds, printing a
    one-line summary per episode as it is generated."""
    episodes = []
    for offset in range(num_episodes):
        seed = start_seed + offset
        print(f"  Episode {offset + 1}/{num_episodes} (seed={seed})...", end=" ")
        ep = build_episode(seed)
        summary = (
            f"{ep['total_messages']} messages, "
            f"{len(ep['decision_points'])} decision points, "
            f"drifts: [{', '.join(ep['drift_events'])}]"
        )
        print(summary)
        episodes.append(ep)
    return episodes
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def save_episodes(episodes: list, filename: str = "episodes.json"):
    """Write episodes to `filename` as pretty-printed JSON and report totals.

    Fix: the summary line previously printed the literal placeholder
    "(unknown)" instead of the path that was actually written.
    """
    with open(filename, "w") as f:
        json.dump(episodes, f, indent=2)
    total_prompts = sum(len(ep["decision_points"]) for ep in episodes)
    print(f"\nSaved {len(episodes)} episodes ({total_prompts} training prompts) to {filename}")
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(description="Generate CrisisInbox training episodes")
    cli.add_argument("-n", "--num-episodes", type=int, default=50, help="Number of episodes")
    cli.add_argument("-s", "--start-seed", type=int, default=1000, help="Starting seed")
    cli.add_argument("-o", "--output", type=str, default="episodes.json", help="Output file")
    cli.add_argument("--sample", type=int, default=5, help="Also save N sample episodes")
    args = cli.parse_args()

    first_seed = args.start_seed
    last_seed = args.start_seed + args.num_episodes - 1
    print(f"Generating {args.num_episodes} episodes (seeds {first_seed}-{last_seed})...")
    generated = generate_episodes(args.num_episodes, args.start_seed)
    save_episodes(generated, args.output)

    # Optionally dump a small sample file for quick inspection.
    if args.sample > 0:
        sample_file = "sample_episodes.json"
        save_episodes(generated[:args.sample], sample_file)
|
messages.py
CHANGED
|
@@ -46,9 +46,10 @@ ALL_MESSAGES: list[Message] = [
|
|
| 46 |
channel=Channel.SMS,
|
| 47 |
subject="Are you safe?",
|
| 48 |
content=(
|
| 49 |
-
"Honey are you ok??
|
| 50 |
-
"
|
| 51 |
-
"if you need us."
|
|
|
|
| 52 |
),
|
| 53 |
urgency=Urgency.HIGH,
|
| 54 |
timestamp_hours=0.5,
|
|
@@ -59,9 +60,11 @@ ALL_MESSAGES: list[Message] = [
|
|
| 59 |
channel=Channel.SMS,
|
| 60 |
subject="Can you take the kids?",
|
| 61 |
content=(
|
| 62 |
-
"Hey
|
| 63 |
-
"Emma and Jake from Oakwood
|
| 64 |
-
"
|
|
|
|
|
|
|
| 65 |
),
|
| 66 |
urgency=Urgency.HIGH,
|
| 67 |
timestamp_hours=0.75,
|
|
@@ -88,9 +91,14 @@ ALL_MESSAGES: list[Message] = [
|
|
| 88 |
channel=Channel.EMAIL,
|
| 89 |
subject="Tomorrow's status - need you in office",
|
| 90 |
content=(
|
| 91 |
-
"Team,
|
| 92 |
-
"
|
| 93 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
),
|
| 95 |
urgency=Urgency.MEDIUM,
|
| 96 |
timestamp_hours=1.5,
|
|
@@ -117,9 +125,10 @@ ALL_MESSAGES: list[Message] = [
|
|
| 117 |
channel=Channel.SMS,
|
| 118 |
subject="Need help boarding windows",
|
| 119 |
content=(
|
| 120 |
-
"Hey
|
| 121 |
-
"
|
| 122 |
-
"
|
|
|
|
| 123 |
),
|
| 124 |
urgency=Urgency.MEDIUM,
|
| 125 |
timestamp_hours=2.0,
|
|
@@ -177,9 +186,10 @@ ALL_MESSAGES: list[Message] = [
|
|
| 177 |
channel=Channel.PHONE,
|
| 178 |
subject="Missed call from Mom",
|
| 179 |
content=(
|
| 180 |
-
"You have a missed call from Mom.
|
| 181 |
-
"me back.
|
| 182 |
-
"
|
|
|
|
| 183 |
),
|
| 184 |
urgency=Urgency.HIGH,
|
| 185 |
timestamp_hours=4.0,
|
|
@@ -190,9 +200,10 @@ ALL_MESSAGES: list[Message] = [
|
|
| 190 |
channel=Channel.SMS,
|
| 191 |
subject="Water rising on our street",
|
| 192 |
content=(
|
| 193 |
-
"
|
| 194 |
-
"
|
| 195 |
-
"
|
|
|
|
| 196 |
),
|
| 197 |
urgency=Urgency.HIGH,
|
| 198 |
timestamp_hours=4.5,
|
|
@@ -233,8 +244,9 @@ ALL_MESSAGES: list[Message] = [
|
|
| 233 |
channel=Channel.SMS,
|
| 234 |
subject="Did you get my email?",
|
| 235 |
content=(
|
| 236 |
-
"
|
| 237 |
-
"The Meridian deck won't finish itself
|
|
|
|
| 238 |
),
|
| 239 |
urgency=Urgency.MEDIUM,
|
| 240 |
timestamp_hours=6.5,
|
|
@@ -247,8 +259,9 @@ ALL_MESSAGES: list[Message] = [
|
|
| 247 |
channel=Channel.SMS,
|
| 248 |
subject="Did you get the kids?",
|
| 249 |
content=(
|
| 250 |
-
"Hey
|
| 251 |
-
"
|
|
|
|
| 252 |
),
|
| 253 |
urgency=Urgency.CRITICAL,
|
| 254 |
timestamp_hours=7.0,
|
|
@@ -273,9 +286,10 @@ ALL_MESSAGES: list[Message] = [
|
|
| 273 |
channel=Channel.SMS,
|
| 274 |
subject="Left my cat - can you check?",
|
| 275 |
content=(
|
| 276 |
-
"I evacuated to Lincoln High but I had to leave
|
| 277 |
-
"the upstairs bathroom
|
| 278 |
-
"
|
|
|
|
| 279 |
),
|
| 280 |
urgency=Urgency.MEDIUM,
|
| 281 |
timestamp_hours=8.0,
|
|
@@ -300,9 +314,10 @@ ALL_MESSAGES: list[Message] = [
|
|
| 300 |
channel=Channel.SMS,
|
| 301 |
subject="Uncle Rick says come to Tahoe",
|
| 302 |
content=(
|
| 303 |
-
"Uncle Rick
|
| 304 |
-
"
|
| 305 |
-
"
|
|
|
|
| 306 |
),
|
| 307 |
urgency=Urgency.HIGH,
|
| 308 |
timestamp_hours=9.0,
|
|
@@ -373,10 +388,12 @@ ALL_MESSAGES: list[Message] = [
|
|
| 373 |
channel=Channel.EMAIL,
|
| 374 |
subject="RE: Tomorrow's status - updated guidance",
|
| 375 |
content=(
|
| 376 |
-
"
|
| 377 |
-
"
|
| 378 |
-
"
|
| 379 |
-
"
|
|
|
|
|
|
|
| 380 |
),
|
| 381 |
urgency=Urgency.MEDIUM,
|
| 382 |
timestamp_hours=12.0,
|
|
@@ -398,15 +415,29 @@ ALL_MESSAGES: list[Message] = [
|
|
| 398 |
timestamp_hours=13.0,
|
| 399 |
deadline_hours=48.0,
|
| 400 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 401 |
Message(
|
| 402 |
id="msg_027",
|
| 403 |
sender="Sister",
|
| 404 |
channel=Channel.SMS,
|
| 405 |
subject="Kids are asking about their stuff",
|
| 406 |
content=(
|
| 407 |
-
"Emma
|
| 408 |
-
"
|
| 409 |
-
"
|
|
|
|
| 410 |
),
|
| 411 |
urgency=Urgency.LOW,
|
| 412 |
timestamp_hours=13.5,
|
|
@@ -445,10 +476,11 @@ ALL_MESSAGES: list[Message] = [
|
|
| 445 |
channel=Channel.SMS,
|
| 446 |
subject="Dad's medication",
|
| 447 |
content=(
|
| 448 |
-
"Honey
|
| 449 |
-
"
|
| 450 |
-
"
|
| 451 |
-
"in
|
|
|
|
| 452 |
),
|
| 453 |
urgency=Urgency.HIGH,
|
| 454 |
timestamp_hours=16.0,
|
|
@@ -511,15 +543,30 @@ ALL_MESSAGES: list[Message] = [
|
|
| 511 |
urgency=Urgency.HIGH,
|
| 512 |
timestamp_hours=19.0,
|
| 513 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 514 |
Message(
|
| 515 |
id="msg_035",
|
| 516 |
sender="Sister",
|
| 517 |
channel=Channel.PHONE,
|
| 518 |
subject="Missed call from Sister",
|
| 519 |
content=(
|
| 520 |
-
"Missed call from Sister. Voicemail: '
|
| 521 |
-
"
|
| 522 |
-
"
|
|
|
|
|
|
|
| 523 |
),
|
| 524 |
urgency=Urgency.MEDIUM,
|
| 525 |
timestamp_hours=19.5,
|
|
@@ -568,9 +615,10 @@ ALL_MESSAGES: list[Message] = [
|
|
| 568 |
channel=Channel.SMS,
|
| 569 |
subject="Office is in Zone B - disregard earlier email",
|
| 570 |
content=(
|
| 571 |
-
"
|
| 572 |
-
"
|
| 573 |
-
"
|
|
|
|
| 574 |
),
|
| 575 |
urgency=Urgency.MEDIUM,
|
| 576 |
timestamp_hours=21.5,
|
|
@@ -611,10 +659,11 @@ ALL_MESSAGES: list[Message] = [
|
|
| 611 |
channel=Channel.SMS,
|
| 612 |
subject="Someone broke into houses on our street",
|
| 613 |
content=(
|
| 614 |
-
"
|
| 615 |
-
"and Oak
|
| 616 |
-
"
|
| 617 |
-
"line
|
|
|
|
| 618 |
),
|
| 619 |
urgency=Urgency.HIGH,
|
| 620 |
timestamp_hours=23.0,
|
|
@@ -702,9 +751,11 @@ ALL_MESSAGES: list[Message] = [
|
|
| 702 |
channel=Channel.SMS,
|
| 703 |
subject="Did you call Dr. Patel?",
|
| 704 |
content=(
|
| 705 |
-
"
|
| 706 |
-
"
|
| 707 |
-
"
|
|
|
|
|
|
|
| 708 |
),
|
| 709 |
urgency=Urgency.CRITICAL,
|
| 710 |
timestamp_hours=27.0,
|
|
@@ -804,9 +855,11 @@ ALL_MESSAGES: list[Message] = [
|
|
| 804 |
channel=Channel.EMAIL,
|
| 805 |
subject="Meridian presentation postponed",
|
| 806 |
content=(
|
| 807 |
-
"
|
| 808 |
-
"
|
| 809 |
-
"
|
|
|
|
|
|
|
| 810 |
),
|
| 811 |
urgency=Urgency.LOW,
|
| 812 |
timestamp_hours=33.0,
|
|
@@ -836,9 +889,10 @@ ALL_MESSAGES: list[Message] = [
|
|
| 836 |
channel=Channel.SMS,
|
| 837 |
subject="Dad's doing better",
|
| 838 |
content=(
|
| 839 |
-
"
|
| 840 |
-
"
|
| 841 |
-
"
|
|
|
|
| 842 |
),
|
| 843 |
urgency=Urgency.LOW,
|
| 844 |
timestamp_hours=34.5,
|
|
@@ -921,9 +975,9 @@ ALL_MESSAGES: list[Message] = [
|
|
| 921 |
channel=Channel.SMS,
|
| 922 |
subject="Going home to check damage",
|
| 923 |
content=(
|
| 924 |
-
"Zone A evac is lifted! I'm
|
| 925 |
-
"
|
| 926 |
-
"it should be safe.
|
| 927 |
),
|
| 928 |
urgency=Urgency.MEDIUM,
|
| 929 |
timestamp_hours=38.5,
|
|
@@ -935,9 +989,10 @@ ALL_MESSAGES: list[Message] = [
|
|
| 935 |
channel=Channel.SMS,
|
| 936 |
subject="We're at Tahoe",
|
| 937 |
content=(
|
| 938 |
-
"
|
| 939 |
-
"like
|
| 940 |
-
"
|
|
|
|
| 941 |
),
|
| 942 |
urgency=Urgency.LOW,
|
| 943 |
timestamp_hours=39.0,
|
|
@@ -995,9 +1050,11 @@ ALL_MESSAGES: list[Message] = [
|
|
| 995 |
channel=Channel.SMS,
|
| 996 |
subject="Please come to Tahoe",
|
| 997 |
content=(
|
| 998 |
-
"
|
| 999 |
-
"
|
| 1000 |
-
"
|
|
|
|
|
|
|
| 1001 |
),
|
| 1002 |
urgency=Urgency.MEDIUM,
|
| 1003 |
timestamp_hours=44.0,
|
|
@@ -1064,9 +1121,24 @@ ALL_MESSAGES: list[Message] = [
|
|
| 1064 |
channel=Channel.SMS,
|
| 1065 |
subject="House update",
|
| 1066 |
content=(
|
| 1067 |
-
"
|
| 1068 |
-
"
|
| 1069 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1070 |
),
|
| 1071 |
urgency=Urgency.LOW,
|
| 1072 |
timestamp_hours=47.0,
|
|
@@ -1077,9 +1149,11 @@ ALL_MESSAGES: list[Message] = [
|
|
| 1077 |
channel=Channel.SMS,
|
| 1078 |
subject="Kids say thank you",
|
| 1079 |
content=(
|
| 1080 |
-
"Emma drew you a picture of a rainbow
|
| 1081 |
-
"
|
| 1082 |
-
"
|
|
|
|
|
|
|
| 1083 |
),
|
| 1084 |
urgency=Urgency.LOW,
|
| 1085 |
timestamp_hours=47.5,
|
|
|
|
| 46 |
channel=Channel.SMS,
|
| 47 |
subject="Are you safe?",
|
| 48 |
content=(
|
| 49 |
+
"Honey are you ok?? Just saw channel 3 and they're showing your neighborhood. "
|
| 50 |
+
"I'm shaking. Your father keeps pacing. PLEASE call us the second you see this. "
|
| 51 |
+
"We will drive down right now if you need us. I don't care about the roads. "
|
| 52 |
+
"I love you so much please be safe."
|
| 53 |
),
|
| 54 |
urgency=Urgency.HIGH,
|
| 55 |
timestamp_hours=0.5,
|
|
|
|
| 60 |
channel=Channel.SMS,
|
| 61 |
subject="Can you take the kids?",
|
| 62 |
content=(
|
| 63 |
+
"Hey I need a huge favor. My idiot boss is making us come in even with "
|
| 64 |
+
"the storm. Can you PLEASE pick up Emma and Jake from Oakwood by 3pm? "
|
| 65 |
+
"They're doing early dismissal and I literally have no one else. "
|
| 66 |
+
"Emma gets scared during storms so just tell her it's an adventure ok? "
|
| 67 |
+
"I owe you big time."
|
| 68 |
),
|
| 69 |
urgency=Urgency.HIGH,
|
| 70 |
timestamp_hours=0.75,
|
|
|
|
| 91 |
channel=Channel.EMAIL,
|
| 92 |
subject="Tomorrow's status - need you in office",
|
| 93 |
content=(
|
| 94 |
+
"Team,\n\n"
|
| 95 |
+
"I realize the weather situation isn't ideal but the Meridian presentation "
|
| 96 |
+
"is Thursday and the client isn't going to reschedule. I need everyone in the "
|
| 97 |
+
"office tomorrow morning, full stop. If you ABSOLUTELY cannot make it, submit "
|
| 98 |
+
"PTO through the portal. I shouldn't have to remind you this is a career-defining "
|
| 99 |
+
"account.\n\n"
|
| 100 |
+
"Confirm your status by tonight.\n\n"
|
| 101 |
+
"— Greg"
|
| 102 |
),
|
| 103 |
urgency=Urgency.MEDIUM,
|
| 104 |
timestamp_hours=1.5,
|
|
|
|
| 125 |
channel=Channel.SMS,
|
| 126 |
subject="Need help boarding windows",
|
| 127 |
content=(
|
| 128 |
+
"Hey it's Dave from next door. My back went out yesterday and I can't "
|
| 129 |
+
"board up these windows by myself. I know you've got your own stuff going on "
|
| 130 |
+
"but if you've got 20 min I've got the plywood and drill ready to go. "
|
| 131 |
+
"I'll owe you a case of beer when this is all over."
|
| 132 |
),
|
| 133 |
urgency=Urgency.MEDIUM,
|
| 134 |
timestamp_hours=2.0,
|
|
|
|
| 186 |
channel=Channel.PHONE,
|
| 187 |
subject="Missed call from Mom",
|
| 188 |
content=(
|
| 189 |
+
"You have a missed call from Mom. Voicemail: '[sounds like she's been crying] "
|
| 190 |
+
"Baby please call me back. I can't stop watching the news. Your dad talked to "
|
| 191 |
+
"Uncle Rick and he says come to Tahoe, he's got plenty of room. But honey we "
|
| 192 |
+
"need to leave SOON before the roads get worse. Please please call me. I love you.'"
|
| 193 |
),
|
| 194 |
urgency=Urgency.HIGH,
|
| 195 |
timestamp_hours=4.0,
|
|
|
|
| 200 |
channel=Channel.SMS,
|
| 201 |
subject="Water rising on our street",
|
| 202 |
content=(
|
| 203 |
+
"Dude the water is coming up FAST on Elm. My basement already has 6 inches. "
|
| 204 |
+
"Get anything you care about upstairs RIGHT NOW. I'm not trying to scare you "
|
| 205 |
+
"but this is way worse than they said it would be. I think I'm heading to "
|
| 206 |
+
"the shelter. You should too."
|
| 207 |
),
|
| 208 |
urgency=Urgency.HIGH,
|
| 209 |
timestamp_hours=4.5,
|
|
|
|
| 244 |
channel=Channel.SMS,
|
| 245 |
subject="Did you get my email?",
|
| 246 |
content=(
|
| 247 |
+
"Still waiting on your status. Yes or no on tomorrow? I've got half the team "
|
| 248 |
+
"going dark on me. The Meridian deck won't finish itself and I'm not going to "
|
| 249 |
+
"be the one explaining to Jacobs why we weren't prepared."
|
| 250 |
),
|
| 251 |
urgency=Urgency.MEDIUM,
|
| 252 |
timestamp_hours=6.5,
|
|
|
|
| 259 |
channel=Channel.SMS,
|
| 260 |
subject="Did you get the kids?",
|
| 261 |
content=(
|
| 262 |
+
"Hey the school keeps calling me and I can't pick up because I'm in this "
|
| 263 |
+
"stupid meeting. Did you get them?? Emma's teacher said she was crying. "
|
| 264 |
+
"Please just text me back one word. PLEASE. I'm losing it over here."
|
| 265 |
),
|
| 266 |
urgency=Urgency.CRITICAL,
|
| 267 |
timestamp_hours=7.0,
|
|
|
|
| 286 |
channel=Channel.SMS,
|
| 287 |
subject="Left my cat - can you check?",
|
| 288 |
content=(
|
| 289 |
+
"Man I feel awful. I evacuated to Lincoln High but I had to leave Whiskers "
|
| 290 |
+
"behind. She's in the upstairs bathroom — I left food and water but she was "
|
| 291 |
+
"meowing so loud when I shut the door. If you're still on the street can you "
|
| 292 |
+
"just peek in and make sure she's ok? The spare key is under the gnome."
|
| 293 |
),
|
| 294 |
urgency=Urgency.MEDIUM,
|
| 295 |
timestamp_hours=8.0,
|
|
|
|
| 314 |
channel=Channel.SMS,
|
| 315 |
subject="Uncle Rick says come to Tahoe",
|
| 316 |
content=(
|
| 317 |
+
"OK so Uncle Rick says absolutely yes, come up. He's got 3 bedrooms and says "
|
| 318 |
+
"stay as long as you need. Your sister and the kids too. Dad and I are packing "
|
| 319 |
+
"the car now. Honey PLEASE come. I won't be able to sleep tonight if you're "
|
| 320 |
+
"still down there. The roads north are clear, we checked. Can you be ready by morning?"
|
| 321 |
),
|
| 322 |
urgency=Urgency.HIGH,
|
| 323 |
timestamp_hours=9.0,
|
|
|
|
| 388 |
channel=Channel.EMAIL,
|
| 389 |
subject="RE: Tomorrow's status - updated guidance",
|
| 390 |
content=(
|
| 391 |
+
"Team,\n\n"
|
| 392 |
+
"Spoke with leadership. Fine — if you're in an evac zone, work remote. "
|
| 393 |
+
"But the Meridian deck is still due Thursday 5 PM, no exceptions. I don't care "
|
| 394 |
+
"if you're working from a shelter, a Starbucks, or your car. Claim your slides "
|
| 395 |
+
"in the shared doc by end of day or I'll reassign them.\n\n"
|
| 396 |
+
"— Greg"
|
| 397 |
),
|
| 398 |
urgency=Urgency.MEDIUM,
|
| 399 |
timestamp_hours=12.0,
|
|
|
|
| 415 |
timestamp_hours=13.0,
|
| 416 |
deadline_hours=48.0,
|
| 417 |
),
|
| 418 |
+
Message(
|
| 419 |
+
id="msg_026b",
|
| 420 |
+
sender="Emma (niece)",
|
| 421 |
+
channel=Channel.SMS,
|
| 422 |
+
subject="from emmas phone",
|
| 423 |
+
content=(
|
| 424 |
+
"hi its emma. jake took the big pillow and wont share. also when is mommy "
|
| 425 |
+
"coming. are we having a sleepover?? can we have mac and cheese for dinner. "
|
| 426 |
+
"i miss mr buttons. love you"
|
| 427 |
+
),
|
| 428 |
+
urgency=Urgency.LOW,
|
| 429 |
+
timestamp_hours=12.5,
|
| 430 |
+
),
|
| 431 |
Message(
|
| 432 |
id="msg_027",
|
| 433 |
sender="Sister",
|
| 434 |
channel=Channel.SMS,
|
| 435 |
subject="Kids are asking about their stuff",
|
| 436 |
content=(
|
| 437 |
+
"Emma won't stop crying about Mr. Buttons — that's her stuffed bunny she left "
|
| 438 |
+
"in her cubby at school. Jake is being Jake, just wants his tablet. Did you "
|
| 439 |
+
"happen to grab their backpacks? My office FINALLY let us leave. I can come "
|
| 440 |
+
"get them around 8 if the roads are ok. Thank you for doing this. Seriously."
|
| 441 |
),
|
| 442 |
urgency=Urgency.LOW,
|
| 443 |
timestamp_hours=13.5,
|
|
|
|
| 476 |
channel=Channel.SMS,
|
| 477 |
subject="Dad's medication",
|
| 478 |
content=(
|
| 479 |
+
"Honey I don't want to add to your plate but your dad just realized he left "
|
| 480 |
+
"his heart medication on the kitchen counter. We're already past Folsom. "
|
| 481 |
+
"Can you PLEASE call Dr. Patel tomorrow morning and ask them to call in a "
|
| 482 |
+
"refill to a pharmacy in Truckee? I'm trying not to panic but you know "
|
| 483 |
+
"he can't miss more than a day. His number should be in dad's phone under P."
|
| 484 |
),
|
| 485 |
urgency=Urgency.HIGH,
|
| 486 |
timestamp_hours=16.0,
|
|
|
|
| 543 |
urgency=Urgency.HIGH,
|
| 544 |
timestamp_hours=19.0,
|
| 545 |
),
|
| 546 |
+
Message(
|
| 547 |
+
id="msg_034b",
|
| 548 |
+
sender="Emma (niece)",
|
| 549 |
+
channel=Channel.SMS,
|
| 550 |
+
subject="im scared",
|
| 551 |
+
content=(
|
| 552 |
+
"the lights went off and its really dark. jake is pretending hes not scared "
|
| 553 |
+
"but he is. when is mommy coming?? you said she was coming. "
|
| 554 |
+
"i dont like the wind noise. can you come sit with us"
|
| 555 |
+
),
|
| 556 |
+
urgency=Urgency.MEDIUM,
|
| 557 |
+
timestamp_hours=19.0,
|
| 558 |
+
),
|
| 559 |
Message(
|
| 560 |
id="msg_035",
|
| 561 |
sender="Sister",
|
| 562 |
channel=Channel.PHONE,
|
| 563 |
subject="Missed call from Sister",
|
| 564 |
content=(
|
| 565 |
+
"Missed call from Sister. Voicemail: '[car horn honking in background] Hey it's me, "
|
| 566 |
+
"I'm trying to get to you but Florin is completely underwater and my GPS keeps "
|
| 567 |
+
"rerouting me in circles. I think I need to turn back. Can you keep them tonight? "
|
| 568 |
+
"Emma's probably already asleep, just — tell her Mommy loves her. I'm so sorry. "
|
| 569 |
+
"I'll figure it out in the morning. [voice cracks] Thank you.'"
|
| 570 |
),
|
| 571 |
urgency=Urgency.MEDIUM,
|
| 572 |
timestamp_hours=19.5,
|
|
|
|
| 615 |
channel=Channel.SMS,
|
| 616 |
subject="Office is in Zone B - disregard earlier email",
|
| 617 |
content=(
|
| 618 |
+
"Well, our office is in Zone B. So obviously disregard the 'come in tomorrow' "
|
| 619 |
+
"thing. But I STILL need those slides. Can you work on them from wherever "
|
| 620 |
+
"you are? Sarah said she can pick up some of the slack but I need to know "
|
| 621 |
+
"what you can deliver. This client won't wait for a hurricane."
|
| 622 |
),
|
| 623 |
urgency=Urgency.MEDIUM,
|
| 624 |
timestamp_hours=21.5,
|
|
|
|
| 659 |
channel=Channel.SMS,
|
| 660 |
subject="Someone broke into houses on our street",
|
| 661 |
content=(
|
| 662 |
+
"Hey bad news. Talked to a cop here at the shelter and he said there have been "
|
| 663 |
+
"break-ins on Elm and Oak. People are the worst — looting during a disaster. "
|
| 664 |
+
"If you left anything valuable at your place you should probably file a report. "
|
| 665 |
+
"Non-emergency line is 555-0199. Guy two cots over from me said they got his TV "
|
| 666 |
+
"and laptop. Unbelievable."
|
| 667 |
),
|
| 668 |
urgency=Urgency.HIGH,
|
| 669 |
timestamp_hours=23.0,
|
|
|
|
| 751 |
channel=Channel.SMS,
|
| 752 |
subject="Did you call Dr. Patel?",
|
| 753 |
content=(
|
| 754 |
+
"Sweetheart did you call Dr. Patel?? Dad hasn't had his medication in over "
|
| 755 |
+
"24 hours now and I can tell his blood pressure is up. He won't admit it "
|
| 756 |
+
"but he's been dizzy. The pharmacy in Truckee says they need the doctor to "
|
| 757 |
+
"call it in — they can't just refill it. I know you have a million things "
|
| 758 |
+
"going on but this one really scares me."
|
| 759 |
),
|
| 760 |
urgency=Urgency.CRITICAL,
|
| 761 |
timestamp_hours=27.0,
|
|
|
|
| 855 |
channel=Channel.EMAIL,
|
| 856 |
subject="Meridian presentation postponed",
|
| 857 |
content=(
|
| 858 |
+
"OK update — Meridian called and pushed the presentation to next Wednesday. "
|
| 859 |
+
"Turns out half their team is dealing with the storm too. Go figure.\n\n"
|
| 860 |
+
"Take the emergency leave. Focus on your family. I know I was being intense "
|
| 861 |
+
"earlier, I'm dealing with this too. We'll regroup Monday.\n\n"
|
| 862 |
+
"— Greg"
|
| 863 |
),
|
| 864 |
urgency=Urgency.LOW,
|
| 865 |
timestamp_hours=33.0,
|
|
|
|
| 889 |
channel=Channel.SMS,
|
| 890 |
subject="Dad's doing better",
|
| 891 |
content=(
|
| 892 |
+
"Oh thank God — Dr. Patel called in the prescription to the Rite Aid in Truckee "
|
| 893 |
+
"and dad picked it up an hour ago. He's already looking better, the color is "
|
| 894 |
+
"back in his face. Thank you honey. Now PLEASE come up here. Rick made up the "
|
| 895 |
+
"guest room and everything. We miss you. Mom loves you."
|
| 896 |
),
|
| 897 |
urgency=Urgency.LOW,
|
| 898 |
timestamp_hours=34.5,
|
|
|
|
| 975 |
channel=Channel.SMS,
|
| 976 |
subject="Going home to check damage",
|
| 977 |
content=(
|
| 978 |
+
"Zone A evac is lifted!! I'm heading back to see how bad it is. Want to come "
|
| 979 |
+
"with me? Safety in numbers and all that. Cops said they caught the guys who "
|
| 980 |
+
"were breaking in so it should be safe. I can swing by in 20 if you're in."
|
| 981 |
),
|
| 982 |
urgency=Urgency.MEDIUM,
|
| 983 |
timestamp_hours=38.5,
|
|
|
|
| 989 |
channel=Channel.SMS,
|
| 990 |
subject="We're at Tahoe",
|
| 991 |
content=(
|
| 992 |
+
"We made it to Uncle Rick's. The kids are already outside throwing snowballs "
|
| 993 |
+
"like the last 24 hours never happened. Kids are incredible. Mom immediately "
|
| 994 |
+
"started making chicken soup because of course she did. There's room for you. "
|
| 995 |
+
"For real, come up when you can. You deserve a break after everything."
|
| 996 |
),
|
| 997 |
urgency=Urgency.LOW,
|
| 998 |
timestamp_hours=39.0,
|
|
|
|
| 1050 |
channel=Channel.SMS,
|
| 1051 |
subject="Please come to Tahoe",
|
| 1052 |
content=(
|
| 1053 |
+
"Honey everyone is here except you and it doesn't feel right. Dad's feeling "
|
| 1054 |
+
"so much better — he's out on the deck with Rick right now. The roads are "
|
| 1055 |
+
"totally clear, I checked three times. Even if it's just for a couple days. "
|
| 1056 |
+
"You've been taking care of everyone else this whole time. Let us take care "
|
| 1057 |
+
"of you for a bit. Please."
|
| 1058 |
),
|
| 1059 |
urgency=Urgency.MEDIUM,
|
| 1060 |
timestamp_hours=44.0,
|
|
|
|
| 1121 |
channel=Channel.SMS,
|
| 1122 |
subject="House update",
|
| 1123 |
content=(
|
| 1124 |
+
"Good news — my place isn't as bad as I thought. Water in the basement but "
|
| 1125 |
+
"first floor survived. Could've been way worse. How'd yours make out? "
|
| 1126 |
+
"Seriously, anything you need with cleanup just say the word. You checked on "
|
| 1127 |
+
"Whiskers for me and I won't forget that. Neighbors gotta stick together."
|
| 1128 |
+
),
|
| 1129 |
+
urgency=Urgency.LOW,
|
| 1130 |
+
timestamp_hours=47.0,
|
| 1131 |
+
),
|
| 1132 |
+
Message(
|
| 1133 |
+
id="msg_072b",
|
| 1134 |
+
sender="Emma (niece)",
|
| 1135 |
+
channel=Channel.SMS,
|
| 1136 |
+
subject="i drew you a pictur",
|
| 1137 |
+
content=(
|
| 1138 |
+
"hi its emma again!! i drew you a rainbow and a house and us. mommy said "
|
| 1139 |
+
"i can send you a foto. jake says thank you for the adventure. i say thank "
|
| 1140 |
+
"you for the mac and cheese. your the best. when can we do a sleepover again "
|
| 1141 |
+
"but WITHOUT the storm part. love emma age 7"
|
| 1142 |
),
|
| 1143 |
urgency=Urgency.LOW,
|
| 1144 |
timestamp_hours=47.0,
|
|
|
|
| 1149 |
channel=Channel.SMS,
|
| 1150 |
subject="Kids say thank you",
|
| 1151 |
content=(
|
| 1152 |
+
"Emma drew you a picture of a rainbow with 'thank you' written in crayon. "
|
| 1153 |
+
"She wants me to send you a photo. And Jake keeps telling everyone about "
|
| 1154 |
+
"'the adventure' and how you let him eat cereal for dinner. Kids man. "
|
| 1155 |
+
"I don't know what I would have done without you this week. I really don't. "
|
| 1156 |
+
"You're my person. Love you."
|
| 1157 |
),
|
| 1158 |
urgency=Urgency.LOW,
|
| 1159 |
timestamp_hours=47.5,
|
notebooks/crisisinbox_grpo_simple.ipynb
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": "# CrisisInbox GRPO Training\n\nTrain a small LLM to triage crisis inbox messages using Group Relative Policy Optimization.\n\n**What this does:**\n1. Loads pre-generated episode data (inbox snapshots at decision points)\n2. For each prompt, the model generates an action (which message to handle + response)\n3. A reward function scores the action based on urgency, deadline, drift adaptation\n4. GRPO updates the model to prefer higher-reward actions\n\n**GPU profiles:**\n- **T4 / free Colab**: Qwen2.5-0.5B, 2048 ctx, 4-bit — runs in ~30 min\n- **H100 / A100**: Qwen2.5-3B, 4096 ctx, 4-bit — better quality, ~20 min"
|
| 7 |
+
},
|
| 8 |
+
{
|
| 9 |
+
"cell_type": "code",
|
| 10 |
+
"execution_count": null,
|
| 11 |
+
"metadata": {},
|
| 12 |
+
"source": "# Install dependencies\n!pip install unsloth trl transformers datasets accelerate peft -q\n!pip install huggingface_hub -q\n\n# Download episode data\n# Option 1: From HF dataset (recommended)\n# Option 2: From GitHub repo\n# Option 3: Generate locally with `python generate_episodes.py -n 100`\n\nimport os\nif not os.path.exists(\"episodes.json\"):\n print(\"Downloading episodes.json from GitHub...\")\n !wget -q --show-progress https://raw.githubusercontent.com/eptan/crisis-inbox/main/episodes.json\n if not os.path.exists(\"episodes.json\"):\n print(\"ERROR: Download failed. Upload episodes.json manually or generate with:\")\n print(\" !git clone https://github.com/eptan/crisis-inbox.git && cd crisis-inbox && python generate_episodes.py -n 100\")\nelse:\n print(\"episodes.json already exists, skipping download\")\n\nprint(\"Setup complete\")",
|
| 13 |
+
"outputs": []
|
| 14 |
+
},
|
| 15 |
+
{
|
| 16 |
+
"cell_type": "code",
|
| 17 |
+
"source": "# === GPU PROFILE ===\n# Change this one variable to switch between T4 and H100 configs.\n# Everything else adapts automatically.\n\nimport torch\n\nif torch.cuda.is_available():\n vram_gb = torch.cuda.get_device_properties(0).total_mem / 1e9\n gpu_name = torch.cuda.get_device_name(0)\n print(f\"GPU: {gpu_name} ({vram_gb:.0f} GB)\")\nelse:\n vram_gb = 0\n print(\"No GPU detected — config will default to smallest profile\")\n\n# Auto-select profile based on VRAM, or override manually\nif vram_gb >= 40: # H100, A100\n PROFILE = \"h100\"\n MODEL_NAME = \"unsloth/Qwen2.5-3B-Instruct\"\n MAX_SEQ_LENGTH = 4096\n MAX_PROMPT_LENGTH = 3584\n MAX_COMPLETION_LENGTH = 512\n BATCH_SIZE = 4\n GRAD_ACCUM = 2\n NUM_GENERATIONS = 8\nelif vram_gb >= 14: # T4, L4\n PROFILE = \"t4\"\n MODEL_NAME = \"unsloth/Qwen2.5-0.5B-Instruct\"\n MAX_SEQ_LENGTH = 2048\n MAX_PROMPT_LENGTH = 1792\n MAX_COMPLETION_LENGTH = 256\n BATCH_SIZE = 2\n GRAD_ACCUM = 4\n NUM_GENERATIONS = 4\nelse:\n PROFILE = \"cpu\"\n MODEL_NAME = \"unsloth/Qwen2.5-0.5B-Instruct\"\n MAX_SEQ_LENGTH = 2048\n MAX_PROMPT_LENGTH = 1792\n MAX_COMPLETION_LENGTH = 256\n BATCH_SIZE = 1\n GRAD_ACCUM = 8\n NUM_GENERATIONS = 2\n\nprint(f\"Profile: {PROFILE} | Model: {MODEL_NAME} | Context: {MAX_SEQ_LENGTH}\")",
|
| 18 |
+
"metadata": {},
|
| 19 |
+
"execution_count": null,
|
| 20 |
+
"outputs": []
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"cell_type": "code",
|
| 24 |
+
"execution_count": null,
|
| 25 |
+
"metadata": {},
|
| 26 |
+
"source": "import json\nimport re\nimport random\nfrom datasets import Dataset\n\n# Load episodes\nwith open(\"episodes.json\") as f:\n episodes = json.load(f)\n\n# Check format — old format has 'messages'/'tasks', new format has 'decision_points'\nif episodes and \"decision_points\" not in episodes[0]:\n old_keys = list(episodes[0].keys())\n raise ValueError(\n f\"episodes.json is in the old format (keys: {old_keys}).\\n\"\n f\"Regenerate with: python generate_episodes.py -n 100\\n\"\n f\"The old format used 'messages'/'tasks'/'schema_events'; \"\n f\"the notebook requires 'decision_points' from generate_episodes.py.\"\n )\n\n# Flatten to individual training prompts\nprompts = []\nfor ep in episodes:\n for dp in ep[\"decision_points\"]:\n prompts.append({\n \"prompt\": dp[\"prompt\"],\n \"hour\": dp[\"hour\"],\n \"visible_count\": dp[\"visible_count\"],\n \"episode_id\": ep[\"episode_id\"],\n \"seed\": ep[\"seed\"],\n \"drift_events\": ep[\"drift_events\"],\n \"superseded\": ep.get(\"superseded_messages\", {}),\n \"messages\": dp[\"visible_messages\"],\n })\n\nif not prompts:\n raise ValueError(\"No decision_points found in episodes; cannot train.\")\n\nprint(f\"Loaded {len(episodes)} episodes -> {len(prompts)} training prompts\")\nprint(f\"Average {len(prompts)/len(episodes):.1f} decision points per episode\")",
|
| 27 |
+
"outputs": []
|
| 28 |
+
},
|
| 29 |
+
{
|
| 30 |
+
"cell_type": "markdown",
|
| 31 |
+
"source": "## Reward Function\n\nScores agent actions based on:\n- **Urgency base** (critical=10, high=5, medium=3, low=1)\n- **Deadline timing** (early=bonus, late=penalty)\n- **Drift adaptation** (+50% for handling policy-change messages)\n- **Stale info penalty** (-50% for acting on superseded messages)\n- **Response quality** (penalty for short/empty responses)",
|
| 32 |
+
"metadata": {}
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"cell_type": "code",
|
| 36 |
+
"source": "def score_action(completion: str, prompt_data: dict) -> float:\n \"\"\"\n Score a model completion against the inbox state.\n \n The model should output: respond_to_message(msg_id, \"response text\")\n We parse the message_id and response, then score based on the reward function.\n \"\"\"\n messages = prompt_data[\"messages\"]\n hour = prompt_data[\"hour\"]\n superseded = prompt_data.get(\"superseded\", {})\n \n # Parse the model output for message_id\n msg_id = None\n response_text = \"\"\n \n # Try to parse respond_to_message(msg_id, response)\n match = re.search(r'respond_to_message\\s*\\(\\s*[\"\\']?(msg_\\d+)[\"\\']?\\s*,\\s*[\"\\'](.+?)[\"\\']', completion, re.DOTALL)\n if match:\n msg_id = match.group(1)\n response_text = match.group(2)\n else:\n # Try simpler format: just a message ID mentioned\n id_match = re.search(r'(msg_\\d+)', completion)\n if id_match:\n msg_id = id_match.group(1)\n # No explicit response text — penalize via quality check below\n response_text = \"\"\n \n if not msg_id:\n return -1.0 # couldn't parse any action\n \n # Find the message in the inbox\n target_msg = None\n for msg in messages:\n if msg[\"id\"] == msg_id:\n target_msg = msg\n break\n \n if target_msg is None:\n return -0.5 # referenced a message not in inbox\n \n # Base reward by urgency\n urgency_rewards = {\"critical\": 10.0, \"high\": 5.0, \"medium\": 3.0, \"low\": 1.0}\n reward = urgency_rewards.get(target_msg[\"urgency\"], 1.0)\n \n # Deadline timing\n deadline = target_msg.get(\"deadline_hours\")\n if deadline is not None:\n if hour <= deadline:\n time_remaining_frac = (deadline - hour) / max(deadline, 1.0)\n reward *= 1.0 + 0.5 * time_remaining_frac\n else:\n reward *= 0.25 # late penalty\n \n # Response quality\n if len(response_text.strip()) < 10:\n reward *= 0.5\n \n # Drift adaptation bonus\n if target_msg.get(\"drift_flag\"):\n reward *= 1.5\n \n # Stale info penalty\n if target_msg[\"id\"] in superseded:\n reward *= 0.5\n \n # Penalize choosing 
low-urgency when unhandled critical messages exist\n unhandled_critical = any(\n m[\"urgency\"] == \"critical\" and not m.get(\"handled\") and not m.get(\"superseded\")\n for m in messages\n )\n if unhandled_critical and target_msg[\"urgency\"] in (\"low\", \"medium\"):\n reward *= 0.3\n \n return round(reward, 2)\n\n\n# Test the reward function\ntest_data = prompts[0]\nprint(\"Testing reward function on first decision point:\")\nprint(f\" Hour: {test_data['hour']}, Messages: {test_data['visible_count']}\")\n\n# Simulate good action (pick critical message)\ncritical_msgs = [m for m in test_data[\"messages\"] if m[\"urgency\"] == \"critical\"]\nif critical_msgs:\n good_action = f'respond_to_message(\"{critical_msgs[0][\"id\"]}\", \"Acknowledged. Evacuating immediately with documents and medication.\")'\n good_score = score_action(good_action, test_data)\n print(f\" Good action (critical msg): {good_score:.2f} pts\")\n\n# Simulate bad action (pick low-urgency message)\nlow_msgs = [m for m in test_data[\"messages\"] if m[\"urgency\"] == \"low\"]\nif low_msgs:\n bad_action = f'respond_to_message(\"{low_msgs[0][\"id\"]}\", \"ok\")'\n bad_score = score_action(bad_action, test_data)\n print(f\" Bad action (low msg, short response): {bad_score:.2f} pts\")\n\n# Simulate unparseable action\njunk_score = score_action(\"I think we should do something\", test_data)\nprint(f\" Unparseable action: {junk_score:.2f} pts\")",
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"execution_count": null,
|
| 39 |
+
"outputs": []
|
| 40 |
+
},
|
| 41 |
+
{
|
| 42 |
+
"cell_type": "markdown",
|
| 43 |
+
"source": "## Load Model & Configure GRPO",
|
| 44 |
+
"metadata": {}
|
| 45 |
+
},
|
| 46 |
+
{
|
| 47 |
+
"cell_type": "code",
|
| 48 |
+
"source": "from unsloth import FastLanguageModel\n\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n model_name=MODEL_NAME,\n max_seq_length=MAX_SEQ_LENGTH,\n load_in_4bit=True,\n)\n\n# Add LoRA adapters — bigger r for bigger models\nlora_r = 32 if PROFILE == \"h100\" else 16\n\nmodel = FastLanguageModel.get_peft_model(\n model,\n r=lora_r,\n target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n \"gate_proj\", \"up_proj\", \"down_proj\"],\n lora_alpha=lora_r,\n lora_dropout=0,\n bias=\"none\",\n use_gradient_checkpointing=\"unsloth\",\n)\nprint(f\"Model loaded: {MODEL_NAME} | LoRA r={lora_r} | ctx={MAX_SEQ_LENGTH}\")",
|
| 49 |
+
"metadata": {},
|
| 50 |
+
"execution_count": null,
|
| 51 |
+
"outputs": []
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"cell_type": "code",
|
| 55 |
+
"source": "# Build the training dataset\n# Each row needs a \"prompt\" field formatted as chat messages\ntrain_data = []\nfor p in prompts:\n train_data.append({\n \"prompt\": [\n {\"role\": \"user\", \"content\": p[\"prompt\"]},\n ],\n # Store metadata for reward calculation (not used by trainer directly)\n \"_hour\": p[\"hour\"],\n \"_episode_id\": p[\"episode_id\"],\n })\n\n# Shuffle and split\nrandom.seed(42)\nrandom.shuffle(train_data)\n\ndataset = Dataset.from_list(train_data)\nprint(f\"Training dataset: {len(dataset)} prompts\")\nprint(f\"Sample prompt length: {len(train_data[0]['prompt'][0]['content'])} chars\")",
|
| 56 |
+
"metadata": {},
|
| 57 |
+
"execution_count": null,
|
| 58 |
+
"outputs": []
|
| 59 |
+
},
|
| 60 |
+
{
|
| 61 |
+
"cell_type": "markdown",
|
| 62 |
+
"source": "## GRPO Training Loop\n\nThe reward function scores each completion by:\n1. Parsing which message the model chose to handle\n2. Checking urgency, deadline timing, drift flags\n3. Penalizing bad choices (low-urgency when critical exists, stale info)",
|
| 63 |
+
"metadata": {}
|
| 64 |
+
},
|
| 65 |
+
{
|
| 66 |
+
"cell_type": "code",
|
| 67 |
+
"source": "from trl import GRPOConfig, GRPOTrainer\n\n# Build a lookup from (episode_id, hour) -> prompt metadata for reward scoring\nprompt_lookup = {}\nfor p in prompts:\n key = (p[\"episode_id\"], p[\"hour\"])\n prompt_lookup[key] = p\n\n\ndef reward_fn(prompts, completions, _episode_id, _hour, **kwargs):\n \"\"\"\n GRPO reward function. Scores each completion against its inbox state.\n\n TRL passes extra dataset columns as keyword arguments, so _episode_id and\n _hour come directly from the dataset — no need to reverse-lookup from text.\n \"\"\"\n rewards = []\n for completion, ep_id, hour in zip(completions, _episode_id, _hour):\n key = (ep_id, hour)\n prompt_data = prompt_lookup.get(key)\n\n if prompt_data is None:\n rewards.append(0.0)\n continue\n\n if isinstance(completion, list):\n comp_text = completion[-1][\"content\"] if completion else \"\"\n else:\n comp_text = str(completion)\n\n score = score_action(comp_text, prompt_data)\n rewards.append(score)\n\n return rewards\n\n\nprint(f\"Prompt lookup: {len(prompt_lookup)} unique keys (expect {len(prompts)})\")\n\n# GRPO training config — all values from GPU profile\ntraining_args = GRPOConfig(\n output_dir=\"crisisinbox-grpo-output\",\n num_train_epochs=3,\n per_device_train_batch_size=BATCH_SIZE,\n gradient_accumulation_steps=GRAD_ACCUM,\n learning_rate=5e-6,\n max_completion_length=MAX_COMPLETION_LENGTH,\n max_prompt_length=MAX_PROMPT_LENGTH,\n num_generations=NUM_GENERATIONS,\n logging_steps=10,\n save_steps=100,\n report_to=\"none\",\n bf16=True,\n)\n\ntrainer = GRPOTrainer(\n model=model,\n processing_class=tokenizer,\n reward_funcs=reward_fn,\n args=training_args,\n train_dataset=dataset,\n)\n\nprint(f\"Trainer configured — batch={BATCH_SIZE}, gen={NUM_GENERATIONS}, prompt≤{MAX_PROMPT_LENGTH}tok\")",
|
| 68 |
+
"metadata": {},
|
| 69 |
+
"execution_count": null,
|
| 70 |
+
"outputs": []
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"cell_type": "code",
|
| 74 |
+
"source": "# Train!\ntrainer.train()\nprint(\"Training complete\")",
|
| 75 |
+
"metadata": {},
|
| 76 |
+
"execution_count": null,
|
| 77 |
+
"outputs": []
|
| 78 |
+
},
|
| 79 |
+
{
|
| 80 |
+
"cell_type": "markdown",
|
| 81 |
+
"source": "## Evaluate Trained Model\n\nSample prompts and check whether the model picks high-urgency messages and produces well-formatted actions.",
|
| 82 |
+
"metadata": {}
|
| 83 |
+
},
|
| 84 |
+
{
|
| 85 |
+
"cell_type": "code",
|
| 86 |
+
"source": "# Evaluate on a few test prompts\nFastLanguageModel.for_inference(model)\n\neval_prompts = random.sample(prompts, min(10, len(prompts)))\ntotal_score = 0\n\nprint(f\"=== Trained Model Evaluation ({MODEL_NAME}) ===\\n\")\nfor p in eval_prompts:\n messages = [{\"role\": \"user\", \"content\": p[\"prompt\"]}]\n inputs = tokenizer.apply_chat_template(messages, return_tensors=\"pt\", add_generation_prompt=True).to(\"cuda\")\n\n with torch.no_grad():\n output = model.generate(inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=0.7, do_sample=True)\n\n completion = tokenizer.decode(output[0][inputs.shape[1]:], skip_special_tokens=True)\n score = score_action(completion, p)\n total_score += score\n\n # Show a summary\n msg_match = re.search(r'(msg_\\d+)', completion)\n chosen_id = msg_match.group(1) if msg_match else \"none\"\n chosen_msg = next((m for m in p[\"messages\"] if m[\"id\"] == chosen_id), None)\n urgency = chosen_msg[\"urgency\"] if chosen_msg else \"?\"\n\n print(f\"Hour {p['hour']:5.1f} | Chose: {chosen_id} ({urgency:8s}) | Score: {score:+.1f}\")\n\nprint(f\"\\nAverage score: {total_score / len(eval_prompts):.2f}\")",
|
| 87 |
+
"metadata": {},
|
| 88 |
+
"execution_count": null,
|
| 89 |
+
"outputs": []
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
"cell_type": "code",
|
| 93 |
+
"source": "# Save the trained model\nmodel.save_pretrained(\"crisisinbox-grpo-trained\")\ntokenizer.save_pretrained(\"crisisinbox-grpo-trained\")\nprint(\"Model saved to crisisinbox-grpo-trained/\")",
|
| 94 |
+
"metadata": {},
|
| 95 |
+
"execution_count": null,
|
| 96 |
+
"outputs": []
|
| 97 |
+
}
|
| 98 |
+
],
|
| 99 |
+
"metadata": {
|
| 100 |
+
"kernelspec": {
|
| 101 |
+
"display_name": "Python 3",
|
| 102 |
+
"language": "python",
|
| 103 |
+
"name": "python3"
|
| 104 |
+
}
|
| 105 |
+
},
|
| 106 |
+
"nbformat": 4,
|
| 107 |
+
"nbformat_minor": 4
|
| 108 |
+
}
|
pyproject.toml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=45", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "crisis-inbox"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "RL environment for training LLMs to manage personal task overload during natural disasters"
|
| 9 |
+
requires-python = ">=3.10"
|
| 10 |
+
dependencies = [
|
| 11 |
+
"openenv-core[core]>=0.2.1",
|
| 12 |
+
"fastapi>=0.115.0",
|
| 13 |
+
"pydantic>=2.0.0",
|
| 14 |
+
"uvicorn>=0.24.0",
|
| 15 |
+
"requests>=2.31.0",
|
| 16 |
+
]
|
| 17 |
+
|
| 18 |
+
[project.optional-dependencies]
|
| 19 |
+
dev = [
|
| 20 |
+
"pytest>=8.0.0",
|
| 21 |
+
"pytest-cov>=4.0.0",
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
[project.scripts]
|
| 25 |
+
server = "crisis_inbox.server.app:main"
|
| 26 |
+
|
| 27 |
+
[tool.setuptools]
|
| 28 |
+
include-package-data = true
|
| 29 |
+
packages = ["crisis_inbox", "crisis_inbox.server"]
|
| 30 |
+
package-dir = { "crisis_inbox" = ".", "crisis_inbox.server" = "server" }
|
server/Dockerfile
CHANGED
|
@@ -2,21 +2,19 @@ FROM python:3.11-slim
|
|
| 2 |
|
| 3 |
WORKDIR /app
|
| 4 |
|
| 5 |
-
|
| 6 |
-
RUN pip install
|
| 7 |
|
| 8 |
-
|
| 9 |
-
COPY
|
| 10 |
-
COPY
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
# Install dependencies
|
| 13 |
-
RUN uv pip install --system -e .
|
| 14 |
-
|
| 15 |
-
# Expose port
|
| 16 |
EXPOSE 8000
|
| 17 |
|
| 18 |
-
# Health check
|
| 19 |
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 20 |
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
|
| 21 |
|
| 22 |
-
CMD ["uvicorn", "
|
|
|
|
| 2 |
|
| 3 |
WORKDIR /app
|
| 4 |
|
| 5 |
+
COPY requirements.txt /app/
|
| 6 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 7 |
|
| 8 |
+
COPY models.py /app/
|
| 9 |
+
COPY messages.py /app/
|
| 10 |
+
COPY drift_events.py /app/
|
| 11 |
+
COPY client.py /app/
|
| 12 |
+
COPY __init__.py /app/
|
| 13 |
+
COPY server/ /app/server/
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
EXPOSE 8000
|
| 16 |
|
|
|
|
| 17 |
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 18 |
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
|
| 19 |
|
| 20 |
+
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
server/crisis_inbox_environment.py
CHANGED
|
@@ -293,7 +293,19 @@ class CrisisInboxEnvironment(MCPEnvironment):
|
|
| 293 |
for msg in ALL_MESSAGES:
|
| 294 |
if msg.id in all_drift_ids and msg.id not in drift_msg_ids:
|
| 295 |
continue # Skip drift messages from events not selected this episode
|
| 296 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
|
| 298 |
self._visible_messages = []
|
| 299 |
self._handled = {}
|
|
|
|
| 293 |
for msg in ALL_MESSAGES:
|
| 294 |
if msg.id in all_drift_ids and msg.id not in drift_msg_ids:
|
| 295 |
continue # Skip drift messages from events not selected this episode
|
| 296 |
+
m = msg.model_copy()
|
| 297 |
+
# Add jitter to arrival times and deadlines for episode variation
|
| 298 |
+
# Keep hour-0 messages at 0, jitter others by +/- 15% (clamped to 0-48)
|
| 299 |
+
if m.timestamp_hours > 0:
|
| 300 |
+
jitter = self._rng.uniform(-0.15, 0.15) * m.timestamp_hours
|
| 301 |
+
m.timestamp_hours = max(0.1, min(47.5, m.timestamp_hours + jitter))
|
| 302 |
+
if m.deadline_hours is not None and m.deadline_hours > 0:
|
| 303 |
+
d_jitter = self._rng.uniform(-0.1, 0.1) * m.deadline_hours
|
| 304 |
+
m.deadline_hours = max(
|
| 305 |
+
m.timestamp_hours + 0.5,
|
| 306 |
+
min(72.0, m.deadline_hours + d_jitter),
|
| 307 |
+
)
|
| 308 |
+
self._all_messages.append(m)
|
| 309 |
|
| 310 |
self._visible_messages = []
|
| 311 |
self._handled = {}
|
training/crisisinbox_training.py
ADDED
|
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CrisisInbox GRPO Training Script
|
| 3 |
+
Person B: ML Pipeline
|
| 4 |
+
|
| 5 |
+
Run this in Google Colab:
|
| 6 |
+
1. Upload this file
|
| 7 |
+
2. Upload episodes.json from repo
|
| 8 |
+
3. Run: python crisisinbox_training.py
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import re
|
| 13 |
+
import json
|
| 14 |
+
import numpy as np
|
| 15 |
+
from datasets import Dataset
|
| 16 |
+
from unsloth import FastLanguageModel
|
| 17 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 18 |
+
|
| 19 |
+
# Download episodes from GitHub repo
|
| 20 |
+
print("Loading episodes...")
|
| 21 |
+
import urllib.request
|
| 22 |
+
urllib.request.urlretrieve(
|
| 23 |
+
"https://raw.githubusercontent.com/eptan/crisis-inbox/main/episodes.json",
|
| 24 |
+
"episodes.json"
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
with open("episodes.json", "r") as f:
|
| 28 |
+
EPISODES = json.load(f)
|
| 29 |
+
|
| 30 |
+
print(f"✓ Loaded {len(EPISODES)} episodes")
|
| 31 |
+
|
| 32 |
+
# =============================================================================
|
| 33 |
+
# PROMPT BUILDING
|
| 34 |
+
# =============================================================================
|
| 35 |
+
|
| 36 |
+
CRISIS_SYSTEM_PROMPT = """
|
| 37 |
+
You are an assistant helping a working parent during a wildfire.
|
| 38 |
+
You must triage messages, act on safety-critical items first,
|
| 39 |
+
respect deadlines, adapt to policy changes, and write appropriate
|
| 40 |
+
tones to different people.
|
| 41 |
+
|
| 42 |
+
Respond ONLY in this format:
|
| 43 |
+
<plan>
|
| 44 |
+
1. [time=min X] ACTION_DESCRIPTION
|
| 45 |
+
2. [time=min Y] ACTION_DESCRIPTION
|
| 46 |
+
...
|
| 47 |
+
</plan>
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
def build_crisis_prompt(episode):
|
| 51 |
+
"""Build prompt from episode data."""
|
| 52 |
+
msgs_str = []
|
| 53 |
+
for m in episode.get("messages", []):
|
| 54 |
+
deadline_info = f" (DEADLINE: {m['deadline']}h)" if m.get("deadline") else ""
|
| 55 |
+
urgency = "🔴 URGENT" if m.get("type") == "safety" else ""
|
| 56 |
+
msgs_str.append(
|
| 57 |
+
f"[t={m['time']}h] {urgency} From {m['sender']} via {m['channel']}: {m['content']}{deadline_info}"
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
drift_str = []
|
| 61 |
+
for d in episode.get("schema_events", []):
|
| 62 |
+
drift_str.append(f"[t={d['time']}h] POLICY UPDATE: {d['kind']} -> {d.get('new_value', 'changed')}")
|
| 63 |
+
|
| 64 |
+
user_content = (
|
| 65 |
+
"Here is your 48-hour message history:\n\n"
|
| 66 |
+
+ "\n".join(msgs_str)
|
| 67 |
+
+ "\n\nPolicy changes observed:\n"
|
| 68 |
+
+ ("\n".join(drift_str) if drift_str else "None yet.")
|
| 69 |
+
+ "\n\nDecide what to do in order. Remember the required <plan> format."
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
return [
|
| 73 |
+
{"role": "system", "content": CRISIS_SYSTEM_PROMPT},
|
| 74 |
+
{"role": "user", "content": user_content},
|
| 75 |
+
]
|
| 76 |
+
|
| 77 |
+
# =============================================================================
|
| 78 |
+
# PLAN PARSER
|
| 79 |
+
# =============================================================================
|
| 80 |
+
|
| 81 |
+
def parse_plan(model_output):
    """Extract the numbered action list from a <plan>...</plan> block.

    Returns a list of dicts with keys ``time_min``, ``action``, ``target``
    and ``task_id`` (always None here; filled in later by the scorers).
    An empty list means no parseable plan was found.
    """
    block = re.search(r'<plan>(.*?)</plan>', model_output, re.DOTALL | re.IGNORECASE)
    if block is None:
        return []

    known_senders = ["spouse", "boss", "mom", "dad", "sister", "neighbor",
                     "FEMA", "Insurance", "Airline", "School"]
    parsed = []

    for raw in block.group(1).strip().split('\n'):
        entry = raw.strip()
        # Only numbered lines ("1. ...") count as plan steps.
        if not entry or not entry[0].isdigit():
            continue

        # "[time=min X]" gives the scheduled minute; default to 0 if absent.
        minute_match = re.search(r'time=min (\d+)', entry)
        minute = int(minute_match.group(1)) if minute_match else 0

        # Everything after the closing bracket is the action description.
        description = entry.split(']', 1)[1].strip() if ']' in entry else entry

        # Best-effort target inference via case-insensitive substring match.
        who = "unknown"
        lowered = description.lower()
        for candidate in known_senders:
            if candidate.lower() in lowered:
                who = candidate
                break

        parsed.append({
            "time_min": minute,
            "action": description,
            "target": who,
            "task_id": None,
        })

    return parsed
|
| 120 |
+
|
| 121 |
+
# =============================================================================
|
| 122 |
+
# 5-COMPONENT REWARD FUNCTION
|
| 123 |
+
# =============================================================================
|
| 124 |
+
|
| 125 |
+
# Component weights for the total reward. Safety and deadline handling
# dominate; tone is a weak shaping signal.
W_SAFETY = 1.5
W_DEADLINE = 1.5
W_DRIFT = 1.0
W_TONE = 0.5
W_COVER = 1.0
|
| 130 |
+
|
| 131 |
+
def score_safety_priority(episode, parsed_actions):
    """Reward acting on safety-critical tasks early.

    +10 if any safety task's sender appears among the first three actions'
    targets, -10 if safety tasks exist but none were acted on at all.

    Side effect: annotates each matched action dict with its ``task_id``.
    """
    safety_tasks = [t_id for t_id, t in episode.get("tasks", {}).items()
                    if t.get("category") == "safety"]

    # Link each action to the first task whose sender it mentions.
    acted_tasks = set()
    for a in parsed_actions:
        target = a.get("target", "").lower()
        # BUG FIX: guard against an empty target — "" is a substring of every
        # sender string and would spuriously match the first task.
        if not target:
            continue
        for t_id, t_info in episode.get("tasks", {}).items():
            if target in t_info.get("sender", "").lower():
                a["task_id"] = t_id
                acted_tasks.add(t_id)
                break

    score = 0.0
    first_three_targets = [a.get("target", "").lower() for a in parsed_actions[:3]]
    # BUG FIX: use .get("sender", "") so a task without a sender field does
    # not raise KeyError.
    safety_targets = [episode["tasks"][t].get("sender", "").lower() for t in safety_tasks]

    # Early attention to a safety sender is rewarded.
    if any(t in first_three_targets for t in safety_targets):
        score += 10.0

    # Ignoring every safety task entirely is penalized.
    if safety_tasks and not acted_tasks.intersection(safety_tasks):
        score -= 10.0

    return score
|
| 156 |
+
|
| 157 |
+
def score_deadlines(episode, parsed_actions):
    """Reward meeting task deadlines.

    +5 when the first matching action lands on or before the task's deadline
    (hours), -5 when it is late, and -5 for each deadlined task that is
    never attempted.
    """
    score = 0.0
    completion_time = {}  # task_id -> hour of the first matching action

    for a in parsed_actions:
        target = a.get("target", "").lower()
        # BUG FIX: skip empty targets — "" is a substring of every sender
        # string and would falsely mark the first task as completed.
        if not target:
            continue
        for t_id, t_info in episode.get("tasks", {}).items():
            if target in t_info.get("sender", "").lower():
                # Only the first action on a task counts as its completion.
                completion_time.setdefault(t_id, a["time_min"] / 60.0)
                break

    for t_id, deadline in episode.get("deadlines", {}).items():
        if t_id not in completion_time:
            score -= 5.0  # never attempted
        elif completion_time[t_id] <= deadline:
            score += 5.0  # on time
        else:
            score -= 5.0  # too late

    return score
|
| 179 |
+
|
| 180 |
+
def score_schema_drift(episode, parsed_actions):
    """Reward adapting to mid-episode policy changes.

    For an insurance deadline change: each insurance-related action scheduled
    after the event earns +10 when it beats the new deadline, -10 when it
    misses even the old one, and -5 when it lands in between.
    """
    score = 0.0

    for event in episode.get("schema_events", []):
        event_hour = event["time"]
        kind = event.get("kind", "")

        # Only actions scheduled strictly after the policy change can react to it.
        later_actions = [a for a in parsed_actions
                         if a["time_min"] / 60.0 > event_hour]
        if not later_actions:
            continue

        if "insurance_deadline_change" in kind:
            revised = event.get("new_value", 72)
            previous = event.get("old_value", 72)

            for act in later_actions:
                if "insurance" not in act.get("action", "").lower():
                    continue
                when = act["time_min"] / 60.0
                if when <= revised:
                    score += 10.0
                elif when > previous:
                    score -= 10.0
                else:
                    score -= 5.0

    return score
|
| 210 |
+
|
| 211 |
+
def score_tone(episode, model_raw_output):
    """Crude tone check: +1 each for formal, empathetic, and factual cues."""
    lowered = model_raw_output.lower()

    # Each bucket contributes at most one point, however many cues match.
    formal = any(cue in lowered for cue in ("dear", "sincerely", "regards"))
    warm = any(cue in lowered for cue in ("love you", "i'm so sorry", "worried"))
    factual = any(cue in lowered for cue in ("confirm", "verified", "documentation"))

    return float(formal) + float(warm) + float(factual)
|
| 226 |
+
|
| 227 |
+
def score_coverage(episode, parsed_actions):
    """Reward breadth: +2 for every task the plan touches, -1 per ignored task."""
    described_tasks = set(episode.get("tasks", {}).keys())

    acted_tasks = set()
    for a in parsed_actions:
        target = a.get("target", "").lower()
        # BUG FIX: an empty target is a substring of every sender string and
        # would incorrectly mark the first task as covered.
        if not target:
            continue
        for t_id, t_info in episode.get("tasks", {}).items():
            if target in t_info.get("sender", "").lower():
                acted_tasks.add(t_id)
                break

    score = 0.0
    for t_id in described_tasks:
        score += 2.0 if t_id in acted_tasks else -1.0

    return score
|
| 246 |
+
|
| 247 |
+
def total_reward(episode, model_raw_output):
    """Weighted sum of the five reward components for one completion.

    A completion with no parseable <plan> block gets a flat -20.0
    format penalty.
    """
    plan = parse_plan(model_raw_output)
    if not plan:
        return -20.0

    weighted_components = (
        (W_SAFETY, score_safety_priority(episode, plan)),
        (W_DEADLINE, score_deadlines(episode, plan)),
        (W_DRIFT, score_schema_drift(episode, plan)),
        (W_TONE, score_tone(episode, model_raw_output)),
        (W_COVER, score_coverage(episode, plan)),
    )
    return sum(weight * value for weight, value in weighted_components)
|
| 269 |
+
|
| 270 |
+
# =============================================================================
|
| 271 |
+
# GRPO TRAINING SETUP
|
| 272 |
+
# =============================================================================
|
| 273 |
+
|
| 274 |
+
def build_dataset(episodes):
    """Wrap episodes into a Dataset of (id, prompt, episode) rows."""
    return Dataset.from_list([
        {"id": idx, "prompt": build_crisis_prompt(ep), "episode": ep}
        for idx, ep in enumerate(episodes)
    ])
|
| 284 |
+
|
| 285 |
+
# Build the training dataset from the pre-generated episodes.
print("Building dataset...")
crisis_dataset = build_dataset(EPISODES)
print(f"✓ Built dataset with {len(crisis_dataset)} episodes")
|
| 288 |
+
|
| 289 |
+
# Load model
print("Loading model...")
# Small 0.5B instruct model in 4-bit; dtype=None lets unsloth auto-select
# bf16/fp16 for the hardware.
MODEL_NAME = "unsloth/Qwen2.5-0.5B-Instruct"

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
)

# Attach LoRA adapters on all attention and MLP projection layers.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,  # fixed seed for reproducible adapter init
)

print(f"✓ Loaded model: {MODEL_NAME}")
|
| 313 |
+
|
| 314 |
+
# Reward function
# Materialize episodes once so the reward fn can map prompts back to episodes.
EPISODES_LIST = [row["episode"] for row in crisis_dataset]

def crisis_reward_fn(prompts, completions, **kwargs):
    """GRPO reward function: score each completion against its episode.

    Args:
        prompts: batch of chat prompts (each a [system, user] message list).
        completions: batch of model outputs. TRL passes these as strings or
            chat-style [{"role": ..., "content": ...}] lists; raw token-id
            lists are also handled defensively.

    Returns:
        List of float rewards, one per completion.
    """
    rewards = []

    for prompt, completion in zip(prompts, completions):
        # Normalize the completion to a plain string.
        if isinstance(completion, str):
            completion_str = completion
        elif isinstance(completion, list) and completion and isinstance(completion[0], dict):
            # Conversational format: concatenate the message contents.
            completion_str = "".join(m.get("content", "") for m in completion)
        elif isinstance(completion, list):
            # Token-id list. BUG FIX: FastLanguageModel.get_model() does not
            # exist in the unsloth API; decode with the module-level tokenizer
            # loaded above instead.
            completion_str = tokenizer.decode(completion, skip_special_tokens=True)
        else:
            completion_str = str(completion)

        # Map the prompt back to its episode by matching the user message.
        episode = None
        prompt_user_content = prompt[1]["content"] if isinstance(prompt, list) else str(prompt)
        for ep in EPISODES_LIST:
            if build_crisis_prompt(ep)[1]["content"] == prompt_user_content:
                episode = ep
                break

        if episode is None:
            # Fallback: cycle deterministically when no exact match is found.
            idx = kwargs.get("episode_idx", 0)
            episode = EPISODES_LIST[idx % len(EPISODES_LIST)]

        rewards.append(float(total_reward(episode, completion_str)))

    return rewards
|
| 349 |
+
|
| 350 |
+
# Training config
training_args = GRPOConfig(
    use_vllm=False,  # Disabled for Colab compatibility
    num_train_epochs=3,
    max_steps=200,  # Start with 200 for test
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,  # effective batch size of 8 prompts
    num_generations=4,  # completions sampled per prompt for the GRPO group
    max_completion_length=512,
    temperature=0.7,
    learning_rate=1e-5,
    logging_steps=10,
    save_steps=50,
    output_dir="./crisisinbox_grpo",
    report_to="none",  # no external experiment tracking
)
|
| 366 |
+
|
| 367 |
+
# Wire the model, tokenizer, reward function, and dataset into GRPO.
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=crisis_reward_fn,
    train_dataset=crisis_dataset,
    args=training_args,
)

print("✓ GRPO Trainer ready")
|
| 376 |
+
|
| 377 |
+
# =============================================================================
# TRAINING RUN
# =============================================================================

print("🚀 Starting CrisisInbox GRPO Training...")
print("This will take 1-4 hours depending on max_steps\n")

# Blocking call: runs the full GRPO loop configured above.
trainer.train()

print("\n✅ Training completed!")
|
| 387 |
+
|
| 388 |
+
# =============================================================================
# SAVE RESULTS
# =============================================================================

# Persist the trained adapter and tokenizer side by side.
model.save_pretrained("./crisisinbox_model_final")
tokenizer.save_pretrained("./crisisinbox_model_final")
print("✓ Model saved")

# Generate demo examples
DEMO_EXAMPLES = []
for ep in EPISODES[:3]:
    prompt = build_crisis_prompt(ep)
    text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    # BUG FIX: temperature is ignored (with a warning) unless sampling is
    # enabled, so pass do_sample=True explicitly.
    outputs = model.generate(**inputs, max_new_tokens=200, do_sample=True, temperature=0.7)
    # Strip the prompt tokens; keep only the newly generated completion.
    response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

    reward = total_reward(ep, response)
    DEMO_EXAMPLES.append({
        "episode_id": ep["episode_id"],
        "response": response,
        "reward": reward,
        "parsed_actions": parse_plan(response),
    })

with open("demo_examples_trained.json", "w") as f:
    json.dump(DEMO_EXAMPLES, f, indent=2)

print(f"✓ Saved {len(DEMO_EXAMPLES)} demo examples")
|