davidsmts committed on
Commit
c13e2c7
·
verified ·
1 Parent(s): 4aadaf6

Fix Trackio init args

Browse files
Files changed (1) hide show
  1. train_sft_qwen25_hf_jobs.py +11 -8
train_sft_qwen25_hf_jobs.py CHANGED
@@ -105,12 +105,11 @@ def main() -> None:
105
  timestamp = datetime.utcnow().strftime("%Y-%m-%d_%H-%M-%S")
106
  run_name = f"sft_{args.model.split('/')[-1]}_{timestamp}"
107
 
108
- trackio.init(
109
- project=args.trackio_project,
110
- run_name=run_name,
111
- space_id=args.trackio_space_id,
112
- group=args.trackio_group,
113
- config={
114
  "model": args.model,
115
  "dataset": f"{args.dataset}:{args.split}",
116
  "max_length": args.max_length,
@@ -118,12 +117,17 @@ def main() -> None:
118
  "epochs": args.num_train_epochs,
119
  "max_train_samples": args.max_train_samples,
120
  },
121
- )
 
 
 
122
 
123
  print(f"👤 HF user: {hf_user}")
124
  print(f"📦 Loading dataset: {args.dataset} [{args.split}]")
125
  dataset = load_dataset(args.dataset, split=args.split)
126
  print(f"✅ Dataset loaded: {len(dataset)} rows")
 
 
127
 
128
  dataset = dataset.shuffle(seed=args.seed)
129
  eval_size = min(args.max_eval_samples, max(1, int(0.1 * len(dataset))))
@@ -199,4 +203,3 @@ def main() -> None:
199
 
200
  if __name__ == "__main__":
201
  main()
202
-
 
105
  timestamp = datetime.utcnow().strftime("%Y-%m-%d_%H-%M-%S")
106
  run_name = f"sft_{args.model.split('/')[-1]}_{timestamp}"
107
 
108
+ trackio_kwargs = {
109
+ "project": args.trackio_project,
110
+ "run_name": run_name,
111
+ "space_id": args.trackio_space_id,
112
+ "config": {
 
113
  "model": args.model,
114
  "dataset": f"{args.dataset}:{args.split}",
115
  "max_length": args.max_length,
 
117
  "epochs": args.num_train_epochs,
118
  "max_train_samples": args.max_train_samples,
119
  },
120
+ }
121
+ if args.trackio_group:
122
+ trackio_kwargs["group"] = args.trackio_group
123
+ trackio.init(**trackio_kwargs)
124
 
125
  print(f"👤 HF user: {hf_user}")
126
  print(f"📦 Loading dataset: {args.dataset} [{args.split}]")
127
  dataset = load_dataset(args.dataset, split=args.split)
128
  print(f"✅ Dataset loaded: {len(dataset)} rows")
129
+ if len(dataset) < 2:
130
+ raise SystemExit("Dataset split must have at least 2 rows to create a train/eval split.")
131
 
132
  dataset = dataset.shuffle(seed=args.seed)
133
  eval_size = min(args.max_eval_samples, max(1, int(0.1 * len(dataset))))
 
203
 
204
  if __name__ == "__main__":
205
  main()