jayantaggarwal-sketch committed on
Commit
27cbc22
·
1 Parent(s): af8810b

Sync latest GitHub commit and notebook

Browse files
training/CommitmentOS_Training.ipynb ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# CommitmentOS Training Notebook\n",
8
+ "\n",
9
+ "This notebook reproduces GRPO training for CommitmentOS using TRL + LoRA."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "!pip -q install --upgrade pip\n",
19
+ "!pip -q install openenv trl transformers peft datasets torch accelerate bitsandbytes matplotlib pandas"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": null,
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "!git clone https://github.com/Jayant2304/commitment_os.git\n",
29
+ "%cd commitment_os\n",
30
+ "!python -m pytest tests/test_environment.py -q"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": null,
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "!python training/train_grpo.py \\\n",
40
+ " --model Qwen/Qwen2.5-1.5B-Instruct \\\n",
41
+ " --epochs 2 \\\n",
42
+ " --lr 5e-6 \\\n",
43
+ " --batch_size 1 \\\n",
44
+ " --group_size 2 \\\n",
45
+ " --lora_rank 16 \\\n",
46
+ " --lora_alpha 32 \\\n",
47
+ " --output_dir ./training_output"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "execution_count": null,
53
+ "metadata": {},
54
+ "outputs": [],
55
+ "source": [
56
+ "import json\n",
57
+ "import matplotlib.pyplot as plt\n",
58
+ "from pathlib import Path\n",
59
+ "\n",
60
+ "p = Path('training_output/training_metrics.json')\n",
61
+ "logs = json.loads(p.read_text())\n",
62
+ "\n",
63
+ "steps = [float(x['step']) for x in logs if 'step' in x and 'loss' in x]\n",
64
+ "loss = [float(x['loss']) for x in logs if 'step' in x and 'loss' in x]\n",
65
+ "r_steps = [float(x['step']) for x in logs if 'step' in x and 'reward' in x]\n",
66
+ "rewards = [float(x['reward']) for x in logs if 'step' in x and 'reward' in x]\n",
67
+ "\n",
68
+ "plt.figure(figsize=(8,5))\n",
69
+ "plt.plot(steps, loss, marker='o')\n",
70
+ "plt.title('CommitmentOS GRPO Loss vs Step')\n",
71
+ "plt.xlabel('Step'); plt.ylabel('Loss'); plt.grid(alpha=0.3)\n",
72
+ "plt.tight_layout(); plt.savefig('loss_curve.png', dpi=200); plt.show()\n",
73
+ "\n",
74
+ "plt.figure(figsize=(8,5))\n",
75
+ "plt.plot(r_steps, rewards, marker='o')\n",
76
+ "plt.title('CommitmentOS GRPO Reward vs Step')\n",
77
+ "plt.xlabel('Step'); plt.ylabel('Reward'); plt.grid(alpha=0.3)\n",
78
+ "plt.tight_layout(); plt.savefig('reward_curve.png', dpi=200); plt.show()"
79
+ ]
80
+ }
81
+ ],
82
+ "metadata": {
83
+ "kernelspec": {
84
+ "display_name": "Python 3",
85
+ "language": "python",
86
+ "name": "python3"
87
+ },
88
+ "language_info": {
89
+ "name": "python",
90
+ "version": "3.10"
91
+ }
92
+ },
93
+ "nbformat": 4,
94
+ "nbformat_minor": 5
95
+ }