Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
Upload folder using huggingface_hub
Browse files- messages.py +275 -0
- models.py +16 -0
- notebooks/crisisinbox_grpo_connected copy.ipynbd +435 -0
- notebooks/crisisinbox_grpo_connected.ipynb +3 -1
- server/crisis_inbox_environment.py +77 -1
- server/rewards.py +4 -0
messages.py
CHANGED
|
@@ -1158,4 +1158,279 @@ ALL_MESSAGES: list[Message] = [
|
|
| 1158 |
urgency=Urgency.LOW,
|
| 1159 |
timestamp_hours=47.5,
|
| 1160 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1161 |
]
|
|
|
|
| 1158 |
urgency=Urgency.LOW,
|
| 1159 |
timestamp_hours=47.5,
|
| 1160 |
),
|
| 1161 |
+
|
| 1162 |
+
# ========== CONFLICTING DEADLINES ==========
|
| 1163 |
+
# These pairs have overlapping deadlines — the agent can only do one.
|
| 1164 |
+
|
| 1165 |
+
# Conflict pair 1: School pickup vs Insurance call (both at hour ~8)
|
| 1166 |
+
Message(
|
| 1167 |
+
id="msg_074",
|
| 1168 |
+
sender="Oakwood Elementary",
|
| 1169 |
+
channel=Channel.PHONE,
|
| 1170 |
+
subject="URGENT: Early dismissal pickup required by 2pm",
|
| 1171 |
+
content=(
|
| 1172 |
+
"This is Oakwood Elementary calling about Emma and Jake. Due to the storm, "
|
| 1173 |
+
"we are doing an emergency early dismissal at 2pm today. Your sister listed you "
|
| 1174 |
+
"as emergency pickup. If no authorized adult arrives by 2pm, we will need to "
|
| 1175 |
+
"contact Child Protective Services per district policy. Please confirm."
|
| 1176 |
+
),
|
| 1177 |
+
urgency=Urgency.CRITICAL,
|
| 1178 |
+
timestamp_hours=6.0,
|
| 1179 |
+
deadline_hours=8.0,
|
| 1180 |
+
conflicts_with="msg_075",
|
| 1181 |
+
),
|
| 1182 |
+
Message(
|
| 1183 |
+
id="msg_075",
|
| 1184 |
+
sender="State Farm Insurance",
|
| 1185 |
+
channel=Channel.PHONE,
|
| 1186 |
+
subject="Scheduled damage assessment call - don't miss",
|
| 1187 |
+
content=(
|
| 1188 |
+
"This is your scheduled callback from State Farm. An adjuster is available to "
|
| 1189 |
+
"do a phone assessment of your property damage between 1:30pm and 2:15pm today ONLY. "
|
| 1190 |
+
"If you miss this window, the next available slot is in 12 days. Missing the initial "
|
| 1191 |
+
"assessment may delay your claim payout by 4-6 weeks. Please be available."
|
| 1192 |
+
),
|
| 1193 |
+
urgency=Urgency.HIGH,
|
| 1194 |
+
timestamp_hours=6.0,
|
| 1195 |
+
deadline_hours=8.5,
|
| 1196 |
+
dependencies=["msg_004"],
|
| 1197 |
+
conflicts_with="msg_074",
|
| 1198 |
+
),
|
| 1199 |
+
|
| 1200 |
+
# Conflict pair 2: Boss presentation vs FEMA registration (both at hour ~14)
|
| 1201 |
+
Message(
|
| 1202 |
+
id="msg_076",
|
| 1203 |
+
sender="Boss",
|
| 1204 |
+
channel=Channel.SMS,
|
| 1205 |
+
subject="Client pushed meeting to today - need you on Zoom at 2pm",
|
| 1206 |
+
content=(
|
| 1207 |
+
"Bad news, Meridian moved the meeting to today. I need you on the Zoom call at "
|
| 1208 |
+
"2pm sharp to present your section. It's 30 minutes max. I already told them "
|
| 1209 |
+
"you'd be there. This is the account we've been working on for 6 months. "
|
| 1210 |
+
"Don't let me down."
|
| 1211 |
+
),
|
| 1212 |
+
urgency=Urgency.HIGH,
|
| 1213 |
+
timestamp_hours=12.0,
|
| 1214 |
+
deadline_hours=14.5,
|
| 1215 |
+
conflicts_with="msg_077",
|
| 1216 |
+
),
|
| 1217 |
+
Message(
|
| 1218 |
+
id="msg_077",
|
| 1219 |
+
sender="FEMA",
|
| 1220 |
+
channel=Channel.GOVERNMENT_ALERT,
|
| 1221 |
+
subject="In-person registration window: 1pm-3pm TODAY ONLY",
|
| 1222 |
+
content=(
|
| 1223 |
+
"FEMA Disaster Recovery Center at Sacramento Convention Center is open for "
|
| 1224 |
+
"in-person registration TODAY ONLY from 1pm to 3pm. In-person registrations "
|
| 1225 |
+
"receive priority processing (2-3 weeks vs 6-8 weeks online). Bring ID, proof "
|
| 1226 |
+
"of residence, and damage documentation. This is the only in-person session "
|
| 1227 |
+
"scheduled for your zip code."
|
| 1228 |
+
),
|
| 1229 |
+
urgency=Urgency.HIGH,
|
| 1230 |
+
timestamp_hours=11.5,
|
| 1231 |
+
deadline_hours=15.0,
|
| 1232 |
+
conflicts_with="msg_076",
|
| 1233 |
+
),
|
| 1234 |
+
|
| 1235 |
+
# ========== ESCALATION CHAINS ==========
|
| 1236 |
+
# These messages escalate (spawn angry follow-ups) if not handled in time.
|
| 1237 |
+
|
| 1238 |
+
Message(
|
| 1239 |
+
id="msg_078",
|
| 1240 |
+
sender="Neighbor Dave",
|
| 1241 |
+
channel=Channel.SMS,
|
| 1242 |
+
subject="Can you help me board up windows?",
|
| 1243 |
+
content=(
|
| 1244 |
+
"Hey man, the plywood I got is too big for me to handle alone. My wife's at her "
|
| 1245 |
+
"mom's with the kids. Can you come over for like 20 minutes to help me board up "
|
| 1246 |
+
"the front windows? I'll return the favor anytime. I'm at 422 Oak St."
|
| 1247 |
+
),
|
| 1248 |
+
urgency=Urgency.MEDIUM,
|
| 1249 |
+
timestamp_hours=3.0,
|
| 1250 |
+
deadline_hours=6.0,
|
| 1251 |
+
escalation_trigger="msg_078e",
|
| 1252 |
+
escalation_delay_hours=1.0,
|
| 1253 |
+
),
|
| 1254 |
+
# Escalation: Dave's follow-up (injected by environment if msg_078 unhandled by hour 7)
|
| 1255 |
+
Message(
|
| 1256 |
+
id="msg_078e",
|
| 1257 |
+
sender="Neighbor Dave",
|
| 1258 |
+
channel=Channel.SMS,
|
| 1259 |
+
subject="Window broke. Thanks for nothing",
|
| 1260 |
+
content=(
|
| 1261 |
+
"Well the front window just blew in. Glass everywhere. Would've taken you 20 "
|
| 1262 |
+
"minutes Dave. Twenty minutes. Now I've got water pouring into my living room "
|
| 1263 |
+
"and I'm trying to tape a tarp up by myself. I hope whatever you were doing "
|
| 1264 |
+
"was worth it. Don't bother coming now."
|
| 1265 |
+
),
|
| 1266 |
+
urgency=Urgency.LOW,
|
| 1267 |
+
timestamp_hours=7.0,
|
| 1268 |
+
),
|
| 1269 |
+
|
| 1270 |
+
Message(
|
| 1271 |
+
id="msg_079",
|
| 1272 |
+
sender="Boss",
|
| 1273 |
+
channel=Channel.EMAIL,
|
| 1274 |
+
subject="Slides due by 5pm - FINAL warning",
|
| 1275 |
+
content=(
|
| 1276 |
+
"I haven't received your section of the Meridian slides. I need them by 5pm "
|
| 1277 |
+
"today or I'm giving your section to Sarah and we'll discuss this when things "
|
| 1278 |
+
"settle down. I understand the situation but the client doesn't care about hurricanes. "
|
| 1279 |
+
"5pm. Final."
|
| 1280 |
+
),
|
| 1281 |
+
urgency=Urgency.HIGH,
|
| 1282 |
+
timestamp_hours=15.0,
|
| 1283 |
+
deadline_hours=17.0,
|
| 1284 |
+
escalation_trigger="msg_079e",
|
| 1285 |
+
escalation_delay_hours=2.0,
|
| 1286 |
+
),
|
| 1287 |
+
# Escalation: Boss fires you from the project (injected if msg_079 unhandled by hour 19)
|
| 1288 |
+
Message(
|
| 1289 |
+
id="msg_079e",
|
| 1290 |
+
sender="Boss",
|
| 1291 |
+
channel=Channel.EMAIL,
|
| 1292 |
+
subject="Re: Slides - Gave your section to Sarah",
|
| 1293 |
+
content=(
|
| 1294 |
+
"I waited. Nothing. Sarah's handling your section now. I covered for you with "
|
| 1295 |
+
"the client but I'm not going to lie — this isn't a good look. We'll need to "
|
| 1296 |
+
"have a conversation when you're back. I get it's a disaster but everyone else "
|
| 1297 |
+
"managed to check in."
|
| 1298 |
+
),
|
| 1299 |
+
urgency=Urgency.MEDIUM,
|
| 1300 |
+
timestamp_hours=19.0,
|
| 1301 |
+
),
|
| 1302 |
+
|
| 1303 |
+
Message(
|
| 1304 |
+
id="msg_080",
|
| 1305 |
+
sender="Mom",
|
| 1306 |
+
channel=Channel.SMS,
|
| 1307 |
+
subject="WHY ARENT YOU ANSWERING",
|
| 1308 |
+
content=(
|
| 1309 |
+
"I've called you SEVEN times. Your father is in the car ready to drive down. "
|
| 1310 |
+
"Please just send ONE TEXT so I know you're alive. I am losing my mind. "
|
| 1311 |
+
"If I don't hear from you in the next hour I'm calling 911."
|
| 1312 |
+
),
|
| 1313 |
+
urgency=Urgency.CRITICAL,
|
| 1314 |
+
timestamp_hours=4.0,
|
| 1315 |
+
deadline_hours=5.0,
|
| 1316 |
+
escalation_trigger="msg_080e",
|
| 1317 |
+
escalation_delay_hours=1.5,
|
| 1318 |
+
),
|
| 1319 |
+
# Escalation: Mom actually calls 911 (injected if msg_080 unhandled by hour 6.5)
|
| 1320 |
+
Message(
|
| 1321 |
+
id="msg_080e",
|
| 1322 |
+
sender="Mom",
|
| 1323 |
+
channel=Channel.SMS,
|
| 1324 |
+
subject="Called 911. Dad is driving down",
|
| 1325 |
+
content=(
|
| 1326 |
+
"That's it. I called 911 and filed a welfare check. Your father is on the highway. "
|
| 1327 |
+
"I don't care if you're busy. I don't care if you think I'm overreacting. "
|
| 1328 |
+
"You don't go SILENT during a hurricane. If you see this call me IMMEDIATELY. "
|
| 1329 |
+
"I haven't slept."
|
| 1330 |
+
),
|
| 1331 |
+
urgency=Urgency.HIGH,
|
| 1332 |
+
timestamp_hours=6.5,
|
| 1333 |
+
),
|
| 1334 |
+
|
| 1335 |
+
# ========== MULTI-TURN CONVERSATIONS ==========
|
| 1336 |
+
# Responding to these messages triggers a follow-up requiring another action.
|
| 1337 |
+
|
| 1338 |
+
Message(
|
| 1339 |
+
id="msg_081",
|
| 1340 |
+
sender="State Farm Insurance",
|
| 1341 |
+
channel=Channel.EMAIL,
|
| 1342 |
+
subject="Claim received - additional photos needed",
|
| 1343 |
+
content=(
|
| 1344 |
+
"Thank you for filing your initial claim (#SF-2026-84721). However, our adjuster "
|
| 1345 |
+
"needs additional documentation before we can proceed: (1) Close-up photos of roof "
|
| 1346 |
+
"damage, (2) Water line marks on interior walls, (3) Serial numbers of damaged "
|
| 1347 |
+
"electronics. Please reply with these within 12 hours to keep your claim in the "
|
| 1348 |
+
"expedited queue."
|
| 1349 |
+
),
|
| 1350 |
+
urgency=Urgency.HIGH,
|
| 1351 |
+
timestamp_hours=16.0,
|
| 1352 |
+
deadline_hours=28.0,
|
| 1353 |
+
dependencies=["msg_004"],
|
| 1354 |
+
reply_trigger="msg_081r",
|
| 1355 |
+
),
|
| 1356 |
+
# Reply: Adjuster confirms and asks one more thing
|
| 1357 |
+
Message(
|
| 1358 |
+
id="msg_081r",
|
| 1359 |
+
sender="State Farm Insurance",
|
| 1360 |
+
channel=Channel.EMAIL,
|
| 1361 |
+
subject="Re: Claim #SF-2026-84721 - One more step",
|
| 1362 |
+
content=(
|
| 1363 |
+
"Got your photos, thank you. Your claim is being processed. One final step: "
|
| 1364 |
+
"we need you to sign the digital authorization form I've attached. This authorizes "
|
| 1365 |
+
"our contractor to begin repairs. Without your signature, repairs cannot start "
|
| 1366 |
+
"even if the claim is approved. Please sign within 6 hours."
|
| 1367 |
+
),
|
| 1368 |
+
urgency=Urgency.HIGH,
|
| 1369 |
+
timestamp_hours=0.0, # Timestamp set dynamically when injected
|
| 1370 |
+
deadline_hours=0.0, # Deadline set dynamically (current_hour + 6)
|
| 1371 |
+
),
|
| 1372 |
+
|
| 1373 |
+
Message(
|
| 1374 |
+
id="msg_082",
|
| 1375 |
+
sender="Sister",
|
| 1376 |
+
channel=Channel.SMS,
|
| 1377 |
+
subject="Can you keep the kids overnight?",
|
| 1378 |
+
content=(
|
| 1379 |
+
"Hey so my boss is now saying we have to work through the night because of the storm "
|
| 1380 |
+
"damage at the warehouse. Can you keep Emma and Jake overnight? I know it's a lot "
|
| 1381 |
+
"to ask right now but I literally have no other option. Mom and Dad's power is out. "
|
| 1382 |
+
"They have their backpacks with PJs and stuff."
|
| 1383 |
+
),
|
| 1384 |
+
urgency=Urgency.HIGH,
|
| 1385 |
+
timestamp_hours=18.0,
|
| 1386 |
+
deadline_hours=20.0,
|
| 1387 |
+
reply_trigger="msg_082r",
|
| 1388 |
+
),
|
| 1389 |
+
# Reply: Sister responds with logistics
|
| 1390 |
+
Message(
|
| 1391 |
+
id="msg_082r",
|
| 1392 |
+
sender="Sister",
|
| 1393 |
+
channel=Channel.SMS,
|
| 1394 |
+
subject="Re: Kids overnight - Emma's medication",
|
| 1395 |
+
content=(
|
| 1396 |
+
"OMG thank you, you're a lifesaver. One thing — Emma needs her allergy medication "
|
| 1397 |
+
"at 8pm. It's the pink liquid in her backpack front pocket. 5mL. She knows but she'll "
|
| 1398 |
+
"try to skip it because it tastes bad. Don't let her. Also Jake needs a nightlight "
|
| 1399 |
+
"or he won't sleep. Sorry I'm the worst. I owe you forever."
|
| 1400 |
+
),
|
| 1401 |
+
urgency=Urgency.MEDIUM,
|
| 1402 |
+
timestamp_hours=0.0, # Set dynamically
|
| 1403 |
+
deadline_hours=0.0, # Set dynamically (current_hour + 2)
|
| 1404 |
+
),
|
| 1405 |
+
|
| 1406 |
+
Message(
|
| 1407 |
+
id="msg_083",
|
| 1408 |
+
sender="Neighbor Dave",
|
| 1409 |
+
channel=Channel.SMS,
|
| 1410 |
+
subject="Found your dog!!",
|
| 1411 |
+
content=(
|
| 1412 |
+
"Dude your dog is in my backyard! Max must have gotten out through the fence that blew "
|
| 1413 |
+
"down. He's soaking wet but seems ok. I put him in my garage with a towel. Come get "
|
| 1414 |
+
"him when you can but he seems pretty stressed — keeps whining. Let me know."
|
| 1415 |
+
),
|
| 1416 |
+
urgency=Urgency.MEDIUM,
|
| 1417 |
+
timestamp_hours=9.0,
|
| 1418 |
+
reply_trigger="msg_083r",
|
| 1419 |
+
),
|
| 1420 |
+
# Reply: Dave found something concerning about the dog
|
| 1421 |
+
Message(
|
| 1422 |
+
id="msg_083r",
|
| 1423 |
+
sender="Neighbor Dave",
|
| 1424 |
+
channel=Channel.SMS,
|
| 1425 |
+
subject="Re: Your dog - he's limping",
|
| 1426 |
+
content=(
|
| 1427 |
+
"Hey so Max is limping on his back left leg. I didn't notice at first because he was "
|
| 1428 |
+
"just laying down but when I gave him water he got up and he's definitely favoring it. "
|
| 1429 |
+
"Might want to get him to a vet. I think the emergency vet on J Street is still open "
|
| 1430 |
+
"despite the storm. Want me to drive you two over there?"
|
| 1431 |
+
),
|
| 1432 |
+
urgency=Urgency.HIGH,
|
| 1433 |
+
timestamp_hours=0.0, # Set dynamically
|
| 1434 |
+
deadline_hours=0.0, # Set dynamically (current_hour + 4)
|
| 1435 |
+
),
|
| 1436 |
]
|
models.py
CHANGED
|
@@ -66,3 +66,19 @@ class Message(BaseModel):
|
|
| 66 |
default=None,
|
| 67 |
description="ID of a previous message this one replaces (due to drift)",
|
| 68 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
default=None,
|
| 67 |
description="ID of a previous message this one replaces (due to drift)",
|
| 68 |
)
|
| 69 |
+
conflicts_with: Optional[str] = Field(
|
| 70 |
+
default=None,
|
| 71 |
+
description="ID of another message with an overlapping deadline — only one can be handled",
|
| 72 |
+
)
|
| 73 |
+
escalation_trigger: Optional[str] = Field(
|
| 74 |
+
default=None,
|
| 75 |
+
description="ID of a follow-up message that appears if THIS message is not handled in time",
|
| 76 |
+
)
|
| 77 |
+
escalation_delay_hours: Optional[float] = Field(
|
| 78 |
+
default=None,
|
| 79 |
+
description="Hours after this message's deadline before the escalation fires",
|
| 80 |
+
)
|
| 81 |
+
reply_trigger: Optional[str] = Field(
|
| 82 |
+
default=None,
|
| 83 |
+
description="ID of a follow-up message injected when THIS message is handled (multi-turn)",
|
| 84 |
+
)
|
notebooks/crisisinbox_grpo_connected copy.ipynbd
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "ym4tunggrm",
|
| 6 |
+
"source": "# CrisisInbox GRPO Training (Connected to HF Space)\n\nTrain a small LLM (Qwen2.5-0.5B) to triage crisis inbox messages using Group Relative Policy Optimization.\n\n**This notebook connects to the live CrisisInbox environment** deployed on HuggingFace Spaces at `https://eptan-crisis-inbox.hf.space` to collect training episodes in real-time, then trains the model using GRPO.\n\n**Stack:** HF TRL + PEFT (LoRA on full bf16 model \u2014 no quantization needed for 0.5B)\n\n**What this does:**\n1. Connects to the deployed CrisisInbox environment via WebSocket\n2. Collects episodes by interacting with the environment (reset, list tools, call tools)\n3. Builds training prompts from live environment observations\n4. Trains the model with GRPO using a reward function\n5. Evaluates the trained model against the live environment\n\nOpen in Google Colab or Northflank with a GPU runtime.",
|
| 7 |
+
"metadata": {}
|
| 8 |
+
},
|
| 9 |
+
{
|
| 10 |
+
"cell_type": "code",
|
| 11 |
+
"id": "p0j1w7pr7ib",
|
| 12 |
+
"source": [
|
| 13 |
+
"# Install dependencies (pure HF TRL + PEFT, no quantization needed for 0.5B model)\n",
|
| 14 |
+
"!pip install trl transformers datasets accelerate peft -q\n",
|
| 15 |
+
"!pip install \"openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git\" -q\n",
|
| 16 |
+
"!pip install huggingface_hub matplotlib -q\n",
|
| 17 |
+
"print(\"Setup complete\")\n"
|
| 18 |
+
],
|
| 19 |
+
"metadata": {},
|
| 20 |
+
"execution_count": null,
|
| 21 |
+
"outputs": []
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"cell_type": "code",
|
| 25 |
+
"id": "sg4tfghxfgb",
|
| 26 |
+
"source": "# Patch transformers logging crash\nimport logging\nimport warnings\n\ndef _patch_transformers_logging():\n try:\n import transformers.utils.logging as trans_log\n _orig = trans_log.logger.warning\n def _safe_warning(msg, *args, **kwargs):\n if args and isinstance(args[0], type) and issubclass(args[0], Warning):\n args = ()\n return _orig(msg, *args, **kwargs)\n trans_log.logger.warning = _safe_warning\n except Exception:\n pass\n warnings.filterwarnings(\"ignore\", message=\".*attention mask API.*\", category=FutureWarning)\n\n_patch_transformers_logging()",
|
| 27 |
+
"metadata": {},
|
| 28 |
+
"execution_count": null,
|
| 29 |
+
"outputs": []
|
| 30 |
+
},
|
| 31 |
+
{
|
| 32 |
+
"cell_type": "code",
|
| 33 |
+
"id": "t0a3hdk1jqk",
|
| 34 |
+
"source": "import torch\n\nif torch.cuda.is_available():\n props = torch.cuda.get_device_properties(0)\n total_bytes = getattr(props, \"total_memory\", None) or getattr(props, \"total_mem\", 0)\n vram_gb = total_bytes / 1e9 if total_bytes else 0\n if vram_gb == 0 and hasattr(torch.cuda, \"mem_get_info\"):\n _, total_bytes = torch.cuda.mem_get_info(0)\n vram_gb = total_bytes / 1e9\n print(f\"GPU: {torch.cuda.get_device_name(0)} ({vram_gb:.1f} GB)\")\nelse:\n print(\"No GPU available.\")",
|
| 35 |
+
"metadata": {},
|
| 36 |
+
"execution_count": null,
|
| 37 |
+
"outputs": []
|
| 38 |
+
},
|
| 39 |
+
{
|
| 40 |
+
"cell_type": "markdown",
|
| 41 |
+
"id": "9b4l2fp3jm5",
|
| 42 |
+
"source": "## Connect to CrisisInbox Environment\n\nConnect to the live environment on HuggingFace Spaces via WebSocket and collect episodes by running through the 48-hour simulation multiple times with different seeds.",
|
| 43 |
+
"metadata": {}
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"cell_type": "code",
|
| 47 |
+
"id": "pmbt9gcp9hb",
|
| 48 |
+
"source": "import json\nimport time as _time\nfrom openenv.core.mcp_client import MCPToolClient\n\nBASE_URL = \"https://eptan-crisis-inbox.hf.space\"\n\n# Wake up the HF Space (may be sleeping) and verify connectivity\nprint(\"Connecting to HF Space (may take a moment if cold-starting)...\")\nfor attempt in range(3):\n try:\n with MCPToolClient(base_url=BASE_URL, connect_timeout_s=60.0).sync() as env:\n env.reset(seed=0)\n tools = env.list_tools()\n print(f\"Connected! Available tools: {[t.name for t in tools]}\")\n for t in tools:\n print(f\" - {t.name}: {t.description[:80]}...\")\n\n status = json.loads(env.call_tool(\"get_status\"))\n print(f\"\\nEnvironment ready \u2014 {status['messages_total_arrived']} messages at hour {status['current_hour']}\")\n break\n except Exception as e:\n if attempt < 2:\n print(f\" Attempt {attempt + 1} failed ({e}), retrying in 10s...\")\n _time.sleep(10)\n else:\n raise RuntimeError(f\"Could not connect to {BASE_URL} after 3 attempts: {e}\")",
|
| 49 |
+
"metadata": {},
|
| 50 |
+
"execution_count": null,
|
| 51 |
+
"outputs": []
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"cell_type": "code",
|
| 55 |
+
"id": "rfzviywy9od",
|
| 56 |
+
"source": [
|
| 57 |
+
"import random\n",
|
| 58 |
+
"\n",
|
| 59 |
+
"\n",
|
| 60 |
+
"def collect_episode(base_url, seed, time_steps=None):\n",
|
| 61 |
+
" \"\"\"Collect one episode from the live environment using OpenEnv tools.\"\"\"\n",
|
| 62 |
+
" if time_steps is None:\n",
|
| 63 |
+
" time_steps = [0, 2, 6, 12, 18, 24, 30, 36, 42, 47]\n",
|
| 64 |
+
"\n",
|
| 65 |
+
" superseded_msgs = {}\n",
|
| 66 |
+
"\n",
|
| 67 |
+
" with MCPToolClient(\n",
|
| 68 |
+
" base_url=base_url, connect_timeout_s=60.0, message_timeout_s=120.0,\n",
|
| 69 |
+
" ).sync() as env:\n",
|
| 70 |
+
" env.reset(seed=seed)\n",
|
| 71 |
+
" decision_points = []\n",
|
| 72 |
+
" current_hour = 0.0\n",
|
| 73 |
+
"\n",
|
| 74 |
+
" for target_hour in time_steps:\n",
|
| 75 |
+
" while current_hour < target_hour - 0.1:\n",
|
| 76 |
+
" jump = min(4.0, target_hour - current_hour)\n",
|
| 77 |
+
" env.call_tool(\"advance_time\", hours=jump)\n",
|
| 78 |
+
" status = json.loads(env.call_tool(\"get_status\"))\n",
|
| 79 |
+
" current_hour = status[\"current_hour\"]\n",
|
| 80 |
+
"\n",
|
| 81 |
+
" inbox = json.loads(env.call_tool(\"get_inbox\"))\n",
|
| 82 |
+
" prompt = env.call_tool(\"get_prompt\") # Server builds the prompt\n",
|
| 83 |
+
"\n",
|
| 84 |
+
" for m in inbox:\n",
|
| 85 |
+
" if m.get(\"superseded\"):\n",
|
| 86 |
+
" superseded_msgs[m[\"id\"]] = \"\"\n",
|
| 87 |
+
"\n",
|
| 88 |
+
" unhandled = [m for m in inbox if not m.get(\"handled\", False)]\n",
|
| 89 |
+
" if not unhandled:\n",
|
| 90 |
+
" continue\n",
|
| 91 |
+
"\n",
|
| 92 |
+
" decision_points.append({\n",
|
| 93 |
+
" \"hour\": target_hour,\n",
|
| 94 |
+
" \"visible_count\": len(inbox),\n",
|
| 95 |
+
" \"prompt\": prompt,\n",
|
| 96 |
+
" \"messages\": inbox,\n",
|
| 97 |
+
" \"superseded\": dict(superseded_msgs),\n",
|
| 98 |
+
" })\n",
|
| 99 |
+
"\n",
|
| 100 |
+
" return {\n",
|
| 101 |
+
" \"episode_id\": f\"ep_{seed}\",\n",
|
| 102 |
+
" \"seed\": seed,\n",
|
| 103 |
+
" \"drift_events\": [],\n",
|
| 104 |
+
" \"superseded_messages\": superseded_msgs,\n",
|
| 105 |
+
" \"decision_points\": decision_points,\n",
|
| 106 |
+
" }\n",
|
| 107 |
+
"\n",
|
| 108 |
+
"\n",
|
| 109 |
+
"# Test: collect one episode\n",
|
| 110 |
+
"print(\"Collecting test episode (seed=42)...\")\n",
|
| 111 |
+
"test_ep = collect_episode(BASE_URL, seed=42)\n",
|
| 112 |
+
"print(f\"Episode {test_ep['episode_id']}: {len(test_ep['decision_points'])} decision points\")\n",
|
| 113 |
+
"for dp in test_ep[\"decision_points\"]:\n",
|
| 114 |
+
" print(f\" Hour {dp['hour']:5.1f}: {dp['visible_count']} messages visible\")\n"
|
| 115 |
+
],
|
| 116 |
+
"metadata": {},
|
| 117 |
+
"execution_count": null,
|
| 118 |
+
"outputs": []
|
| 119 |
+
},
|
| 120 |
+
{
|
| 121 |
+
"cell_type": "code",
|
| 122 |
+
"id": "nmmfb62s9ph",
|
| 123 |
+
"source": "# Collect multiple episodes from the live environment\nNUM_EPISODES = 10\nSEEDS = list(range(NUM_EPISODES))\n\nepisodes = []\nfor seed in SEEDS:\n print(f\"Collecting episode {seed + 1}/{NUM_EPISODES} (seed={seed})...\", end=\" \")\n for attempt in range(3):\n try:\n ep = collect_episode(BASE_URL, seed=seed)\n episodes.append(ep)\n print(f\"{len(ep['decision_points'])} decision points\")\n break\n except Exception as e:\n if attempt < 2:\n print(f\"retry {attempt + 1}...\", end=\" \")\n _time.sleep(5)\n else:\n print(f\"FAILED ({e}), skipping\")\n\n# Flatten to training prompts\nprompts = []\nfor ep in episodes:\n for dp in ep[\"decision_points\"]:\n prompts.append({\n \"prompt\": dp[\"prompt\"],\n \"hour\": dp[\"hour\"],\n \"visible_count\": dp[\"visible_count\"],\n \"episode_id\": ep[\"episode_id\"],\n \"seed\": ep[\"seed\"],\n \"drift_events\": ep[\"drift_events\"],\n \"superseded\": ep.get(\"superseded_messages\", {}),\n \"messages\": dp[\"visible_messages\"],\n })\n\nprint(f\"\\nCollected {len(episodes)} episodes -> {len(prompts)} training prompts\")\nprint(f\"Average {len(prompts)/len(episodes):.1f} decision points per episode\")",
|
| 124 |
+
"metadata": {},
|
| 125 |
+
"execution_count": null,
|
| 126 |
+
"outputs": []
|
| 127 |
+
},
|
| 128 |
+
{
|
| 129 |
+
"cell_type": "markdown",
|
| 130 |
+
"id": "pg833f350e",
|
| 131 |
+
"source": "## Reward Function\n\nScores agent actions based on:\n- **Urgency base** (critical=10, high=5, medium=3, low=1)\n- **Deadline timing** (early=bonus, late=penalty)\n- **Drift adaptation** (+50% for handling policy-change messages)\n- **Stale info penalty** (-50% for acting on superseded messages)\n- **Response quality** (penalty for short/empty responses)",
|
| 132 |
+
"metadata": {}
|
| 133 |
+
},
|
| 134 |
+
{
|
| 135 |
+
"cell_type": "code",
|
| 136 |
+
"id": "2xd2afp4g99",
|
| 137 |
+
"source": [
|
| 138 |
+
"import re\n",
|
| 139 |
+
"\n",
|
| 140 |
+
"\n",
|
| 141 |
+
"def score_action(completion, prompt_data):\n",
|
| 142 |
+
" \"\"\"Score a model completion. Mirrors server's calculate_reward().\"\"\"\n",
|
| 143 |
+
" messages = prompt_data[\"messages\"]\n",
|
| 144 |
+
" hour = prompt_data[\"hour\"]\n",
|
| 145 |
+
" superseded = prompt_data.get(\"superseded\", {})\n",
|
| 146 |
+
"\n",
|
| 147 |
+
" match = re.search(\n",
|
| 148 |
+
" r'respond_to_message\\s*\\(\\s*[\"\\']?(msg_\\d+)[\"\\']?\\s*,\\s*[\"\\'](.+?)[\"\\']',\n",
|
| 149 |
+
" completion, re.DOTALL,\n",
|
| 150 |
+
" )\n",
|
| 151 |
+
" if match:\n",
|
| 152 |
+
" msg_id, response_text = match.group(1), match.group(2)\n",
|
| 153 |
+
" else:\n",
|
| 154 |
+
" id_match = re.search(r'(msg_\\d+)', completion)\n",
|
| 155 |
+
" if id_match:\n",
|
| 156 |
+
" msg_id, response_text = id_match.group(1), completion[:200]\n",
|
| 157 |
+
" else:\n",
|
| 158 |
+
" return -1.0\n",
|
| 159 |
+
"\n",
|
| 160 |
+
" target = next((m for m in messages if m[\"id\"] == msg_id), None)\n",
|
| 161 |
+
" if target is None:\n",
|
| 162 |
+
" return -0.5\n",
|
| 163 |
+
"\n",
|
| 164 |
+
" urgency_rewards = {\"critical\": 10.0, \"high\": 5.0, \"medium\": 3.0, \"low\": 1.0}\n",
|
| 165 |
+
" reward = urgency_rewards.get(target[\"urgency\"], 1.0)\n",
|
| 166 |
+
"\n",
|
| 167 |
+
" deadline = target.get(\"deadline_hours\")\n",
|
| 168 |
+
" if deadline is not None:\n",
|
| 169 |
+
" if hour <= deadline:\n",
|
| 170 |
+
" reward *= 1.0 + 0.5 * ((deadline - hour) / max(deadline, 1.0))\n",
|
| 171 |
+
" else:\n",
|
| 172 |
+
" reward *= 0.25\n",
|
| 173 |
+
"\n",
|
| 174 |
+
" if len(response_text.strip()) < 10:\n",
|
| 175 |
+
" reward *= 0.5\n",
|
| 176 |
+
"\n",
|
| 177 |
+
" if target.get(\"drift_flag\"):\n",
|
| 178 |
+
" reward *= 1.5\n",
|
| 179 |
+
"\n",
|
| 180 |
+
" if target[\"id\"] in superseded:\n",
|
| 181 |
+
" reward *= 0.5\n",
|
| 182 |
+
"\n",
|
| 183 |
+
" unhandled = [m for m in messages if not m.get(\"handled\") and m[\"id\"] != msg_id]\n",
|
| 184 |
+
" if any(m[\"urgency\"] == \"critical\" for m in unhandled) and target[\"urgency\"] in (\"low\", \"medium\"):\n",
|
| 185 |
+
" reward *= 0.3\n",
|
| 186 |
+
"\n",
|
| 187 |
+
" return round(reward, 2)\n",
|
| 188 |
+
"\n",
|
| 189 |
+
"\n",
|
| 190 |
+
"# Test\n",
|
| 191 |
+
"test_data = prompts[0]\n",
|
| 192 |
+
"print(\"Testing reward function:\")\n",
|
| 193 |
+
"print(f\" Hour: {test_data['hour']}, Messages: {test_data['visible_count']}\")\n",
|
| 194 |
+
"critical_msgs = [m for m in test_data[\"messages\"] if m[\"urgency\"] == \"critical\"]\n",
|
| 195 |
+
"if critical_msgs:\n",
|
| 196 |
+
" good = f'respond_to_message(\"{critical_msgs[0][\"id\"]}\", \"Evacuating now with documents.\")'\n",
|
| 197 |
+
" print(f\" Good action (critical): {score_action(good, test_data):.2f}\")\n",
|
| 198 |
+
"low_msgs = [m for m in test_data[\"messages\"] if m[\"urgency\"] == \"low\"]\n",
|
| 199 |
+
"if low_msgs:\n",
|
| 200 |
+
" bad = f'respond_to_message(\"{low_msgs[0][\"id\"]}\", \"ok\")'\n",
|
| 201 |
+
" print(f\" Bad action (low, short): {score_action(bad, test_data):.2f}\")\n",
|
| 202 |
+
"print(f\" Unparseable: {score_action('do something', test_data):.2f}\")\n"
|
| 203 |
+
],
|
| 204 |
+
"metadata": {},
|
| 205 |
+
"execution_count": null,
|
| 206 |
+
"outputs": []
|
| 207 |
+
},
|
| 208 |
+
{
|
| 209 |
+
"cell_type": "markdown",
|
| 210 |
+
"id": "0fr2fzreorqr",
|
| 211 |
+
"source": "## Load Model & Baseline Evaluation\n\nLoad the model, run a **pre-training baseline** against the live environment, then configure GRPO training.",
|
| 212 |
+
"metadata": {}
|
| 213 |
+
},
|
| 214 |
+
{
|
| 215 |
+
"cell_type": "code",
|
| 216 |
+
"id": "zey499u5w1a",
|
| 217 |
+
"source": "from transformers import AutoModelForCausalLM, AutoTokenizer\nfrom peft import LoraConfig\nimport torch\n\n# Auto-detect precision\n_use_bf16 = torch.cuda.is_bf16_supported() if torch.cuda.is_available() else False\n_compute_dtype = torch.bfloat16 if _use_bf16 else torch.float16\n\n# Load in full bf16/fp16 \u2014 no 4-bit quantization.\n# Qwen2.5-0.5B is ~1GB in bf16, fits easily on any GPU.\n# This avoids all bitsandbytes dtype mismatch issues with lm_head.\nmodel = AutoModelForCausalLM.from_pretrained(\n \"Qwen/Qwen2.5-0.5B-Instruct\",\n device_map=\"auto\",\n torch_dtype=_compute_dtype,\n)\ntokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-0.5B-Instruct\")\n\n# Fix: TRL GRPOTrainer expects warnings_issued but newer transformers removed it.\nif not hasattr(model, \"warnings_issued\"):\n model.warnings_issued = {}\n\n# GRPO requires left padding so completions align across the batch\ntokenizer.padding_side = \"left\"\nif tokenizer.pad_token_id is None:\n tokenizer.pad_token = tokenizer.eos_token\n tokenizer.pad_token_id = tokenizer.eos_token_id\n\n# LoRA config \u2014 passed to GRPOTrainer, not applied here\nlora_config = LoraConfig(\n r=16,\n lora_alpha=16,\n target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n \"gate_proj\", \"up_proj\", \"down_proj\"],\n lora_dropout=0.0,\n bias=\"none\",\n task_type=\"CAUSAL_LM\",\n)\n\nprint(f\"Model loaded in {_compute_dtype} (no quantization)\")\nprint(f\"Precision: {'bf16' if _use_bf16 else 'fp16'}\")\nprint(f\"Model size: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M params\")",
|
| 218 |
+
"metadata": {},
|
| 219 |
+
"execution_count": null,
|
| 220 |
+
"outputs": []
|
| 221 |
+
},
|
| 222 |
+
{
|
| 223 |
+
"cell_type": "markdown",
|
| 224 |
+
"id": "8iigxyoxiks",
|
| 225 |
+
"source": "### Pre-Training Baseline\n\nEvaluate the **untrained** model against the live environment before any GRPO training. This gives us a baseline to measure improvement.",
|
| 226 |
+
"metadata": {}
|
| 227 |
+
},
|
| 228 |
+
{
|
| 229 |
+
"cell_type": "code",
|
| 230 |
+
"id": "6r2zhgg94fk",
|
| 231 |
+
"source": [
|
| 232 |
+
"# --- Pre-training baseline evaluation against live environment ---\n",
|
| 233 |
+
"from openenv.core.env_server.mcp_types import CallToolAction\n",
|
| 234 |
+
"\n",
|
| 235 |
+
"\n",
|
| 236 |
+
"def generate_action(model, tokenizer, prompt_text):\n",
|
| 237 |
+
" \"\"\"Generate an action from the model.\"\"\"\n",
|
| 238 |
+
" msgs = [{\"role\": \"user\", \"content\": prompt_text}]\n",
|
| 239 |
+
" input_ids = tokenizer.apply_chat_template(msgs, return_tensors=\"pt\", add_generation_prompt=True)\n",
|
| 240 |
+
" if not isinstance(input_ids, torch.Tensor):\n",
|
| 241 |
+
" input_ids = input_ids[\"input_ids\"]\n",
|
| 242 |
+
" input_ids = input_ids.to(\"cuda\")\n",
|
| 243 |
+
" prompt_len = input_ids.shape[1]\n",
|
| 244 |
+
" with torch.no_grad():\n",
|
| 245 |
+
" output = model.generate(input_ids=input_ids, max_new_tokens=200, temperature=0.7,\n",
|
| 246 |
+
" pad_token_id=tokenizer.pad_token_id, do_sample=True)\n",
|
| 247 |
+
" return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)\n",
|
| 248 |
+
"\n",
|
| 249 |
+
"\n",
|
| 250 |
+
"def _extract_tool_result(obs):\n",
|
| 251 |
+
" \"\"\"Extract JSON from a CallToolObservation (handles FastMCP wrapping).\"\"\"\n",
|
| 252 |
+
" raw = getattr(obs, \"result\", None)\n",
|
| 253 |
+
" if hasattr(raw, \"data\"):\n",
|
| 254 |
+
" raw = raw.data\n",
|
| 255 |
+
" if isinstance(raw, dict) and \"data\" in raw:\n",
|
| 256 |
+
" raw = raw[\"data\"]\n",
|
| 257 |
+
" if isinstance(raw, str):\n",
|
| 258 |
+
" try:\n",
|
| 259 |
+
" return json.loads(raw)\n",
|
| 260 |
+
" except (json.JSONDecodeError, TypeError):\n",
|
| 261 |
+
" return {}\n",
|
| 262 |
+
" return raw if isinstance(raw, dict) else {}\n",
|
| 263 |
+
"\n",
|
| 264 |
+
"\n",
|
| 265 |
+
"def evaluate_on_live_env(model, tokenizer, base_url, seed, max_steps=20):\n",
|
| 266 |
+
" \"\"\"Evaluate model against the live environment using OpenEnv step() flow.\"\"\"\n",
|
| 267 |
+
" with MCPToolClient(base_url=base_url, connect_timeout_s=60.0, message_timeout_s=120.0).sync() as env:\n",
|
| 268 |
+
" env.reset(seed=seed)\n",
|
| 269 |
+
" total_reward = 0.0\n",
|
| 270 |
+
" actions_taken = []\n",
|
| 271 |
+
"\n",
|
| 272 |
+
" for step_i in range(max_steps):\n",
|
| 273 |
+
" status = json.loads(env.call_tool(\"get_status\"))\n",
|
| 274 |
+
" if status.get(\"done\"):\n",
|
| 275 |
+
" break\n",
|
| 276 |
+
"\n",
|
| 277 |
+
" inbox = json.loads(env.call_tool(\"get_inbox\"))\n",
|
| 278 |
+
" current_hour = status[\"current_hour\"]\n",
|
| 279 |
+
" unhandled = [m for m in inbox if not m.get(\"handled\", False)]\n",
|
| 280 |
+
" if not unhandled:\n",
|
| 281 |
+
" env.call_tool(\"advance_time\", hours=2.0)\n",
|
| 282 |
+
" continue\n",
|
| 283 |
+
"\n",
|
| 284 |
+
" # Use server's get_prompt tool\n",
|
| 285 |
+
" prompt = env.call_tool(\"get_prompt\")\n",
|
| 286 |
+
" completion = generate_action(model, tokenizer, prompt)\n",
|
| 287 |
+
"\n",
|
| 288 |
+
" match = re.search(r'respond_to_message\\s*\\(\\s*[\"\\']?(msg_\\d+)[\"\\']?\\s*,\\s*[\"\\'](.+?)[\"\\']', completion, re.DOTALL)\n",
|
| 289 |
+
" if not match:\n",
|
| 290 |
+
" id_match = re.search(r'(msg_\\d+)', completion)\n",
|
| 291 |
+
" if id_match:\n",
|
| 292 |
+
" msg_id, response_text = id_match.group(1), completion[:200]\n",
|
| 293 |
+
" else:\n",
|
| 294 |
+
" env.call_tool(\"advance_time\", hours=1.0)\n",
|
| 295 |
+
" continue\n",
|
| 296 |
+
" else:\n",
|
| 297 |
+
" msg_id, response_text = match.group(1), match.group(2)\n",
|
| 298 |
+
"\n",
|
| 299 |
+
" action = CallToolAction(\n",
|
| 300 |
+
" tool_name=\"respond_to_message\",\n",
|
| 301 |
+
" arguments={\"message_id\": msg_id, \"response\": response_text},\n",
|
| 302 |
+
" )\n",
|
| 303 |
+
" step_result = env.step(action)\n",
|
| 304 |
+
" obs = step_result.observation\n",
|
| 305 |
+
" reward = obs.reward if obs.reward is not None else 0.0\n",
|
| 306 |
+
" done = obs.done\n",
|
| 307 |
+
"\n",
|
| 308 |
+
" result_data = _extract_tool_result(obs)\n",
|
| 309 |
+
" if \"error\" in result_data:\n",
|
| 310 |
+
" env.call_tool(\"advance_time\", hours=1.0)\n",
|
| 311 |
+
" continue\n",
|
| 312 |
+
"\n",
|
| 313 |
+
" total_reward += reward\n",
|
| 314 |
+
" target_msg = next((m for m in inbox if m[\"id\"] == msg_id), None)\n",
|
| 315 |
+
" urgency = target_msg[\"urgency\"] if target_msg else \"?\"\n",
|
| 316 |
+
" actions_taken.append({\"step\": step_i, \"hour\": current_hour, \"msg_id\": msg_id, \"urgency\": urgency, \"reward\": reward})\n",
|
| 317 |
+
" print(f\" Step {step_i:2d} | Hour {current_hour:5.1f} | {msg_id} ({urgency:8s}) | Reward: {reward:+.1f} | Total: {total_reward:.1f}\")\n",
|
| 318 |
+
"\n",
|
| 319 |
+
" if done:\n",
|
| 320 |
+
" break\n",
|
| 321 |
+
"\n",
|
| 322 |
+
" final_status = json.loads(env.call_tool(\"get_status\"))\n",
|
| 323 |
+
"\n",
|
| 324 |
+
" return {\"seed\": seed, \"total_reward\": total_reward, \"actions\": actions_taken, \"final_status\": final_status}\n",
|
| 325 |
+
"\n",
|
| 326 |
+
"\n",
|
| 327 |
+
"# Run baseline on 3 seeds\n",
|
| 328 |
+
"print(\"=== PRE-TRAINING BASELINE (untrained model) ===\\n\")\n",
|
| 329 |
+
"baseline_results = []\n",
|
| 330 |
+
"for seed in [99, 42, 7]:\n",
|
| 331 |
+
" print(f\"--- Seed {seed} ---\")\n",
|
| 332 |
+
" res = evaluate_on_live_env(model, tokenizer, BASE_URL, seed=seed)\n",
|
| 333 |
+
" baseline_results.append(res)\n",
|
| 334 |
+
" print(f\" Total: {res['total_reward']:.1f} | Actions: {len(res['actions'])}\\n\")\n",
|
| 335 |
+
"\n",
|
| 336 |
+
"baseline_avg = sum(r[\"total_reward\"] for r in baseline_results) / len(baseline_results)\n",
|
| 337 |
+
"print(f\"Baseline average reward: {baseline_avg:.1f}\")\n"
|
| 338 |
+
],
|
| 339 |
+
"metadata": {},
|
| 340 |
+
"execution_count": null,
|
| 341 |
+
"outputs": []
|
| 342 |
+
},
|
| 343 |
+
{
|
| 344 |
+
"cell_type": "code",
|
| 345 |
+
"id": "97iryd40qzt",
|
| 346 |
+
"source": "from datasets import Dataset\n\nMAX_PROMPT_LENGTH = 1024\n\ntrain_data = []\nfor p in prompts:\n msgs = [{\"role\": \"user\", \"content\": p[\"prompt\"]}]\n tok = tokenizer.apply_chat_template(msgs, truncation=True, max_length=1024, return_tensors=\"pt\", add_generation_prompt=True)\n try:\n ids = tok[\"input_ids\"]\n except (TypeError, KeyError):\n ids = tok\n n_tokens = ids.shape[1] if ids.dim() > 1 else ids.shape[0]\n if n_tokens > MAX_PROMPT_LENGTH:\n continue\n train_data.append({\n \"prompt\": msgs,\n \"_prompt_key\": p[\"prompt\"][:200],\n })\n\nrandom.seed(42)\nrandom.shuffle(train_data)\n\ndataset = Dataset.from_list(train_data)\nprint(f\"Training dataset: {len(dataset)} prompts (after dropping prompts > {MAX_PROMPT_LENGTH} tokens)\")\nprint(f\"Sample prompt length: {len(train_data[0]['prompt'][0]['content'])} chars\")",
|
| 347 |
+
"metadata": {},
|
| 348 |
+
"execution_count": null,
|
| 349 |
+
"outputs": []
|
| 350 |
+
},
|
| 351 |
+
{
|
| 352 |
+
"cell_type": "markdown",
|
| 353 |
+
"id": "2vcnc1dr4d1",
|
| 354 |
+
"source": "## GRPO Training Loop\n\nThe reward function scores each completion by:\n1. Parsing which message the model chose to handle\n2. Checking urgency, deadline timing, drift flags\n3. Penalizing bad choices (low-urgency when critical exists, stale info)",
|
| 355 |
+
"metadata": {}
|
| 356 |
+
},
|
| 357 |
+
{
|
| 358 |
+
"cell_type": "code",
|
| 359 |
+
"id": "9liiw6eifdo",
|
| 360 |
+
"source": "import gc\ngc.collect()\ntorch.cuda.empty_cache()",
|
| 361 |
+
"metadata": {},
|
| 362 |
+
"execution_count": null,
|
| 363 |
+
"outputs": []
|
| 364 |
+
},
|
| 365 |
+
{
|
| 366 |
+
"cell_type": "code",
|
| 367 |
+
"id": "arbp96a9wi",
|
| 368 |
+
"source": "from trl import GRPOConfig, GRPOTrainer\n\n# Build lookup from prompt text -> prompt metadata for reward scoring\n# Use first 200 chars as key (reliable \u2014 TRL may not pass custom dataset columns)\nprompt_lookup = {}\nfor p in prompts:\n key = p[\"prompt\"][:200]\n prompt_lookup[key] = p\n\n\ndef reward_fn(prompts, completions, **kwargs):\n \"\"\"GRPO reward function. Scores each completion against its inbox state.\"\"\"\n rewards = []\n for prompt_msgs, completion in zip(prompts, completions):\n # Extract prompt text to look up metadata\n if isinstance(prompt_msgs, list):\n prompt_text = prompt_msgs[-1][\"content\"] if prompt_msgs else \"\"\n else:\n prompt_text = str(prompt_msgs)\n\n key = prompt_text[:200]\n prompt_data = prompt_lookup.get(key)\n\n if prompt_data is None:\n rewards.append(0.0)\n continue\n\n if isinstance(completion, list):\n if completion and isinstance(completion[0], (int, float)):\n comp_text = tokenizer.decode(completion, skip_special_tokens=True)\n else:\n comp_text = completion[-1].get(\"content\", \"\") if completion else \"\"\n else:\n comp_text = str(completion)\n\n score = score_action(comp_text, prompt_data)\n rewards.append(score)\n\n return rewards\n\n\ntraining_args = GRPOConfig(\n output_dir=\"crisisinbox-grpo-output\",\n num_train_epochs=3,\n per_device_train_batch_size=2,\n gradient_accumulation_steps=2,\n learning_rate=1e-5,\n max_completion_length=256,\n max_prompt_length=1024,\n num_generations=4,\n logging_steps=1,\n save_steps=100,\n bf16=_use_bf16,\n fp16=not _use_bf16,\n sync_ref_model=True,\n)\n\n# Let GRPOTrainer handle PEFT wrapping (avoids dtype mismatches from manual setup)\ntrainer = GRPOTrainer(\n model=model,\n processing_class=tokenizer,\n reward_funcs=reward_fn,\n args=training_args,\n train_dataset=dataset,\n peft_config=lora_config,\n)\n\n# After trainer init, update model ref to the PEFT-wrapped version\nmodel = trainer.model\n\nprint(f\"Trainer configured \u2014 {len(prompt_lookup)} unique prompt keys\")\nprint(f\"Precision: {'bf16' if _use_bf16 else 'fp16'}\")\nprint(f\"Training for {training_args.num_train_epochs} epochs\")\nprint(\"Ready to train\")",
|
| 369 |
+
"metadata": {},
|
| 370 |
+
"execution_count": null,
|
| 371 |
+
"outputs": []
|
| 372 |
+
},
|
| 373 |
+
{
|
| 374 |
+
"cell_type": "code",
|
| 375 |
+
"id": "ni6dh0hkegm",
|
| 376 |
+
"source": "# Train!\ntrainer.train()\nprint(\"Training complete\")",
|
| 377 |
+
"metadata": {},
|
| 378 |
+
"execution_count": null,
|
| 379 |
+
"outputs": []
|
| 380 |
+
},
|
| 381 |
+
{
|
| 382 |
+
"cell_type": "markdown",
|
| 383 |
+
"id": "gv86tb0swwl",
|
| 384 |
+
"source": "## Evaluate: Offline + Training Curve\n\nEvaluate the trained model on collected prompts and plot the training reward curve.",
|
| 385 |
+
"metadata": {}
|
| 386 |
+
},
|
| 387 |
+
{
|
| 388 |
+
"cell_type": "code",
|
| 389 |
+
"id": "it6zcy49jp9",
|
| 390 |
+
"source": [
|
| 391 |
+
"import matplotlib.pyplot as plt\nimport pandas as pd\n\nmodel.eval()\n\n# --- Post-training evaluation on same seeds as baseline ---\nprint(\"=== POST-TRAINING EVALUATION ===\\n\")\ntrained_results = []\nfor seed in [99, 42, 7]:\n print(f\"--- Seed {seed} ---\")\n res = evaluate_on_live_env(model, tokenizer, BASE_URL, seed=seed)\n trained_results.append(res)\n print(f\" Total: {res['total_reward']:.1f} | Actions: {len(res['actions'])}\\n\")\n\ntrained_avg = sum(r[\"total_reward\"] for r in trained_results) / len(trained_results)\nprint(f\"Trained average reward: {trained_avg:.1f}\")\nprint(f\"Baseline average reward: {baseline_avg:.1f}\")\nimprovement = ((trained_avg - baseline_avg) / max(baseline_avg, 0.1)) * 100\nprint(f\"Improvement: {improvement:+.1f}%\")\n\n# --- Plot 1: Before/After Comparison Bar Chart ---\nfig, axes = plt.subplots(1, 3, figsize=(16, 5))\n\n# Bar chart: per-seed comparison\nseeds = [99, 42, 7]\nbaseline_scores = [r[\"total_reward\"] for r in baseline_results]\ntrained_scores = [r[\"total_reward\"] for r in trained_results]\n\nx = range(len(seeds))\nwidth = 0.35\nbars1 = axes[0].bar([i - width/2 for i in x], baseline_scores, width, label=\"Before Training\", color=\"#d62728\", alpha=0.8)\nbars2 = axes[0].bar([i + width/2 for i in x], trained_scores, width, label=\"After Training\", color=\"#2ca02c\", alpha=0.8)\naxes[0].set_xlabel(\"Episode Seed\")\naxes[0].set_ylabel(\"Total Reward\")\naxes[0].set_title(\"Before vs After GRPO Training\")\naxes[0].set_xticks(list(x))\naxes[0].set_xticklabels([f\"Seed {s}\" for s in seeds])\naxes[0].legend()\naxes[0].grid(axis=\"y\", linestyle=\"--\", alpha=0.6)\n# Add value labels on bars\nfor bar in bars1:\n axes[0].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.5,\n f'{bar.get_height():.1f}', ha='center', va='bottom', fontsize=9, color=\"#d62728\")\nfor bar in bars2:\n axes[0].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.5,\n f'{bar.get_height():.1f}', ha='center', va='bottom', fontsize=9, color=\"#2ca02c\")\n\n# Bar chart: average comparison\naxes[1].bar([\"Untrained\\n(Baseline)\", \"GRPO\\nTrained\"], [baseline_avg, trained_avg],\n color=[\"#d62728\", \"#2ca02c\"], alpha=0.8, width=0.5)\naxes[1].set_ylabel(\"Average Reward\")\naxes[1].set_title(f\"Average Reward ({improvement:+.1f}% improvement)\")\naxes[1].grid(axis=\"y\", linestyle=\"--\", alpha=0.6)\naxes[1].text(0, baseline_avg + 0.5, f\"{baseline_avg:.1f}\", ha=\"center\", va=\"bottom\", fontweight=\"bold\")\naxes[1].text(1, trained_avg + 0.5, f\"{trained_avg:.1f}\", ha=\"center\", va=\"bottom\", fontweight=\"bold\")\n\n# Plot 2: Training reward curve\nhistory = pd.DataFrame(trainer.state.log_history)\nif \"rewards/reward_fn/mean\" in history.columns:\n reward_steps = history.dropna(subset=[\"rewards/reward_fn/mean\"])\n axes[2].plot(reward_steps[\"step\"], reward_steps[\"rewards/reward_fn/mean\"],\n label=\"Mean Reward\", color=\"#2ca02c\", linewidth=2)\n axes[2].fill_between(reward_steps[\"step\"],\n reward_steps[\"rewards/reward_fn/mean\"] - reward_steps[\"rewards/reward_fn/std\"],\n reward_steps[\"rewards/reward_fn/mean\"] + reward_steps[\"rewards/reward_fn/std\"],\n alpha=0.2, color=\"#2ca02c\")\n # Add baseline reference line\n axes[2].axhline(y=baseline_avg, color=\"#d62728\", linestyle=\"--\", linewidth=1.5, label=f\"Baseline ({baseline_avg:.1f})\")\n axes[2].set_xlabel(\"Training Steps\")\n axes[2].set_ylabel(\"Reward\")\n axes[2].set_title(\"GRPO Training Curve\")\n axes[2].legend()\n axes[2].grid(True, linestyle=\"--\", alpha=0.6)\nelse:\n axes[2].text(0.5, 0.5, \"No reward history\\n(run trainer.train() first)\",\n ha=\"center\", va=\"center\", transform=axes[2].transAxes, fontsize=12)\n axes[2].set_title(\"GRPO Training Curve\")\n\nplt.tight_layout()\nplt.savefig(\"crisisinbox_grpo_results.png\", dpi=150, bbox_inches=\"tight\")\nplt.show()\nprint(\"Results saved to crisisinbox_grpo_results.png\")"
|
| 392 |
+
],
|
| 393 |
+
"metadata": {},
|
| 394 |
+
"execution_count": null,
|
| 395 |
+
"outputs": []
|
| 396 |
+
},
|
| 397 |
+
{
|
| 398 |
+
"cell_type": "markdown",
|
| 399 |
+
"id": "wob0szg0fv",
|
| 400 |
+
"source": "## Evaluate Against Live Environment\n\nRun the trained model in a closed loop against the actual CrisisInbox environment to get real server-side rewards.",
|
| 401 |
+
"metadata": {}
|
| 402 |
+
},
|
| 403 |
+
{
|
| 404 |
+
"cell_type": "code",
|
| 405 |
+
"id": "a5ih8db6zz",
|
| 406 |
+
"source": [
|
| 407 |
+
"# Additional live eval on a fresh seed (not used in baseline comparison)\nprint(\"=== Extra Live Evaluation (seed=123) ===\\n\")\nextra_result = evaluate_on_live_env(model, tokenizer, BASE_URL, seed=123, max_steps=25)\nprint(f\"\\nTotal reward: {extra_result['total_reward']:.1f}\")\nprint(f\"Actions taken: {len(extra_result['actions'])}\")\nprint(f\"Messages handled: {extra_result['final_status']['messages_handled']}\")"
|
| 408 |
+
],
|
| 409 |
+
"metadata": {},
|
| 410 |
+
"execution_count": null,
|
| 411 |
+
"outputs": []
|
| 412 |
+
},
|
| 413 |
+
{
|
| 414 |
+
"cell_type": "code",
|
| 415 |
+
"id": "g76r4wc8jx",
|
| 416 |
+
"source": "# Save the trained model\nmodel.save_pretrained(\"crisisinbox-grpo-trained\")\ntokenizer.save_pretrained(\"crisisinbox-grpo-trained\")\nprint(\"Model saved to crisisinbox-grpo-trained/\")",
|
| 417 |
+
"metadata": {},
|
| 418 |
+
"execution_count": null,
|
| 419 |
+
"outputs": []
|
| 420 |
+
}
|
| 421 |
+
],
|
| 422 |
+
"metadata": {
|
| 423 |
+
"kernelspec": {
|
| 424 |
+
"display_name": "Python 3",
|
| 425 |
+
"language": "python",
|
| 426 |
+
"name": "python3"
|
| 427 |
+
},
|
| 428 |
+
"language_info": {
|
| 429 |
+
"name": "python",
|
| 430 |
+
"version": "3.11.0"
|
| 431 |
+
}
|
| 432 |
+
},
|
| 433 |
+
"nbformat": 4,
|
| 434 |
+
"nbformat_minor": 5
|
| 435 |
+
}
|
notebooks/crisisinbox_grpo_connected.ipynb
CHANGED
|
@@ -120,7 +120,9 @@
|
|
| 120 |
{
|
| 121 |
"cell_type": "code",
|
| 122 |
"id": "nmmfb62s9ph",
|
| 123 |
-
"source":
|
|
|
|
|
|
|
| 124 |
"metadata": {},
|
| 125 |
"execution_count": null,
|
| 126 |
"outputs": []
|
|
|
|
| 120 |
{
|
| 121 |
"cell_type": "code",
|
| 122 |
"id": "nmmfb62s9ph",
|
| 123 |
+
"source": [
|
| 124 |
+
"# Collect multiple episodes from the live environment\nNUM_EPISODES = 10\nSEEDS = list(range(NUM_EPISODES))\n\nepisodes = []\nfor seed in SEEDS:\n print(f\"Collecting episode {seed + 1}/{NUM_EPISODES} (seed={seed})...\", end=\" \")\n for attempt in range(3):\n try:\n ep = collect_episode(BASE_URL, seed=seed)\n episodes.append(ep)\n print(f\"{len(ep['decision_points'])} decision points\")\n break\n except Exception as e:\n if attempt < 2:\n print(f\"retry {attempt + 1}...\", end=\" \")\n _time.sleep(5)\n else:\n print(f\"FAILED ({e}), skipping\")\n\n# Flatten to training prompts\nprompts = []\nfor ep in episodes:\n for dp in ep[\"decision_points\"]:\n prompts.append({\n \"prompt\": dp[\"prompt\"],\n \"hour\": dp[\"hour\"],\n \"visible_count\": dp[\"visible_count\"],\n \"episode_id\": ep[\"episode_id\"],\n \"seed\": ep[\"seed\"],\n \"drift_events\": ep[\"drift_events\"],\n \"superseded\": ep.get(\"superseded_messages\", {}),\n \"messages\": dp[\"messages\"],\n })\n\nprint(f\"\\nCollected {len(episodes)} episodes -> {len(prompts)} training prompts\")\nprint(f\"Average {len(prompts)/len(episodes):.1f} decision points per episode\")"
|
| 125 |
+
],
|
| 126 |
"metadata": {},
|
| 127 |
"execution_count": null,
|
| 128 |
"outputs": []
|
server/crisis_inbox_environment.py
CHANGED
|
@@ -73,6 +73,9 @@ class CrisisInboxEnvironment(MCPEnvironment):
|
|
| 73 |
self._drift_events: list[DriftEvent] = []
|
| 74 |
self._fired_drifts: set[str] = set()
|
| 75 |
self._superseded: dict[str, str] = {} # old_msg_id -> new_msg_id
|
|
|
|
|
|
|
|
|
|
| 76 |
self._rng = random.Random()
|
| 77 |
|
| 78 |
@mcp.tool
|
|
@@ -99,6 +102,7 @@ class CrisisInboxEnvironment(MCPEnvironment):
|
|
| 99 |
"read": msg.id in self._read_msgs,
|
| 100 |
"drift_flag": msg.drift_flag,
|
| 101 |
"superseded": is_superseded,
|
|
|
|
| 102 |
})
|
| 103 |
return json.dumps(summaries, indent=2)
|
| 104 |
|
|
@@ -170,6 +174,29 @@ class CrisisInboxEnvironment(MCPEnvironment):
|
|
| 170 |
self._handled[message_id] = response
|
| 171 |
self._score += reward
|
| 172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
# Advance time
|
| 174 |
self._advance_clock(0.25)
|
| 175 |
|
|
@@ -243,11 +270,12 @@ class CrisisInboxEnvironment(MCPEnvironment):
|
|
| 243 |
status = "HANDLED" if msg.id in self._handled else "UNHANDLED"
|
| 244 |
drift = " [POLICY CHANGE]" if msg.drift_flag else ""
|
| 245 |
superseded = " [SUPERSEDED]" if msg.id in self._superseded else ""
|
|
|
|
| 246 |
deadline_str = f", deadline: hour {msg.deadline_hours}" if msg.deadline_hours else ""
|
| 247 |
lines.append(
|
| 248 |
f"[{status}] {msg.id} | {msg.urgency.value.upper()} | "
|
| 249 |
f"From: {msg.sender} via {msg.channel.value} | "
|
| 250 |
-
f"\"{msg.subject}\"{deadline_str}{drift}{superseded}"
|
| 251 |
)
|
| 252 |
lines.extend([
|
| 253 |
"=" * 60,
|
|
@@ -290,6 +318,7 @@ class CrisisInboxEnvironment(MCPEnvironment):
|
|
| 290 |
self._current_hour = min(48.0, self._current_hour + hours)
|
| 291 |
self._deliver_messages()
|
| 292 |
self._fire_drift_events()
|
|
|
|
| 293 |
|
| 294 |
def _deliver_messages(self):
|
| 295 |
"""Make messages visible if their timestamp has been reached."""
|
|
@@ -298,6 +327,22 @@ class CrisisInboxEnvironment(MCPEnvironment):
|
|
| 298 |
if not any(m.id == msg.id for m in self._visible_messages):
|
| 299 |
self._visible_messages.append(msg)
|
| 300 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
def _fire_drift_events(self):
|
| 302 |
"""Fire any drift events whose trigger time has been reached."""
|
| 303 |
for drift in self._drift_events:
|
|
@@ -354,6 +399,37 @@ class CrisisInboxEnvironment(MCPEnvironment):
|
|
| 354 |
)
|
| 355 |
self._all_messages.append(m)
|
| 356 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
self._visible_messages = []
|
| 358 |
self._handled = {}
|
| 359 |
self._read_msgs = set()
|
|
|
|
| 73 |
self._drift_events: list[DriftEvent] = []
|
| 74 |
self._fired_drifts: set[str] = set()
|
| 75 |
self._superseded: dict[str, str] = {} # old_msg_id -> new_msg_id
|
| 76 |
+
self._escalation_map: dict[str, Message] = {} # parent_id -> escalation msg
|
| 77 |
+
self._reply_map: dict[str, Message] = {} # parent_id -> reply msg
|
| 78 |
+
self._conflict_pairs: dict[str, str] = {} # msg_id -> conflicting msg_id
|
| 79 |
self._rng = random.Random()
|
| 80 |
|
| 81 |
@mcp.tool
|
|
|
|
| 102 |
"read": msg.id in self._read_msgs,
|
| 103 |
"drift_flag": msg.drift_flag,
|
| 104 |
"superseded": is_superseded,
|
| 105 |
+
"conflicts_with": msg.conflicts_with,
|
| 106 |
})
|
| 107 |
return json.dumps(summaries, indent=2)
|
| 108 |
|
|
|
|
| 174 |
self._handled[message_id] = response
|
| 175 |
self._score += reward
|
| 176 |
|
| 177 |
+
# Conflict resolution: if this message conflicts with another,
|
| 178 |
+
# the conflicting message can no longer be handled (time conflict)
|
| 179 |
+
if msg.conflicts_with and msg.conflicts_with not in self._handled:
|
| 180 |
+
self._handled[msg.conflicts_with] = "[AUTO-EXPIRED: time conflict]"
|
| 181 |
+
|
| 182 |
+
# Multi-turn: if handling this message triggers a reply, inject it
|
| 183 |
+
if msg.reply_trigger and msg.reply_trigger in self._reply_map:
|
| 184 |
+
reply_msg = self._reply_map[msg.reply_trigger]
|
| 185 |
+
reply_msg.timestamp_hours = self._current_hour + 0.5
|
| 186 |
+
if reply_msg.deadline_hours is not None and reply_msg.deadline_hours == 0.0:
|
| 187 |
+
# Dynamic deadline based on message content hints
|
| 188 |
+
reply_msg.deadline_hours = self._current_hour + 6.0
|
| 189 |
+
if not any(m.id == reply_msg.id for m in self._all_messages):
|
| 190 |
+
self._all_messages.append(reply_msg)
|
| 191 |
+
|
| 192 |
+
# Escalation: if this message had an escalation, cancel it
|
| 193 |
+
# (handled in time, no need to escalate)
|
| 194 |
+
if message_id in self._escalation_map:
|
| 195 |
+
esc = self._escalation_map[message_id]
|
| 196 |
+
# Remove from all_messages so it never appears
|
| 197 |
+
self._all_messages = [m for m in self._all_messages if m.id != esc.id]
|
| 198 |
+
self._visible_messages = [m for m in self._visible_messages if m.id != esc.id]
|
| 199 |
+
|
| 200 |
# Advance time
|
| 201 |
self._advance_clock(0.25)
|
| 202 |
|
|
|
|
| 270 |
status = "HANDLED" if msg.id in self._handled else "UNHANDLED"
|
| 271 |
drift = " [POLICY CHANGE]" if msg.drift_flag else ""
|
| 272 |
superseded = " [SUPERSEDED]" if msg.id in self._superseded else ""
|
| 273 |
+
conflict = f" [CONFLICTS WITH {msg.conflicts_with}]" if msg.conflicts_with else ""
|
| 274 |
deadline_str = f", deadline: hour {msg.deadline_hours}" if msg.deadline_hours else ""
|
| 275 |
lines.append(
|
| 276 |
f"[{status}] {msg.id} | {msg.urgency.value.upper()} | "
|
| 277 |
f"From: {msg.sender} via {msg.channel.value} | "
|
| 278 |
+
f"\"{msg.subject}\"{deadline_str}{drift}{superseded}{conflict}"
|
| 279 |
)
|
| 280 |
lines.extend([
|
| 281 |
"=" * 60,
|
|
|
|
| 318 |
self._current_hour = min(48.0, self._current_hour + hours)
|
| 319 |
self._deliver_messages()
|
| 320 |
self._fire_drift_events()
|
| 321 |
+
self._fire_escalations()
|
| 322 |
|
| 323 |
def _deliver_messages(self):
|
| 324 |
"""Make messages visible if their timestamp has been reached."""
|
|
|
|
| 327 |
if not any(m.id == msg.id for m in self._visible_messages):
|
| 328 |
self._visible_messages.append(msg)
|
| 329 |
|
| 330 |
+
def _fire_escalations(self):
|
| 331 |
+
"""Inject escalation messages for unhandled messages past their deadline + delay."""
|
| 332 |
+
for parent_id, esc_msg in list(self._escalation_map.items()):
|
| 333 |
+
if parent_id in self._handled:
|
| 334 |
+
continue # Handled in time, no escalation
|
| 335 |
+
# Find the parent message to check deadline
|
| 336 |
+
parent = next((m for m in self._all_messages if m.id == parent_id), None)
|
| 337 |
+
if parent is None or parent.deadline_hours is None:
|
| 338 |
+
continue
|
| 339 |
+
trigger_hour = parent.deadline_hours + (parent.escalation_delay_hours or 0.0)
|
| 340 |
+
if self._current_hour >= trigger_hour:
|
| 341 |
+
# Inject escalation message if not already present
|
| 342 |
+
if not any(m.id == esc_msg.id for m in self._all_messages):
|
| 343 |
+
esc_msg.timestamp_hours = trigger_hour
|
| 344 |
+
self._all_messages.append(esc_msg)
|
| 345 |
+
|
| 346 |
def _fire_drift_events(self):
|
| 347 |
"""Fire any drift events whose trigger time has been reached."""
|
| 348 |
for drift in self._drift_events:
|
|
|
|
| 399 |
)
|
| 400 |
self._all_messages.append(m)
|
| 401 |
|
| 402 |
+
# Build escalation, reply, and conflict maps from loaded messages.
|
| 403 |
+
# Escalation and reply messages start outside the pool — they're injected
|
| 404 |
+
# dynamically when triggered.
|
| 405 |
+
self._escalation_map = {}
|
| 406 |
+
self._reply_map = {}
|
| 407 |
+
self._conflict_pairs = {}
|
| 408 |
+
|
| 409 |
+
# Collect IDs of escalation/reply targets so we can remove them from the pool
|
| 410 |
+
deferred_ids: set[str] = set()
|
| 411 |
+
for m in self._all_messages:
|
| 412 |
+
if m.escalation_trigger:
|
| 413 |
+
deferred_ids.add(m.escalation_trigger)
|
| 414 |
+
if m.reply_trigger:
|
| 415 |
+
deferred_ids.add(m.reply_trigger)
|
| 416 |
+
if m.conflicts_with:
|
| 417 |
+
self._conflict_pairs[m.id] = m.conflicts_with
|
| 418 |
+
|
| 419 |
+
# Pull deferred messages out of the pool into their maps
|
| 420 |
+
kept: list[Message] = []
|
| 421 |
+
for m in self._all_messages:
|
| 422 |
+
if m.id in deferred_ids:
|
| 423 |
+
# Find which parent references this
|
| 424 |
+
for parent in self._all_messages:
|
| 425 |
+
if parent.escalation_trigger == m.id:
|
| 426 |
+
self._escalation_map[parent.id] = m
|
| 427 |
+
if parent.reply_trigger == m.id:
|
| 428 |
+
self._reply_map[m.id] = m
|
| 429 |
+
else:
|
| 430 |
+
kept.append(m)
|
| 431 |
+
self._all_messages = kept
|
| 432 |
+
|
| 433 |
self._visible_messages = []
|
| 434 |
self._handled = {}
|
| 435 |
self._read_msgs = set()
|
server/rewards.py
CHANGED
|
@@ -96,6 +96,10 @@ def calculate_reward(
|
|
| 96 |
if msg.id in superseded:
|
| 97 |
reward *= 0.5
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
# Priority penalty: choosing low/medium when unhandled critical messages exist
|
| 100 |
if visible_messages and handled is not None:
|
| 101 |
has_unhandled_critical = any(
|
|
|
|
| 96 |
if msg.id in superseded:
|
| 97 |
reward *= 0.5
|
| 98 |
|
| 99 |
+
# Conflict-resolution bonus: handling a message that forces a trade-off
|
| 100 |
+
if msg.conflicts_with:
|
| 101 |
+
reward *= 1.25
|
| 102 |
+
|
| 103 |
# Priority penalty: choosing low/medium when unhandled critical messages exist
|
| 104 |
if visible_messages and handled is not None:
|
| 105 |
has_unhandled_critical = any(
|