Spaces:
Sleeping
Sleeping
Fix merge: fall back to warm-start adapter from HF when GRPO skipped
Browse files- train_on_hf.py +15 -2
train_on_hf.py
CHANGED
|
@@ -331,16 +331,29 @@ def train_strategist(data_dir: Path, max_samples: int = 10000):
|
|
| 331 |
|
| 332 |
def merge_and_push(hf_token: str):
|
| 333 |
"""Merge LoRA, push merged model to HF Hub."""
|
|
|
|
| 334 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 335 |
from peft import PeftModel
|
| 336 |
|
| 337 |
config = json.load(open("data/preprocessing_config.json"))
|
| 338 |
MODEL_NAME = config["model"]["name"]
|
| 339 |
|
| 340 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="cpu")
|
| 342 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 343 |
-
model = PeftModel.from_pretrained(base,
|
| 344 |
merged = model.merge_and_unload()
|
| 345 |
|
| 346 |
merged.save_pretrained("./strategist_merged")
|
|
|
|
| 331 |
|
| 332 |
def merge_and_push(hf_token: str):
|
| 333 |
"""Merge LoRA, push merged model to HF Hub."""
|
| 334 |
+
import os
|
| 335 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 336 |
from peft import PeftModel
|
| 337 |
|
| 338 |
config = json.load(open("data/preprocessing_config.json"))
|
| 339 |
MODEL_NAME = config["model"]["name"]
|
| 340 |
|
| 341 |
+
# Use strategist_final if it exists, otherwise fall back to warm-start
|
| 342 |
+
adapter_path = "./strategist_final" if os.path.exists("./strategist_final/adapter_config.json") else "./strategist_warmstart"
|
| 343 |
+
if not os.path.exists(adapter_path):
|
| 344 |
+
# If neither local dir exists, download the warm-start from HF
|
| 345 |
+
from huggingface_hub import snapshot_download
|
| 346 |
+
adapter_path = snapshot_download(
|
| 347 |
+
repo_id="Rayugacodes/kernelx-strategist",
|
| 348 |
+
allow_patterns=["adapter/*"],
|
| 349 |
+
local_dir="./hf_adapter",
|
| 350 |
+
)
|
| 351 |
+
adapter_path = "./hf_adapter/adapter"
|
| 352 |
+
|
| 353 |
+
print(f"\n=== Merging LoRA from {adapter_path} and pushing to HF ===")
|
| 354 |
base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="cpu")
|
| 355 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 356 |
+
model = PeftModel.from_pretrained(base, adapter_path)
|
| 357 |
merged = model.merge_and_unload()
|
| 358 |
|
| 359 |
merged.save_pretrained("./strategist_merged")
|