Aswini-Kumar commited on
Commit
f936921
·
verified ·
1 Parent(s): 0336beb

Upload training/make_notebook.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training/make_notebook.py +162 -0
training/make_notebook.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ nb = {
4
+ "nbformat": 4,
5
+ "nbformat_minor": 5,
6
+ "metadata": {
7
+ "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
8
+ "language_info": {"name": "python", "version": "3.10.0"}
9
+ },
10
+ "cells": [
11
+ {
12
+ "cell_type": "markdown", "id": "a1", "metadata": {},
13
+ "source": [
14
+ "# DataCentric-Env — GRPO Training Notebook\n\n",
15
+ "Trains Qwen2.5-3B-Instruct as a data quality agent using GRPO.\n\n",
16
+ "**Sections:**\n",
17
+ "1. Install dependencies\n",
18
+ "2. Model setup (Qwen2.5-3B-Instruct, 4-bit LoRA)\n",
19
+ "3. Rollout function\n",
20
+ "4. Collect training data\n",
21
+ "5. GRPO training loop\n",
22
+ "6. Save model via Unsloth merge path"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code", "id": "c1", "metadata": {}, "outputs": [], "execution_count": None,
27
+ "source": [
28
+ "# Cell 1: Install dependencies\n",
29
+ "!pip install unsloth trl transformers accelerate peft datasets requests"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code", "id": "c2", "metadata": {}, "outputs": [], "execution_count": None,
34
+ "source": [
35
+ "# Cell 2: Imports and config\n",
36
+ "from unsloth import FastLanguageModel\n",
37
+ "from trl import GRPOTrainer, GRPOConfig\n",
38
+ "from datasets import Dataset\n",
39
+ "import requests, json, torch\n",
40
+ "\n",
41
+ "ENV_URL = 'https://your-hf-username-datacentric-env.hf.space' # set your HF Space URL\n",
42
+ "\n",
43
+ "SYSTEM_PROMPT = (\n",
44
+ " 'You are a data quality agent. You receive dataset statistics and must choose '\n",
45
+ " 'which specialist tool to call to improve the dataset so a downstream classifier '\n",
46
+ " 'performs better.\\n\\n'\n",
47
+ " 'Always respond with valid JSON in this exact format:\\n'\n",
48
+ " '{\"agent\": \"<tool_name>\", \"target\": \"<column_or_all>\", \"strategy\": \"<strategy_name>\"}\\n\\n'\n",
49
+ " 'Available tools: cleaner, augmenter, balancer, relabeler, validator\\n'\n",
50
+ " 'Cleaner strategies: median_impute, mean_impute, drop_rows\\n'\n",
51
+ " 'Balancer strategies: undersample\\n'\n",
52
+ " 'Relabeler: use when labels are noisy, costs 2 budget points.'\n",
53
+ ")\n",
54
+ "print('Imports OK')"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code", "id": "c3", "metadata": {}, "outputs": [], "execution_count": None,
59
+ "source": [
60
+ "# Cell 3: Model setup\n",
61
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
62
+ " model_name='unsloth/Qwen2.5-3B-Instruct',\n",
63
+ " max_seq_length=1024,\n",
64
+ " load_in_4bit=True,\n",
65
+ ")\n",
66
+ "model = FastLanguageModel.get_peft_model(model, r=16, lora_alpha=32)\n",
67
+ "print('Model loaded')"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code", "id": "c4", "metadata": {}, "outputs": [], "execution_count": None,
72
+ "source": [
73
+ "# Cell 4: Rollout function\n",
74
+ "def build_prompt(obs):\n",
75
+ " return SYSTEM_PROMPT + '\\n\\nCurrent state:\\n' + json.dumps(obs, indent=2) + '\\n\\nYour action:'\n",
76
+ "\n",
77
+ "def rollout(prompt='start'):\n",
78
+ " obs = requests.post(ENV_URL + '/reset').json()\n",
79
+ " trajectories = []\n",
80
+ " for step in range(10):\n",
81
+ " full_prompt = build_prompt(obs)\n",
82
+ " inputs = tokenizer(full_prompt, return_tensors='pt').to('cuda')\n",
83
+ " with torch.no_grad():\n",
84
+ " outputs = model.generate(**inputs, max_new_tokens=100, temperature=0.7)\n",
85
+ " response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n",
86
+ " try:\n",
87
+ " action = json.loads(response.strip())\n",
88
+ " except Exception:\n",
89
+ " action = {'agent': 'validator'}\n",
90
+ " result = requests.post(ENV_URL + '/step', json=action).json()\n",
91
+ " reward = result.get('reward', -1.0)\n",
92
+ " trajectories.append({'prompt': full_prompt, 'response': response, 'reward': reward})\n",
93
+ " obs = result.get('observation', obs)\n",
94
+ " if result.get('done'):\n",
95
+ " break\n",
96
+ " return trajectories\n",
97
+ "\n",
98
+ "print('Rollout function defined')"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "code", "id": "c5", "metadata": {}, "outputs": [], "execution_count": None,
103
+ "source": [
104
+ "# Cell 5: Collect rollouts and build dataset\n",
105
+ "print('Collecting rollouts...')\n",
106
+ "all_trajectories = []\n",
107
+ "for episode in range(50):\n",
108
+ " all_trajectories.extend(rollout('start'))\n",
109
+ " if episode % 10 == 0:\n",
110
+ " print(f' Episode {episode}/50 collected')\n",
111
+ "\n",
112
+ "dataset = Dataset.from_list([\n",
113
+ " {'prompt': t['prompt'], 'chosen': t['response'], 'reward': t['reward']}\n",
114
+ " for t in all_trajectories\n",
115
+ "])\n",
116
+ "print(f'Dataset size: {len(dataset)}')"
117
+ ]
118
+ },
119
+ {
120
+ "cell_type": "code", "id": "c6", "metadata": {}, "outputs": [], "execution_count": None,
121
+ "source": [
122
+ "# Cell 6: GRPO training\n",
123
+ "config = GRPOConfig(\n",
124
+ " output_dir='./datacentric-grpo',\n",
125
+ " num_train_epochs=3,\n",
126
+ " per_device_train_batch_size=4,\n",
127
+ " learning_rate=5e-5,\n",
128
+ " logging_steps=10,\n",
129
+ " save_steps=100,\n",
130
+ " report_to='none',\n",
131
+ ")\n",
132
+ "\n",
133
+ "trainer = GRPOTrainer(\n",
134
+ " model=model,\n",
135
+ " args=config,\n",
136
+ " train_dataset=dataset,\n",
137
+ " tokenizer=tokenizer,\n",
138
+ ")\n",
139
+ "\n",
140
+ "trainer.train()"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code", "id": "c7", "metadata": {}, "outputs": [], "execution_count": None,
145
+ "source": [
146
+ "# Cell 7: Save via Unsloth merge path\n",
147
+ "# IMPORTANT: do NOT use naive save_pretrained — use Unsloth merge path\n",
148
+ "model.save_pretrained_merged(\n",
149
+ " 'datacentric-grpo-final',\n",
150
+ " tokenizer,\n",
151
+ " save_method='merged_16bit',\n",
152
+ ")\n",
153
+ "print('Training complete. Model saved to datacentric-grpo-final/')"
154
+ ]
155
+ }
156
+ ]
157
+ }
158
+
159
+ with open("training/train.ipynb", "w") as f:
160
+ json.dump(nb, f, indent=1)
161
+
162
+ print("Notebook created successfully.")