| |
| """ |
| KernelX — Full GPU Training Script for Hugging Face |
| |
| Run this on a HF Space or notebook with GPU (T4/A10/A100). |
| It handles everything: download data, train World Model, train Strategist (GRPO), |
| merge LoRA, and push results back to HF Hub (GGUF conversion is a separate follow-up step). |
| |
| Usage (on HF with GPU): |
| pip install torch transformers trl peft datasets accelerate huggingface_hub |
| python train_on_hf.py --hf-token YOUR_TOKEN |
| """ |
|
|
| import argparse |
| import json |
| import os |
| import sys |
| from pathlib import Path |
|
|
|
|
def setup(hf_token: str) -> Path:
    """Log in to the Hugging Face Hub and download the training inputs.

    Fetches the dataset files into ``data/`` and the ``training/`` folder of
    the model repository into ``model_repo/``.

    Args:
        hf_token: Hugging Face access token passed to ``login``.

    Returns:
        The local directory (``data``) that now holds the dataset files.
    """
    from huggingface_hub import login, hf_hub_download, snapshot_download

    login(token=hf_token)

    data_dir = Path("data")
    data_dir.mkdir(exist_ok=True)

    dataset_files = (
        "state_transitions.jsonl",
        "train.jsonl",
        "val.jsonl",
        "test.jsonl",
        "preprocessing_config.json",
    )
    for fname in dataset_files:
        # The returned path is not needed: files land directly in data_dir.
        hf_hub_download(
            repo_id="Rayugacodes/kernelx-training-data",
            filename=fname,
            repo_type="dataset",
            local_dir=str(data_dir),
        )
        print(f"Downloaded {fname}")

    # Only the training scripts are pulled from the model repo.
    snapshot_download(
        repo_id="Rayugacodes/kernelx-strategist",
        local_dir="model_repo",
        allow_patterns=["training/**"],
    )
    print("Downloaded training scripts")

    return data_dir
|
|
|
|
def train_world_model(data_dir: Path, max_samples: int = 50000):
    """Stage 2: Train the World Model via LoRA supervised fine-tuning.

    Reads ``preprocessing_config.json``, ``train.jsonl`` and ``val.jsonl``
    from ``data_dir``, renders every transition record as a prompt/answer
    text sample, and fine-tunes the base model named in the config. The
    trainer output and the tokenizer are saved to ``./world_model_final``.

    Args:
        data_dir: Directory holding the downloaded dataset files.
        max_samples: Maximum number of training records to use (expected to
            be non-negative). Validation uses at most ``max_samples // 8``.

    Returns:
        A ``(model, tokenizer)`` tuple: the model object that was handed to
        the trainer and the tokenizer that was saved next to it.
    """
    from itertools import islice

    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import LoraConfig
    from trl import SFTTrainer, SFTConfig

    with open(data_dir / "preprocessing_config.json", encoding="utf-8") as config_file:
        config = json.load(config_file)
    MODEL_NAME = config["model"]["name"]
    FEATURE_NAMES = config["feature_names"]

    def format_state(features):
        """Render a feature vector as ``name:value | name:value | ...``."""
        parts = []
        for name, val in zip(FEATURE_NAMES, features):
            # Whole numbers are printed without a fractional part.
            if val == int(val):
                parts.append(f"{name}:{int(val)}")
            else:
                parts.append(f"{name}:{val:.2f}")
        return " | ".join(parts)

    def make_sft_example(record):
        """Turn one transition record into a single SFT text sample."""
        state_str = format_state(record["state"])
        action_str = f"{record['action']:.4f}"
        next_state_str = format_state(record["next_state"])
        text = (
            "<|system|>You are a Linux kernel simulator. "
            "Predict the next system state.<|end|>\n"
            f"<|user|>[STATE] {state_str}\n"
            f"[ACTION] {action_str}\n"
            f"[PID] {record['pid']}\n"
            "Predict [NEXT_STATE]<|end|>\n"
            f"<|assistant|>[NEXT_STATE] {next_state_str}<|end|>"
        )
        return {"text": text}

    def load_records(path, limit):
        """Parse at most ``limit`` non-blank JSONL lines, then close the file."""
        with open(path, encoding="utf-8") as handle:
            non_blank = (line for line in handle if line.strip())
            # islice stops reading once the limit is reached, so the rest of
            # a large file is never parsed or held in memory.
            return [json.loads(line) for line in islice(non_blank, limit)]

    print("\n=== Stage 2: World Model SFT ===")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    train_records = load_records(data_dir / "train.jsonl", max_samples)
    val_records = load_records(data_dir / "val.jsonl", max_samples // 8)

    train_dataset = Dataset.from_list([make_sft_example(r) for r in train_records])
    val_dataset = Dataset.from_list([make_sft_example(r) for r in val_records])
    print(f"  Train: {len(train_dataset)}  Val: {len(val_dataset)}")

    lora_config = LoraConfig(
        r=16, lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
    )

    training_args = SFTConfig(
        output_dir="./world_model_checkpoints",
        num_train_epochs=3,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=2,
        learning_rate=2e-4,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        logging_steps=10,
        eval_strategy="steps",
        eval_steps=200,
        save_steps=500,
        save_total_limit=2,
        fp16=True,  # the script is meant to be run on a GPU
        max_length=512,
        report_to="none",
    )

    trainer = SFTTrainer(
        model=model, args=training_args,
        train_dataset=train_dataset, eval_dataset=val_dataset,
        peft_config=lora_config,
    )

    trainer.train()
    trainer.save_model("./world_model_final")
    tokenizer.save_pretrained("./world_model_final")
    print("World Model saved.")
    return model, tokenizer
|
|
|
|
def train_strategist(data_dir: Path, max_samples: int = 10000):
    """Stage 3: Warm-start SFT followed by GRPO for the Strategist.

    Phase 1 fine-tunes the base model on actions produced by a small
    hand-written heuristic. Phase 2 runs GRPO with a reward computed from a
    nearest-neighbour lookup over recorded transitions. The warm-start result
    is saved to ``./strategist_warmstart`` and the final result to
    ``./strategist_final``.

    Args:
        data_dir: Directory holding ``preprocessing_config.json`` and
            ``train.jsonl``.
        max_samples: Maximum number of records randomly sampled from
            ``train.jsonl`` for this stage.

    Returns:
        A ``(model, tokenizer)`` tuple: the model object that was handed to
        the trainers and the tokenizer that was saved next to it.
    """
    import re
    import random
    import numpy as np
    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import LoraConfig
    from trl import SFTTrainer, SFTConfig, GRPOConfig, GRPOTrainer

    with open(data_dir / "preprocessing_config.json", encoding="utf-8") as config_file:
        config = json.load(config_file)
    MODEL_NAME = config["model"]["name"]
    FEATURE_NAMES = config["feature_names"]
    # Positions inside a state vector read by the heuristic and the reward.
    IDX_WAIT_US = 9
    IDX_CTX_SWITCHES = 8
    IDX_EXEC_NS = 4
    # Number of leading records used for warm-start and as the simulator pool.
    HEAD_SIZE = 500

    def format_state(features):
        """Render a feature vector as ``name:value | name:value | ...``."""
        parts = []
        for name, val in zip(FEATURE_NAMES, features):
            # Whole numbers are printed without a fractional part.
            if val == int(val):
                parts.append(f"{name}:{int(val)}")
            else:
                parts.append(f"{name}:{val:.2f}")
        return " | ".join(parts)

    def build_prompt(state, pid, cpu):
        """Build the strategist prompt, ending where the model must answer."""
        state_str = format_state(state)
        return (
            "<|system|>You are a Linux kernel scheduling strategist. "
            "Given the current system state, output a scheduling action.<|end|>\n"
            f"<|user|>[STATE] {state_str}\n"
            f"[PID] {pid} [CPU] {cpu}\n"
            "[ACTION]<|end|>\n"
            "<|assistant|>"
        )

    def parse_action(text):
        """Extract the action value from generated text.

        Prefers a number following ``[ACTION]`` and falls back to the first
        number anywhere in the text. Raises ``ValueError`` when none exists.
        """
        m = re.search(r"\[ACTION\]\s*([-+]?\d*\.?\d+)", text)
        if not m:
            m = re.search(r"([-+]?\d*\.?\d+)", text)
        if not m:
            raise ValueError("No action found")
        return float(m.group(1))

    with open(data_dir / "train.jsonl", encoding="utf-8") as train_file:
        all_records = [json.loads(line) for line in train_file if line.strip()]
    records = random.sample(all_records, min(max_samples, len(all_records)))
    print(f"\n=== Stage 3: Strategist Training ({len(records)} samples) ===")

    print("\n--- Phase 1: Warm-start SFT ---")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    warmstart_examples = []
    for rec in records[:HEAD_SIZE]:
        state = rec["state"]
        wait_us = state[IDX_WAIT_US]
        csw = state[IDX_CTX_SWITCHES]
        # Heuristic labels: the first matching rule decides the action.
        if wait_us > 15:
            action = -0.6
        elif csw > 10:
            action = -0.3
        elif wait_us < 3:
            action = 0.1
        else:
            action = 0.05
        prompt = build_prompt(state, rec["pid"], rec["cpu"])
        warmstart_examples.append({"text": f"{prompt}{action:.4f}<|end|>"})

    ws_dataset = Dataset.from_list(warmstart_examples)

    lora_config = LoraConfig(
        r=16, lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
    )

    ws_args = SFTConfig(
        output_dir="./strategist_warmstart",
        num_train_epochs=2,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=2,
        learning_rate=2e-4,
        fp16=True,
        max_length=512,
        logging_steps=5,
        save_steps=100,
        report_to="none",
    )

    trainer = SFTTrainer(
        model=model, args=ws_args,
        train_dataset=ws_dataset, peft_config=lora_config,
    )
    trainer.train()
    trainer.save_model("./strategist_warmstart")
    tokenizer.save_pretrained("./strategist_warmstart")
    print("Warm-start complete.")

    print("\n--- Phase 2: GRPO RL Training ---")

    # The simulator only ever searches the first HEAD_SIZE records, so only
    # those are materialised.
    sim_states = np.array([r["state"] for r in records[:HEAD_SIZE]])
    sim_next_states = [r["next_state"] for r in records[:HEAD_SIZE]]

    def simulate(state_features, action_val):
        """Return the recorded next state of the nearest pooled state.

        NOTE: ``action_val`` is accepted but not used; the lookup depends on
        the state features only.
        """
        state_arr = np.array(state_features)
        dists = np.linalg.norm(sim_states - state_arr, axis=1)
        return sim_next_states[int(np.argmin(dists))]

    def reward_fn(completions, prompts, **kwargs):
        """Score each completion; unparseable inputs receive ``-5.0``.

        ``**kwargs`` absorbs any additional keyword arguments the trainer
        passes to reward functions, so they do not cause a ``TypeError``.
        """
        rewards = []
        for prompt, completion in zip(prompts, completions):
            try:
                # Recover the numeric state from the "[STATE] ..." prompt line.
                state_match = re.search(r"\[STATE\]\s*(.+?)(?:\n|$)", prompt)
                values = []
                for part in state_match.group(1).split("|"):
                    part = part.strip()
                    if ":" in part:
                        values.append(float(part.split(":")[1]))

                action_val = parse_action(completion)
                next_state = simulate(values, action_val)

                # Reward terms: log of the non-negative exec increase, a
                # penalty for increased wait, a penalty on action magnitude,
                # and a bonus for an action within [-1, 1].
                exec_delta = next_state[IDX_EXEC_NS] - values[IDX_EXEC_NS]
                r_throughput = float(np.log(max(0.0, exec_delta) + 1))
                wait_delta = next_state[IDX_WAIT_US] - values[IDX_WAIT_US]
                r_latency = -2.0 * max(0.0, wait_delta)
                r_stability = -0.5 * abs(action_val)
                r_format = 1.0 if -1.0 <= action_val <= 1.0 else 0.0

                rewards.append(r_throughput + r_latency + r_stability + r_format)
            except (ValueError, IndexError, AttributeError):
                rewards.append(-5.0)
        return rewards

    prompt_dataset = Dataset.from_list([
        {"prompt": build_prompt(r["state"], r["pid"], r["cpu"])}
        for r in records
    ])

    grpo_lora = LoraConfig(
        r=16, lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
    )

    grpo_config = GRPOConfig(
        output_dir="./strategist_grpo",
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=5e-6,
        num_generations=4,
        max_completion_length=16,
        max_prompt_length=384,
        logging_steps=5,
        save_steps=200,
        save_total_limit=2,
        temperature=0.7,
        fp16=True,
        report_to="none",
    )

    # NOTE(review): `model` has already been through the warm-start SFTTrainer
    # with a LoRA config. Confirm that passing it here together with another
    # peft_config continues from the warm-start adapter instead of
    # re-initialising it.
    grpo_trainer = GRPOTrainer(
        model=model,
        args=grpo_config,
        train_dataset=prompt_dataset,
        reward_funcs=reward_fn,
        peft_config=grpo_lora,
    )

    grpo_trainer.train()
    grpo_trainer.save_model("./strategist_final")
    tokenizer.save_pretrained("./strategist_final")
    print("GRPO training complete.")

    return model, tokenizer
|
|
|
|
def merge_and_push(hf_token: str):
    """Merge the strategist LoRA adapter into the base model and push it.

    Loads the base model named in ``data/preprocessing_config.json`` on CPU,
    applies the adapter stored in ``./strategist_final``, merges it, saves the
    merged model and tokenizer to ``./strategist_merged``, and pushes both to
    the ``Rayugacodes/kernelx-strategist`` repository.

    Args:
        hf_token: Hugging Face access token passed to ``login``.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel
    from huggingface_hub import login

    login(token=hf_token)

    # Use a context manager so the config file handle is closed promptly.
    with open("data/preprocessing_config.json", encoding="utf-8") as config_file:
        config = json.load(config_file)
    MODEL_NAME = config["model"]["name"]

    print("\n=== Merging LoRA and pushing to HF ===")
    base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="cpu")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = PeftModel.from_pretrained(base, "./strategist_final")
    merged = model.merge_and_unload()

    merged.save_pretrained("./strategist_merged")
    tokenizer.save_pretrained("./strategist_merged")

    merged.push_to_hub("Rayugacodes/kernelx-strategist", commit_message="Merged strategist (warm-start + GRPO)")
    tokenizer.push_to_hub("Rayugacodes/kernelx-strategist", commit_message="Tokenizer")
    print("Pushed to https://huggingface.co/Rayugacodes/kernelx-strategist")
|
|
|
|
def main():
    """Parse command-line options and run the requested pipeline stages.

    The data download always runs. World-model training, strategist training
    and the merge/push step each run unless disabled by their ``--skip-*``
    flag. The token may be given with ``--hf-token`` or, so that the secret
    does not have to appear on the command line, through the ``HF_TOKEN``
    environment variable.
    """
    parser = argparse.ArgumentParser(description="KernelX GPU Training on HF")
    parser.add_argument(
        "--hf-token",
        default=os.environ.get("HF_TOKEN"),
        help="HuggingFace token (defaults to the HF_TOKEN environment variable)",
    )
    parser.add_argument("--world-model-samples", type=int, default=50000)
    parser.add_argument("--strategist-samples", type=int, default=10000)
    parser.add_argument("--skip-world-model", action="store_true")
    parser.add_argument("--skip-strategist", action="store_true")
    parser.add_argument("--skip-merge", action="store_true")
    args = parser.parse_args()

    # Same exit path as a missing required option: usage message, exit code 2.
    if not args.hf_token:
        parser.error("--hf-token is required (or set the HF_TOKEN environment variable)")

    data_dir = setup(args.hf_token)

    if not args.skip_world_model:
        train_world_model(data_dir, max_samples=args.world_model_samples)

    if not args.skip_strategist:
        train_strategist(data_dir, max_samples=args.strategist_samples)

    if not args.skip_merge:
        merge_and_push(args.hf_token)

    print("\n=== All done! ===")
    print("Model: https://huggingface.co/Rayugacodes/kernelx-strategist")
    print("Next: convert to GGUF for sub-50ms CPU inference")


if __name__ == "__main__":
    main()
|
|