# CodeSensei — GRPO Training Notebook
# ========================================
# Run this in Google Colab with a T4 GPU runtime.
# Fine-tunes Qwen3-1.7B to debug Python code using GRPO.
#
# CELL STRUCTURE (copy each section into a separate Colab cell):
#   Cell 1: Install dependencies
#   Cell 2: Login to HF
#   Cell 3: Configuration + Bug Dataset
#   Cell 4: Build Training Dataset
#   Cell 5: Reward Function
#   Cell 6: GRPO Config + Train
#   Cell 7: Push Final Model to HF Hub (optional)
#   Cell 8: Quick Test — Inference (optional)

# ============================================================
# CELL 1: Install dependencies
# ============================================================
# Run these in order in separate Colab cells:
#
# STEP 1 — Fix vLLM conflict (Colab ships broken/old vLLM that crashes TRL):
#   !pip uninstall vllm -y
#   !rm -rf /usr/local/lib/python3.12/dist-packages/vllm*
#
# STEP 2 — Install TRL and dependencies:
#   !pip install -q "trl==0.15.0" "transformers>=4.51.0" datasets accelerate "bitsandbytes>=0.43" peft
#
# NOTE(review): Qwen3 checkpoints require transformers>=4.51.0; older releases
# (including the previously pinned 4.46.3) fail with `KeyError: 'qwen3'` when
# loading MODEL_NAME. Confirm that the TRL pin still works against the
# transformers version pip resolves.
#
# Why: TRL's GRPOTrainer imports vllm.sampling_params.StructuredOutputsParams at module load
# time even when you don't use vLLM. Colab's vllm==0.10.2 doesn't have this class.
# Uninstalling vLLM is safe here — we're using HuggingFace Transformers, not vLLM serving.

# ============================================================
# CELL 2: Login to Hugging Face
# ============================================================
# from huggingface_hub import notebook_login
# notebook_login()

# ============================================================
# CELL 3: Configuration + Bug Dataset
# ============================================================
MODEL_NAME = "Qwen/Qwen3-1.7B"
OUTPUT_DIR = "codesensei-qwen3-1.7b"

# Instruction block prepended to every training prompt (see make_prompt).
SYSTEM_PROMPT = """You are an expert Python debugger. You receive a buggy Python function
and failing test results. Fix the bug.

RULES:
1. Respond with ONLY the corrected Python function.
2. Do NOT include explanations or markdown — just raw Python code.
3. Keep the same function name and signature.
4. Make minimal changes to fix the bug.

Example response:
def function_name(args):
    # corrected implementation"""

# Each entry describes one bug:
#   function_name   — name of the function under repair
#   buggy_code      — source shown to the model (must be valid, 4-space-indented Python)
#   bug_description — human-readable note; not shown to the model
#   tests           — assert statements executed against the model's proposed fix
BUG_DATASET = [
    {
        "function_name": "add_numbers",
        "buggy_code": "def add_numbers(a, b):\n    return a - b",
        "bug_description": "Uses subtraction instead of addition",
        "tests": [
            {"name": "basic addition", "code": "assert add_numbers(2, 3) == 5"},
            {"name": "zero addition", "code": "assert add_numbers(0, 0) == 0"},
            {"name": "negative addition", "code": "assert add_numbers(-1, 1) == 0"},
        ],
    },
    {
        "function_name": "find_max",
        "buggy_code": (
            "def find_max(lst):\n"
            "    if not lst:\n"
            "        return None\n"
            "    result = lst[0]\n"
            "    for x in lst:\n"
            "        if x < result:\n"
            "            result = x\n"
            "    return result"
        ),
        "bug_description": "Uses < instead of > (finds minimum instead of maximum)",
        "tests": [
            {"name": "basic max", "code": "assert find_max([1, 3, 2]) == 3"},
            {"name": "single element", "code": "assert find_max([5]) == 5"},
            {"name": "negative numbers", "code": "assert find_max([-1, -5, -2]) == -1"},
            {"name": "empty list", "code": "assert find_max([]) is None"},
        ],
    },
    {
        "function_name": "reverse_string",
        "buggy_code": "def reverse_string(s):\n    return s[1:]",
        "bug_description": "Slices from index 1 instead of reversing",
        "tests": [
            {"name": "basic reverse", "code": 'assert reverse_string("hello") == "olleh"'},
            {"name": "empty string", "code": 'assert reverse_string("") == ""'},
            {"name": "single char", "code": 'assert reverse_string("a") == "a"'},
        ],
    },
    {
        "function_name": "fibonacci",
        "buggy_code": (
            "def fibonacci(n):\n"
            "    if n <= 0:\n"
            "        return 0\n"
            "    if n == 1:\n"
            "        return 1\n"
            "    return fibonacci(n - 1) + fibonacci(n - 3)"
        ),
        "bug_description": "Recursive call uses n-3 instead of n-2",
        "tests": [
            {"name": "fib(0)", "code": "assert fibonacci(0) == 0"},
            {"name": "fib(1)", "code": "assert fibonacci(1) == 1"},
            {"name": "fib(5)", "code": "assert fibonacci(5) == 5"},
            {"name": "fib(10)", "code": "assert fibonacci(10) == 55"},
        ],
    },
    {
        "function_name": "count_vowels",
        "buggy_code": (
            "def count_vowels(s):\n"
            "    count = 0\n"
            "    for c in s:\n"
            "        if c in 'aeiou':\n"
            "            count += 1\n"
            "    return count"
        ),
        "bug_description": "Does not handle uppercase vowels",
        "tests": [
            {"name": "lowercase", "code": "assert count_vowels('hello') == 2"},
            {"name": "uppercase", "code": "assert count_vowels('HELLO') == 2"},
            {"name": "mixed case", "code": "assert count_vowels('HeLLo') == 2"},
        ],
    },
    {
        "function_name": "is_palindrome",
        "buggy_code": "def is_palindrome(s):\n    s = s.lower()\n    return s == s[::-1]",
        "bug_description": "Does not strip non-alphanumeric characters before checking",
        "tests": [
            {"name": "basic palindrome", "code": "assert is_palindrome('racecar') == True"},
            {"name": "with punctuation", "code": "assert is_palindrome('A man, a plan, a canal: Panama') == True"},
            {"name": "not palindrome", "code": "assert is_palindrome('hello') == False"},
        ],
    },
    {
        "function_name": "flatten_list",
        "buggy_code": (
            "def flatten_list(lst):\n"
            "    result = []\n"
            "    for item in lst:\n"
            "        if isinstance(item, list):\n"
            "            result.append(item)\n"
            "        else:\n"
            "            result.append(item)\n"
            "    return result"
        ),
        "bug_description": "Appends nested lists instead of recursively flattening them",
        "tests": [
            {"name": "nested", "code": "assert flatten_list([1, [2, 3], [4, [5]]]) == [1, 2, 3, 4, 5]"},
            {"name": "already flat", "code": "assert flatten_list([1, 2, 3]) == [1, 2, 3]"},
            {"name": "deep nesting", "code": "assert flatten_list([[[[1]]]]) == [1]"},
        ],
    },
    {
        "function_name": "binary_search",
        "buggy_code": (
            "def binary_search(arr, target):\n"
            "    left, right = 0, len(arr) - 1\n"
            "    while left < right:\n"
            "        mid = (left + right) // 2\n"
            "        if arr[mid] == target:\n"
            "            return mid\n"
            "        elif arr[mid] < target:\n"
            "            left = mid\n"
            "        else:\n"
            "            right = mid - 1\n"
            "    return -1"
        ),
        "bug_description": "Uses < instead of <= in while condition, and left=mid instead of left=mid+1",
        "tests": [
            {"name": "found middle", "code": "assert binary_search([1,2,3,4,5], 3) == 2"},
            {"name": "found first", "code": "assert binary_search([1,2,3,4,5], 1) == 0"},
            {"name": "found last", "code": "assert binary_search([1,2,3,4,5], 5) == 4"},
            {"name": "not found", "code": "assert binary_search([1,2,3,4,5], 6) == -1"},
        ],
    },
    {
        "function_name": "merge_sorted",
        "buggy_code": (
            "def merge_sorted(a, b):\n"
            "    result = []\n"
            "    i = j = 0\n"
            "    while i < len(a) and j < len(b):\n"
            "        if a[i] <= b[j]:\n"
            "            result.append(a[i])\n"
            "            i += 1\n"
            "        else:\n"
            "            result.append(b[j])\n"
            "            j += 1\n"
            "    return result"
        ),
        "bug_description": "Missing the remaining elements after the while loop ends",
        "tests": [
            {"name": "basic merge", "code": "assert merge_sorted([1,3,5], [2,4,6]) == [1,2,3,4,5,6]"},
            {"name": "one empty", "code": "assert merge_sorted([], [1,2,3]) == [1,2,3]"},
            {"name": "unequal length", "code": "assert merge_sorted([1], [2,3,4]) == [1,2,3,4]"},
        ],
    },
    {
        "function_name": "remove_duplicates",
        "buggy_code": (
            "def remove_duplicates(lst):\n"
            "    seen = set()\n"
            "    result = []\n"
            "    for item in lst:\n"
            "        if item in seen:\n"
            "            result.append(item)\n"
            "        seen.add(item)\n"
            "    return result"
        ),
        "bug_description": "Condition is inverted: keeps duplicates and removes unique items",
        "tests": [
            {"name": "basic dedup", "code": "assert remove_duplicates([1,2,2,3,3,3]) == [1,2,3]"},
            {"name": "no duplicates", "code": "assert remove_duplicates([1,2,3]) == [1,2,3]"},
            {"name": "all same", "code": "assert remove_duplicates([5,5,5]) == [5]"},
        ],
    },
]

print(f"✅ Config loaded | Model: {MODEL_NAME} | Bugs: {len(BUG_DATASET)}")

# ============================================================
# CELL 4: Build Training Dataset
# ============================================================
import random

from datasets import Dataset

# Number of training prompts to generate for the quick trial run.
NUM_PROMPTS = 10


def make_prompt(bug):
    """Create a training prompt from a bug entry.

    Args:
        bug: One BUG_DATASET entry; its "buggy_code" and "tests" keys are read.

    Returns:
        The prompt string: SYSTEM_PROMPT, the buggy function, one bullet per
        test assertion, and a closing instruction.
    """
    failing_tests = "\n".join(f" - {t['code']}" for t in bug["tests"])
    return (
        f"{SYSTEM_PROMPT}\n\n"
        f"BUGGY FUNCTION:\n{bug['buggy_code']}\n\n"
        f"FAILING TESTS:\n{failing_tests}\n\n"
        f"Fix the function:"
    )


# Cycle through the bug dataset; "bug_index" is forwarded by the trainer to the
# reward function so it can look up the matching tests.
prompts = [
    {
        "prompt": make_prompt(BUG_DATASET[i % len(BUG_DATASET)]),
        "bug_index": i % len(BUG_DATASET),
    }
    for i in range(NUM_PROMPTS)
]

dataset = Dataset.from_list(prompts)
print(f"✅ Dataset: {len(dataset)} prompts from {len(BUG_DATASET)} bugs")

# ============================================================
# CELL 5: Reward Function (runs tests directly — no WebSocket)
# ============================================================
import subprocess
import tempfile
import sys
import os
import re

# Reward bounds used by reward_code_debug.
MIN_REWARD = 0.01
MAX_REWARD = 0.99

# Closed reasoning blocks that Qwen3 may emit before its answer.
_THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", re.DOTALL)
# Fenced code blocks, tried in order: explicit ```python first, then any fence.
_FENCE_RES = (
    re.compile(r"```python\s*\n(.*?)```", re.DOTALL),
    re.compile(r"```\s*\n(.*?)```", re.DOTALL),
)


def extract_code(text):
    """Extract Python code from an LLM response.

    Closed ``<think>...</think>`` blocks are removed first so that reasoning
    text is never mistaken for code. If the remaining text contains a
    ```python fence (or, failing that, any ``` fence) the fence body is
    returned; otherwise the whole stripped text is treated as raw code.

    Args:
        text: Raw completion text.

    Returns:
        The extracted source code, stripped of surrounding whitespace.
    """
    text = _THINK_BLOCK_RE.sub("", text)
    for fence_re in _FENCE_RES:
        match = fence_re.search(text)
        if match:
            return match.group(1).strip()
    return text.strip()


def run_code(code, timeout=5):
    """Run code in a subprocess and report the outcome.

    NOTE: the code is model-generated and runs unsandboxed with the notebook's
    privileges. That is acceptable on a throwaway Colab VM; do not reuse this
    on a machine holding data you care about.

    Args:
        code: Python source to execute.
        timeout: Seconds to wait before the run is counted as a failure.

    Returns:
        A ``(stdout, stderr, success)`` tuple. ``success`` is True only when
        the interpreter exits with status 0. On timeout stderr is "Timeout";
        on any other failure stderr is the exception message.
    """
    tmp = tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", delete=False, encoding="utf-8"
    )
    path = tmp.name
    try:
        # Close the file before running it so the content is flushed to disk.
        with tmp:
            tmp.write(code)
        result = subprocess.run(
            [sys.executable, path],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
        return result.stdout, result.stderr, result.returncode == 0
    except subprocess.TimeoutExpired:
        return "", "Timeout", False
    except Exception as e:
        # Best effort: any failure to write or launch counts as a failed run
        # rather than aborting the training step.
        return "", str(e), False
    finally:
        os.unlink(path)


def reward_code_debug(completions, bug_index, **kwargs):
    """Reward function: extract code, run tests, return score.

    Rewards are bounded to [0.01, 0.99] to:
      1. Keep the reward scale small: a full solve gets 0.99, not 2.0.
      2. Match OpenEnv Phase 2 grader requirements.

    NOTE: bounding does not by itself prevent a zero training signal. GRPO
    advantages are computed relative to the other completions sampled for the
    same prompt, so a group whose completions all receive the same reward
    (for example all 0.01) still contributes no gradient. Reward variance
    comes from the partial-credit scale below and from sampling temperature.

    Score mapping:
      All tests pass      → 0.99
      Partial pass        → 0.1 + (passed/total) * 0.79
      No test passes      → 0.01
      Code does not parse → 0.01

    Args:
        completions: One completion per sample, either a plain string or a
            chat-style list of message dicts (the last message's "content"
            is used).
        bug_index: BUG_DATASET index for each completion, aligned with
            ``completions``.
        **kwargs: Extra dataset columns passed by the trainer; ignored.

    Returns:
        A list of float rewards, one per completion.
    """
    rewards = []
    for completion, idx in zip(completions, bug_index):
        # Get the completion text
        if isinstance(completion, list):
            text = completion[-1].get("content", "") if completion else ""
        else:
            text = str(completion)

        proposed_fix = extract_code(text)
        bug = BUG_DATASET[int(idx)]

        # Cheap syntax check before paying for subprocesses. ValueError covers
        # sources containing null bytes on Python versions before 3.12.
        try:
            compile(proposed_fix, "<completion>", "exec")
        except (SyntaxError, ValueError):
            rewards.append(MIN_REWARD)
            continue

        # Run each test in its own process so one failing or hanging test
        # cannot affect the others.
        tests_passed = 0
        total = len(bug["tests"])
        for test in bug["tests"]:
            test_code = f"{proposed_fix}\n\n{test['code']}\nprint('PASS')"
            stdout, stderr, ok = run_code(test_code)
            if ok and "PASS" in stdout:
                tests_passed += 1

        # Compute the bounded reward
        if tests_passed == total:
            reward = MAX_REWARD  # Full solve
        elif tests_passed > 0:
            reward = 0.1 + (tests_passed / total) * 0.79  # Partial credit
        else:
            reward = MIN_REWARD  # All failed
        rewards.append(reward)
    return rewards


print("✅ Reward function ready (inline test execution)")

# ============================================================
# CELL 6: GRPO Config + Train
# ============================================================
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers.trainer_utils import get_last_checkpoint
from peft import LoraConfig
import torch

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Only fall back to EOS when the tokenizer ships without a pad token;
# overwriting an existing pad token makes padding indistinguishable from
# end-of-sequence.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# 4-bit quantization — reduces model from ~4GB to ~1.5GB
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# Load model with 4-bit quantization
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
)
print(f"✅ Model loaded in 4-bit | VRAM: {torch.cuda.memory_allocated()/1e9:.1f} GB")

# LoRA config — only train small adapter weights
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)

# NOTE(review): with only NUM_PROMPTS prompts the run has very few optimizer
# steps, so warmup_steps=5 may cover most or all of training and save_steps=10
# may never trigger — confirm against the logged step count.
grpo_config = GRPOConfig(
    num_train_epochs=1,
    learning_rate=2e-5,  # Higher LR — previous run had 0 updates
    gradient_accumulation_steps=2,  # Faster feedback loop
    per_device_train_batch_size=4,  # Must match num_generations!
    warmup_steps=5,
    num_generations=4,  # Variance in rewards to fix loss=0
    max_completion_length=150,  # Shorter fits in T4 VRAM
    temperature=0.9,  # Needed for generation variance
    output_dir=OUTPUT_DIR,
    report_to="none",
    logging_steps=1,
    save_steps=10,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    push_to_hub=False,
    optim="paged_adamw_8bit",
)

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[reward_code_debug],
    train_dataset=dataset,
    args=grpo_config,
    peft_config=peft_config,
)

print("🚀 Starting GRPO training...")
print(f" Model: {MODEL_NAME}")
print(f" Dataset: {len(dataset)} prompts")
print(f" Checkpointing every {grpo_config.save_steps} steps to {OUTPUT_DIR}")
print()

import os

# Resume only from a real `checkpoint-<step>` directory; get_last_checkpoint
# returns None when OUTPUT_DIR holds no such directory.
_ckpt = get_last_checkpoint(OUTPUT_DIR) if os.path.isdir(OUTPUT_DIR) else None
trainer.train(resume_from_checkpoint=_ckpt)

# Write the final weights to OUTPUT_DIR itself so the path reported below
# (and used by the inference cell) actually contains a loadable model.
trainer.save_model(OUTPUT_DIR)

print("\n✅ Training complete!")
print(f"📦 Model saved to: {OUTPUT_DIR}")

# ============================================================
# CELL 7: Push Final Model to HF Hub
# ============================================================
# trainer.push_to_hub()
# print(f"✅ Model pushed to hub: {OUTPUT_DIR}")

# ============================================================
# CELL 8: Quick Test — Inference
# ============================================================
# from transformers import pipeline
#
# pipe = pipeline("text-generation", model=OUTPUT_DIR, tokenizer=tokenizer)
# test_bug = "def add_numbers(a, b):\n    return a - b"
# prompt = f"{SYSTEM_PROMPT}\n\nBUGGY FUNCTION:\n{test_bug}\n\nFix the function:"
# result = pipe(prompt, max_new_tokens=256)
# print(result[0]["generated_text"])