Spaces:
Sleeping
Sleeping
vineetshukla.work@gmail.com
chore: test trial run for colab script with 10 dataset size and fixed dependencies
# CodeSensei — GRPO Training Notebook (commit a59d79f)
# ========================================
# Run this in Google Colab with a T4 GPU runtime.
# Fine-tunes Qwen3-1.7B to debug Python code using GRPO.
#
# CELL STRUCTURE (copy each section into a separate Colab cell):
# Cell 1: Install dependencies
# Cell 2: Login to HF
# Cell 3: Configuration + Bug Dataset
# Cell 4: Build Training Dataset
# Cell 5: Reward Function
# Cell 6: GRPO Config + Train
# Cell 7: Push Final Model to HF Hub (optional, commented out)
# Cell 8: Quick Test — Inference (optional, commented out)
# ============================================================
# CELL 1: Install dependencies
# ============================================================
# Run these in order in separate Colab cells:
#
# STEP 1 — Fix vLLM conflict (Colab ships a broken/old vLLM that crashes TRL):
# !pip uninstall vllm -y
# !rm -rf /usr/local/lib/python3.12/dist-packages/vllm*
#   (this path assumes a Python 3.12 runtime; adjust it if `python --version` differs)
#
# STEP 2 — Install TRL and dependencies (TRL pinned to 0.15.0):
# !pip install -q "trl==0.15.0" "transformers>=4.51.0" datasets accelerate "bitsandbytes>=0.43" peft
#
# NOTE: Qwen3 checkpoints need transformers>=4.51.0. Older releases, such as the
# previously pinned 4.46.3, fail with KeyError: 'qwen3' when loading
# Qwen/Qwen3-1.7B (see the Qwen3 model card). TODO: confirm trl==0.15.0 still
# imports cleanly against the newer transformers release.
#
# Why STEP 1: TRL's GRPOTrainer imports vllm.sampling_params.StructuredOutputsParams at module load
# time even when you don't use vLLM. Colab's vllm==0.10.2 doesn't have this class.
# Uninstalling vLLM is safe here — we're using Hugging Face Transformers, not vLLM serving.
| # ============================================================ | |
| # CELL 2: Login to Hugging Face | |
| # ============================================================ | |
| # from huggingface_hub import notebook_login | |
| # notebook_login() | |
# ============================================================
# CELL 3: Configuration + Bug Dataset
# ============================================================
MODEL_NAME = "Qwen/Qwen3-1.7B"
OUTPUT_DIR = "codesensei-qwen3-1.7b"
SYSTEM_PROMPT = """You are an expert Python debugger. You receive a buggy Python function and failing test results. Fix the bug.
RULES:
1. Respond with ONLY the corrected Python function.
2. Do NOT include explanations or markdown — just raw Python code.
3. Keep the same function name and signature.
4. Make minimal changes to fix the bug.
Example response:
def function_name(args):
    # corrected implementation"""

# Each entry describes one buggy function:
#   function_name   - name of the function the model must keep
#   buggy_code      - source shown to the model (must be valid, 4-space-indented
#                     Python so the only problem is the logic bug)
#   bug_description - human-readable note (not referenced by make_prompt)
#   tests           - list of {"name", "code"}; "code" is a single assert line
#                     that the reward function runs against the proposed fix
BUG_DATASET = [
    {
        "function_name": "add_numbers",
        "buggy_code": "def add_numbers(a, b):\n    return a - b",
        "bug_description": "Uses subtraction instead of addition",
        "tests": [
            {"name": "basic addition", "code": "assert add_numbers(2, 3) == 5"},
            {"name": "zero addition", "code": "assert add_numbers(0, 0) == 0"},
            {"name": "negative addition", "code": "assert add_numbers(-1, 1) == 0"},
        ],
    },
    {
        "function_name": "find_max",
        "buggy_code": "def find_max(lst):\n    if not lst:\n        return None\n    result = lst[0]\n    for x in lst:\n        if x < result:\n            result = x\n    return result",
        "bug_description": "Uses < instead of > (finds minimum instead of maximum)",
        "tests": [
            {"name": "basic max", "code": "assert find_max([1, 3, 2]) == 3"},
            {"name": "single element", "code": "assert find_max([5]) == 5"},
            {"name": "negative numbers", "code": "assert find_max([-1, -5, -2]) == -1"},
            {"name": "empty list", "code": "assert find_max([]) is None"},
        ],
    },
    {
        "function_name": "reverse_string",
        "buggy_code": 'def reverse_string(s):\n    return s[1:]',
        "bug_description": "Slices from index 1 instead of reversing",
        "tests": [
            {"name": "basic reverse", "code": 'assert reverse_string("hello") == "olleh"'},
            {"name": "empty string", "code": 'assert reverse_string("") == ""'},
            {"name": "single char", "code": 'assert reverse_string("a") == "a"'},
        ],
    },
    {
        "function_name": "fibonacci",
        "buggy_code": "def fibonacci(n):\n    if n <= 0:\n        return 0\n    if n == 1:\n        return 1\n    return fibonacci(n - 1) + fibonacci(n - 3)",
        "bug_description": "Recursive call uses n-3 instead of n-2",
        "tests": [
            {"name": "fib(0)", "code": "assert fibonacci(0) == 0"},
            {"name": "fib(1)", "code": "assert fibonacci(1) == 1"},
            {"name": "fib(5)", "code": "assert fibonacci(5) == 5"},
            {"name": "fib(10)", "code": "assert fibonacci(10) == 55"},
        ],
    },
    {
        "function_name": "count_vowels",
        "buggy_code": "def count_vowels(s):\n    count = 0\n    for c in s:\n        if c in 'aeiou':\n            count += 1\n    return count",
        "bug_description": "Does not handle uppercase vowels",
        "tests": [
            {"name": "lowercase", "code": "assert count_vowels('hello') == 2"},
            {"name": "uppercase", "code": "assert count_vowels('HELLO') == 2"},
            {"name": "mixed case", "code": "assert count_vowels('HeLLo') == 2"},
        ],
    },
    {
        "function_name": "is_palindrome",
        "buggy_code": "def is_palindrome(s):\n    s = s.lower()\n    return s == s[::-1]",
        "bug_description": "Does not strip non-alphanumeric characters before checking",
        "tests": [
            {"name": "basic palindrome", "code": "assert is_palindrome('racecar') == True"},
            {"name": "with punctuation", "code": "assert is_palindrome('A man, a plan, a canal: Panama') == True"},
            {"name": "not palindrome", "code": "assert is_palindrome('hello') == False"},
        ],
    },
    {
        "function_name": "flatten_list",
        "buggy_code": "def flatten_list(lst):\n    result = []\n    for item in lst:\n        if isinstance(item, list):\n            result.append(item)\n        else:\n            result.append(item)\n    return result",
        "bug_description": "Appends nested lists instead of recursively flattening them",
        "tests": [
            {"name": "nested", "code": "assert flatten_list([1, [2, 3], [4, [5]]]) == [1, 2, 3, 4, 5]"},
            {"name": "already flat", "code": "assert flatten_list([1, 2, 3]) == [1, 2, 3]"},
            {"name": "deep nesting", "code": "assert flatten_list([[[[1]]]]) == [1]"},
        ],
    },
    {
        "function_name": "binary_search",
        # NOTE: with left = mid this buggy version can loop forever; the reward
        # function's per-test subprocess timeout is what bounds it.
        "buggy_code": "def binary_search(arr, target):\n    left, right = 0, len(arr) - 1\n    while left < right:\n        mid = (left + right) // 2\n        if arr[mid] == target:\n            return mid\n        elif arr[mid] < target:\n            left = mid\n        else:\n            right = mid - 1\n    return -1",
        "bug_description": "Uses < instead of <= in while condition, and left=mid instead of left=mid+1",
        "tests": [
            {"name": "found middle", "code": "assert binary_search([1,2,3,4,5], 3) == 2"},
            {"name": "found first", "code": "assert binary_search([1,2,3,4,5], 1) == 0"},
            {"name": "found last", "code": "assert binary_search([1,2,3,4,5], 5) == 4"},
            {"name": "not found", "code": "assert binary_search([1,2,3,4,5], 6) == -1"},
        ],
    },
    {
        "function_name": "merge_sorted",
        "buggy_code": "def merge_sorted(a, b):\n    result = []\n    i = j = 0\n    while i < len(a) and j < len(b):\n        if a[i] <= b[j]:\n            result.append(a[i])\n            i += 1\n        else:\n            result.append(b[j])\n            j += 1\n    return result",
        "bug_description": "Missing the remaining elements after the while loop ends",
        "tests": [
            {"name": "basic merge", "code": "assert merge_sorted([1,3,5], [2,4,6]) == [1,2,3,4,5,6]"},
            {"name": "one empty", "code": "assert merge_sorted([], [1,2,3]) == [1,2,3]"},
            {"name": "unequal length", "code": "assert merge_sorted([1], [2,3,4]) == [1,2,3,4]"},
        ],
    },
    {
        "function_name": "remove_duplicates",
        "buggy_code": "def remove_duplicates(lst):\n    seen = set()\n    result = []\n    for item in lst:\n        if item in seen:\n            result.append(item)\n        seen.add(item)\n    return result",
        "bug_description": "Condition is inverted: keeps duplicates and removes unique items",
        "tests": [
            {"name": "basic dedup", "code": "assert remove_duplicates([1,2,2,3,3,3]) == [1,2,3]"},
            {"name": "no duplicates", "code": "assert remove_duplicates([1,2,3]) == [1,2,3]"},
            {"name": "all same", "code": "assert remove_duplicates([5,5,5]) == [5]"},
        ],
    },
]
print(f"✅ Config loaded | Model: {MODEL_NAME} | Bugs: {len(BUG_DATASET)}")
| # ============================================================ | |
| # CELL 4: Build Training Dataset | |
| # ============================================================ | |
| import random | |
| from datasets import Dataset | |
def make_prompt(bug, system_prompt=None):
    """Build the plain-text training prompt for one bug entry.

    Args:
        bug: A BUG_DATASET-style dict. Only "buggy_code" (str) and "tests"
            (list of dicts with a "code" str) are read.
        system_prompt: Instruction text placed at the top of the prompt.
            Defaults to the module-level SYSTEM_PROMPT, looked up at call time.

    Returns:
        The system prompt, the buggy function, every test assertion as a
        " - " bullet, and a closing "Fix the function:" line, separated by
        blank lines.
    """
    if system_prompt is None:
        system_prompt = SYSTEM_PROMPT
    # NOTE(review): every test is listed under "FAILING TESTS", including
    # ones the buggy code may already pass.
    failing_tests = "\n".join(f" - {t['code']}" for t in bug["tests"])
    return (
        f"{system_prompt}\n\n"
        f"BUGGY FUNCTION:\n{bug['buggy_code']}\n\n"
        f"FAILING TESTS:\n{failing_tests}\n\n"
        f"Fix the function:"
    )
# Generate the training prompts by cycling through the bug dataset.
# 10 is a quick trial size; raise NUM_TRAIN_PROMPTS for a real run
# (bug indices wrap around with %).
NUM_TRAIN_PROMPTS = 10
prompts = []
for i in range(NUM_TRAIN_PROMPTS):
    bug_index = i % len(BUG_DATASET)
    # "bug_index" is an extra dataset column; reward_code_debug receives it as
    # its bug_index argument and uses it to look up the tests.
    prompts.append({"prompt": make_prompt(BUG_DATASET[bug_index]), "bug_index": bug_index})
dataset = Dataset.from_list(prompts)
print(f"✅ Dataset: {len(dataset)} prompts from {len(BUG_DATASET)} bugs")
| # ============================================================ | |
# CELL 5: Reward Function (runs tests directly — no WebSocket)
| # ============================================================ | |
| import subprocess | |
| import tempfile | |
| import sys | |
| import os | |
| import re | |
# Fenced code blocks the model may emit. The closing fence is optional: with a
# small max_completion_length the completion is often cut off before it.
_PY_FENCE_RE = re.compile(r"```(?:python3?|py)[ \t\r]*\n(.*?)(?:```|\Z)", re.DOTALL | re.IGNORECASE)
_ANY_FENCE_RE = re.compile(r"```[^\n`]*\n(.*?)(?:```|\Z)", re.DOTALL)
# A function definition starting at column 0 of some line.
_TOP_LEVEL_DEF_RE = re.compile(r"^def ", re.MULTILINE)


def _compiles(source):
    """Return True if *source* parses as Python. Nothing is executed."""
    try:
        compile(source, "<response>", "exec")
    except (SyntaxError, ValueError):  # ValueError: NUL bytes on Python < 3.12
        return False
    return True


def extract_code(text):
    """Extract Python code from an LLM response.

    Strategies, tried in order:
      1. the first ```python / ```py fenced block (tag is case-insensitive);
      2. the first fenced block with any other (or no) language tag;
      3. if the whole response does not parse as Python but contains a line
         starting with "def ", everything from that line on (drops leading
         prose such as "Here is the fix:");
      4. otherwise the whole response.

    For fenced blocks the closing ``` may be missing, so a completion that was
    truncated mid-block still yields its code.

    Args:
        text: Raw completion text.

    Returns:
        The extracted code with surrounding whitespace stripped (may be "").
    """
    for pattern in (_PY_FENCE_RE, _ANY_FENCE_RE):
        m = pattern.search(text)
        if m:
            return m.group(1).strip()
    stripped = text.strip()
    m = _TOP_LEVEL_DEF_RE.search(stripped)
    # Only cut when the full text is invalid, so imports the model correctly
    # wrote above the function are kept.
    if m and m.start() > 0 and not _compiles(stripped):
        return stripped[m.start():].strip()
    return stripped
def run_code(code, timeout=5):
    """Run *code* as a standalone script in a fresh Python subprocess.

    WARNING: *code* is model-generated and is executed with the notebook
    user's full permissions. A subprocess plus timeout is NOT a sandbox; only
    run this on a disposable machine such as a Colab VM.

    Args:
        code: Python source to execute.
        timeout: Seconds before the child is killed.

    Returns:
        (stdout, stderr, success), where success means exit code 0.
        On timeout: ("", "Timeout", False). On any other failure, including
        failing to write the temp file: ("", str(error), False).
    """
    path = None
    try:
        # Write UTF-8 explicitly: the child reads source as UTF-8 (PEP 3120),
        # and a non-UTF-8 locale would otherwise raise on non-ASCII output.
        with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as f:
            path = f.name
            f.write(code)
        r = subprocess.run(
            [sys.executable, path],
            capture_output=True,
            text=True,
            encoding="utf-8",
            errors="replace",
            # No stdin: input() fails fast with EOFError instead of hanging until timeout.
            stdin=subprocess.DEVNULL,
            # Make the child's stdout/stderr encoding match how we decode it.
            env={**os.environ, "PYTHONIOENCODING": "utf-8"},
            timeout=timeout,
        )
        return r.stdout, r.stderr, r.returncode == 0
    except subprocess.TimeoutExpired:
        return "", "Timeout", False
    except Exception as e:
        return "", str(e), False
    finally:
        if path is not None:
            try:
                os.unlink(path)
            except OSError:
                pass  # best-effort cleanup of the temp script
def _completion_text(completion):
    """Return the generated text of one GRPO completion.

    A list is treated as conversational messages (the last message's
    "content" is used); anything else is converted with str().
    """
    if isinstance(completion, list):
        return completion[-1].get("content", "") if completion else ""
    return str(completion)


def _bounded_reward(tests_passed, total):
    """Map a test pass count to a reward in [0.01, 0.99]."""
    if tests_passed == total:
        return 0.99  # Full solve
    if tests_passed > 0:
        return 0.1 + (tests_passed / total) * 0.79  # Partial: 0.1–0.89
    return 0.01  # All failed


def reward_code_debug(completions, bug_index, **kwargs):
    """Reward function: extract code, run the bug's tests, return scores.

    Args:
        completions: One completion per sample, as given by GRPOTrainer
            (plain text, or a list of chat messages).
        bug_index: Parallel list of BUG_DATASET indices (the dataset's
            "bug_index" column).
        **kwargs: Other dataset columns; ignored.

    Returns:
        A list of floats in [0.01, 0.99], one per completion:
            All tests pass              → 0.99
            Partial pass                → 0.1 + (passed/total) * 0.79
            Syntax error                → 0.01
            No test passes (incl. runtime errors) → 0.01

    Rewards stay within [0.01, 0.99] to match the OpenEnv Phase 2 grader
    range. GRPO advantages are computed relative to the other completions of
    the same prompt, so a group in which every completion scores the same
    (e.g. all 0.01) still gives zero advantage. Learning signal comes from
    score differences within a group, not from the floor value.
    """
    import secrets  # unguessable per-call pass marker

    # A fixed marker such as "PASS" can be faked by a completion that prints it
    # and exits early. A fresh random marker on the last stdout line raises the
    # bar. This is still not a sandbox (see run_code).
    marker = f"CODESENSEI_PASS_{secrets.token_hex(8)}"
    rewards = []
    for completion, idx in zip(completions, bug_index):
        proposed_fix = extract_code(_completion_text(completion))
        bug = BUG_DATASET[int(idx)]
        # Syntax check (compile only; nothing runs here).
        # ValueError covers NUL bytes on Python < 3.12.
        try:
            compile(proposed_fix, "<fix>", "exec")
        except (SyntaxError, ValueError):
            rewards.append(0.01)
            continue
        tests_passed = 0
        for test in bug["tests"]:
            # The bare print() guarantees the marker starts its own line even if
            # the fix printed something without a trailing newline.
            script = f"{proposed_fix}\n\n{test['code']}\nprint()\nprint({marker!r})"
            stdout, _stderr, ok = run_code(script)
            if ok and stdout.splitlines()[-1:] == [marker]:
                tests_passed += 1
        rewards.append(_bounded_reward(tests_passed, len(bug["tests"])))
    return rewards


print("✅ Reward function ready (inline test execution)")
# ============================================================
# CELL 6: GRPO Config + Train
# ============================================================
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers.trainer_utils import get_last_checkpoint
from peft import LoraConfig
import torch
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
# 4-bit quantization — reduces model from ~4GB to ~1.5GB
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)
# Load model with 4-bit quantization
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
)
print(f"✅ Model loaded in 4-bit | VRAM: {torch.cuda.memory_allocated()/1e9:.1f} GB")
# LoRA config — only train small adapter weights
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
grpo_config = GRPOConfig(
    num_train_epochs=1,
    learning_rate=2e-5,  # Higher LR — previous run had 0 updates
    gradient_accumulation_steps=2,  # Faster feedback loop
    # TRL GRPO: (num devices x per_device_train_batch_size) must be evenly
    # divisible by num_generations.
    per_device_train_batch_size=4,
    warmup_steps=5,
    num_generations=4,  # Completions per prompt; reward spread within this group drives learning
    max_completion_length=150,  # Shorter fits in T4 VRAM
    temperature=0.9,  # Needed for generation variance
    output_dir=OUTPUT_DIR,
    report_to="none",
    logging_steps=1,
    save_steps=10,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    push_to_hub=False,
    optim="paged_adamw_8bit",
)
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[reward_code_debug],
    train_dataset=dataset,
    args=grpo_config,
    peft_config=peft_config,
)
print("🚀 Starting GRPO training...")
print(f" Model: {MODEL_NAME}")
print(f" Dataset: {len(dataset)} prompts")
print(f" Checkpoints every {grpo_config.save_steps} steps to {OUTPUT_DIR}/ (local only; push_to_hub=False)")
print()
import os
# Resume only from a real Trainer checkpoint folder ("checkpoint-<step>").
# resume_from_checkpoint=True raises when no such folder exists, so pass the
# concrete path (or None for a fresh run).
_last_checkpoint = get_last_checkpoint(OUTPUT_DIR) if os.path.isdir(OUTPUT_DIR) else None
trainer.train(resume_from_checkpoint=_last_checkpoint)
# Step checkpoints go to OUTPUT_DIR/checkpoint-*. Save the final weights
# (presumably just the LoRA adapter, since peft_config is set) into OUTPUT_DIR
# itself so Cell 8's pipeline(model=OUTPUT_DIR) has something to load.
trainer.save_model(OUTPUT_DIR)
print("\n✅ Training complete!")
print(f"📦 Model saved to: {OUTPUT_DIR}")
| # ============================================================ | |
| # CELL 7: Push Final Model to HF Hub | |
| # ============================================================ | |
| # trainer.push_to_hub() | |
| # print(f"β Model pushed to hub: {OUTPUT_DIR}") | |
| # ============================================================ | |
# CELL 8: Quick Test — Inference
| # ============================================================ | |
| # from transformers import pipeline | |
| # | |
| # pipe = pipeline("text-generation", model=OUTPUT_DIR, tokenizer=tokenizer) | |
| # test_bug = "def add_numbers(a, b):\n return a - b" | |
| # prompt = f"{SYSTEM_PROMPT}\n\nBUGGY FUNCTION:\n{test_bug}\n\nFix the function:" | |
| # result = pipe(prompt, max_new_tokens=256) | |
| # print(result[0]["generated_text"]) | |