| #!/usr/bin/env node |
|
|
| import fs from "node:fs"; |
| import path from "node:path"; |
|
|
/** Dataset selection and core optimizer settings for a single SFT phase. */
type PhaseConfig = {
  /** "hub" loads a Hub dataset by id; "json" loads local JSON/JSONL files. */
  dataset_source: "hub" | "json";
  /** Hub dataset repo id — required when dataset_source is "hub". */
  dataset_id?: string;
  /** Hub split names; generated script defaults to "train" / "validation". */
  train_split?: string;
  eval_split?: string;
  /** Local file paths — train_file is required when dataset_source is "json". */
  train_file?: string;
  eval_file?: string;
  /** Optimizer knobs; defaults applied in the generated Python script. */
  num_train_epochs?: number;
  per_device_train_batch_size?: number;
  gradient_accumulation_steps?: number;
  learning_rate?: number;
};
|
|
/**
 * Top-level config for a two-phase SFT run: one base model, two phase specs,
 * optional LoRA hyperparameters, and runtime/trainer knobs shared by both
 * phases. Read from JSON by readJson and checked by validateConfig.
 */
type TwoPhaseConfig = {
  /** HF model id the tokenizer and phase-1 model are loaded from. */
  base_model: string;
  /** If set, phase-1 results are also pushed to this Hub repo. */
  phase1_hub_model_id?: string;
  /** Hub repo the final (phase-2) model is pushed to. */
  final_hub_model_id: string;
  /** Local root for trainer outputs; per-phase subdirectories are created under it. */
  output_root?: string;
  phase1: PhaseConfig;
  phase2: PhaseConfig;
  /** LoRA adapter hyperparameters (defaults applied in the generated script). */
  lora?: {
    r?: number;
    lora_alpha?: number;
    lora_dropout?: number;
    bias?: string;
    task_type?: string;
    target_modules?: string[];
  };
  /** Trainer/runtime knobs shared by both phases. */
  runtime?: {
    trackio_project?: string;
    trackio_run_name?: string;
    seed?: number;
    logging_steps?: number;
    save_steps?: number;
    save_total_limit?: number;
    eval_steps?: number;
    warmup_ratio?: number;
    lr_scheduler_type?: string;
    max_length?: number;
    gradient_checkpointing?: boolean;
    /** When true (default), chat-template render failures abort instead of falling back. */
    strict_chat_template?: boolean;
    /** When true (default), the custom trainer skips its model-to-device move. */
    skip_trainer_model_move?: boolean;
    /** When true (default), pin the whole model to cuda:0 instead of device_map="auto". */
    force_single_device_model?: boolean;
  };
};
|
|
/** Resolved command-line options for this payload-generator script. */
type CliOptions = {
  /** Absolute path of the two-phase config JSON to read. */
  configPath: string;
  /** Absolute path the job payload JSON is written to. */
  outPath: string;
  /** Optional path to also dump the generated Python script; null = skip. */
  emitScriptPath: string | null;
  /** HF Jobs hardware flavor (e.g. "a10g-large"). */
  flavor: string;
  /** HF Jobs timeout duration string (e.g. "4h"). */
  timeout: string;
  /** When true, validate and print a summary without writing any files. */
  dryRun: boolean;
};
|
|
// Help text printed for --help/-h and appended to option-parsing errors.
// This is a runtime string: keep content in sync with parseArgs' flags.
const usageText = `
Usage:
  npx tsx ./training/hf-jobs/build_two_phase_job_payload.ts [options]

Options:
  --config <path>       Two-phase config JSON path
                        (default: ./training/hf-jobs/two-phase-sft.hf.config.json)
  --out <path>          Output JSON payload path
                        (default: ./training/hf-jobs/two-phase-job.payload.json)
  --emit-script <path>  Also write generated Python training script to this path
  --flavor <name>       HF Jobs hardware flavor (default: a10g-large)
  --timeout <dur>       HF Jobs timeout (default: 4h)
  --dry-run             Validate and print summary without writing files
  --help                Show this help
`.trim();
|
|
| const parseArgs = (argv: string[]): CliOptions => { |
| const opts: CliOptions = { |
| configPath: path.resolve( |
| process.cwd(), |
| "training", |
| "hf-jobs", |
| "two-phase-sft.hf.config.json" |
| ), |
| outPath: path.resolve( |
| process.cwd(), |
| "training", |
| "hf-jobs", |
| "two-phase-job.payload.json" |
| ), |
| emitScriptPath: null, |
| flavor: "a10g-large", |
| timeout: "4h", |
| dryRun: false, |
| }; |
|
|
| for (let i = 0; i < argv.length; i += 1) { |
| const arg = argv[i]; |
| if (arg === "--config") { |
| opts.configPath = path.resolve(argv[i + 1]); |
| i += 1; |
| } else if (arg === "--out") { |
| opts.outPath = path.resolve(argv[i + 1]); |
| i += 1; |
| } else if (arg === "--emit-script") { |
| opts.emitScriptPath = path.resolve(argv[i + 1]); |
| i += 1; |
| } else if (arg === "--flavor") { |
| opts.flavor = String(argv[i + 1] || "").trim(); |
| i += 1; |
| } else if (arg === "--timeout") { |
| opts.timeout = String(argv[i + 1] || "").trim(); |
| i += 1; |
| } else if (arg === "--dry-run") { |
| opts.dryRun = true; |
| } else if (arg === "--help" || arg === "-h") { |
| console.log(usageText); |
| process.exit(0); |
| } else if (arg.startsWith("-")) { |
| throw new Error(`Unknown option: ${arg}\n\n${usageText}`); |
| } |
| } |
|
|
| return opts; |
| }; |
|
|
| const ensureDir = (dirPath: string) => { |
| fs.mkdirSync(dirPath, { recursive: true }); |
| }; |
|
|
| const readJson = <T>(filePath: string): T => { |
| return JSON.parse(fs.readFileSync(filePath, "utf8")); |
| }; |
|
|
| const writeJson = (filePath: string, value: unknown) => { |
| ensureDir(path.dirname(filePath)); |
| fs.writeFileSync(filePath, `${JSON.stringify(value, null, 2)}\n`, "utf8"); |
| }; |
|
|
| const writeText = (filePath: string, value: string) => { |
| ensureDir(path.dirname(filePath)); |
| fs.writeFileSync(filePath, value, "utf8"); |
| }; |
|
|
| const validateConfig = (config: TwoPhaseConfig) => { |
| const required = ["base_model", "final_hub_model_id", "phase1", "phase2"] as const; |
| for (const key of required) { |
| if (!(key in config)) { |
| throw new Error(`Missing required config key: ${key}`); |
| } |
| } |
|
|
| const phaseNames: Array<"phase1" | "phase2"> = ["phase1", "phase2"]; |
| for (const phaseName of phaseNames) { |
| const phase = config[phaseName]; |
| if (!phase || typeof phase !== "object") { |
| throw new Error(`Invalid ${phaseName} config`); |
| } |
| if (phase.dataset_source !== "hub" && phase.dataset_source !== "json") { |
| throw new Error(`${phaseName}.dataset_source must be "hub" or "json"`); |
| } |
| if (phase.dataset_source === "hub" && !phase.dataset_id) { |
| throw new Error(`${phaseName}.dataset_id is required when dataset_source=hub`); |
| } |
| if (phase.dataset_source === "json" && !phase.train_file) { |
| throw new Error(`${phaseName}.train_file is required when dataset_source=json`); |
| } |
| } |
| }; |
|
|
/**
 * Render the self-contained Python training script that HF Jobs will run.
 *
 * The entire config is embedded in the script as a base64-encoded JSON
 * literal, so the generated file has no external config dependency. The
 * returned string is a `uv`-runnable script (PEP 723 inline metadata header)
 * that performs two sequential SFT phases: phase 1 trains LoRA adapters on
 * the base model, phase 2 continues training the phase-1 model and pushes
 * the result to final_hub_model_id.
 *
 * NOTE(review): everything inside the template literal below is runtime
 * payload (the generated Python source). Do not edit it for TS style —
 * any change alters the script that runs on the job.
 */
const generatePythonScript = (config: TwoPhaseConfig) => {
  // base64 sidesteps quoting/escaping issues when embedding arbitrary JSON
  // inside Python source.
  const cfgB64 = Buffer.from(JSON.stringify(config), "utf8").toString("base64");

  return `#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "datasets>=2.19.0",
#     "trl>=0.12.0",
#     "peft>=0.12.0",
#     "transformers>=4.45.0",
#     "accelerate>=0.34.0",
#     "jinja2>=3.1.0",
#     "trackio",
# ]
# ///

import base64
import json
import os

import trackio
import torch
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer


CONFIG = json.loads(base64.b64decode("${cfgB64}").decode("utf-8"))


def load_phase_datasets(phase_cfg):
    source = str(phase_cfg.get("dataset_source", "hub")).strip().lower()
    if source == "hub":
        dataset_id = phase_cfg["dataset_id"]
        train_split = phase_cfg.get("train_split", "train")
        eval_split = phase_cfg.get("eval_split", "validation")
        train_ds = load_dataset(dataset_id, split=train_split)
        eval_ds = load_dataset(dataset_id, split=eval_split) if eval_split else None
        return train_ds, eval_ds
    if source == "json":
        train_file = phase_cfg["train_file"]
        eval_file = phase_cfg.get("eval_file")
        data_files = {"train": train_file}
        if eval_file:
            data_files["validation"] = eval_file
        ds = load_dataset("json", data_files=data_files)
        train_ds = ds["train"]
        eval_ds = ds["validation"] if "validation" in ds else None
        return train_ds, eval_ds
    raise ValueError(f"Unsupported dataset_source: {source}")


def _safe_json(value):
    try:
        return json.dumps(value, ensure_ascii=False)
    except Exception:
        return str(value)


def _message_to_text(message):
    if not isinstance(message, dict):
        return _safe_json(message)
    role = str(message.get("role", "unknown"))
    if isinstance(message.get("tool_calls"), list):
        return f"{role}: <tool_calls> " + _safe_json(message.get("tool_calls"))
    content = message.get("content")
    if isinstance(content, str):
        return f"{role}: {content}"
    if content is None:
        return f"{role}:"
    return f"{role}: " + _safe_json(content)


def _fallback_chat_render(messages):
    lines = [_message_to_text(message) for message in messages]
    return "\\n".join(lines).strip()


def load_base_model_for_training(model_id, runtime_cfg):
    on_cuda = torch.cuda.is_available()
    force_single_device_model = bool(runtime_cfg.get("force_single_device_model", True))
    device_map = None
    if on_cuda:
        device_map = {"": 0} if force_single_device_model else "auto"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16 if on_cuda else torch.float32,
        low_cpu_mem_usage=True,
        device_map=device_map,
    )
    if hasattr(model, "config") and hasattr(model.config, "use_cache"):
        model.config.use_cache = False
    if on_cuda and hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()
    return model


def assert_no_meta_parameters(model):
    meta_params = [name for name, param in model.named_parameters() if param.device.type == "meta"]
    if meta_params:
        sample = ", ".join(meta_params[:5])
        raise RuntimeError(f"Model has {len(meta_params)} meta parameters after load: {sample}")


class StaticDeviceSFTTrainer(SFTTrainer):
    def __init__(self, *args, skip_model_move=False, **kwargs):
        self._skip_model_move = bool(skip_model_move)
        super().__init__(*args, **kwargs)

    def _move_model_to_device(self, model, device):
        if self._skip_model_move:
            return model
        return super()._move_model_to_device(model, device)


def normalize_dataset_for_sft(dataset, tokenizer, split_name, strict_chat_template):
    fallback_used = 0
    render_errors = 0

    def row_to_text(example):
        nonlocal fallback_used
        nonlocal render_errors
        messages = example.get("messages")
        if isinstance(messages, list):
            try:
                rendered = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=False,
                )
                if isinstance(rendered, str) and rendered.strip():
                    return {"text": rendered}
                if strict_chat_template:
                    raise RuntimeError("chat template returned empty text")
            except Exception as exc:
                render_errors += 1
                if strict_chat_template:
                    raise RuntimeError(
                        f"chat template failed in split={split_name}: {type(exc).__name__}: {exc}"
                    ) from exc
            fallback_used += 1
            return {"text": _fallback_chat_render(messages)}

        prompt = example.get("prompt")
        response = example.get("response")
        completion = example.get("completion")
        chosen = example.get("chosen")
        text = example.get("text")

        if isinstance(prompt, str) and isinstance(response, str):
            return {"text": f"{prompt}\\n\\n{response}"}
        if isinstance(prompt, str) and isinstance(completion, str):
            return {"text": f"{prompt}\\n\\n{completion}"}
        if isinstance(prompt, str) and isinstance(chosen, str):
            return {"text": f"{prompt}\\n\\n{chosen}"}
        if isinstance(text, str):
            return {"text": text}
        return {"text": _safe_json(example)}

    mapped = dataset.map(
        row_to_text,
        desc=f"Normalize {split_name} dataset to text",
    )
    drop_columns = [column for column in mapped.column_names if column != "text"]
    if drop_columns:
        mapped = mapped.remove_columns(drop_columns)
    print(
        f"[normalize] split={split_name} rows={len(mapped)} "
        f"fallback_used={fallback_used} render_errors={render_errors} "
        f"strict_chat_template={strict_chat_template}"
    )
    return mapped


def build_sft_config(phase_name, phase_cfg, runtime_cfg, output_root, push_to_hub, hub_model_id, has_eval):
    run_name_root = runtime_cfg.get("trackio_run_name", "two-phase-sft")
    cfg_kwargs = {
        "output_dir": os.path.join(output_root, phase_name),
        "push_to_hub": push_to_hub,
        "hub_model_id": hub_model_id if push_to_hub and hub_model_id else None,
        "num_train_epochs": float(phase_cfg.get("num_train_epochs", 1)),
        "per_device_train_batch_size": int(phase_cfg.get("per_device_train_batch_size", 4)),
        "gradient_accumulation_steps": int(phase_cfg.get("gradient_accumulation_steps", 4)),
        "learning_rate": float(phase_cfg.get("learning_rate", 2e-5)),
        "logging_steps": int(runtime_cfg.get("logging_steps", 10)),
        "save_strategy": "steps",
        "save_steps": int(runtime_cfg.get("save_steps", 100)),
        "save_total_limit": int(runtime_cfg.get("save_total_limit", 2)),
        "warmup_ratio": float(runtime_cfg.get("warmup_ratio", 0.1)),
        "lr_scheduler_type": str(runtime_cfg.get("lr_scheduler_type", "cosine")),
        "dataset_text_field": "text",
        "seed": int(runtime_cfg.get("seed", 42)),
        "bf16": bool(torch.cuda.is_available()),
        "gradient_checkpointing": bool(runtime_cfg.get("gradient_checkpointing", True)),
        "report_to": "trackio",
        "project": str(runtime_cfg.get("trackio_project", "wish-engine-jssg")),
        "run_name": f"{run_name_root}-{phase_name}",
    }
    max_length = runtime_cfg.get("max_length")
    if max_length is not None:
        cfg_kwargs["max_length"] = int(max_length)
    if has_eval:
        cfg_kwargs["eval_strategy"] = "steps"
        cfg_kwargs["eval_steps"] = int(runtime_cfg.get("eval_steps", 100))
    cfg_kwargs = {k: v for k, v in cfg_kwargs.items() if v is not None}
    return SFTConfig(**cfg_kwargs)


def build_lora_config(lora_cfg):
    return LoraConfig(
        r=int(lora_cfg.get("r", 16)),
        lora_alpha=int(lora_cfg.get("lora_alpha", 32)),
        lora_dropout=float(lora_cfg.get("lora_dropout", 0.05)),
        bias=str(lora_cfg.get("bias", "none")),
        task_type=str(lora_cfg.get("task_type", "CAUSAL_LM")),
        target_modules=list(lora_cfg.get("target_modules", ["q_proj", "v_proj"])),
    )


def train_phase(phase_name, model_ref_or_obj, tokenizer, phase_cfg, runtime_cfg, output_root, peft_config, push_to_hub, hub_model_id):
    strict_chat_template = bool(runtime_cfg.get("strict_chat_template", True))
    skip_model_move = bool(runtime_cfg.get("skip_trainer_model_move", True))
    train_ds, eval_ds = load_phase_datasets(phase_cfg)
    train_ds = normalize_dataset_for_sft(
        train_ds,
        tokenizer,
        f"{phase_name}-train",
        strict_chat_template=strict_chat_template,
    )
    if eval_ds is not None:
        eval_ds = normalize_dataset_for_sft(
            eval_ds,
            tokenizer,
            f"{phase_name}-eval",
            strict_chat_template=strict_chat_template,
        )
    print(f"[{phase_name}] train={len(train_ds)} eval={len(eval_ds) if eval_ds is not None else 0}")
    if len(train_ds) > 0:
        print(f"[{phase_name}] sample_text_chars={len(train_ds[0]['text'])}")
    sft_cfg = build_sft_config(
        phase_name=phase_name,
        phase_cfg=phase_cfg,
        runtime_cfg=runtime_cfg,
        output_root=output_root,
        push_to_hub=push_to_hub,
        hub_model_id=hub_model_id,
        has_eval=eval_ds is not None,
    )
    trainer = StaticDeviceSFTTrainer(
        model=model_ref_or_obj,
        processing_class=tokenizer,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        args=sft_cfg,
        peft_config=peft_config,
        skip_model_move=skip_model_move,
    )
    trainer.train()
    trainer.save_model()
    if push_to_hub:
        trainer.push_to_hub()
    return trainer


def main():
    base_model = CONFIG["base_model"]
    final_hub_model_id = CONFIG["final_hub_model_id"]
    output_root = CONFIG.get("output_root", "wish-engine-two-phase-sft")
    runtime_cfg = CONFIG.get("runtime", {})
    phase1_cfg = CONFIG["phase1"]
    phase2_cfg = CONFIG["phase2"]
    lora_cfg = CONFIG.get("lora", {})
    phase1_hub_model_id = CONFIG.get("phase1_hub_model_id")

    print(f"[setup] base_model={base_model}")
    print(f"[setup] final_hub_model_id={final_hub_model_id}")
    print(f"[setup] skip_trainer_model_move={bool(runtime_cfg.get('skip_trainer_model_move', True))}")
    print(f"[setup] force_single_device_model={bool(runtime_cfg.get('force_single_device_model', True))}")
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True, use_fast=True)
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    base_model_obj = load_base_model_for_training(base_model, runtime_cfg)
    assert_no_meta_parameters(base_model_obj)

    trainer1 = train_phase(
        phase_name="phase1-focus",
        model_ref_or_obj=base_model_obj,
        tokenizer=tokenizer,
        phase_cfg=phase1_cfg,
        runtime_cfg=runtime_cfg,
        output_root=output_root,
        peft_config=build_lora_config(lora_cfg),
        push_to_hub=bool(phase1_hub_model_id),
        hub_model_id=phase1_hub_model_id,
    )

    assert_no_meta_parameters(trainer1.model)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    train_phase(
        phase_name="phase2-curriculum",
        model_ref_or_obj=trainer1.model,
        tokenizer=tokenizer,
        phase_cfg=phase2_cfg,
        runtime_cfg=runtime_cfg,
        output_root=output_root,
        peft_config=None,
        push_to_hub=True,
        hub_model_id=final_hub_model_id,
    )

    trackio.finish()
    print("[done] two-phase training complete")


if __name__ == "__main__":
    main()
`;
};
|
|
| const buildPayload = ( |
| pythonScript: string, |
| options: Pick<CliOptions, "flavor" | "timeout"> |
| ) => { |
| const jobParameters = { |
| script: pythonScript, |
| flavor: options.flavor, |
| timeout: options.timeout, |
| secrets: { |
| HF_TOKEN: "$HF_TOKEN", |
| }, |
| }; |
|
|
| return { |
| createdAt: new Date().toISOString(), |
| tool: "hf_jobs", |
| method: "uv", |
| parameters: jobParameters, |
| callSnippet: `hf_jobs("uv", ${JSON.stringify(jobParameters, null, 2)})`, |
| }; |
| }; |
|
|
| const main = () => { |
| const options = parseArgs(process.argv.slice(2)); |
| const config = readJson<TwoPhaseConfig>(options.configPath); |
| validateConfig(config); |
|
|
| const pythonScript = generatePythonScript(config); |
| const payload = buildPayload(pythonScript, options); |
|
|
| if (options.dryRun) { |
| console.log( |
| JSON.stringify( |
| { |
| message: "Dry run OK", |
| configPath: options.configPath, |
| flavor: options.flavor, |
| timeout: options.timeout, |
| hasPhase1HubPush: Boolean(config.phase1_hub_model_id), |
| finalHubModelId: config.final_hub_model_id, |
| phase1Source: config.phase1.dataset_source, |
| phase2Source: config.phase2.dataset_source, |
| outputPath: options.outPath, |
| }, |
| null, |
| 2 |
| ) |
| ); |
| return; |
| } |
|
|
| writeJson(options.outPath, payload); |
| if (options.emitScriptPath) { |
| writeText(options.emitScriptPath, pythonScript); |
| } |
|
|
| console.log( |
| JSON.stringify( |
| { |
| message: "Two-phase HF Jobs payload generated", |
| configPath: options.configPath, |
| outPath: options.outPath, |
| emitScriptPath: options.emitScriptPath, |
| flavor: options.flavor, |
| timeout: options.timeout, |
| }, |
| null, |
| 2 |
| ) |
| ); |
| }; |
|
|
// Script entry point — runs immediately when executed via node/tsx.
main();
|
|