Aswini-Kumar committed on
Commit
054ddcc
·
verified ·
1 Parent(s): d8b5f3c

Upload training/train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training/train.py +293 -85
training/train.py CHANGED
@@ -1,71 +1,199 @@
1
  """
2
- GRPO training script — DataCentric-Env
 
3
  Trains Qwen2.5-3B-Instruct with GRPO via TRL + Unsloth.
 
 
 
 
 
 
4
 
5
- Run in Colab (GPU required). Make sure the environment server is deployed
6
- to HuggingFace Spaces and set ENV_URL below before running.
 
 
 
 
7
  """
8
 
9
- from unsloth import FastLanguageModel
10
- from trl import GRPOTrainer, GRPOConfig
11
- from datasets import Dataset
12
- import requests
13
  import json
 
 
14
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- # ─── Configuration ───────────────────────────────────────────────────────────
17
- ENV_URL = "https://aswini-kumar-datacentric-env.hf.space" # HuggingFace Space URL
 
 
 
 
 
18
 
19
- SYSTEM_PROMPT = """You are a data quality agent. You receive dataset statistics and must choose which specialist tool to call to improve the dataset so a downstream classifier performs better.
 
 
 
 
 
 
20
 
21
- Always respond with valid JSON in this exact format:
22
- {"agent": "<tool_name>", "target": "<column_or_all>", "strategy": "<strategy_name>"}
 
 
23
 
24
- Available tools: cleaner, augmenter, balancer, relabeler, validator
25
- Cleaner strategies: median_impute, mean_impute, drop_rows
26
- Balancer strategies: undersample
27
- Relabeler: use when labels are noisy, costs 2 budget points."""
28
 
29
- # ─── Model setup ─────────────────────────────────────────────────────────────
 
30
  model, tokenizer = FastLanguageModel.from_pretrained(
31
- model_name="unsloth/Qwen2.5-3B-Instruct",
32
- max_seq_length=1024,
33
  load_in_4bit=True,
34
  )
35
- model = FastLanguageModel.get_peft_model(model, r=16, lora_alpha=32)
 
 
 
 
 
 
 
 
 
 
 
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- # ─── Rollout function ─────────────────────────────────────────────────────────
39
- def build_prompt(obs):
40
- obs_text = json.dumps(obs, indent=2)
41
- return f"{SYSTEM_PROMPT}\n\nCurrent state:\n{obs_text}\n\nYour action:"
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- def rollout(prompt="start"):
45
- """Run one episode and return (prompt, response, reward) tuples."""
46
- obs = requests.post(f"{ENV_URL}/reset").json()
 
 
47
 
48
  trajectories = []
49
 
50
- for step in range(10):
51
- full_prompt = build_prompt(obs)
 
 
52
 
53
- inputs = tokenizer(full_prompt, return_tensors="pt").to("cuda")
54
  with torch.no_grad():
55
- outputs = model.generate(**inputs, max_new_tokens=100, temperature=0.7)
56
- response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
57
 
58
- # Parse action
59
  try:
60
- action = json.loads(response.strip())
 
 
61
  except Exception:
62
- action = {"agent": "validator"} # fallback
 
63
 
64
- result = requests.post(f"{ENV_URL}/step", json=action).json()
65
- reward = result.get("reward", -1.0)
 
 
 
 
 
 
 
 
 
 
 
 
66
 
 
67
  trajectories.append({
68
- "prompt": full_prompt,
69
  "response": response,
70
  "reward": reward,
71
  })
@@ -77,70 +205,150 @@ def rollout(prompt="start"):
77
  return trajectories
78
 
79
 
80
- # ─── Collect rollouts ─────────────────────────────────────────────────────────
81
- print("Collecting rollouts...")
82
  all_trajectories = []
83
- for episode in range(50):
84
- all_trajectories.extend(rollout("start"))
85
- if episode % 10 == 0:
86
- print(f" Episode {episode}/50 collected")
87
-
88
- # ─── Build training dataset ───────────────────────────────────────────────────
89
- dataset = Dataset.from_list([
90
- {"prompt": t["prompt"], "chosen": t["response"], "reward": t["reward"]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  for t in all_trajectories
 
92
  ])
 
93
 
94
- # ─── GRPO config ─────────────────────────────────────────────────────────────
 
95
  config = GRPOConfig(
96
  output_dir="./datacentric-grpo",
97
  num_train_epochs=3,
98
- per_device_train_batch_size=4,
99
- learning_rate=5e-5,
100
- logging_steps=10,
101
- save_steps=100,
102
- report_to="none", # swap to "wandb" if you want live curves
 
 
 
 
 
 
103
  )
104
 
105
- # ─── Monitor logging ──────────────────────────────────────────────────────────
106
- def log_sample(step):
107
- """Log a live episode sample every 20 steps — watch for reward hacking."""
108
- obs = requests.post(f"{ENV_URL}/reset").json()
109
- print(f"\n--- Generation sample at step {step} ---")
110
- for t in range(5):
111
- inputs = tokenizer(build_prompt(obs), return_tensors="pt").to("cuda")
112
- with torch.no_grad():
113
- out = model.generate(**inputs, max_new_tokens=80)
114
- response = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
115
- print(f" Step {t}: agent output = {response[:120]}")
116
- try:
117
- action = json.loads(response.strip())
118
- except Exception:
119
- print(" WARNING: agent produced invalid JSON — format reward not working")
120
- action = {"agent": "validator"}
121
- result = requests.post(f"{ENV_URL}/step", json=action).json()
122
- print(f" Reward: {result.get('reward')} | Accuracy: {result['info']['new_accuracy']} | Done: {result.get('done')}")
123
- obs = result.get("observation", obs)
124
- if result.get("done"):
125
- break
126
-
127
-
128
- # ─── Train ────────────────────────────────────────────────────────────────────
129
  trainer = GRPOTrainer(
130
  model=model,
131
  args=config,
132
- train_dataset=dataset,
133
  tokenizer=tokenizer,
 
134
  )
135
 
136
- trainer.train()
 
 
 
137
 
138
- # ─── Save β€” use Unsloth merge path, NOT naive save_pretrained ────────────────
139
- # IMPORTANT: do NOT upcast 4-bit model to 16-bit then merge naively.
140
- # That damages model quality. Use the Unsloth merge path instead.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  model.save_pretrained_merged(
142
  "datacentric-grpo-final",
143
  tokenizer,
144
- save_method="merged_16bit", # correct merge path via Unsloth
145
  )
146
- print("Training complete. Test inference immediately before moving on.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ training/train.py — GRPO training for DataCentric-Env (v0.5).
3
+
4
  Trains Qwen2.5-3B-Instruct with GRPO via TRL + Unsloth.
5
+ Run end-to-end in Colab (T4 GPU is sufficient).
6
+
7
+ Before running:
8
+ 1. Deploy the environment to HF Spaces (run deploy_to_hf.py locally)
9
+ 2. Set ENV_URL below to your HF Space URL
10
+ 3. Runtime → Run all
11
 
12
+ What the agent learns:
13
+ - Given a real messy dataset (UCI Adult Census, Pima Diabetes, etc.)
14
+ - Query specialist agents to diagnose issues (domain-aware analysis)
15
+ - Apply recommended fixes to improve classifier accuracy on a frozen holdout
16
+ - Navigate: when to rollback a bad apply, how to interpret feature importance,
17
+ how to prioritize domain-specific issues (zeros-as-missing in medical data)
18
  """
19
 
20
+ import os
 
 
 
21
  import json
22
+ import time
23
+ import requests
24
  import torch
25
+ import numpy as np
26
+ import matplotlib
27
+ matplotlib.use("Agg")
28
+ import matplotlib.pyplot as plt
29
+ from datasets import Dataset
30
+ from unsloth import FastLanguageModel
31
+ from trl import GRPOTrainer, GRPOConfig
32
+
33
+ # ── Configuration ──────────────────────────────────────────────────────────────
34
+ ENV_URL = "https://aswini-kumar-datacentric-env.hf.space" # ← your HF Space URL
35
+ MODEL_NAME = "unsloth/Qwen2.5-3B-Instruct"
36
+ MAX_SEQ_LEN = 2048
37
+ N_ROLLOUT_EPISODES = 60
38
+ MAX_STEPS_PER_EPISODE = 12
39
+ LORA_RANK = 16
40
+
41
+ # ── System prompt ──────────────────────────────────────────────────────────────
42
+ SYSTEM_PROMPT = """You are an expert data engineer agent. You are given a real-world dataset \
43
+ with known quality issues and must fix it so a frozen classifier achieves the target accuracy.
44
+
45
+ You work by querying specialist agents for analysis, then deciding which recommendation to apply.
46
 
47
+ WORKFLOW:
48
+ 1. Start by calling query_analyst (cost 2) — it gives you a prioritized action plan and \
49
+ references the published benchmark accuracy for this dataset.
50
+ 2. Then call the specific agent it recommends (query_cleaner, query_balancer, etc.)
51
+ 3. Apply the best recommendation using its rec_id
52
+ 4. If accuracy dropped after an apply, use rollback to undo it (max 3 per episode)
53
+ 5. Read feature_importance in the response β€” it shows what the model actually learned
54
 
55
+ DOMAIN RULES (critical):
56
+ - In medical datasets, zero values for physiological measurements are IMPOSSIBLE — they mean \
57
+ missing data. Always apply zero_to_nan_impute before other cleaning.
58
+ - In financial datasets, heavily skewed features (like capital-gain) should be log-transformed.
59
+ - Removing rows is dangerous — data integrity limit is 10% of training rows max.
60
+ - Large augmentation (>200 rows) may overfit training set and HURT holdout accuracy. \
61
+ If accuracy drops after augmentation, rollback and try balancer instead.
62
 
63
+ OUTPUT FORMAT — respond with valid JSON only, no explanation:
64
+ For queries: {"action": "query_analyst"} or {"action": "query_cleaner"} etc.
65
+ For apply: {"action": "apply", "rec_id": "<exact_id_from_recommendations>"}
66
+ For rollback: {"action": "rollback"}"""
67
 
 
 
 
 
68
 
69
+ # ── Model setup ────────────────────────────────────────────────────────────────
70
+ print("Loading model...")
71
  model, tokenizer = FastLanguageModel.from_pretrained(
72
+ model_name=MODEL_NAME,
73
+ max_seq_length=MAX_SEQ_LEN,
74
  load_in_4bit=True,
75
  )
76
+ model = FastLanguageModel.get_peft_model(
77
+ model,
78
+ r=LORA_RANK,
79
+ lora_alpha=LORA_RANK * 2,
80
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
81
+ "gate_proj", "up_proj", "down_proj"],
82
+ lora_dropout=0.05,
83
+ bias="none",
84
+ use_gradient_checkpointing=True,
85
+ )
86
+ print(f"Model loaded: {MODEL_NAME} with LoRA r={LORA_RANK}")
87
+
88
 
89
+ # ── Prompt builder ─────────────────────────────────────────────────────────────
90
+ def build_prompt(obs: dict) -> str:
91
+ """
92
+ Build a compact but information-rich prompt from the observation.
93
+ Excludes the full pending_recommendations dict (too verbose) —
94
+ only includes rec_ids and their reason.
95
+ """
96
+ # Compact observation for the prompt
97
+ compact = {
98
+ "dataset": obs.get("dataset", {}).get("name", "unknown"),
99
+ "domain": obs.get("dataset", {}).get("domain", ""),
100
+ "known_issues": obs.get("dataset", {}).get("known_issues", [])[:2],
101
+ "current_accuracy": obs.get("current_accuracy"),
102
+ "target_accuracy": obs.get("target_accuracy"),
103
+ "accuracy_gap": obs.get("accuracy_gap"),
104
+ "benchmarks": obs.get("benchmarks", {}),
105
+ "budget_remaining": obs.get("budget_remaining"),
106
+ "dataset_stats": obs.get("dataset_stats", {}),
107
+ "pending_recommendations": {
108
+ rid: {
109
+ "agent": info.get("agent"),
110
+ "type": info.get("type"),
111
+ "reason": info.get("reason", "")[:120], # truncate
112
+ "domain_informed": info.get("domain_informed", False),
113
+ }
114
+ for rid, info in obs.get("pending_recommendations", {}).items()
115
+ },
116
+ "episode_trace": obs.get("episode_trace", [])[-3:], # last 3 steps
117
+ "feature_importance": obs.get("feature_importance", {}).get("top_positive", [])[:2],
118
+ "available_actions": obs.get("available_actions"),
119
+ }
120
+ return (
121
+ f"<|system|>\n{SYSTEM_PROMPT}\n"
122
+ f"<|user|>\nCurrent environment state:\n{json.dumps(compact, indent=2)}\n"
123
+ f"<|assistant|>\n"
124
+ )
125
 
 
 
 
 
126
 
127
+ # ── Episode rollout ────────────────────────────────────────────────────────────
128
+ def run_episode(difficulty: str = "easy") -> list[dict]:
129
+ """
130
+ Run one full episode. Returns a list of dicts with "prompt", "response", and "reward" keys.
131
+ """
132
+ # Reset — get session_id and initial observation
133
+ try:
134
+ resp = requests.post(
135
+ f"{ENV_URL}/reset",
136
+ json={"difficulty": difficulty},
137
+ timeout=60,
138
+ )
139
+ resp.raise_for_status()
140
+ except Exception as e:
141
+ print(f" Reset failed: {e}")
142
+ return []
143
 
144
+ obs = resp.json()
145
+ session_id = obs.get("session_id")
146
+ if not session_id:
147
+ print(" No session_id in reset response.")
148
+ return []
149
 
150
  trajectories = []
151
 
152
+ for step_num in range(MAX_STEPS_PER_EPISODE):
153
+ prompt = build_prompt(obs)
154
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
155
+ max_length=MAX_SEQ_LEN).to("cuda")
156
 
 
157
  with torch.no_grad():
158
+ outputs = model.generate(
159
+ **inputs,
160
+ max_new_tokens=80,
161
+ temperature=0.8,
162
+ do_sample=True,
163
+ pad_token_id=tokenizer.eos_token_id,
164
+ )
165
+ response = tokenizer.decode(
166
+ outputs[0][inputs["input_ids"].shape[1]:],
167
+ skip_special_tokens=True
168
+ ).strip()
169
 
170
+ # Parse and validate action
171
  try:
172
+ action = json.loads(response)
173
+ if "action" not in action:
174
+ raise ValueError("missing 'action' key")
175
  except Exception:
176
+ # Invalid JSON — format grader will penalize this
177
+ action = {"action": "query_analyst"}
178
 
179
+ # Always inject session_id
180
+ payload = {"session_id": session_id, **action}
181
+
182
+ try:
183
+ step_resp = requests.post(
184
+ f"{ENV_URL}/step",
185
+ json=payload,
186
+ timeout=30,
187
+ )
188
+ step_resp.raise_for_status()
189
+ result = step_resp.json()
190
+ except Exception as e:
191
+ print(f" Step failed: {e}")
192
+ break
193
 
194
+ reward = float(result.get("reward", 0.001))
195
  trajectories.append({
196
+ "prompt": prompt,
197
  "response": response,
198
  "reward": reward,
199
  })
 
205
  return trajectories
206
 
207
 
208
+ # ── Collect rollouts across difficulties ───────────────────────────────────────
209
+ print(f"\nCollecting {N_ROLLOUT_EPISODES} episodes...")
210
  all_trajectories = []
211
+ episode_rewards = []
212
+ difficulty_schedule = (
213
+ ["easy"] * 20 + ["medium"] * 20 + ["hard"] * 20
214
+ )
215
+
216
+ for ep_idx, difficulty in enumerate(difficulty_schedule):
217
+ trajs = run_episode(difficulty=difficulty)
218
+ if trajs:
219
+ ep_reward = np.mean([t["reward"] for t in trajs])
220
+ episode_rewards.append(ep_reward)
221
+ all_trajectories.extend(trajs)
222
+ if ep_idx % 10 == 0:
223
+ print(f" Episode {ep_idx}/{N_ROLLOUT_EPISODES} | "
224
+ f"difficulty={difficulty} | mean_reward={ep_reward:.4f} | "
225
+ f"n_steps={len(trajs)}")
226
+ time.sleep(0.5) # avoid hammering the server
227
+
228
+ print(f"\nTotal training samples: {len(all_trajectories)}")
229
+ print(f"Mean reward across all episodes: {np.mean(episode_rewards):.4f}")
230
+
231
+ if len(all_trajectories) < 10:
232
+ raise RuntimeError("Too few training samples collected. Check ENV_URL and server status.")
233
+
234
+
235
+ # ── Build GRPO training dataset ────────────────────────────────────────────────
236
+ # GRPO needs: prompt, completion (response), reward
237
+ train_dataset = Dataset.from_list([
238
+ {
239
+ "prompt": t["prompt"],
240
+ "completion": t["response"],
241
+ "reward": t["reward"],
242
+ }
243
  for t in all_trajectories
244
+ if t["reward"] > 0.001 # filter degenerate samples
245
  ])
246
+ print(f"Training dataset: {len(train_dataset)} samples")
247
 
248
+
249
+ # ── GRPO training ──────────────────────────────────────────────────────────────
250
  config = GRPOConfig(
251
  output_dir="./datacentric-grpo",
252
  num_train_epochs=3,
253
+ per_device_train_batch_size=2,
254
+ gradient_accumulation_steps=4,
255
+ learning_rate=2e-5,
256
+ warmup_ratio=0.1,
257
+ lr_scheduler_type="cosine",
258
+ logging_steps=5,
259
+ save_steps=50,
260
+ report_to="none",
261
+ max_grad_norm=0.3,
262
+ fp16=True,
263
+ dataloader_num_workers=0,
264
  )
265
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  trainer = GRPOTrainer(
267
  model=model,
268
  args=config,
269
+ train_dataset=train_dataset,
270
  tokenizer=tokenizer,
271
+ reward_funcs=[], # rewards come from environment, already in dataset
272
  )
273
 
274
+ print("\nStarting GRPO training...")
275
+ train_result = trainer.train()
276
+ print(f"Training complete. Final loss: {train_result.training_loss:.4f}")
277
+
278
 
279
+ # ── Sample inspection β€” check for reward hacking ──────────────────────────────
280
+ print("\n--- Sampling 3 agent generations (reward hacking check) ---")
281
+ for i in range(3):
282
+ try:
283
+ resp = requests.post(f"{ENV_URL}/reset", json={"difficulty": "easy"}, timeout=60)
284
+ obs = resp.json()
285
+ session_id = obs["session_id"]
286
+ prompt = build_prompt(obs)
287
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
288
+ max_length=MAX_SEQ_LEN).to("cuda")
289
+ with torch.no_grad():
290
+ out = model.generate(**inputs, max_new_tokens=80, do_sample=False)
291
+ response = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
292
+ print(f"\n Sample {i+1}: {response[:200]}")
293
+
294
+ try:
295
+ action = json.loads(response.strip())
296
+ payload = {"session_id": session_id, **action}
297
+ step_r = requests.post(f"{ENV_URL}/step", json=payload, timeout=30).json()
298
+ print(f" → reward={step_r.get('reward')} | accuracy={step_r.get('observation', {}).get('current_accuracy')}")
299
+ except Exception as e:
300
+ print(f" → parse/step failed: {e}")
301
+ except Exception as e:
302
+ print(f" Sample {i+1} failed: {e}")
303
+
304
+
305
+ # ── Save model β€” Unsloth merge path (NOT naive save_pretrained) ───────────────
306
+ print("\nSaving model (Unsloth merged_16bit path)...")
307
  model.save_pretrained_merged(
308
  "datacentric-grpo-final",
309
  tokenizer,
310
+ save_method="merged_16bit",
311
  )
312
+ print("Model saved to ./datacentric-grpo-final")
313
+
314
+
315
+ # ── Plot training curves β†’ results.png ────────────────────────────────────────
316
+ fig, axes = plt.subplots(1, 2, figsize=(14, 5))
317
+ fig.suptitle("DataCentric-Env — GRPO Training Results", fontsize=14, fontweight="bold")
318
+
319
+ # Episode rewards
320
+ ax1 = axes[0]
321
+ ax1.plot(episode_rewards, color="#4f46e5", linewidth=1.5, alpha=0.6, label="Episode mean reward")
322
+ if len(episode_rewards) >= 5:
323
+ smoothed = np.convolve(episode_rewards, np.ones(5)/5, mode="valid")
324
+ ax1.plot(range(4, len(episode_rewards)), smoothed,
325
+ color="#4f46e5", linewidth=2.5, label="5-ep moving avg")
326
+ ax1.axvline(x=20, color="gray", linestyle="--", alpha=0.5, label="→ medium")
327
+ ax1.axvline(x=40, color="gray", linestyle=":", alpha=0.5, label="→ hard")
328
+ ax1.set_xlabel("Episode")
329
+ ax1.set_ylabel("Mean Reward")
330
+ ax1.set_title("Reward Progression Over Episodes")
331
+ ax1.legend()
332
+ ax1.set_ylim(0, 1)
333
+ ax1.grid(alpha=0.3)
334
+
335
+ # Reward distribution
336
+ ax2 = axes[1]
337
+ rewards_array = [t["reward"] for t in all_trajectories]
338
+ ax2.hist(rewards_array, bins=30, color="#7c3aed", alpha=0.7, edgecolor="white")
339
+ ax2.axvline(np.mean(rewards_array), color="#ef4444", linewidth=2,
340
+ label=f"Mean={np.mean(rewards_array):.3f}")
341
+ ax2.axvline(np.median(rewards_array), color="#f97316", linewidth=2,
342
+ linestyle="--", label=f"Median={np.median(rewards_array):.3f}")
343
+ ax2.set_xlabel("Reward")
344
+ ax2.set_ylabel("Count")
345
+ ax2.set_title("Distribution of Step Rewards")
346
+ ax2.legend()
347
+ ax2.grid(alpha=0.3)
348
+
349
+ plt.tight_layout()
350
+ plt.savefig("results.png", dpi=150, bbox_inches="tight")
351
+ print("results.png saved.")
352
+ plt.show()
353
+
354
+ print("\n✅ All done. Submit results.png + datacentric-grpo-final/ directory.")