jayantaggarwal-sketch commited on
Commit
af8810b
·
1 Parent(s): 9318eea

Sync latest code and non-binary artifacts

Browse files
artifacts/training_metrics.json ADDED
@@ -0,0 +1,821 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "loss": 0.6357966065406799,
4
+ "grad_norm": 0.5020767450332642,
5
+ "learning_rate": 0.0,
6
+ "num_tokens": 2584.0,
7
+ "completions/mean_length": 200.0,
8
+ "completions/min_length": 20.0,
9
+ "completions/max_length": 380.0,
10
+ "completions/clipped_ratio": 0.0,
11
+ "completions/mean_terminated_length": 200.0,
12
+ "completions/min_terminated_length": 20.0,
13
+ "completions/max_terminated_length": 380.0,
14
+ "rewards/reward_function/mean": 0.574999988079071,
15
+ "rewards/reward_function/std": 0.10606600344181061,
16
+ "reward": 0.574999988079071,
17
+ "reward_std": 0.10606600344181061,
18
+ "frac_reward_zero_std": 0.0,
19
+ "entropy": 0.26839178800582886,
20
+ "clip_ratio/low_mean": 0.0,
21
+ "clip_ratio/low_min": 0.0,
22
+ "clip_ratio/high_mean": 0.0,
23
+ "clip_ratio/high_max": 0.0,
24
+ "clip_ratio/region_mean": 0.0,
25
+ "step_time": 61.83920078600022,
26
+ "epoch": 0.06666666666666667,
27
+ "step": 1
28
+ },
29
+ {
30
+ "loss": -0.24778124690055847,
31
+ "grad_norm": 1.877001166343689,
32
+ "learning_rate": 1.6666666666666667e-06,
33
+ "num_tokens": 4297.0,
34
+ "completions/mean_length": 35.5,
35
+ "completions/min_length": 23.0,
36
+ "completions/max_length": 48.0,
37
+ "completions/clipped_ratio": 0.0,
38
+ "completions/mean_terminated_length": 35.5,
39
+ "completions/min_terminated_length": 23.0,
40
+ "completions/max_terminated_length": 48.0,
41
+ "rewards/reward_function/mean": 0.40209999680519104,
42
+ "rewards/reward_function/std": 0.020647529512643814,
43
+ "reward": 0.40209999680519104,
44
+ "reward_std": 0.020647529512643814,
45
+ "frac_reward_zero_std": 0.0,
46
+ "entropy": 0.29453833028674126,
47
+ "clip_ratio/low_mean": 0.0,
48
+ "clip_ratio/low_min": 0.0,
49
+ "clip_ratio/high_mean": 0.0,
50
+ "clip_ratio/high_max": 0.0,
51
+ "clip_ratio/region_mean": 0.0,
52
+ "step_time": 14.531932316999928,
53
+ "epoch": 0.13333333333333333,
54
+ "step": 2
55
+ },
56
+ {
57
+ "loss": -0.07843422889709473,
58
+ "grad_norm": 1.543445110321045,
59
+ "learning_rate": 3.3333333333333333e-06,
60
+ "num_tokens": 5871.0,
61
+ "completions/mean_length": 18.0,
62
+ "completions/min_length": 16.0,
63
+ "completions/max_length": 20.0,
64
+ "completions/clipped_ratio": 0.0,
65
+ "completions/mean_terminated_length": 18.0,
66
+ "completions/min_terminated_length": 16.0,
67
+ "completions/max_terminated_length": 20.0,
68
+ "rewards/reward_function/mean": 0.4583500027656555,
69
+ "rewards/reward_function/std": 0.058901991695165634,
70
+ "reward": 0.4583500027656555,
71
+ "reward_std": 0.058901991695165634,
72
+ "frac_reward_zero_std": 0.0,
73
+ "entropy": 0.1914939135313034,
74
+ "clip_ratio/low_mean": 0.0,
75
+ "clip_ratio/low_min": 0.0,
76
+ "clip_ratio/high_mean": 0.0,
77
+ "clip_ratio/high_max": 0.0,
78
+ "clip_ratio/region_mean": 0.0,
79
+ "step_time": 10.524165547000166,
80
+ "epoch": 0.2,
81
+ "step": 3
82
+ },
83
+ {
84
+ "loss": 0.07059085369110107,
85
+ "grad_norm": 1.5465283393859863,
86
+ "learning_rate": 5e-06,
87
+ "num_tokens": 7485.0,
88
+ "completions/mean_length": 20.0,
89
+ "completions/min_length": 18.0,
90
+ "completions/max_length": 22.0,
91
+ "completions/clipped_ratio": 0.0,
92
+ "completions/mean_terminated_length": 20.0,
93
+ "completions/min_terminated_length": 18.0,
94
+ "completions/max_terminated_length": 22.0,
95
+ "rewards/reward_function/mean": 0.4583500027656555,
96
+ "rewards/reward_function/std": 0.058901991695165634,
97
+ "reward": 0.4583500027656555,
98
+ "reward_std": 0.058901991695165634,
99
+ "frac_reward_zero_std": 0.0,
100
+ "entropy": 0.1294238492846489,
101
+ "clip_ratio/low_mean": 0.0,
102
+ "clip_ratio/low_min": 0.0,
103
+ "clip_ratio/high_mean": 0.0,
104
+ "clip_ratio/high_max": 0.0,
105
+ "clip_ratio/region_mean": 0.0,
106
+ "step_time": 10.990043340000057,
107
+ "epoch": 0.26666666666666666,
108
+ "step": 4
109
+ },
110
+ {
111
+ "loss": -0.008725225925445557,
112
+ "grad_norm": 1.369038462638855,
113
+ "learning_rate": 4.814814814814815e-06,
114
+ "num_tokens": 8692.0,
115
+ "completions/mean_length": 40.5,
116
+ "completions/min_length": 40.0,
117
+ "completions/max_length": 41.0,
118
+ "completions/clipped_ratio": 0.0,
119
+ "completions/mean_terminated_length": 40.5,
120
+ "completions/min_terminated_length": 40.0,
121
+ "completions/max_terminated_length": 41.0,
122
+ "rewards/reward_function/mean": 0.5187499523162842,
123
+ "rewards/reward_function/std": 0.18561552464962006,
124
+ "reward": 0.5187499523162842,
125
+ "reward_std": 0.18561550974845886,
126
+ "frac_reward_zero_std": 0.0,
127
+ "entropy": 0.19958198070526123,
128
+ "clip_ratio/low_mean": 0.0,
129
+ "clip_ratio/low_min": 0.0,
130
+ "clip_ratio/high_mean": 0.0,
131
+ "clip_ratio/high_max": 0.0,
132
+ "clip_ratio/region_mean": 0.0,
133
+ "step_time": 11.131567607000306,
134
+ "epoch": 0.3333333333333333,
135
+ "step": 5
136
+ },
137
+ {
138
+ "loss": -0.07062190771102905,
139
+ "grad_norm": 0.8509910702705383,
140
+ "learning_rate": 4.62962962962963e-06,
141
+ "num_tokens": 9862.0,
142
+ "completions/mean_length": 40.0,
143
+ "completions/min_length": 36.0,
144
+ "completions/max_length": 44.0,
145
+ "completions/clipped_ratio": 0.0,
146
+ "completions/mean_terminated_length": 40.0,
147
+ "completions/min_terminated_length": 36.0,
148
+ "completions/max_terminated_length": 44.0,
149
+ "rewards/reward_function/mean": 0.4437499940395355,
150
+ "rewards/reward_function/std": 0.07954952120780945,
151
+ "reward": 0.4437499940395355,
152
+ "reward_std": 0.07954952120780945,
153
+ "frac_reward_zero_std": 0.0,
154
+ "entropy": 0.1297583170235157,
155
+ "clip_ratio/low_mean": 0.0,
156
+ "clip_ratio/low_min": 0.0,
157
+ "clip_ratio/high_mean": 0.0,
158
+ "clip_ratio/high_max": 0.0,
159
+ "clip_ratio/region_mean": 0.0,
160
+ "step_time": 11.613226033000046,
161
+ "epoch": 0.4,
162
+ "step": 6
163
+ },
164
+ {
165
+ "loss": -5.960464477539062e-07,
166
+ "grad_norm": 0.1141546443104744,
167
+ "learning_rate": 4.444444444444444e-06,
168
+ "num_tokens": 11000.0,
169
+ "completions/mean_length": 19.0,
170
+ "completions/min_length": 19.0,
171
+ "completions/max_length": 19.0,
172
+ "completions/clipped_ratio": 0.0,
173
+ "completions/mean_terminated_length": 19.0,
174
+ "completions/min_terminated_length": 19.0,
175
+ "completions/max_terminated_length": 19.0,
176
+ "rewards/reward_function/mean": 0.5349999666213989,
177
+ "rewards/reward_function/std": 0.04949747025966644,
178
+ "reward": 0.5349999666213989,
179
+ "reward_std": 0.04949747025966644,
180
+ "frac_reward_zero_std": 0.0,
181
+ "entropy": 0.07411494851112366,
182
+ "clip_ratio/low_mean": 0.0,
183
+ "clip_ratio/low_min": 0.0,
184
+ "clip_ratio/high_mean": 0.0,
185
+ "clip_ratio/high_max": 0.0,
186
+ "clip_ratio/region_mean": 0.0,
187
+ "step_time": 8.378681117999804,
188
+ "epoch": 0.4666666666666667,
189
+ "step": 7
190
+ },
191
+ {
192
+ "loss": -0.5384407639503479,
193
+ "grad_norm": 0.575139045715332,
194
+ "learning_rate": 4.2592592592592596e-06,
195
+ "num_tokens": 12808.0,
196
+ "completions/mean_length": 119.0,
197
+ "completions/min_length": 28.0,
198
+ "completions/max_length": 210.0,
199
+ "completions/clipped_ratio": 0.0,
200
+ "completions/mean_terminated_length": 119.0,
201
+ "completions/min_terminated_length": 28.0,
202
+ "completions/max_terminated_length": 210.0,
203
+ "rewards/reward_function/mean": 0.6333500146865845,
204
+ "rewards/reward_function/std": 0.023546643555164337,
205
+ "reward": 0.6333500146865845,
206
+ "reward_std": 0.023546643555164337,
207
+ "frac_reward_zero_std": 0.0,
208
+ "entropy": 0.2676837705075741,
209
+ "clip_ratio/low_mean": 0.0,
210
+ "clip_ratio/low_min": 0.0,
211
+ "clip_ratio/high_mean": 0.0,
212
+ "clip_ratio/high_max": 0.0,
213
+ "clip_ratio/region_mean": 0.0,
214
+ "step_time": 37.1004951179998,
215
+ "epoch": 0.5333333333333333,
216
+ "step": 8
217
+ },
218
+ {
219
+ "loss": 0.0,
220
+ "grad_norm": 0.14882154762744904,
221
+ "learning_rate": 4.074074074074074e-06,
222
+ "num_tokens": 14288.0,
223
+ "completions/mean_length": 22.0,
224
+ "completions/min_length": 22.0,
225
+ "completions/max_length": 22.0,
226
+ "completions/clipped_ratio": 0.0,
227
+ "completions/mean_terminated_length": 22.0,
228
+ "completions/min_terminated_length": 22.0,
229
+ "completions/max_terminated_length": 22.0,
230
+ "rewards/reward_function/mean": 0.4583500027656555,
231
+ "rewards/reward_function/std": 0.058901991695165634,
232
+ "reward": 0.4583500027656555,
233
+ "reward_std": 0.058901991695165634,
234
+ "frac_reward_zero_std": 0.0,
235
+ "entropy": 0.0809866338968277,
236
+ "clip_ratio/low_mean": 0.0,
237
+ "clip_ratio/low_min": 0.0,
238
+ "clip_ratio/high_mean": 0.0,
239
+ "clip_ratio/high_max": 0.0,
240
+ "clip_ratio/region_mean": 0.0,
241
+ "step_time": 10.457592102000262,
242
+ "epoch": 0.6,
243
+ "step": 9
244
+ },
245
+ {
246
+ "loss": -0.07851982116699219,
247
+ "grad_norm": 1.6114908456802368,
248
+ "learning_rate": 3.88888888888889e-06,
249
+ "num_tokens": 16384.0,
250
+ "completions/mean_length": 27.0,
251
+ "completions/min_length": 24.0,
252
+ "completions/max_length": 30.0,
253
+ "completions/clipped_ratio": 0.0,
254
+ "completions/mean_terminated_length": 27.0,
255
+ "completions/min_terminated_length": 24.0,
256
+ "completions/max_terminated_length": 30.0,
257
+ "rewards/reward_function/mean": 0.5333499908447266,
258
+ "rewards/reward_function/std": 0.16496798396110535,
259
+ "reward": 0.5333499908447266,
260
+ "reward_std": 0.16496798396110535,
261
+ "frac_reward_zero_std": 0.0,
262
+ "entropy": 0.4823211133480072,
263
+ "clip_ratio/low_mean": 0.0,
264
+ "clip_ratio/low_min": 0.0,
265
+ "clip_ratio/high_mean": 0.0,
266
+ "clip_ratio/high_max": 0.0,
267
+ "clip_ratio/region_mean": 0.0,
268
+ "step_time": 14.966705318999857,
269
+ "epoch": 0.6666666666666666,
270
+ "step": 10
271
+ },
272
+ {
273
+ "loss": 0.0,
274
+ "grad_norm": 0.0,
275
+ "learning_rate": 3.7037037037037037e-06,
276
+ "num_tokens": 17617.0,
277
+ "completions/mean_length": 35.5,
278
+ "completions/min_length": 16.0,
279
+ "completions/max_length": 55.0,
280
+ "completions/clipped_ratio": 0.0,
281
+ "completions/mean_terminated_length": 35.5,
282
+ "completions/min_terminated_length": 16.0,
283
+ "completions/max_terminated_length": 55.0,
284
+ "rewards/reward_function/mean": 0.41670000553131104,
285
+ "rewards/reward_function/std": 0.0,
286
+ "reward": 0.41670000553131104,
287
+ "reward_std": 0.0,
288
+ "frac_reward_zero_std": 1.0,
289
+ "entropy": 0.9841015487909317,
290
+ "clip_ratio/low_mean": 0.0,
291
+ "clip_ratio/low_min": 0.0,
292
+ "clip_ratio/high_mean": 0.0,
293
+ "clip_ratio/high_max": 0.0,
294
+ "clip_ratio/region_mean": 0.0,
295
+ "step_time": 13.133867920999819,
296
+ "epoch": 0.7333333333333333,
297
+ "step": 11
298
+ },
299
+ {
300
+ "loss": -0.05432739853858948,
301
+ "grad_norm": 1.1030373573303223,
302
+ "learning_rate": 3.5185185185185187e-06,
303
+ "num_tokens": 19429.0,
304
+ "completions/mean_length": 39.0,
305
+ "completions/min_length": 36.0,
306
+ "completions/max_length": 42.0,
307
+ "completions/clipped_ratio": 0.0,
308
+ "completions/mean_terminated_length": 39.0,
309
+ "completions/min_terminated_length": 36.0,
310
+ "completions/max_terminated_length": 42.0,
311
+ "rewards/reward_function/mean": 0.5583499670028687,
312
+ "rewards/reward_function/std": 0.08251935988664627,
313
+ "reward": 0.5583499670028687,
314
+ "reward_std": 0.08251935988664627,
315
+ "frac_reward_zero_std": 0.0,
316
+ "entropy": 0.12238830700516701,
317
+ "clip_ratio/low_mean": 0.0,
318
+ "clip_ratio/low_min": 0.0,
319
+ "clip_ratio/high_mean": 0.0,
320
+ "clip_ratio/high_max": 0.0,
321
+ "clip_ratio/region_mean": 0.0,
322
+ "step_time": 15.016316874000267,
323
+ "epoch": 0.8,
324
+ "step": 12
325
+ },
326
+ {
327
+ "loss": 0.149135559797287,
328
+ "grad_norm": 1.254807472229004,
329
+ "learning_rate": 3.3333333333333333e-06,
330
+ "num_tokens": 20626.0,
331
+ "completions/mean_length": 35.5,
332
+ "completions/min_length": 28.0,
333
+ "completions/max_length": 43.0,
334
+ "completions/clipped_ratio": 0.0,
335
+ "completions/mean_terminated_length": 35.5,
336
+ "completions/min_terminated_length": 28.0,
337
+ "completions/max_terminated_length": 43.0,
338
+ "rewards/reward_function/mean": 0.4583500027656555,
339
+ "rewards/reward_function/std": 0.058901991695165634,
340
+ "reward": 0.4583500027656555,
341
+ "reward_std": 0.058901991695165634,
342
+ "frac_reward_zero_std": 0.0,
343
+ "entropy": 0.2583826147019863,
344
+ "clip_ratio/low_mean": 0.0,
345
+ "clip_ratio/low_min": 0.0,
346
+ "clip_ratio/high_mean": 0.0,
347
+ "clip_ratio/high_max": 0.0,
348
+ "clip_ratio/region_mean": 0.0,
349
+ "step_time": 11.519657740999946,
350
+ "epoch": 0.8666666666666667,
351
+ "step": 13
352
+ },
353
+ {
354
+ "loss": 0.27955153584480286,
355
+ "grad_norm": 1.167170763015747,
356
+ "learning_rate": 3.1481481481481483e-06,
357
+ "num_tokens": 22067.0,
358
+ "completions/mean_length": 36.5,
359
+ "completions/min_length": 22.0,
360
+ "completions/max_length": 51.0,
361
+ "completions/clipped_ratio": 0.0,
362
+ "completions/mean_terminated_length": 36.5,
363
+ "completions/min_terminated_length": 22.0,
364
+ "completions/max_terminated_length": 51.0,
365
+ "rewards/reward_function/mean": 0.40209999680519104,
366
+ "rewards/reward_function/std": 0.020647529512643814,
367
+ "reward": 0.40209999680519104,
368
+ "reward_std": 0.020647529512643814,
369
+ "frac_reward_zero_std": 0.0,
370
+ "entropy": 0.24101658910512924,
371
+ "clip_ratio/low_mean": 0.0,
372
+ "clip_ratio/low_min": 0.0,
373
+ "clip_ratio/high_mean": 0.0,
374
+ "clip_ratio/high_max": 0.0,
375
+ "clip_ratio/region_mean": 0.0,
376
+ "step_time": 14.053339626000025,
377
+ "epoch": 0.9333333333333333,
378
+ "step": 14
379
+ },
380
+ {
381
+ "loss": 0.06053304672241211,
382
+ "grad_norm": 2.425391435623169,
383
+ "learning_rate": 2.962962962962963e-06,
384
+ "num_tokens": 23691.0,
385
+ "completions/mean_length": 35.0,
386
+ "completions/min_length": 32.0,
387
+ "completions/max_length": 38.0,
388
+ "completions/clipped_ratio": 0.0,
389
+ "completions/mean_terminated_length": 35.0,
390
+ "completions/min_terminated_length": 32.0,
391
+ "completions/max_terminated_length": 38.0,
392
+ "rewards/reward_function/mean": 0.4437499940395355,
393
+ "rewards/reward_function/std": 0.07954952120780945,
394
+ "reward": 0.4437499940395355,
395
+ "reward_std": 0.07954952120780945,
396
+ "frac_reward_zero_std": 0.0,
397
+ "entropy": 0.25765860080718994,
398
+ "clip_ratio/low_mean": 0.0,
399
+ "clip_ratio/low_min": 0.0,
400
+ "clip_ratio/high_mean": 0.0,
401
+ "clip_ratio/high_max": 0.0,
402
+ "clip_ratio/region_mean": 0.0,
403
+ "step_time": 13.23964212199985,
404
+ "epoch": 1.0,
405
+ "step": 15
406
+ },
407
+ {
408
+ "loss": 0.0,
409
+ "grad_norm": 0.0,
410
+ "learning_rate": 2.7777777777777783e-06,
411
+ "num_tokens": 24930.0,
412
+ "completions/mean_length": 56.5,
413
+ "completions/min_length": 55.0,
414
+ "completions/max_length": 58.0,
415
+ "completions/clipped_ratio": 0.0,
416
+ "completions/mean_terminated_length": 56.5,
417
+ "completions/min_terminated_length": 55.0,
418
+ "completions/max_terminated_length": 58.0,
419
+ "rewards/reward_function/mean": 0.5,
420
+ "rewards/reward_function/std": 0.0,
421
+ "reward": 0.5,
422
+ "reward_std": 0.0,
423
+ "frac_reward_zero_std": 1.0,
424
+ "entropy": 0.4330967664718628,
425
+ "clip_ratio/low_mean": 0.0,
426
+ "clip_ratio/low_min": 0.0,
427
+ "clip_ratio/high_mean": 0.0,
428
+ "clip_ratio/high_max": 0.0,
429
+ "clip_ratio/region_mean": 0.0,
430
+ "step_time": 13.377671128000202,
431
+ "epoch": 1.0666666666666667,
432
+ "step": 16
433
+ },
434
+ {
435
+ "loss": 0.0672960877418518,
436
+ "grad_norm": 1.6053359508514404,
437
+ "learning_rate": 2.5925925925925925e-06,
438
+ "num_tokens": 26072.0,
439
+ "completions/mean_length": 21.0,
440
+ "completions/min_length": 19.0,
441
+ "completions/max_length": 23.0,
442
+ "completions/clipped_ratio": 0.0,
443
+ "completions/mean_terminated_length": 21.0,
444
+ "completions/min_terminated_length": 19.0,
445
+ "completions/max_terminated_length": 23.0,
446
+ "rewards/reward_function/mean": 0.516700029373169,
447
+ "rewards/reward_function/std": 0.1414213478565216,
448
+ "reward": 0.516700029373169,
449
+ "reward_std": 0.1414213478565216,
450
+ "frac_reward_zero_std": 0.0,
451
+ "entropy": 0.06669686548411846,
452
+ "clip_ratio/low_mean": 0.0,
453
+ "clip_ratio/low_min": 0.0,
454
+ "clip_ratio/high_mean": 0.0,
455
+ "clip_ratio/high_max": 0.0,
456
+ "clip_ratio/region_mean": 0.0,
457
+ "step_time": 8.82481953599995,
458
+ "epoch": 1.1333333333333333,
459
+ "step": 17
460
+ },
461
+ {
462
+ "loss": -0.13629436492919922,
463
+ "grad_norm": 1.381037950515747,
464
+ "learning_rate": 2.4074074074074075e-06,
465
+ "num_tokens": 27771.0,
466
+ "completions/mean_length": 28.5,
467
+ "completions/min_length": 23.0,
468
+ "completions/max_length": 34.0,
469
+ "completions/clipped_ratio": 0.0,
470
+ "completions/mean_terminated_length": 28.5,
471
+ "completions/min_terminated_length": 23.0,
472
+ "completions/max_terminated_length": 34.0,
473
+ "rewards/reward_function/mean": 0.5583499670028687,
474
+ "rewards/reward_function/std": 0.08251935988664627,
475
+ "reward": 0.5583499670028687,
476
+ "reward_std": 0.08251935988664627,
477
+ "frac_reward_zero_std": 0.0,
478
+ "entropy": 0.12161804735660553,
479
+ "clip_ratio/low_mean": 0.0,
480
+ "clip_ratio/low_min": 0.0,
481
+ "clip_ratio/high_mean": 0.0,
482
+ "clip_ratio/high_max": 0.0,
483
+ "clip_ratio/region_mean": 0.0,
484
+ "step_time": 13.361655292000023,
485
+ "epoch": 1.2,
486
+ "step": 18
487
+ },
488
+ {
489
+ "loss": -0.1279437243938446,
490
+ "grad_norm": 2.226921558380127,
491
+ "learning_rate": 2.222222222222222e-06,
492
+ "num_tokens": 29251.0,
493
+ "completions/mean_length": 22.0,
494
+ "completions/min_length": 18.0,
495
+ "completions/max_length": 26.0,
496
+ "completions/clipped_ratio": 0.0,
497
+ "completions/mean_terminated_length": 22.0,
498
+ "completions/min_terminated_length": 18.0,
499
+ "completions/max_terminated_length": 26.0,
500
+ "rewards/reward_function/mean": 0.6021000146865845,
501
+ "rewards/reward_function/std": 0.020647529512643814,
502
+ "reward": 0.6021000146865845,
503
+ "reward_std": 0.020647529512643814,
504
+ "frac_reward_zero_std": 0.0,
505
+ "entropy": 0.08564786985516548,
506
+ "clip_ratio/low_mean": 0.0,
507
+ "clip_ratio/low_min": 0.0,
508
+ "clip_ratio/high_mean": 0.0,
509
+ "clip_ratio/high_max": 0.0,
510
+ "clip_ratio/region_mean": 0.0,
511
+ "step_time": 11.087925465999888,
512
+ "epoch": 1.2666666666666666,
513
+ "step": 19
514
+ },
515
+ {
516
+ "loss": 0.06728780269622803,
517
+ "grad_norm": 1.0975555181503296,
518
+ "learning_rate": 2.037037037037037e-06,
519
+ "num_tokens": 30425.0,
520
+ "completions/mean_length": 42.0,
521
+ "completions/min_length": 38.0,
522
+ "completions/max_length": 46.0,
523
+ "completions/clipped_ratio": 0.0,
524
+ "completions/mean_terminated_length": 42.0,
525
+ "completions/min_terminated_length": 38.0,
526
+ "completions/max_terminated_length": 46.0,
527
+ "rewards/reward_function/mean": 0.5020999908447266,
528
+ "rewards/reward_function/std": 0.1207738146185875,
529
+ "reward": 0.5020999908447266,
530
+ "reward_std": 0.1207738146185875,
531
+ "frac_reward_zero_std": 0.0,
532
+ "entropy": 0.16879020631313324,
533
+ "clip_ratio/low_mean": 0.0,
534
+ "clip_ratio/low_min": 0.0,
535
+ "clip_ratio/high_mean": 0.0,
536
+ "clip_ratio/high_max": 0.0,
537
+ "clip_ratio/region_mean": 0.0,
538
+ "step_time": 11.843729876999987,
539
+ "epoch": 1.3333333333333333,
540
+ "step": 20
541
+ },
542
+ {
543
+ "loss": 0.054305046796798706,
544
+ "grad_norm": 1.0439469814300537,
545
+ "learning_rate": 1.8518518518518519e-06,
546
+ "num_tokens": 32047.0,
547
+ "completions/mean_length": 26.0,
548
+ "completions/min_length": 24.0,
549
+ "completions/max_length": 28.0,
550
+ "completions/clipped_ratio": 0.0,
551
+ "completions/mean_terminated_length": 26.0,
552
+ "completions/min_terminated_length": 24.0,
553
+ "completions/max_terminated_length": 28.0,
554
+ "rewards/reward_function/mean": 0.543749988079071,
555
+ "rewards/reward_function/std": 0.06187182664871216,
556
+ "reward": 0.543749988079071,
557
+ "reward_std": 0.06187182664871216,
558
+ "frac_reward_zero_std": 0.0,
559
+ "entropy": 0.09812109172344208,
560
+ "clip_ratio/low_mean": 0.0,
561
+ "clip_ratio/low_min": 0.0,
562
+ "clip_ratio/high_mean": 0.0,
563
+ "clip_ratio/high_max": 0.0,
564
+ "clip_ratio/region_mean": 0.0,
565
+ "step_time": 12.01839622600005,
566
+ "epoch": 1.4,
567
+ "step": 21
568
+ },
569
+ {
570
+ "loss": -0.43347233533859253,
571
+ "grad_norm": 0.6157990097999573,
572
+ "learning_rate": 1.6666666666666667e-06,
573
+ "num_tokens": 33354.0,
574
+ "completions/mean_length": 72.5,
575
+ "completions/min_length": 28.0,
576
+ "completions/max_length": 117.0,
577
+ "completions/clipped_ratio": 0.0,
578
+ "completions/mean_terminated_length": 72.5,
579
+ "completions/min_terminated_length": 28.0,
580
+ "completions/max_terminated_length": 117.0,
581
+ "rewards/reward_function/mean": 0.4437499940395355,
582
+ "rewards/reward_function/std": 0.07954952120780945,
583
+ "reward": 0.4437499940395355,
584
+ "reward_std": 0.07954952120780945,
585
+ "frac_reward_zero_std": 0.0,
586
+ "entropy": 0.16432299464941025,
587
+ "clip_ratio/low_mean": 0.0,
588
+ "clip_ratio/low_min": 0.0,
589
+ "clip_ratio/high_mean": 0.0,
590
+ "clip_ratio/high_max": 0.0,
591
+ "clip_ratio/region_mean": 0.0,
592
+ "step_time": 21.621784166999987,
593
+ "epoch": 1.4666666666666668,
594
+ "step": 22
595
+ },
596
+ {
597
+ "loss": 0.21844279766082764,
598
+ "grad_norm": 1.18323814868927,
599
+ "learning_rate": 1.4814814814814815e-06,
600
+ "num_tokens": 35143.0,
601
+ "completions/mean_length": 27.5,
602
+ "completions/min_length": 19.0,
603
+ "completions/max_length": 36.0,
604
+ "completions/clipped_ratio": 0.0,
605
+ "completions/mean_terminated_length": 27.5,
606
+ "completions/min_terminated_length": 19.0,
607
+ "completions/max_terminated_length": 36.0,
608
+ "rewards/reward_function/mean": 0.6312500238418579,
609
+ "rewards/reward_function/std": 0.18561552464962006,
610
+ "reward": 0.6312500238418579,
611
+ "reward_std": 0.18561550974845886,
612
+ "frac_reward_zero_std": 0.0,
613
+ "entropy": 0.10744666680693626,
614
+ "clip_ratio/low_mean": 0.0,
615
+ "clip_ratio/low_min": 0.0,
616
+ "clip_ratio/high_mean": 0.0,
617
+ "clip_ratio/high_max": 0.0,
618
+ "clip_ratio/region_mean": 0.0,
619
+ "step_time": 14.225728247999996,
620
+ "epoch": 1.5333333333333332,
621
+ "step": 23
622
+ },
623
+ {
624
+ "loss": -0.6148233413696289,
625
+ "grad_norm": 0.8197247982025146,
626
+ "learning_rate": 1.2962962962962962e-06,
627
+ "num_tokens": 36929.0,
628
+ "completions/mean_length": 124.0,
629
+ "completions/min_length": 16.0,
630
+ "completions/max_length": 232.0,
631
+ "completions/clipped_ratio": 0.0,
632
+ "completions/mean_terminated_length": 124.0,
633
+ "completions/min_terminated_length": 16.0,
634
+ "completions/max_terminated_length": 232.0,
635
+ "rewards/reward_function/mean": 0.4583500027656555,
636
+ "rewards/reward_function/std": 0.058901991695165634,
637
+ "reward": 0.4583500027656555,
638
+ "reward_std": 0.058901991695165634,
639
+ "frac_reward_zero_std": 0.0,
640
+ "entropy": 0.6331216096878052,
641
+ "clip_ratio/low_mean": 0.0,
642
+ "clip_ratio/low_min": 0.0,
643
+ "clip_ratio/high_mean": 0.0,
644
+ "clip_ratio/high_max": 0.0,
645
+ "clip_ratio/region_mean": 0.0,
646
+ "step_time": 39.883540922000066,
647
+ "epoch": 1.6,
648
+ "step": 24
649
+ },
650
+ {
651
+ "loss": -0.054349154233932495,
652
+ "grad_norm": 1.030266284942627,
653
+ "learning_rate": 1.111111111111111e-06,
654
+ "num_tokens": 39023.0,
655
+ "completions/mean_length": 26.0,
656
+ "completions/min_length": 24.0,
657
+ "completions/max_length": 28.0,
658
+ "completions/clipped_ratio": 0.0,
659
+ "completions/mean_terminated_length": 26.0,
660
+ "completions/min_terminated_length": 24.0,
661
+ "completions/max_terminated_length": 28.0,
662
+ "rewards/reward_function/mean": 0.6749999523162842,
663
+ "rewards/reward_function/std": 0.1237436980009079,
664
+ "reward": 0.6749999523162842,
665
+ "reward_std": 0.1237436980009079,
666
+ "frac_reward_zero_std": 0.0,
667
+ "entropy": 0.07906700298190117,
668
+ "clip_ratio/low_mean": 0.0,
669
+ "clip_ratio/low_min": 0.0,
670
+ "clip_ratio/high_mean": 0.0,
671
+ "clip_ratio/high_max": 0.0,
672
+ "clip_ratio/region_mean": 0.0,
673
+ "step_time": 14.697501995000039,
674
+ "epoch": 1.6666666666666665,
675
+ "step": 25
676
+ },
677
+ {
678
+ "loss": 0.0,
679
+ "grad_norm": 0.0,
680
+ "learning_rate": 9.259259259259259e-07,
681
+ "num_tokens": 40211.0,
682
+ "completions/mean_length": 31.0,
683
+ "completions/min_length": 22.0,
684
+ "completions/max_length": 40.0,
685
+ "completions/clipped_ratio": 0.0,
686
+ "completions/mean_terminated_length": 31.0,
687
+ "completions/min_terminated_length": 22.0,
688
+ "completions/max_terminated_length": 40.0,
689
+ "rewards/reward_function/mean": 0.6499999761581421,
690
+ "rewards/reward_function/std": 0.0,
691
+ "reward": 0.6499999761581421,
692
+ "reward_std": 0.0,
693
+ "frac_reward_zero_std": 1.0,
694
+ "entropy": 0.14391540735960007,
695
+ "clip_ratio/low_mean": 0.0,
696
+ "clip_ratio/low_min": 0.0,
697
+ "clip_ratio/high_mean": 0.0,
698
+ "clip_ratio/high_max": 0.0,
699
+ "clip_ratio/region_mean": 0.0,
700
+ "step_time": 11.100019781000128,
701
+ "epoch": 1.7333333333333334,
702
+ "step": 26
703
+ },
704
+ {
705
+ "loss": 0.11774009466171265,
706
+ "grad_norm": 1.699737310409546,
707
+ "learning_rate": 7.407407407407407e-07,
708
+ "num_tokens": 42443.0,
709
+ "completions/mean_length": 24.0,
710
+ "completions/min_length": 20.0,
711
+ "completions/max_length": 28.0,
712
+ "completions/clipped_ratio": 0.0,
713
+ "completions/mean_terminated_length": 24.0,
714
+ "completions/min_terminated_length": 20.0,
715
+ "completions/max_terminated_length": 28.0,
716
+ "rewards/reward_function/mean": 0.574999988079071,
717
+ "rewards/reward_function/std": 0.10606600344181061,
718
+ "reward": 0.574999988079071,
719
+ "reward_std": 0.10606600344181061,
720
+ "frac_reward_zero_std": 0.0,
721
+ "entropy": 0.14994759857654572,
722
+ "clip_ratio/low_mean": 0.0,
723
+ "clip_ratio/low_min": 0.0,
724
+ "clip_ratio/high_mean": 0.0,
725
+ "clip_ratio/high_max": 0.0,
726
+ "clip_ratio/region_mean": 0.0,
727
+ "step_time": 15.477631969999948,
728
+ "epoch": 1.8,
729
+ "step": 27
730
+ },
731
+ {
732
+ "loss": -0.05434012413024902,
733
+ "grad_norm": 0.8750740885734558,
734
+ "learning_rate": 5.555555555555555e-07,
735
+ "num_tokens": 44062.0,
736
+ "completions/mean_length": 32.5,
737
+ "completions/min_length": 30.0,
738
+ "completions/max_length": 35.0,
739
+ "completions/clipped_ratio": 0.0,
740
+ "completions/mean_terminated_length": 32.5,
741
+ "completions/min_terminated_length": 30.0,
742
+ "completions/max_terminated_length": 35.0,
743
+ "rewards/reward_function/mean": 0.6895999908447266,
744
+ "rewards/reward_function/std": 0.10309616476297379,
745
+ "reward": 0.6895999908447266,
746
+ "reward_std": 0.10309616476297379,
747
+ "frac_reward_zero_std": 0.0,
748
+ "entropy": 0.142868272960186,
749
+ "clip_ratio/low_mean": 0.0,
750
+ "clip_ratio/low_min": 0.0,
751
+ "clip_ratio/high_mean": 0.0,
752
+ "clip_ratio/high_max": 0.0,
753
+ "clip_ratio/region_mean": 0.0,
754
+ "step_time": 12.90737977799995,
755
+ "epoch": 1.8666666666666667,
756
+ "step": 28
757
+ },
758
+ {
759
+ "loss": 0.15356993675231934,
760
+ "grad_norm": 1.543925404548645,
761
+ "learning_rate": 3.7037037037037036e-07,
762
+ "num_tokens": 45476.0,
763
+ "completions/mean_length": 23.0,
764
+ "completions/min_length": 18.0,
765
+ "completions/max_length": 28.0,
766
+ "completions/clipped_ratio": 0.0,
767
+ "completions/mean_terminated_length": 23.0,
768
+ "completions/min_terminated_length": 18.0,
769
+ "completions/max_terminated_length": 28.0,
770
+ "rewards/reward_function/mean": 0.6895999908447266,
771
+ "rewards/reward_function/std": 0.10309616476297379,
772
+ "reward": 0.6895999908447266,
773
+ "reward_std": 0.10309616476297379,
774
+ "frac_reward_zero_std": 0.0,
775
+ "entropy": 0.17455117404460907,
776
+ "clip_ratio/low_mean": 0.0,
777
+ "clip_ratio/low_min": 0.0,
778
+ "clip_ratio/high_mean": 0.0,
779
+ "clip_ratio/high_max": 0.0,
780
+ "clip_ratio/region_mean": 0.0,
781
+ "step_time": 11.09434006399988,
782
+ "epoch": 1.9333333333333333,
783
+ "step": 29
784
+ },
785
+ {
786
+ "loss": -0.030694186687469482,
787
+ "grad_norm": 1.561765432357788,
788
+ "learning_rate": 1.8518518518518518e-07,
789
+ "num_tokens": 47096.0,
790
+ "completions/mean_length": 23.0,
791
+ "completions/min_length": 22.0,
792
+ "completions/max_length": 24.0,
793
+ "completions/clipped_ratio": 0.0,
794
+ "completions/mean_terminated_length": 23.0,
795
+ "completions/min_terminated_length": 22.0,
796
+ "completions/max_terminated_length": 24.0,
797
+ "rewards/reward_function/mean": 0.543749988079071,
798
+ "rewards/reward_function/std": 0.06187182664871216,
799
+ "reward": 0.543749988079071,
800
+ "reward_std": 0.06187182664871216,
801
+ "frac_reward_zero_std": 0.0,
802
+ "entropy": 0.13204744830727577,
803
+ "clip_ratio/low_mean": 0.0,
804
+ "clip_ratio/low_min": 0.0,
805
+ "clip_ratio/high_mean": 0.0,
806
+ "clip_ratio/high_max": 0.0,
807
+ "clip_ratio/region_mean": 0.0,
808
+ "step_time": 11.530309271000078,
809
+ "epoch": 2.0,
810
+ "step": 30
811
+ },
812
+ {
813
+ "train_runtime": 507.6102,
814
+ "train_samples_per_second": 0.059,
815
+ "train_steps_per_second": 0.059,
816
+ "total_flos": 0.0,
817
+ "train_loss": -0.021817301710446674,
818
+ "epoch": 2.0,
819
+ "step": 30
820
+ }
821
+ ]
artifacts/training_summary.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ train_runtime_sec,train_steps,epochs,train_loss_final,reward_min,reward_max,reward_last
2
+ 507.6,30,2,-0.02182,0.40209999680519104,0.6895999908447266,0.543749988079071
inference.py CHANGED
@@ -20,11 +20,14 @@ from typing import Any, Dict, List
20
 
21
  import requests
22
  from openai import OpenAI
 
23
 
24
  # ---------------------------------------------------------------------------
25
  # Configuration
26
  # ---------------------------------------------------------------------------
27
 
 
 
28
  API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
29
  MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
30
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or ""
 
20
 
21
  import requests
22
  from openai import OpenAI
23
+ from dotenv import load_dotenv
24
 
25
  # ---------------------------------------------------------------------------
26
  # Configuration
27
  # ---------------------------------------------------------------------------
28
 
29
+ load_dotenv()
30
+
31
  API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
32
  MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
33
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or ""
server/app.py CHANGED
@@ -3,9 +3,11 @@
3
  from __future__ import annotations
4
 
5
  import os
 
6
 
7
  from openenv.core.env_server import create_fastapi_app
8
  from fastapi import Query
 
9
 
10
  from constants import PROJECT_DESCRIPTION, VERSION
11
  from models import CommitmentAction, CommitmentObservation, CommitmentState
@@ -13,10 +15,32 @@ from server.environment import CommitmentEnvironment
13
  from server.mcp import router as mcp_router
14
  from server.tasks import get_scenario_ids_grouped
15
 
16
- _shared_env = CommitmentEnvironment()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  app = create_fastapi_app(
19
- env=lambda: _shared_env,
20
  action_cls=CommitmentAction,
21
  observation_cls=CommitmentObservation,
22
  )
@@ -27,7 +51,7 @@ app.version = VERSION
27
 
28
  app.routes[:] = [
29
  r for r in app.routes
30
- if not (hasattr(r, "path") and r.path in ("/state", "/mcp", "/reset"))
31
  ]
32
 
33
 
@@ -44,9 +68,11 @@ def reset_episode(
44
  query params in this deployment setup, which made scenario selection
45
  non-deterministic for demos/evaluations.
46
  """
47
- obs = _shared_env.reset(
 
 
48
  seed=seed,
49
- episode_id=episode_id,
50
  task_id=task_id,
51
  difficulty=difficulty,
52
  )
@@ -54,12 +80,30 @@ def reset_episode(
54
  "observation": obs.model_dump(),
55
  "reward": float(obs.reward),
56
  "done": bool(obs.done),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  }
58
 
59
 
60
  @app.get("/state", response_model=CommitmentState)
61
- def get_state() -> CommitmentState:
62
- return _shared_env.state
 
63
 
64
 
65
  @app.get("/tasks")
 
3
  from __future__ import annotations
4
 
5
  import os
6
+ from threading import Lock
7
 
8
  from openenv.core.env_server import create_fastapi_app
9
  from fastapi import Query
10
+ from pydantic import BaseModel
11
 
12
  from constants import PROJECT_DESCRIPTION, VERSION
13
  from models import CommitmentAction, CommitmentObservation, CommitmentState
 
15
  from server.mcp import router as mcp_router
16
  from server.tasks import get_scenario_ids_grouped
17
 
18
+ _DEFAULT_SESSION_ID = "default"
19
+ _env_store: dict[str, CommitmentEnvironment] = {
20
+ _DEFAULT_SESSION_ID: CommitmentEnvironment(),
21
+ }
22
+ _env_store_lock = Lock()
23
+
24
+
25
def _get_env(session_id: str) -> CommitmentEnvironment:
    """Look up the environment bound to ``session_id``, creating it on first use.

    Every session id maps to its own ``CommitmentEnvironment`` so that mutable
    episode state is not shared between callers using different ids. Both the
    lookup and the insertion happen while ``_env_store_lock`` is held.
    """
    with _env_store_lock:
        existing = _env_store.get(session_id)
        if existing is not None:
            return existing
        # First request for this id: build a fresh environment and remember it.
        created = CommitmentEnvironment()
        _env_store[session_id] = created
        return created
37
+
38
+
39
class StepPayload(BaseModel):
    """Request body model wrapping a single ``CommitmentAction``."""

    # The action to apply to the environment.
    action: CommitmentAction
41
 
42
  app = create_fastapi_app(
43
+ env=lambda: _get_env(_DEFAULT_SESSION_ID),
44
  action_cls=CommitmentAction,
45
  observation_cls=CommitmentObservation,
46
  )
 
51
 
52
  app.routes[:] = [
53
  r for r in app.routes
54
+ if not (hasattr(r, "path") and r.path in ("/state", "/mcp", "/reset", "/step"))
55
  ]
56
 
57
 
 
68
  query params in this deployment setup, which made scenario selection
69
  non-deterministic for demos/evaluations.
70
  """
71
+ session_id = episode_id or _DEFAULT_SESSION_ID
72
+ env = _get_env(session_id)
73
+ obs = env.reset(
74
  seed=seed,
75
+ episode_id=session_id,
76
  task_id=task_id,
77
  difficulty=difficulty,
78
  )
 
80
  "observation": obs.model_dump(),
81
  "reward": float(obs.reward),
82
  "done": bool(obs.done),
83
+ "episode_id": session_id,
84
+ }
85
+
86
+
87
+ @app.post("/step")
88
+ def step_episode(
89
+ payload: StepPayload,
90
+ episode_id: str | None = Query(default=None),
91
+ ) -> dict[str, object]:
92
+ session_id = episode_id or _DEFAULT_SESSION_ID
93
+ env = _get_env(session_id)
94
+ obs = env.step(payload.action)
95
+ return {
96
+ "observation": obs.model_dump(),
97
+ "reward": float(obs.reward),
98
+ "done": bool(obs.done),
99
+ "episode_id": session_id,
100
  }
101
 
102
 
103
  @app.get("/state", response_model=CommitmentState)
104
+ def get_state(episode_id: str | None = Query(default=None)) -> CommitmentState:
105
+ session_id = episode_id or _DEFAULT_SESSION_ID
106
+ return _get_env(session_id).state
107
 
108
 
109
  @app.get("/tasks")
server/environment.py CHANGED
@@ -109,12 +109,12 @@ class CommitmentEnvironment(
109
  return self._finish_episode()
110
 
111
  step_reward = 0.0
112
- tool_result = self._dispatch_tool(action, at)
113
  self._last_tool_result = tool_result
114
 
115
- if "CONFLICT" in tool_result:
116
  step_reward = -0.05
117
- elif at in ("schedule_meeting", "reschedule_event", "send_email", "book_restaurant"):
118
  step_reward = 0.05
119
 
120
  self._cumulative_reward += step_reward
@@ -144,14 +144,14 @@ class CommitmentEnvironment(
144
  # Tool dispatch
145
  # ------------------------------------------------------------------
146
 
147
- def _dispatch_tool(self, action: CommitmentAction, at: str) -> str:
148
  assert self._world is not None
149
  turn = self._step_count
150
 
151
  if at == "view_calendar":
152
- return self._world.view_calendar(action.date)
153
  elif at == "check_availability":
154
- return self._world.check_availability(action.person)
155
  elif at == "search_restaurants":
156
  return self._world.search_restaurants(
157
  cuisine=action.cuisine,
@@ -159,9 +159,9 @@ class CommitmentEnvironment(
159
  dietary=action.dietary,
160
  max_distance_miles=action.max_distance_miles,
161
  near_airport=action.near_airport,
162
- )
163
  elif at == "schedule_meeting":
164
- return self._world.schedule_meeting(
165
  title=action.title,
166
  date=action.date,
167
  time=action.time,
@@ -170,25 +170,36 @@ class CommitmentEnvironment(
170
  location=action.location,
171
  turn=turn,
172
  )
 
 
173
  elif at == "reschedule_event":
174
- return self._world.reschedule_event(
175
  event_id=action.event_id,
176
  new_time=action.new_time,
177
  turn=turn,
178
  )
 
 
179
  elif at == "cancel_event":
180
- return self._world.cancel_event(action.event_id, turn=turn)
 
 
181
  elif at == "send_email":
182
  return self._world.send_email(
183
  to=action.to,
184
  subject=action.subject,
185
  body=action.body,
186
  turn=turn,
187
- )
188
  elif at == "book_restaurant":
189
- return self._world.book_restaurant(action.restaurant_name, turn=turn)
 
 
190
  else:
191
- return f"Unknown action_type: '{at}'. Valid types: view_calendar, check_availability, search_restaurants, schedule_meeting, reschedule_event, cancel_event, send_email, book_restaurant, submit_plan"
 
 
 
192
 
193
  # ------------------------------------------------------------------
194
  # Observation builder
 
109
  return self._finish_episode()
110
 
111
  step_reward = 0.0
112
+ tool_result, dispatch_status = self._dispatch_tool(action, at)
113
  self._last_tool_result = tool_result
114
 
115
+ if dispatch_status == "conflict":
116
  step_reward = -0.05
117
+ elif dispatch_status == "success" and at in ("schedule_meeting", "reschedule_event", "send_email", "book_restaurant"):
118
  step_reward = 0.05
119
 
120
  self._cumulative_reward += step_reward
 
144
  # Tool dispatch
145
  # ------------------------------------------------------------------
146
 
147
+ def _dispatch_tool(self, action: CommitmentAction, at: str) -> tuple[str, str]:
148
  assert self._world is not None
149
  turn = self._step_count
150
 
151
  if at == "view_calendar":
152
+ return self._world.view_calendar(action.date), "info"
153
  elif at == "check_availability":
154
+ return self._world.check_availability(action.person), "info"
155
  elif at == "search_restaurants":
156
  return self._world.search_restaurants(
157
  cuisine=action.cuisine,
 
159
  dietary=action.dietary,
160
  max_distance_miles=action.max_distance_miles,
161
  near_airport=action.near_airport,
162
+ ), "info"
163
  elif at == "schedule_meeting":
164
+ result = self._world.schedule_meeting(
165
  title=action.title,
166
  date=action.date,
167
  time=action.time,
 
170
  location=action.location,
171
  turn=turn,
172
  )
173
+ status = "conflict" if result.startswith("CONFLICT:") else "success"
174
+ return result, status
175
  elif at == "reschedule_event":
176
+ result = self._world.reschedule_event(
177
  event_id=action.event_id,
178
  new_time=action.new_time,
179
  turn=turn,
180
  )
181
+ status = "conflict" if result.startswith("CONFLICT:") else ("error" if "not found" in result.lower() else "success")
182
+ return result, status
183
  elif at == "cancel_event":
184
+ result = self._world.cancel_event(action.event_id, turn=turn)
185
+ status = "error" if "not found" in result.lower() else "success"
186
+ return result, status
187
  elif at == "send_email":
188
  return self._world.send_email(
189
  to=action.to,
190
  subject=action.subject,
191
  body=action.body,
192
  turn=turn,
193
+ ), "success"
194
  elif at == "book_restaurant":
195
+ result = self._world.book_restaurant(action.restaurant_name, turn=turn)
196
+ status = "error" if "not found" in result.lower() else "success"
197
+ return result, status
198
  else:
199
+ return (
200
+ f"Unknown action_type: '{at}'. Valid types: view_calendar, check_availability, search_restaurants, schedule_meeting, reschedule_event, cancel_event, send_email, book_restaurant, submit_plan",
201
+ "error",
202
+ )
203
 
204
  # ------------------------------------------------------------------
205
  # Observation builder
server/graders.py CHANGED
@@ -98,7 +98,7 @@ def _check_constraint(constraint, world: WorldState) -> bool:
98
  em.get("to", "").lower() == lower or lower in em.get("body", "").lower()
99
  for em in world.emails_sent
100
  )
101
- return higher_kept
102
 
103
  return False
104
 
 
98
  em.get("to", "").lower() == lower or lower in em.get("body", "").lower()
99
  for em in world.emails_sent
100
  )
101
+ return higher_kept and lower_moved
102
 
103
  return False
104
 
training/env_factory.py CHANGED
@@ -143,7 +143,9 @@ class CommitmentOSEnvFactory:
143
  if obs.done:
144
  break
145
  except Exception:
146
- continue
 
 
147
 
148
  if not env._done:
149
  obs = env.step(CommitmentAction(action_type="submit_plan"))
 
143
  if obs.done:
144
  break
145
  except Exception:
146
+ # Invalid action payloads should be penalized, not silently ignored.
147
+ last_reward = 0.01
148
+ break
149
 
150
  if not env._done:
151
  obs = env.step(CommitmentAction(action_type="submit_plan"))