wish-engine-toolcall-runbook / artifacts /build_two_phase_job_payload.ts
sahilmob's picture
docs: add config and job snapshots
8020510 verified
#!/usr/bin/env node
import fs from "node:fs";
import path from "node:path";
type PhaseConfig = {
dataset_source: "hub" | "json";
dataset_id?: string;
train_split?: string;
eval_split?: string;
train_file?: string;
eval_file?: string;
num_train_epochs?: number;
per_device_train_batch_size?: number;
gradient_accumulation_steps?: number;
learning_rate?: number;
};
type TwoPhaseConfig = {
base_model: string;
phase1_hub_model_id?: string;
final_hub_model_id: string;
output_root?: string;
phase1: PhaseConfig;
phase2: PhaseConfig;
lora?: {
r?: number;
lora_alpha?: number;
lora_dropout?: number;
bias?: string;
task_type?: string;
target_modules?: string[];
};
runtime?: {
trackio_project?: string;
trackio_run_name?: string;
seed?: number;
logging_steps?: number;
save_steps?: number;
save_total_limit?: number;
eval_steps?: number;
warmup_ratio?: number;
lr_scheduler_type?: string;
max_length?: number;
gradient_checkpointing?: boolean;
strict_chat_template?: boolean;
skip_trainer_model_move?: boolean;
force_single_device_model?: boolean;
};
};
type CliOptions = {
configPath: string;
outPath: string;
emitScriptPath: string | null;
flavor: string;
timeout: string;
dryRun: boolean;
};
const usageText = `
Usage:
npx tsx ./training/hf-jobs/build_two_phase_job_payload.ts [options]
Options:
--config <path> Two-phase config JSON path
(default: ./training/hf-jobs/two-phase-sft.hf.config.json)
--out <path> Output JSON payload path
(default: ./training/hf-jobs/two-phase-job.payload.json)
--emit-script <path> Also write generated Python training script to this path
--flavor <name> HF Jobs hardware flavor (default: a10g-large)
--timeout <dur> HF Jobs timeout (default: 4h)
--dry-run Validate and print summary without writing files
--help Show this help
`.trim();
const parseArgs = (argv: string[]): CliOptions => {
const opts: CliOptions = {
configPath: path.resolve(
process.cwd(),
"training",
"hf-jobs",
"two-phase-sft.hf.config.json"
),
outPath: path.resolve(
process.cwd(),
"training",
"hf-jobs",
"two-phase-job.payload.json"
),
emitScriptPath: null,
flavor: "a10g-large",
timeout: "4h",
dryRun: false,
};
for (let i = 0; i < argv.length; i += 1) {
const arg = argv[i];
if (arg === "--config") {
opts.configPath = path.resolve(argv[i + 1]);
i += 1;
} else if (arg === "--out") {
opts.outPath = path.resolve(argv[i + 1]);
i += 1;
} else if (arg === "--emit-script") {
opts.emitScriptPath = path.resolve(argv[i + 1]);
i += 1;
} else if (arg === "--flavor") {
opts.flavor = String(argv[i + 1] || "").trim();
i += 1;
} else if (arg === "--timeout") {
opts.timeout = String(argv[i + 1] || "").trim();
i += 1;
} else if (arg === "--dry-run") {
opts.dryRun = true;
} else if (arg === "--help" || arg === "-h") {
console.log(usageText);
process.exit(0);
} else if (arg.startsWith("-")) {
throw new Error(`Unknown option: ${arg}\n\n${usageText}`);
}
}
return opts;
};
const ensureDir = (dirPath: string) => {
fs.mkdirSync(dirPath, { recursive: true });
};
const readJson = <T>(filePath: string): T => {
return JSON.parse(fs.readFileSync(filePath, "utf8"));
};
const writeJson = (filePath: string, value: unknown) => {
ensureDir(path.dirname(filePath));
fs.writeFileSync(filePath, `${JSON.stringify(value, null, 2)}\n`, "utf8");
};
const writeText = (filePath: string, value: string) => {
ensureDir(path.dirname(filePath));
fs.writeFileSync(filePath, value, "utf8");
};
const validateConfig = (config: TwoPhaseConfig) => {
const required = ["base_model", "final_hub_model_id", "phase1", "phase2"] as const;
for (const key of required) {
if (!(key in config)) {
throw new Error(`Missing required config key: ${key}`);
}
}
const phaseNames: Array<"phase1" | "phase2"> = ["phase1", "phase2"];
for (const phaseName of phaseNames) {
const phase = config[phaseName];
if (!phase || typeof phase !== "object") {
throw new Error(`Invalid ${phaseName} config`);
}
if (phase.dataset_source !== "hub" && phase.dataset_source !== "json") {
throw new Error(`${phaseName}.dataset_source must be "hub" or "json"`);
}
if (phase.dataset_source === "hub" && !phase.dataset_id) {
throw new Error(`${phaseName}.dataset_id is required when dataset_source=hub`);
}
if (phase.dataset_source === "json" && !phase.train_file) {
throw new Error(`${phaseName}.train_file is required when dataset_source=json`);
}
}
};
const generatePythonScript = (config: TwoPhaseConfig) => {
const cfgB64 = Buffer.from(JSON.stringify(config), "utf8").toString("base64");
return `#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "datasets>=2.19.0",
# "trl>=0.12.0",
# "peft>=0.12.0",
# "transformers>=4.45.0",
# "accelerate>=0.34.0",
# "jinja2>=3.1.0",
# "trackio",
# ]
# ///
import base64
import json
import os
import trackio
import torch
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
CONFIG = json.loads(base64.b64decode("${cfgB64}").decode("utf-8"))
def load_phase_datasets(phase_cfg):
source = str(phase_cfg.get("dataset_source", "hub")).strip().lower()
if source == "hub":
dataset_id = phase_cfg["dataset_id"]
train_split = phase_cfg.get("train_split", "train")
eval_split = phase_cfg.get("eval_split", "validation")
train_ds = load_dataset(dataset_id, split=train_split)
eval_ds = load_dataset(dataset_id, split=eval_split) if eval_split else None
return train_ds, eval_ds
if source == "json":
train_file = phase_cfg["train_file"]
eval_file = phase_cfg.get("eval_file")
data_files = {"train": train_file}
if eval_file:
data_files["validation"] = eval_file
ds = load_dataset("json", data_files=data_files)
train_ds = ds["train"]
eval_ds = ds["validation"] if "validation" in ds else None
return train_ds, eval_ds
raise ValueError(f"Unsupported dataset_source: {source}")
def _safe_json(value):
try:
return json.dumps(value, ensure_ascii=False)
except Exception:
return str(value)
def _message_to_text(message):
if not isinstance(message, dict):
return _safe_json(message)
role = str(message.get("role", "unknown"))
if isinstance(message.get("tool_calls"), list):
return f"{role}: <tool_calls> " + _safe_json(message.get("tool_calls"))
content = message.get("content")
if isinstance(content, str):
return f"{role}: {content}"
if content is None:
return f"{role}:"
return f"{role}: " + _safe_json(content)
def _fallback_chat_render(messages):
lines = [_message_to_text(message) for message in messages]
return "\\n".join(lines).strip()
def load_base_model_for_training(model_id, runtime_cfg):
on_cuda = torch.cuda.is_available()
force_single_device_model = bool(runtime_cfg.get("force_single_device_model", True))
device_map = None
if on_cuda:
device_map = {"": 0} if force_single_device_model else "auto"
model = AutoModelForCausalLM.from_pretrained(
model_id,
trust_remote_code=True,
torch_dtype=torch.bfloat16 if on_cuda else torch.float32,
low_cpu_mem_usage=True,
device_map=device_map,
)
if hasattr(model, "config") and hasattr(model.config, "use_cache"):
model.config.use_cache = False
if on_cuda and hasattr(model, "gradient_checkpointing_enable"):
model.gradient_checkpointing_enable()
return model
def assert_no_meta_parameters(model):
meta_params = [name for name, param in model.named_parameters() if param.device.type == "meta"]
if meta_params:
sample = ", ".join(meta_params[:5])
raise RuntimeError(f"Model has {len(meta_params)} meta parameters after load: {sample}")
class StaticDeviceSFTTrainer(SFTTrainer):
def __init__(self, *args, skip_model_move=False, **kwargs):
self._skip_model_move = bool(skip_model_move)
super().__init__(*args, **kwargs)
def _move_model_to_device(self, model, device):
if self._skip_model_move:
return model
return super()._move_model_to_device(model, device)
def normalize_dataset_for_sft(dataset, tokenizer, split_name, strict_chat_template):
fallback_used = 0
render_errors = 0
def row_to_text(example):
nonlocal fallback_used
nonlocal render_errors
messages = example.get("messages")
if isinstance(messages, list):
try:
rendered = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False,
)
if isinstance(rendered, str) and rendered.strip():
return {"text": rendered}
if strict_chat_template:
raise RuntimeError("chat template returned empty text")
except Exception as exc:
render_errors += 1
if strict_chat_template:
raise RuntimeError(
f"chat template failed in split={split_name}: {type(exc).__name__}: {exc}"
) from exc
fallback_used += 1
return {"text": _fallback_chat_render(messages)}
prompt = example.get("prompt")
response = example.get("response")
completion = example.get("completion")
chosen = example.get("chosen")
text = example.get("text")
if isinstance(prompt, str) and isinstance(response, str):
return {"text": f"{prompt}\\n\\n{response}"}
if isinstance(prompt, str) and isinstance(completion, str):
return {"text": f"{prompt}\\n\\n{completion}"}
if isinstance(prompt, str) and isinstance(chosen, str):
return {"text": f"{prompt}\\n\\n{chosen}"}
if isinstance(text, str):
return {"text": text}
return {"text": _safe_json(example)}
mapped = dataset.map(
row_to_text,
desc=f"Normalize {split_name} dataset to text",
)
drop_columns = [column for column in mapped.column_names if column != "text"]
if drop_columns:
mapped = mapped.remove_columns(drop_columns)
print(
f"[normalize] split={split_name} rows={len(mapped)} "
f"fallback_used={fallback_used} render_errors={render_errors} "
f"strict_chat_template={strict_chat_template}"
)
return mapped
def build_sft_config(phase_name, phase_cfg, runtime_cfg, output_root, push_to_hub, hub_model_id, has_eval):
run_name_root = runtime_cfg.get("trackio_run_name", "two-phase-sft")
cfg_kwargs = {
"output_dir": os.path.join(output_root, phase_name),
"push_to_hub": push_to_hub,
"hub_model_id": hub_model_id if push_to_hub and hub_model_id else None,
"num_train_epochs": float(phase_cfg.get("num_train_epochs", 1)),
"per_device_train_batch_size": int(phase_cfg.get("per_device_train_batch_size", 4)),
"gradient_accumulation_steps": int(phase_cfg.get("gradient_accumulation_steps", 4)),
"learning_rate": float(phase_cfg.get("learning_rate", 2e-5)),
"logging_steps": int(runtime_cfg.get("logging_steps", 10)),
"save_strategy": "steps",
"save_steps": int(runtime_cfg.get("save_steps", 100)),
"save_total_limit": int(runtime_cfg.get("save_total_limit", 2)),
"warmup_ratio": float(runtime_cfg.get("warmup_ratio", 0.1)),
"lr_scheduler_type": str(runtime_cfg.get("lr_scheduler_type", "cosine")),
"dataset_text_field": "text",
"seed": int(runtime_cfg.get("seed", 42)),
"bf16": bool(torch.cuda.is_available()),
"gradient_checkpointing": bool(runtime_cfg.get("gradient_checkpointing", True)),
"report_to": "trackio",
"project": str(runtime_cfg.get("trackio_project", "wish-engine-jssg")),
"run_name": f"{run_name_root}-{phase_name}",
}
max_length = runtime_cfg.get("max_length")
if max_length is not None:
cfg_kwargs["max_length"] = int(max_length)
if has_eval:
cfg_kwargs["eval_strategy"] = "steps"
cfg_kwargs["eval_steps"] = int(runtime_cfg.get("eval_steps", 100))
cfg_kwargs = {k: v for k, v in cfg_kwargs.items() if v is not None}
return SFTConfig(**cfg_kwargs)
def build_lora_config(lora_cfg):
return LoraConfig(
r=int(lora_cfg.get("r", 16)),
lora_alpha=int(lora_cfg.get("lora_alpha", 32)),
lora_dropout=float(lora_cfg.get("lora_dropout", 0.05)),
bias=str(lora_cfg.get("bias", "none")),
task_type=str(lora_cfg.get("task_type", "CAUSAL_LM")),
target_modules=list(lora_cfg.get("target_modules", ["q_proj", "v_proj"])),
)
def train_phase(phase_name, model_ref_or_obj, tokenizer, phase_cfg, runtime_cfg, output_root, peft_config, push_to_hub, hub_model_id):
strict_chat_template = bool(runtime_cfg.get("strict_chat_template", True))
skip_model_move = bool(runtime_cfg.get("skip_trainer_model_move", True))
train_ds, eval_ds = load_phase_datasets(phase_cfg)
train_ds = normalize_dataset_for_sft(
train_ds,
tokenizer,
f"{phase_name}-train",
strict_chat_template=strict_chat_template,
)
if eval_ds is not None:
eval_ds = normalize_dataset_for_sft(
eval_ds,
tokenizer,
f"{phase_name}-eval",
strict_chat_template=strict_chat_template,
)
print(f"[{phase_name}] train={len(train_ds)} eval={len(eval_ds) if eval_ds is not None else 0}")
if len(train_ds) > 0:
print(f"[{phase_name}] sample_text_chars={len(train_ds[0]['text'])}")
sft_cfg = build_sft_config(
phase_name=phase_name,
phase_cfg=phase_cfg,
runtime_cfg=runtime_cfg,
output_root=output_root,
push_to_hub=push_to_hub,
hub_model_id=hub_model_id,
has_eval=eval_ds is not None,
)
trainer = StaticDeviceSFTTrainer(
model=model_ref_or_obj,
processing_class=tokenizer,
train_dataset=train_ds,
eval_dataset=eval_ds,
args=sft_cfg,
peft_config=peft_config,
skip_model_move=skip_model_move,
)
trainer.train()
trainer.save_model()
if push_to_hub:
trainer.push_to_hub()
return trainer
def main():
base_model = CONFIG["base_model"]
final_hub_model_id = CONFIG["final_hub_model_id"]
output_root = CONFIG.get("output_root", "wish-engine-two-phase-sft")
runtime_cfg = CONFIG.get("runtime", {})
phase1_cfg = CONFIG["phase1"]
phase2_cfg = CONFIG["phase2"]
lora_cfg = CONFIG.get("lora", {})
phase1_hub_model_id = CONFIG.get("phase1_hub_model_id")
print(f"[setup] base_model={base_model}")
print(f"[setup] final_hub_model_id={final_hub_model_id}")
print(f"[setup] skip_trainer_model_move={bool(runtime_cfg.get('skip_trainer_model_move', True))}")
print(f"[setup] force_single_device_model={bool(runtime_cfg.get('force_single_device_model', True))}")
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True, use_fast=True)
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
base_model_obj = load_base_model_for_training(base_model, runtime_cfg)
assert_no_meta_parameters(base_model_obj)
trainer1 = train_phase(
phase_name="phase1-focus",
model_ref_or_obj=base_model_obj,
tokenizer=tokenizer,
phase_cfg=phase1_cfg,
runtime_cfg=runtime_cfg,
output_root=output_root,
peft_config=build_lora_config(lora_cfg),
push_to_hub=bool(phase1_hub_model_id),
hub_model_id=phase1_hub_model_id,
)
assert_no_meta_parameters(trainer1.model)
if torch.cuda.is_available():
torch.cuda.empty_cache()
train_phase(
phase_name="phase2-curriculum",
model_ref_or_obj=trainer1.model,
tokenizer=tokenizer,
phase_cfg=phase2_cfg,
runtime_cfg=runtime_cfg,
output_root=output_root,
peft_config=None,
push_to_hub=True,
hub_model_id=final_hub_model_id,
)
trackio.finish()
print("[done] two-phase training complete")
if __name__ == "__main__":
main()
`;
};
const buildPayload = (
pythonScript: string,
options: Pick<CliOptions, "flavor" | "timeout">
) => {
const jobParameters = {
script: pythonScript,
flavor: options.flavor,
timeout: options.timeout,
secrets: {
HF_TOKEN: "$HF_TOKEN",
},
};
return {
createdAt: new Date().toISOString(),
tool: "hf_jobs",
method: "uv",
parameters: jobParameters,
callSnippet: `hf_jobs("uv", ${JSON.stringify(jobParameters, null, 2)})`,
};
};
const main = () => {
const options = parseArgs(process.argv.slice(2));
const config = readJson<TwoPhaseConfig>(options.configPath);
validateConfig(config);
const pythonScript = generatePythonScript(config);
const payload = buildPayload(pythonScript, options);
if (options.dryRun) {
console.log(
JSON.stringify(
{
message: "Dry run OK",
configPath: options.configPath,
flavor: options.flavor,
timeout: options.timeout,
hasPhase1HubPush: Boolean(config.phase1_hub_model_id),
finalHubModelId: config.final_hub_model_id,
phase1Source: config.phase1.dataset_source,
phase2Source: config.phase2.dataset_source,
outputPath: options.outPath,
},
null,
2
)
);
return;
}
writeJson(options.outPath, payload);
if (options.emitScriptPath) {
writeText(options.emitScriptPath, pythonScript);
}
console.log(
JSON.stringify(
{
message: "Two-phase HF Jobs payload generated",
configPath: options.configPath,
outPath: options.outPath,
emitScriptPath: options.emitScriptPath,
flavor: options.flavor,
timeout: options.timeout,
},
null,
2
)
);
};
main();