Buckets:
bbkdevops/unicosys-hypergraph-bucket / tinymind-native-colab-handoff /bundle /train /code_grpo_trainer.py
| """ | |
| Code GRPO Trainer — Policy Optimization with Python Execution Feedback | |
| โมเดลเขียนโค้ด → รัน test cases จริง → reward จากผลลัพธ์จริง | |
| เหมือน DeepSeek-Coder-V2 แต่ปรับสำหรับ TinyMind | |
| Reward structure: | |
| +1.0 ผ่านทุก test case | |
| +0.5 ผ่านบางส่วน (partial credit) | |
| +0.2 format ถูก (มี <think> + <answer> + def) | |
| 0.0 โค้ด compile ได้แต่ logic ผิด | |
| -0.3 SyntaxError | |
| -0.5 ไม่มีโค้ดเลย | |
| """ | |
| from __future__ import annotations | |
| import ast | |
| import io | |
| import json | |
| import re | |
| import sys | |
| import time | |
| import traceback | |
| from contextlib import redirect_stdout, redirect_stderr | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Callable | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset, DataLoader | |
| from tokenizers import Tokenizer | |
| from model.config import OmegaConfig, small_config | |
| from model.architecture import OmegaModel | |
| from model.reasoning import extract_thinking | |
| from train.grpo_trainer import ( | |
| compute_group_advantages, grpo_policy_loss, grpo_collate, | |
| GRPO_CFG, | |
| ) | |
# System prompt placed in the <system> turn of every code-generation prompt.
# It asks for reasoning inside <think> tags and the final program inside
# <answer> tags, which is the layout the reward function looks for.
CODE_SYSTEM = "".join(
    [
        "You are an expert Python programmer. ",
        "Think step-by-step in <think>...</think>, ",
        "then provide complete correct Python in <answer>...</answer>.",
    ]
)
# Configuration for code GRPO: every key of the base GRPO_CFG, with the
# entries below added or overridden.
CODE_GRPO_CFG = dict(
    GRPO_CFG,
    data_path="data/filtered/code_grpo.jsonl",
    ref_checkpoint="checkpoints/omega_best.pt",
    out_dir="checkpoints",
    n_samples=4,
    max_new_tokens=512,
    temperature=0.8,
    max_steps=4_000,
    save_every=400,
    timeout_sec=5.0,  # max seconds per code execution
)
| # ─── Code Extraction ───────────────────────────────────────────────────────── | |
def _unfence(text: str):
    """Return the body of the first usable Markdown code fence in *text*.

    Fences tagged ``python``, ``python3`` or ``py`` (case-insensitive) are
    preferred; if there is none, the first untagged fence is used.  Fences
    tagged with any other language and fences that are never closed are
    ignored.

    Returns:
        The stripped fence body, or ``None`` when no usable fence exists.
    """
    pieces = text.split("```")
    # Odd-indexed pieces sit between an opening fence and its closing fence;
    # the final piece is dropped because nothing closes it.
    blocks = pieces[1:-1:2]

    for block in blocks:
        # The negative lookahead keeps tags such as "pycon" from matching "py".
        tagged = re.match(
            r"(?:python3?|py)(?![\w+#.-])\s*([\s\S]*)", block, re.IGNORECASE
        )
        if tagged:
            return tagged.group(1).strip()

    for block in blocks:
        # An untagged fence has nothing but blanks before its first newline.
        untagged = re.match(r"[ \t]*\r?\n([\s\S]*)", block)
        if untagged:
            return untagged.group(1).strip()

    return None


def extract_code(response: str) -> str:
    """Extract Python source from a model response.

    Lookup order:
      1. The first ``<answer>...</answer>`` block (tag match is
         case-insensitive).  A Markdown fence inside it is unwrapped;
         otherwise the stripped block itself is returned.
      2. The first usable Markdown fence anywhere in the response.
      3. Every line from the first ``def`` line to the end of the response.

    Unlike a plain ```` ```python ```` search, fences tagged ``py`` /
    ``python3`` and untagged fences are also unwrapped, so the backticks (or
    a stray "3" from ``python3``) never leak into the returned code.

    Args:
        response: Raw text generated by the model.

    Returns:
        The extracted code, or ``""`` when nothing code-like is found.
    """
    # 1. <answer>...</answer>
    answer = re.search(r"<answer>([\s\S]*?)</answer>", response, re.IGNORECASE)
    if answer:
        block = answer.group(1).strip()
        inner = _unfence(block)
        # `is not None`: an empty fence body is still a fence.
        return inner if inner is not None else block

    # 2. Markdown fence anywhere
    fenced = _unfence(response)
    if fenced is not None:
        return fenced

    # 3. Everything from the first `def` line onwards
    lines = response.split("\n")
    start = next(
        (i for i, line in enumerate(lines) if re.match(r"^\s*def \w+", line)),
        None,
    )
    return "" if start is None else "\n".join(lines[start:])
| # ─── Execution Sandbox ──────────────────────────────────────────────────────── | |
| class ExecResult: | |
| passed: int | |
| total: int | |
| error: str | |
| def run_code_tests(code: str, test_cases: str, timeout: float = 5.0) -> ExecResult: | |
| """Execute code + test cases in a sandboxed namespace with timeout.""" | |
| if not code.strip(): | |
| return ExecResult(0, 1, "no_code") | |
| # Parse check | |
| try: | |
| ast.parse(code) | |
| except SyntaxError as e: | |
| return ExecResult(0, 1, f"SyntaxError: {e}") | |
| # Count assertions | |
| total = max(1, test_cases.count("assert")) | |
| namespace: dict = {"__builtins__": __builtins__} | |
| stdout_buf = io.StringIO() | |
| stderr_buf = io.StringIO() | |
| try: | |
| with redirect_stdout(stdout_buf), redirect_stderr(stderr_buf): | |
| exec(compile(code, "<solution>", "exec"), namespace) | |
| except Exception as e: | |
| return ExecResult(0, total, f"exec_error: {e}") | |
| # Run each assertion separately for partial credit | |
| passed = 0 | |
| assertion_lines = [ | |
| line.strip() for line in test_cases.split("\n") | |
| if line.strip().startswith("assert") | |
| ] | |
| if not assertion_lines: | |
| assertion_lines = [test_cases.strip()] | |
| for assertion in assertion_lines: | |
| try: | |
| with redirect_stdout(stdout_buf), redirect_stderr(stderr_buf): | |
| exec(compile(assertion, "<test>", "exec"), namespace) | |
| passed += 1 | |
| except AssertionError: | |
| pass | |
| except Exception: | |
| pass | |
| return ExecResult(passed, max(len(assertion_lines), 1), "") | |
| # ─── Code Reward Function ───────────────────────────────────────────────────── | |
def code_reward(
    generated: str,
    test_cases: str,
    timeout: float = 5.0,
) -> float:
    """Score one completion: execution correctness plus a small format bonus.

    Format bonus (at most 0.2):
        +0.10  a non-empty think section
        +0.05  a non-empty ``<answer>...</answer>`` block
        +0.05  the extracted code contains a ``def``

    Base score:
        -0.5         no code could be extracted
        -0.3         the code does not parse
        -0.1         the code parses but raises (or times out) when executed
        1.0          every test passes
        0.5 * ratio  at least half of the tests pass
        0.1 * ratio  fewer than half pass

    Args:
        generated: Text produced by the model.
        test_cases: Test source handed to ``run_code_tests``.
        timeout: Execution budget in seconds, forwarded to ``run_code_tests``.

    Returns:
        Base score plus format bonus.
    """
    code = extract_code(generated)

    # Format reward.  The training prompt already ends with an opening
    # "<think>", so a sampled completion normally contains only the closing
    # tag.  Accept that shape as well; otherwise the think bonus could never
    # be earned.
    has_think = bool(re.search(r"<think>[\s\S]+</think>", generated, re.IGNORECASE))
    if not has_think and not re.search(r"<think>", generated, re.IGNORECASE):
        has_think = bool(re.match(r"\s*\S[\s\S]*?</think>", generated, re.IGNORECASE))
    has_answer = bool(re.search(r"<answer>[\s\S]+</answer>", generated, re.IGNORECASE))
    has_def = bool(re.search(r"def \w+", code))
    format_score = (
        (0.1 if has_think else 0)
        + (0.05 if has_answer else 0)
        + (0.05 if has_def else 0)
    )

    if not code:
        return -0.5 + format_score

    result = run_code_tests(code, test_cases, timeout)
    if "SyntaxError" in result.error:
        return -0.3 + format_score
    if "exec_error" in result.error:
        return -0.1 + format_score

    ratio = result.passed / max(result.total, 1)
    if ratio == 1.0:
        correctness = 1.0
    elif ratio >= 0.5:
        correctness = 0.5 * ratio
    else:
        correctness = 0.1 * ratio
    return correctness + format_score
| # ─── Code Dataset ───────────────────────────────────────────────────────────── | |
class CodeDataset(Dataset):
    """Coding problems loaded from a JSONL file, one JSON object per line.

    Only records with a truthy ``question`` and a truthy ``test_cases`` field
    are kept.  Each item carries the tokenised chat prompt for its question.
    """

    def __init__(self, path: str, tokenizer: Tokenizer, max_prompt_len: int = 512):
        self.tokenizer = tokenizer
        self.max_prompt_len = max_prompt_len
        self.records: list[dict] = []
        with open(path, encoding="utf-8") as handle:
            for raw in handle:
                text = raw.strip()
                if text:
                    entry = json.loads(text)
                    # Keep only problems that can actually be scored.
                    if entry.get("question") and entry.get("test_cases"):
                        self.records.append(entry)

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int) -> dict:
        entry = self.records[idx]
        # The prompt stops right after the opening <think> tag so that
        # generation continues from inside the reasoning section.
        system_turn = f"<bos><system>{CODE_SYSTEM}</system>\n"
        user_turn = f"<user>{entry['question']}</user>\n"
        prompt = system_turn + user_turn + f"<assistant><think>"
        token_ids = self.tokenizer.encode(prompt).ids
        item = {
            "question": entry["question"],
            "test_cases": entry["test_cases"],
            # Over-long prompts are cut to the first max_prompt_len tokens.
            "prompt_ids": token_ids[: self.max_prompt_len],
            "level": entry.get("level", 1),
            "category": entry.get("category", "unknown"),
        }
        return item
def code_grpo_collate(batch: list[dict], pad_id: int = 0) -> dict:
    """Collate dataset items into one batch, right-padding prompts with *pad_id*.

    Returns a dict with the questions, test cases and levels as plain lists,
    ``prompt_ids`` as a (batch, longest_prompt) long tensor and
    ``prompt_lens`` as a long tensor of the unpadded lengths.
    """
    lengths = [len(item["prompt_ids"]) for item in batch]
    width = max(lengths)

    rows = []
    for item, length in zip(batch, lengths):
        rows.append(item["prompt_ids"] + [pad_id] * (width - length))

    collated = {
        "questions": [item["question"] for item in batch],
        "test_cases": [item["test_cases"] for item in batch],
        "prompt_ids": torch.tensor(rows, dtype=torch.long),
        "prompt_lens": torch.tensor(lengths, dtype=torch.long),
        "levels": [item["level"] for item in batch],
    }
    return collated
| # ─── Code GRPO Trainer ──────────────────────────────────────────────────────── | |
class CodeGRPOTrainer:
    """GRPO trainer whose reward comes from executing generated Python.

    Each step samples ``n_samples`` completions for one problem, scores them
    with ``code_reward``, converts the rewards to group advantages and applies
    one clipped policy-gradient update.
    """

    def __init__(self, cfg: dict | None = None, model_cfg: OmegaConfig | None = None):
        """Store the configuration; no files are touched until ``setup()``.

        Args:
            cfg: Training configuration.  ``None`` (the default) selects
                ``CODE_GRPO_CFG``.
            model_cfg: Model configuration used when no checkpoint is loaded;
                defaults to ``small_config()``.
        """
        # A None default avoids a dict as a default argument; omitting the
        # argument still yields CODE_GRPO_CFG exactly as before.
        self.cfg = CODE_GRPO_CFG if cfg is None else cfg
        self.model_cfg = model_cfg or small_config()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch.bfloat16 if self.cfg.get("dtype") == "bfloat16" else torch.float16
        self.step = 0
        # Running counts of sampled completions by reward band.
        self.stats = {"total": 0, "pass_all": 0, "partial": 0, "fail": 0}

    def setup(self):
        """Load the tokenizer, dataset, model and optimizer.

        Raises:
            FileNotFoundError: The tokenizer file does not exist.
            ValueError: The dataset contains no usable problems.
        """
        print(f"Code GRPO | device={self.device} | G={self.cfg['n_samples']}")
        tok_path = self.cfg["tokenizer_path"]
        if not Path(tok_path).exists():
            raise FileNotFoundError(f"Tokenizer not found: {tok_path}")
        self.tokenizer = Tokenizer.from_file(tok_path)

        ds = CodeDataset(
            self.cfg["data_path"], self.tokenizer, self.cfg["max_prompt_len"]
        )
        if len(ds) == 0:
            # Fail with a clear message instead of an obscure sampler error.
            raise ValueError(
                f"No usable problems in {self.cfg['data_path']} "
                "(records need both 'question' and 'test_cases')"
            )
        self.loader = DataLoader(
            ds, batch_size=1, shuffle=True,
            collate_fn=lambda b: code_grpo_collate(b, pad_id=self.model_cfg.pad_token_id),
            num_workers=0,
        )
        print(f"Code dataset: {len(ds):,} problems")

        ref_path = self.cfg.get("ref_checkpoint")
        if ref_path and Path(ref_path).exists():
            # NOTE(security): weights_only=False unpickles arbitrary objects
            # (the checkpoint stores an OmegaConfig instance); only load
            # checkpoints from a trusted source.
            ckpt = torch.load(ref_path, map_location=self.device, weights_only=False)
            saved_cfg: OmegaConfig = ckpt["model_cfg"]
            self.model = OmegaModel(saved_cfg).to(self.device)
            self.model.load_state_dict(ckpt["model_state"])
            self.model_cfg = saved_cfg
        else:
            # No checkpoint available: start from a freshly initialised model.
            self.model = OmegaModel(self.model_cfg).to(self.device)

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=float(self.cfg.get("lr", 1e-6)), betas=(0.9, 0.95)
        )

    def _sample(self, prompt_ids: torch.Tensor) -> list[str]:
        """Generate ``n_samples`` completions for one prompt, decoded to text."""
        self.model.eval()
        prompt_dev = prompt_ids.to(self.device)
        prompt_len = prompt_ids.shape[1]
        completions: list[str] = []
        # Sampling needs no gradients; log-probs are recomputed in train_step.
        with torch.no_grad():
            for _ in range(self.cfg["n_samples"]):
                generated = self.model.generate(
                    prompt_dev,
                    max_new_tokens=self.cfg["max_new_tokens"],
                    temperature=self.cfg["temperature"],
                    top_p=0.95,
                )
                # Keep only the newly generated tokens.
                new_tokens = generated[0, prompt_len:].tolist()
                completions.append(self.tokenizer.decode(new_tokens))
        return completions

    def train_step(self, batch: dict) -> float:
        """Run one GRPO update on a single-problem batch.

        Returns:
            The summed per-completion loss, or 0.0 when the step is skipped
            (all rewards equal, or every completion tokenised to nothing).
        """
        test_cases = batch["test_cases"][0]
        prompt_ids = batch["prompt_ids"][:1]

        completions = self._sample(prompt_ids)
        rewards = [
            code_reward(c, test_cases, timeout=self.cfg["timeout_sec"])
            for c in completions
        ]

        # Stats
        for r in rewards:
            self.stats["total"] += 1
            if r >= 0.9:
                self.stats["pass_all"] += 1
            elif r >= 0.4:
                self.stats["partial"] += 1
            else:
                self.stats["fail"] += 1

        # Identical rewards carry no learning signal within the group.
        if all(r == rewards[0] for r in rewards):
            return 0.0

        advantages = compute_group_advantages(rewards)
        self.model.train()

        n_samples = self.cfg["n_samples"]
        max_new = self.cfg["max_new_tokens"]
        clip_eps = float(self.cfg.get("clip_eps", 0.2))
        prompt_dev = prompt_ids.to(self.device)
        p_len = prompt_ids.shape[1]

        total_loss = 0.0
        contributed = False
        for comp_text, adv in zip(completions, advantages):
            token_ids = self.tokenizer.encode(comp_text).ids[:max_new]
            if not token_ids:
                # An empty completion would give a NaN mean log-prob and
                # poison the update, so it is skipped.
                continue
            comp_dev = torch.tensor([token_ids], dtype=torch.long, device=self.device)

            # NOTE(review): the "old" log-probs come from the current policy
            # (self.model), not from a frozen reference copy.
            with torch.no_grad():
                full_ids = torch.cat([prompt_dev, comp_dev], dim=1)
                logits = self.model(full_ids)["logits"]
                log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
                target = full_ids[:, 1:]
                token_lp = log_probs.gather(2, target.unsqueeze(-1)).squeeze(-1)
                # Position p_len - 1 predicts the first completion token.
                old_seq_lp = token_lp[:, p_len - 1:p_len - 1 + comp_dev.shape[1]].mean()

            adv_t = torch.tensor([adv], device=self.device, dtype=self.dtype)
            with torch.amp.autocast(
                device_type=self.device.type, dtype=self.dtype,
                enabled=self.device.type == "cuda"
            ):
                loss = grpo_policy_loss(
                    self.model,
                    prompt_dev,
                    comp_dev,
                    adv_t,
                    old_seq_lp,
                    clip_eps=clip_eps,
                ) / n_samples
            loss.backward()
            total_loss += loss.item() * n_samples
            contributed = True

        if not contributed:
            return 0.0

        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), float(self.cfg.get("grad_clip", 1.0))
        )
        self.optimizer.step()
        self.optimizer.zero_grad()
        return total_loss

    def save(self, tag: str = "code_grpo_latest"):
        """Write step, model weights and model config to ``<out_dir>/omega_<tag>.pt``."""
        path = Path(self.cfg["out_dir"]) / f"omega_{tag}.pt"
        # Create the output directory on first use instead of failing.
        path.parent.mkdir(parents=True, exist_ok=True)
        torch.save({
            "step": self.step,
            "model_state": self.model.state_dict(),
            "model_cfg": self.model_cfg,
        }, path)
        print(f" Saved → {path}")

    def _pass_rate(self) -> str:
        """Format the running pass / partial / fail percentages."""
        t = max(self.stats["total"], 1)  # avoid dividing by zero before any sample
        return (f"pass={self.stats['pass_all']/t*100:.1f}% "
                f"partial={self.stats['partial']/t*100:.1f}% "
                f"fail={self.stats['fail']/t*100:.1f}%")

    def train(self):
        """Run ``setup()`` and train for ``max_steps`` steps with periodic saves."""
        self.setup()
        data_iter = iter(self.loader)
        t0 = time.time()
        running_loss = 0.0
        print(f"Code GRPO for {self.cfg['max_steps']:,} steps\n")

        while self.step < self.cfg["max_steps"]:
            try:
                batch = next(data_iter)
            except StopIteration:
                # End of an epoch: restart the (reshuffled) loader.
                data_iter = iter(self.loader)
                batch = next(data_iter)

            running_loss += self.train_step(batch)
            self.step += 1

            if self.step % 10 == 0:
                dt = time.time() - t0
                print(f"step {self.step:5d} | loss {running_loss/10:.4f} | "
                      f"{self._pass_rate()} | {dt:.1f}s")
                running_loss = 0.0
                t0 = time.time()

            if self.step % self.cfg["save_every"] == 0:
                self.save(f"code_grpo_step{self.step}")

        self.save("code_grpo_final")
        print(f"\nCode GRPO done! Final: {self._pass_rate()}")
# Script entry point: train with the default CODE_GRPO_CFG.
if __name__ == "__main__":
    CodeGRPOTrainer().train()
Xet Storage Details
- Size:
- 14.5 kB
- Xet hash:
- e470a6a594dc6d3fbe2ba1d27cdaf38f09df42aa94fb3b1c02504953aa9edae9
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.