natnael kahssay committed on
Commit
aae5554
·
1 Parent(s): 6dd8379

add training/ as real directory (Dockerfile + train.py)

Browse files
Files changed (4) hide show
  1. training +0 -1
  2. training/.gitignore +3 -0
  3. training/Dockerfile +18 -0
  4. training/train.py +130 -0
training DELETED
@@ -1 +0,0 @@
1
- Subproject commit 6e2e91b196e9185240ede4fde3629358c5455b33
 
 
training/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ /output/
training/Dockerfile ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM pytorch/pytorch:2.8.0-cuda12.6-cudnn9-runtime
2
+
3
+ RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
4
+
5
+ RUN pip install --no-cache-dir \
6
+ "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git" \
7
+ trl \
8
+ httpx \
9
+ datasets \
10
+ transformers \
11
+ accelerate \
12
+ peft \
13
+ bitsandbytes
14
+
15
+ WORKDIR /app
16
+ COPY train.py .
17
+
18
+ CMD ["python", "train.py"]
training/train.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GRPO training on MOA RL environment using TRL + Unsloth.
3
+ Connects to the deployed moa-rl-env server for rewards.
4
+ """
5
+ import asyncio
6
+ import os
7
+ import httpx
8
+ from datasets import Dataset
9
+ from trl import GRPOTrainer, GRPOConfig
10
+ from unsloth import FastLanguageModel
11
+
12
+ ENV_URL = os.environ.get("ENV_URL", "https://http--moa-rl-env--7b2fgcxb6gxp.code.run")
13
+ MODEL_NAME = os.environ.get("MODEL_NAME", "unsloth/Llama-3.1-8B-Instruct")
14
+ OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "/output/moa-rl-grpo")
15
+
16
+ # ── Model ──────────────────────────────────────────────────────────────────────
17
+ model, tokenizer = FastLanguageModel.from_pretrained(
18
+ model_name=MODEL_NAME,
19
+ max_seq_length=2048,
20
+ load_in_4bit=True,
21
+ dtype=None, # auto
22
+ )
23
+ model = FastLanguageModel.get_peft_model(
24
+ model,
25
+ r=16,
26
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
27
+ "gate_proj", "up_proj", "down_proj"],
28
+ lora_alpha=16,
29
+ lora_dropout=0,
30
+ bias="none",
31
+ use_gradient_checkpointing="unsloth",
32
+ random_state=42,
33
+ )
34
+
35
+ # ── Tasks dataset ──────────────────────────────────────────────────────────────
36
+ def fetch_tasks() -> list[dict]:
37
+ resp = httpx.get(f"{ENV_URL}/tasks", timeout=30)
38
+ resp.raise_for_status()
39
+ return resp.json()["tasks"]
40
+
41
+ PROMPT_TEMPLATE = """\
42
+ You are an expert TypeScript developer.
43
+ Fix the following broken file so that all tests pass.
44
+
45
+ File: {file_path}
46
+
47
+ Current content:
48
+ ```typescript
49
+ {current_content}
50
+ ```
51
+
52
+ Respond with ONLY the fixed TypeScript file contents, no explanation.
53
+ """
54
+
55
+ def build_dataset() -> Dataset:
56
+ tasks = fetch_tasks()
57
+ rows = []
58
+ for t in tasks:
59
+ prompt = PROMPT_TEMPLATE.format(
60
+ file_path=t["file_path"],
61
+ current_content=t.get("current_content", "// empty"),
62
+ )
63
+ rows.append({"prompt": prompt, "task_id": t["id"], "file_path": t["file_path"]})
64
+ return Dataset.from_list(rows)
65
+
66
+ dataset = build_dataset()
67
+
68
+ # ── Reward function ────────────────────────────────────────────────────────────
69
+ async def _call_step(session_id: str, file_path: str, content: str) -> float:
70
+ async with httpx.AsyncClient(timeout=60) as client:
71
+ resp = await client.post(f"{ENV_URL}/step", json={
72
+ "session_id": session_id,
73
+ "action": {"file_path": file_path, "content": content},
74
+ })
75
+ resp.raise_for_status()
76
+ data = resp.json()
77
+ return data["reward"]
78
+
79
+ async def _reset(task_id: str) -> str:
80
+ async with httpx.AsyncClient(timeout=30) as client:
81
+ resp = await client.post(f"{ENV_URL}/reset", json={"task_id": task_id})
82
+ resp.raise_for_status()
83
+ return resp.json()["session_id"]
84
+
85
+ def reward_fn(prompts: list[str], completions: list[str], **kwargs) -> list[float]:
86
+ task_ids = kwargs.get("task_id", [None] * len(prompts))
87
+ file_paths = kwargs.get("file_path", [None] * len(prompts))
88
+
89
+ async def run_all():
90
+ rewards = []
91
+ for task_id, file_path, completion in zip(task_ids, file_paths, completions):
92
+ try:
93
+ session_id = await _reset(task_id)
94
+ reward = await _call_step(session_id, file_path, completion)
95
+ except Exception as e:
96
+ print(f"[reward_fn] error: {e}")
97
+ reward = 0.0
98
+ rewards.append(reward)
99
+ return rewards
100
+
101
+ return asyncio.run(run_all())
102
+
103
+ # ── Training ───────────────────────────────────────────────────────────────────
104
+ trainer = GRPOTrainer(
105
+ model=model,
106
+ tokenizer=tokenizer,
107
+ reward_funcs=[reward_fn],
108
+ args=GRPOConfig(
109
+ output_dir=OUTPUT_DIR,
110
+ num_train_epochs=3,
111
+ per_device_train_batch_size=1,
112
+ gradient_accumulation_steps=4,
113
+ learning_rate=5e-6,
114
+ lr_scheduler_type="cosine",
115
+ warmup_ratio=0.1,
116
+ logging_steps=10,
117
+ save_steps=100,
118
+ bf16=True,
119
+ report_to="none",
120
+ num_generations=4,
121
+ max_prompt_length=1024,
122
+ max_completion_length=1024,
123
+ ),
124
+ train_dataset=dataset,
125
+ )
126
+
127
+ print(f"Training on {len(dataset)} tasks against {ENV_URL}")
128
+ trainer.train()
129
+ trainer.save_model(OUTPUT_DIR)
130
+ print("Done. Model saved to", OUTPUT_DIR)