File size: 1,807 Bytes
d02bacd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# AutoDataLab++ Chief of Staff Training\n",
        "Minimal Colab-friendly scaffold for GRPO/PPO over the Chief of Staff discrete action space."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Install the training stack into the running kernel's environment.\n",
        "%pip install -q trl transformers accelerate pandas matplotlib\n",
        "from pathlib import Path\n",
        "import json, random\n",
        "\n",
        "# Project root path (presumably the repo checkout location on the Colab VM -- TODO confirm).\n",
        "ROOT = Path('/content/autodatalab-plus')\n",
        "\n",
        "# Discrete action space: consult one of four experts, summarize, or submit.\n",
        "ACTION_SPACE = [\n",
        "    {'action_type': 'consult', 'expert_id': 'analyst'},\n",
        "    {'action_type': 'consult', 'expert_id': 'finance'},\n",
        "    {'action_type': 'consult', 'expert_id': 'hr'},\n",
        "    {'action_type': 'consult', 'expert_id': 'strategy'},\n",
        "    {'action_type': 'summarize'},\n",
        "    {'action_type': 'submit'},\n",
        "]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Placeholder cell: imports the trainer class and prints the intended next steps;\n",
        "# no rollout or training loop is implemented here yet.\n",
        "from trl import PPOTrainer  # swap to GRPOTrainer if available\n",
        "print('Use the environment in ceo_brief_env/environment.py to roll out episodes and map actions to token IDs.')\n",
        "print('Checkpoint 0 = random / base model, checkpoint final = post-training.')\n",
        "# Save reward curve to training/reward_curves/reward_curve.png after evaluation."
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.10"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}