Fix Trackio init args
Browse files- train_sft_qwen25_hf_jobs.py +11 -8
train_sft_qwen25_hf_jobs.py
CHANGED
|
@@ -105,12 +105,11 @@ def main() -> None:
|
|
| 105 |
timestamp = datetime.utcnow().strftime("%Y-%m-%d_%H-%M-%S")
|
| 106 |
run_name = f"sft_{args.model.split('/')[-1]}_{timestamp}"
|
| 107 |
|
| 108 |
-
|
| 109 |
-
project
|
| 110 |
-
run_name
|
| 111 |
-
space_id
|
| 112 |
-
|
| 113 |
-
config={
|
| 114 |
"model": args.model,
|
| 115 |
"dataset": f"{args.dataset}:{args.split}",
|
| 116 |
"max_length": args.max_length,
|
|
@@ -118,12 +117,17 @@ def main() -> None:
|
|
| 118 |
"epochs": args.num_train_epochs,
|
| 119 |
"max_train_samples": args.max_train_samples,
|
| 120 |
},
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
print(f"👤 HF user: {hf_user}")
|
| 124 |
print(f"📦 Loading dataset: {args.dataset} [{args.split}]")
|
| 125 |
dataset = load_dataset(args.dataset, split=args.split)
|
| 126 |
print(f"✅ Dataset loaded: {len(dataset)} rows")
|
|
|
|
|
|
|
| 127 |
|
| 128 |
dataset = dataset.shuffle(seed=args.seed)
|
| 129 |
eval_size = min(args.max_eval_samples, max(1, int(0.1 * len(dataset))))
|
|
@@ -199,4 +203,3 @@ def main() -> None:
|
|
| 199 |
|
| 200 |
if __name__ == "__main__":
|
| 201 |
main()
|
| 202 |
-
|
|
|
|
| 105 |
timestamp = datetime.utcnow().strftime("%Y-%m-%d_%H-%M-%S")
|
| 106 |
run_name = f"sft_{args.model.split('/')[-1]}_{timestamp}"
|
| 107 |
|
| 108 |
+
trackio_kwargs = {
|
| 109 |
+
"project": args.trackio_project,
|
| 110 |
+
"run_name": run_name,
|
| 111 |
+
"space_id": args.trackio_space_id,
|
| 112 |
+
"config": {
|
|
|
|
| 113 |
"model": args.model,
|
| 114 |
"dataset": f"{args.dataset}:{args.split}",
|
| 115 |
"max_length": args.max_length,
|
|
|
|
| 117 |
"epochs": args.num_train_epochs,
|
| 118 |
"max_train_samples": args.max_train_samples,
|
| 119 |
},
|
| 120 |
+
}
|
| 121 |
+
if args.trackio_group:
|
| 122 |
+
trackio_kwargs["group"] = args.trackio_group
|
| 123 |
+
trackio.init(**trackio_kwargs)
|
| 124 |
|
| 125 |
print(f"👤 HF user: {hf_user}")
|
| 126 |
print(f"📦 Loading dataset: {args.dataset} [{args.split}]")
|
| 127 |
dataset = load_dataset(args.dataset, split=args.split)
|
| 128 |
print(f"✅ Dataset loaded: {len(dataset)} rows")
|
| 129 |
+
if len(dataset) < 2:
|
| 130 |
+
raise SystemExit("Dataset split must have at least 2 rows to create a train/eval split.")
|
| 131 |
|
| 132 |
dataset = dataset.shuffle(seed=args.seed)
|
| 133 |
eval_size = min(args.max_eval_samples, max(1, int(0.1 * len(dataset))))
|
|
|
|
| 203 |
|
| 204 |
if __name__ == "__main__":
|
| 205 |
main()
|
|
|