# codesensei-env / training / colab_train.py
# vineetshukla.work@gmail.com
# chore: test trial run for colab script with 10 dataset size and fixed dependencies
# a59d79f
# CodeSensei — GRPO Training Notebook
# ========================================
# Run this in Google Colab with a T4 GPU runtime.
# Fine-tunes Qwen3-1.7B to debug Python code using GRPO.
#
# CELL STRUCTURE (copy each section into a separate Colab cell):
# Cell 1: Install dependencies
# Cell 2: Login to HF
# Cell 3: Configuration + Bug Dataset
# Cell 4: Build Training Dataset
# Cell 5: Reward Function
# Cell 6: GRPO Config + Train
# Cell 7: Push Final Model to HF Hub (optional)
# Cell 8: Quick Test — Inference (optional)
# ============================================================
# CELL 1: Install dependencies
# ============================================================
# Run these in order in separate Colab cells:
#
# STEP 1 — Fix vLLM conflict (Colab ships a broken/old vLLM that crashes TRL):
# !pip uninstall vllm -y
# !rm -rf /usr/local/lib/python3.12/dist-packages/vllm*
#
# STEP 2 — Install TRL and dependencies (pinned to TRL 0.15.0 and Transformers 4.46.3):
# !pip install -q "trl==0.15.0" "transformers==4.46.3" datasets accelerate "bitsandbytes>=0.43" peft
#
# Why: TRL's GRPOTrainer imports vllm.sampling_params.StructuredOutputsParams at module load
# time even when you don't use vLLM. Colab's vllm==0.10.2 doesn't have this class.
# Uninstalling vLLM is safe here — we're using Hugging Face Transformers, not vLLM serving.
# ============================================================
# CELL 2: Login to Hugging Face
# ============================================================
# from huggingface_hub import notebook_login
# notebook_login()
# ============================================================
# CELL 3: Configuration + Bug Dataset
# ============================================================
# Base model to fine-tune and the local directory for checkpoints / final adapter.
MODEL_NAME = "Qwen/Qwen3-1.7B"
OUTPUT_DIR = "codesensei-qwen3-1.7b"

# Instructions prepended to every training prompt (see make_prompt).
SYSTEM_PROMPT = """You are an expert Python debugger. You receive a buggy Python function and failing test results. Fix the bug.
RULES:
1. Respond with ONLY the corrected Python function.
2. Do NOT include explanations or markdown — just raw Python code.
3. Keep the same function name and signature.
4. Make minimal changes to fix the bug.
Example response:
def function_name(args):
    # corrected implementation"""

# Bug corpus. Each entry has:
#   function_name   - name the fixed code must define
#   buggy_code      - syntactically valid Python source containing the bug
#                     (4-space indented so the model sees real Python)
#   bug_description - human-readable note; not used in the prompt
#   tests           - list of {"name", "code"}; "code" is a single assert
#                     statement executed after the proposed fix
BUG_DATASET = [
    {
        "function_name": "add_numbers",
        "buggy_code": "def add_numbers(a, b):\n    return a - b",
        "bug_description": "Uses subtraction instead of addition",
        "tests": [
            {"name": "basic addition", "code": "assert add_numbers(2, 3) == 5"},
            {"name": "zero addition", "code": "assert add_numbers(0, 0) == 0"},
            {"name": "negative addition", "code": "assert add_numbers(-1, 1) == 0"},
        ],
    },
    {
        "function_name": "find_max",
        "buggy_code": (
            "def find_max(lst):\n"
            "    if not lst:\n"
            "        return None\n"
            "    result = lst[0]\n"
            "    for x in lst:\n"
            "        if x < result:\n"
            "            result = x\n"
            "    return result"
        ),
        "bug_description": "Uses < instead of > (finds minimum instead of maximum)",
        "tests": [
            {"name": "basic max", "code": "assert find_max([1, 3, 2]) == 3"},
            {"name": "single element", "code": "assert find_max([5]) == 5"},
            {"name": "negative numbers", "code": "assert find_max([-1, -5, -2]) == -1"},
            {"name": "empty list", "code": "assert find_max([]) is None"},
        ],
    },
    {
        "function_name": "reverse_string",
        "buggy_code": "def reverse_string(s):\n    return s[1:]",
        "bug_description": "Slices from index 1 instead of reversing",
        "tests": [
            {"name": "basic reverse", "code": 'assert reverse_string("hello") == "olleh"'},
            {"name": "empty string", "code": 'assert reverse_string("") == ""'},
            {"name": "single char", "code": 'assert reverse_string("a") == "a"'},
        ],
    },
    {
        "function_name": "fibonacci",
        "buggy_code": (
            "def fibonacci(n):\n"
            "    if n <= 0:\n"
            "        return 0\n"
            "    if n == 1:\n"
            "        return 1\n"
            "    return fibonacci(n - 1) + fibonacci(n - 3)"
        ),
        "bug_description": "Recursive call uses n-3 instead of n-2",
        "tests": [
            {"name": "fib(0)", "code": "assert fibonacci(0) == 0"},
            {"name": "fib(1)", "code": "assert fibonacci(1) == 1"},
            {"name": "fib(5)", "code": "assert fibonacci(5) == 5"},
            {"name": "fib(10)", "code": "assert fibonacci(10) == 55"},
        ],
    },
    {
        "function_name": "count_vowels",
        "buggy_code": (
            "def count_vowels(s):\n"
            "    count = 0\n"
            "    for c in s:\n"
            "        if c in 'aeiou':\n"
            "            count += 1\n"
            "    return count"
        ),
        "bug_description": "Does not handle uppercase vowels",
        "tests": [
            {"name": "lowercase", "code": "assert count_vowels('hello') == 2"},
            {"name": "uppercase", "code": "assert count_vowels('HELLO') == 2"},
            {"name": "mixed case", "code": "assert count_vowels('HeLLo') == 2"},
        ],
    },
    {
        "function_name": "is_palindrome",
        "buggy_code": "def is_palindrome(s):\n    s = s.lower()\n    return s == s[::-1]",
        "bug_description": "Does not strip non-alphanumeric characters before checking",
        "tests": [
            {"name": "basic palindrome", "code": "assert is_palindrome('racecar') == True"},
            {"name": "with punctuation", "code": "assert is_palindrome('A man, a plan, a canal: Panama') == True"},
            {"name": "not palindrome", "code": "assert is_palindrome('hello') == False"},
        ],
    },
    {
        "function_name": "flatten_list",
        "buggy_code": (
            "def flatten_list(lst):\n"
            "    result = []\n"
            "    for item in lst:\n"
            "        if isinstance(item, list):\n"
            "            result.append(item)\n"
            "        else:\n"
            "            result.append(item)\n"
            "    return result"
        ),
        "bug_description": "Appends nested lists instead of recursively flattening them",
        "tests": [
            {"name": "nested", "code": "assert flatten_list([1, [2, 3], [4, [5]]]) == [1, 2, 3, 4, 5]"},
            {"name": "already flat", "code": "assert flatten_list([1, 2, 3]) == [1, 2, 3]"},
            {"name": "deep nesting", "code": "assert flatten_list([[[[1]]]]) == [1]"},
        ],
    },
    {
        "function_name": "binary_search",
        "buggy_code": (
            "def binary_search(arr, target):\n"
            "    left, right = 0, len(arr) - 1\n"
            "    while left < right:\n"
            "        mid = (left + right) // 2\n"
            "        if arr[mid] == target:\n"
            "            return mid\n"
            "        elif arr[mid] < target:\n"
            "            left = mid\n"
            "        else:\n"
            "            right = mid - 1\n"
            "    return -1"
        ),
        "bug_description": "Uses < instead of <= in while condition, and left=mid instead of left=mid+1",
        "tests": [
            {"name": "found middle", "code": "assert binary_search([1,2,3,4,5], 3) == 2"},
            {"name": "found first", "code": "assert binary_search([1,2,3,4,5], 1) == 0"},
            {"name": "found last", "code": "assert binary_search([1,2,3,4,5], 5) == 4"},
            {"name": "not found", "code": "assert binary_search([1,2,3,4,5], 6) == -1"},
        ],
    },
    {
        "function_name": "merge_sorted",
        "buggy_code": (
            "def merge_sorted(a, b):\n"
            "    result = []\n"
            "    i = j = 0\n"
            "    while i < len(a) and j < len(b):\n"
            "        if a[i] <= b[j]:\n"
            "            result.append(a[i])\n"
            "            i += 1\n"
            "        else:\n"
            "            result.append(b[j])\n"
            "            j += 1\n"
            "    return result"
        ),
        "bug_description": "Missing the remaining elements after the while loop ends",
        "tests": [
            {"name": "basic merge", "code": "assert merge_sorted([1,3,5], [2,4,6]) == [1,2,3,4,5,6]"},
            {"name": "one empty", "code": "assert merge_sorted([], [1,2,3]) == [1,2,3]"},
            {"name": "unequal length", "code": "assert merge_sorted([1], [2,3,4]) == [1,2,3,4]"},
        ],
    },
    {
        "function_name": "remove_duplicates",
        "buggy_code": (
            "def remove_duplicates(lst):\n"
            "    seen = set()\n"
            "    result = []\n"
            "    for item in lst:\n"
            "        if item in seen:\n"
            "            result.append(item)\n"
            "        seen.add(item)\n"
            "    return result"
        ),
        "bug_description": "Condition is inverted: keeps duplicates and removes unique items",
        "tests": [
            {"name": "basic dedup", "code": "assert remove_duplicates([1,2,2,3,3,3]) == [1,2,3]"},
            {"name": "no duplicates", "code": "assert remove_duplicates([1,2,3]) == [1,2,3]"},
            {"name": "all same", "code": "assert remove_duplicates([5,5,5]) == [5]"},
        ],
    },
]
print(f"✅ Config loaded | Model: {MODEL_NAME} | Bugs: {len(BUG_DATASET)}")
# ============================================================
# CELL 4: Build Training Dataset
# ============================================================
import random
from datasets import Dataset
def make_prompt(bug):
    """Render one BUG_DATASET entry as a plain-text training prompt.

    The prompt is four sections separated by blank lines: SYSTEM_PROMPT, the
    buggy source, a bulleted list of the entry's test assertions, and a final
    "Fix the function:" cue.

    NOTE(review): every test is listed under "FAILING TESTS", including tests
    the buggy code already passes (e.g. add_numbers(0, 0) == 0 holds for a - b).
    """
    bullet_lines = [f" - {test['code']}" for test in bug["tests"]]
    sections = [
        f"{SYSTEM_PROMPT}",
        f"BUGGY FUNCTION:\n{bug['buggy_code']}",
        "FAILING TESTS:\n" + "\n".join(bullet_lines),
        "Fix the function:",
    ]
    return "\n\n".join(sections)
# Number of training prompts to generate. Prompts cycle through BUG_DATASET,
# so with its 10 bugs a value of 10 uses each bug exactly once (quick trial
# run). Raise this for a real training run.
NUM_TRAIN_PROMPTS = 10

# Each row carries the rendered prompt plus "bug_index", which GRPOTrainer
# forwards to reward_code_debug so it can look up that bug's tests.
prompts = [
    {"prompt": make_prompt(BUG_DATASET[bug_idx]), "bug_index": bug_idx}
    for bug_idx in (i % len(BUG_DATASET) for i in range(NUM_TRAIN_PROMPTS))
]
dataset = Dataset.from_list(prompts)
print(f"✅ Dataset: {len(dataset)} prompts from {len(BUG_DATASET)} bugs")
# ============================================================
# CELL 5: Reward Function (runs tests directly — no WebSocket)
# ============================================================
import subprocess
import tempfile
import sys
import os
import re
# Fenced block explicitly tagged as Python: ```python, ```py, ```python3 (any case).
_PY_FENCE_RE = re.compile(r"```(?:python3?|py)[ \t\r]*\n(.*?)```", re.DOTALL | re.IGNORECASE)
# Any other fenced block: untagged, or tagged with some other label.
_ANY_FENCE_RE = re.compile(r"```[^\n`]*\n(.*?)```", re.DOTALL)


def extract_code(text):
    """Extract the Python source from an LLM response.

    Steps:
      1. If the response contains a reasoning block ending in ``</think>``
         (Qwen3 style), keep only the text after the last ``</think>``.
      2. Return the contents of the first Python-tagged fenced block, if any.
      3. Otherwise return the contents of the first fenced block of any kind.
      4. Otherwise return the whole (remaining) response.

    Args:
        text: Raw completion text.

    Returns:
        The extracted code, stripped of surrounding whitespace ("" for an
        empty response).
    """
    # Drop any chain-of-thought preamble; drafts inside it must not be scored.
    _, think_end, answer = text.rpartition("</think>")
    if think_end:
        text = answer
    # Prefer a Python-tagged fence over a generic one, as the original did.
    for pattern in (_PY_FENCE_RE, _ANY_FENCE_RE):
        match = pattern.search(text)
        if match:
            return match.group(1).strip()
    # No fences: treat the response as raw code (the format SYSTEM_PROMPT asks for).
    return text.strip()
def run_code(code, timeout=5):
    """Run ``code`` as a standalone script in a fresh Python subprocess.

    SECURITY: ``code`` is model-generated and runs with this notebook's
    privileges (no sandbox beyond a separate process and a timeout). Only use
    this in a disposable environment such as a Colab VM.

    Args:
        code: Python source to execute.
        timeout: Seconds before the child process is killed.

    Returns:
        ``(stdout, stderr, success)``, where ``success`` is True iff the child
        exited with status 0. On timeout returns ``("", "Timeout", False)``.
        Any other failure (e.g. the source cannot be written, or the process
        cannot be launched) returns ``("", <error message>, False)``.
    """
    # Python reads source files as UTF-8 by default, so write UTF-8 explicitly
    # rather than the locale's encoding.
    tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8")
    path = tmp.name
    try:
        # Writing happens inside the try so an unencodable completion (e.g. a
        # lone surrogate) is reported as a failed run instead of crashing the
        # reward computation.
        with tmp:
            tmp.write(code)
        result = subprocess.run(
            [sys.executable, path],
            capture_output=True,
            text=True,
            errors="replace",  # never raise on undecodable child output
            timeout=timeout,
        )
        return result.stdout, result.stderr, result.returncode == 0
    except subprocess.TimeoutExpired:
        return "", "Timeout", False
    except Exception as e:  # deliberate best-effort: any failure is a failed run
        return "", str(e), False
    finally:
        os.unlink(path)
def _completion_text(completion):
    """Return the generated text of one GRPO completion.

    Conversational completions are a list of message dicts; the last message's
    "content" is used (an empty list yields ""). Anything else goes through str().
    """
    if isinstance(completion, list):
        return completion[-1].get("content", "") if completion else ""
    return str(completion)


def _count_passing_tests(proposed_fix, tests):
    """Run each test against ``proposed_fix`` in its own subprocess; count passes.

    A test counts only if the script exits with status 0 AND prints a random
    per-test token emitted after the assertion. The model cannot know the
    token, so a "fix" that prints PASS and exits before the assertion runs is
    not rewarded.
    """
    import secrets  # local import keeps this cell self-contained

    passed = 0
    for test in tests:
        token = f"PASS-{secrets.token_hex(8)}"
        test_code = f"{proposed_fix}\n\n{test['code']}\nprint({token!r})"
        stdout, _stderr, ok = run_code(test_code)
        if ok and token in stdout:
            passed += 1
    return passed


def _bounded_reward(passed, total):
    """Map a pass count to a reward in [0.01, 0.99] (see reward_code_debug)."""
    if passed == total:
        return 0.99  # full solve
    if passed > 0:
        return 0.1 + (passed / total) * 0.79  # partial: strictly between 0.1 and 0.89
    return 0.01  # nothing passed


def reward_code_debug(completions, bug_index, **kwargs):
    """GRPO reward: extract each proposed fix, run its bug's tests, score it.

    Score mapping (all rewards lie in [0.01, 0.99]):
        all tests pass              -> 0.99   (previously 2.0)
        some, not all, tests pass   -> 0.1 + (passed / total) * 0.79
        no test passes              -> 0.01   (previously -0.5)
        syntax error in the fix     -> 0.01   (previously -0.5)

    Why bounded: small positive rewards keep the scale modest (the previous
    version used -0.5 .. 2.0). The author also notes this matches the OpenEnv
    Phase 2 grader requirements. Bounding does NOT prevent zero advantage: if
    every completion for a prompt gets the same score (e.g. all 0.01), the
    group std is 0 and that group contributes no learning signal. Variance has
    to come from sampling (temperature, num_generations).

    Args:
        completions: One completion per sample, either a string or a list of
            chat messages (the last message's "content" is used).
        bug_index: The dataset's "bug_index" column, aligned with
            ``completions``; indexes into BUG_DATASET.
        **kwargs: Other columns/arguments passed by the trainer (unused).

    Returns:
        A list of float rewards, one per completion.
    """
    rewards = []
    for completion, idx in zip(completions, bug_index):
        proposed_fix = extract_code(_completion_text(completion))
        bug = BUG_DATASET[int(idx)]
        try:
            # Cheap in-process parse check before spawning any subprocesses.
            compile(proposed_fix, "<fix>", "exec")
        except (SyntaxError, ValueError):
            # ValueError: some Python versions raise it (not SyntaxError) for
            # source containing NUL bytes; either way the fix is unusable.
            rewards.append(0.01)
            continue
        tests = bug["tests"]
        rewards.append(_bounded_reward(_count_passing_tests(proposed_fix, tests), len(tests)))
    return rewards


print("✅ Reward function ready (inline test execution)")
# ============================================================
# CELL 6: GRPO Config + Train
# ============================================================
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig
import torch
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Reuse EOS as the padding token so padding always has a defined token id.
tokenizer.pad_token = tokenizer.eos_token

# 4-bit NF4 quantization (with double quantization) — reduces model from ~4GB to ~1.5GB.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,  # fp16 compute for the T4 target runtime (see header)
    bnb_4bit_use_double_quant=True,
)

# Load the base model quantized to 4-bit; device_map="auto" handles placement.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
)
print(f"✅ Model loaded in 4-bit | VRAM: {torch.cuda.memory_allocated()/1e9:.1f} GB")

# LoRA config — only the small adapter weights on the attention projections are
# trained; the config is handed to GRPOTrainer via peft_config.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
grpo_config = GRPOConfig(
    num_train_epochs=1,
    learning_rate=2e-5,  # Higher LR — previous run had 0 updates
    gradient_accumulation_steps=2,  # Faster feedback loop
    # TRL (0.15) requires the global batch size (num_processes x
    # per_device_train_batch_size) to be divisible by num_generations; it does
    # not have to equal it.
    per_device_train_batch_size=4,
    # NOTE(review): on a 10-prompt trial run, 5 warmup steps may span most or
    # all of training — consider warmup_ratio for short runs.
    warmup_steps=5,
    num_generations=4,  # Completions sampled per prompt; rewards are compared within this group
    max_completion_length=150,  # Shorter fits in T4 VRAM
    temperature=0.9,  # Sampling diversity -> reward variance within each group
    output_dir=OUTPUT_DIR,
    report_to="none",
    logging_steps=1,
    save_steps=10,  # Periodic checkpoints are written under OUTPUT_DIR/checkpoint-<step>/
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    push_to_hub=False,
    optim="paged_adamw_8bit",
)
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[reward_code_debug],
    train_dataset=dataset,
    args=grpo_config,
    peft_config=peft_config,
)
print("🚀 Starting GRPO training...")
print(f" Model: {MODEL_NAME}")
print(f" Dataset: {len(dataset)} prompts")
# Report the actual save settings (the old message claimed every 5 steps to the Hub).
print(
    f" Checkpoints every {grpo_config.save_steps} steps in {OUTPUT_DIR}/ "
    f"(push_to_hub={grpo_config.push_to_hub})"
)
print()

import os
from transformers.trainer_utils import get_last_checkpoint

# Resume only from a real Trainer checkpoint directory (checkpoint-<step>).
# get_last_checkpoint returns None when there is none, so a stray directory
# whose name merely contains "checkpoint" no longer makes train() fail.
last_checkpoint = get_last_checkpoint(OUTPUT_DIR) if os.path.isdir(OUTPUT_DIR) else None
trainer.train(resume_from_checkpoint=last_checkpoint)

# Periodic checkpoints live in subdirectories, and a short trial run may never
# reach save_steps. Write the final adapter + tokenizer to OUTPUT_DIR itself so
# it can be loaded from OUTPUT_DIR directly (as the inference cell does).
trainer.save_model(OUTPUT_DIR)
print("\n✅ Training complete!")
print(f"📦 Model saved to: {OUTPUT_DIR}")
# ============================================================
# CELL 7: Push Final Model to HF Hub
# ============================================================
# trainer.push_to_hub()
# print(f"βœ… Model pushed to hub: {OUTPUT_DIR}")
# ============================================================
# CELL 8: Quick Test — Inference
# ============================================================
# from transformers import pipeline
#
# pipe = pipeline("text-generation", model=OUTPUT_DIR, tokenizer=tokenizer)
# test_bug = "def add_numbers(a, b):\n return a - b"
# prompt = f"{SYSTEM_PROMPT}\n\nBUGGY FUNCTION:\n{test_bug}\n\nFix the function:"
# result = pipe(prompt, max_new_tokens=256)
# print(result[0]["generated_text"])