Codyfederer commited on
Commit
38d7272
·
verified ·
1 Parent(s): 67cafff

Upload train_tool_calling.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_tool_calling.py +559 -0
train_tool_calling.py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "torch>=2.0.0",
5
+ # "transformers>=4.50.0",
6
+ # "datasets>=2.14.0",
7
+ # "trl>=0.12.0",
8
+ # "peft>=0.7.0",
9
+ # "accelerate>=0.25.0",
10
+ # "bitsandbytes>=0.41.0",
11
+ # "trackio",
12
+ # "huggingface_hub",
13
+ # ]
14
+ # ///
15
+ """
16
+ LoRA Fine-tuning Script: Add Tool Calling to Synthia-S1-27b
17
+
18
+ This script fine-tunes Tesslate/Synthia-S1-27b with LoRA using the
19
+ nvidia/Nemotron-Agentic-v1 tool_calling dataset.
20
+
21
+ Usage:
22
+ # With uv (recommended)
23
+ uv run train_tool_calling.py
24
+
25
+ # Or with pip
26
+ pip install torch transformers datasets trl peft accelerate bitsandbytes trackio
27
+ python train_tool_calling.py
28
+
29
+ Hardware Requirements:
30
+ - Minimum: 1x A100 80GB or 2x A10G 24GB
31
+ - Recommended: 1x A100 80GB for fastest training
32
+ """
33
+
34
+ import os
35
+ import json
36
+ from datasets import load_dataset, Dataset
37
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, DataCollatorForLanguageModeling
38
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
39
+ from trl import SFTTrainer, SFTConfig
40
+ import torch
41
+ import trackio
42
+ from huggingface_hub import hf_hub_download, HfApi, create_repo
43
+
44
+ # ============================================================================
45
+ # CONFIGURATION - Modify these values as needed
46
+ # ============================================================================
47
+
48
+ # Model configuration
49
+ BASE_MODEL = "Tesslate/Synthia-S1-27b"
50
+ OUTPUT_MODEL = "Synthia-S1-27b-tool-calling" # Will be pushed as Codyfederer/Synthia-S1-27b-tool-calling
51
+
52
+ # Dataset configuration
53
+ DATASET_NAME = "nvidia/Nemotron-Agentic-v1"
54
+ DATASET_SPLIT = "tool_calling"
55
+ MAX_SAMPLES = None # Set to a number (e.g., 10000) to limit dataset size for testing
56
+
57
+ # Training hyperparameters
58
+ NUM_EPOCHS = 1 # 1 epoch is often sufficient for large datasets
59
+ MAX_SEQ_LENGTH = 4096 # Adjust based on your GPU memory
60
+ BATCH_SIZE = 1 # Per device batch size
61
+ GRADIENT_ACCUMULATION = 16 # Effective batch size = BATCH_SIZE * GRADIENT_ACCUMULATION
62
+ LEARNING_RATE = 2e-4
63
+ WARMUP_RATIO = 0.03
64
+
65
+ # LoRA configuration
66
+ LORA_R = 64 # LoRA rank - higher = more capacity but more memory
67
+ LORA_ALPHA = 128 # LoRA alpha - typically 2x rank
68
+ LORA_DROPOUT = 0.05
69
+
70
+ # Quantization (4-bit for memory efficiency)
71
+ USE_4BIT = False # Using BF16 LoRA on H100 for better quality
72
+
73
+ # Tokenized dataset caching
74
+ TOKENIZED_DATASET_REPO = "Codyfederer/synthia-tool-calling-tokenized"
75
+ SAVE_TOKENIZED = True # Save tokenized dataset to Hub for reuse
76
+ TOKENIZED_DATASET_PRIVATE = True # Make tokenized dataset private
77
+ LOAD_TOKENIZED_IF_EXISTS = True # Skip tokenization if already exists on Hub
78
+
79
+ # Hub configuration
80
+ PUSH_TO_HUB = True
81
+ HUB_PRIVATE = False # Set to True for private model
82
+
83
+ # ============================================================================
84
+ # TRAINING SCRIPT
85
+ # ============================================================================
86
+
87
+ def tokenize_conversation(example, tokenizer, max_length):
88
+ """
89
+ Tokenize a conversation using the model's chat template.
90
+ Returns input_ids, attention_mask, and labels for causal LM training.
91
+ """
92
+ messages = example["messages"]
93
+
94
+ # Apply chat template to get the full text
95
+ text = tokenizer.apply_chat_template(
96
+ messages,
97
+ tokenize=False,
98
+ add_generation_prompt=False
99
+ )
100
+
101
+ # Tokenize the text
102
+ tokenized = tokenizer(
103
+ text,
104
+ truncation=True,
105
+ max_length=max_length,
106
+ padding=False, # We'll pad later in the data collator
107
+ return_tensors=None, # Return lists, not tensors
108
+ )
109
+
110
+ # For causal LM, labels are the same as input_ids (shifted internally by the model)
111
+ tokenized["labels"] = tokenized["input_ids"].copy()
112
+
113
+ return tokenized
114
+
115
+
116
+ def main():
117
+ print("=" * 60)
118
+ print("Tool Calling Fine-tuning for Synthia-S1-27b")
119
+ print("=" * 60)
120
+
121
+ # Initialize Trackio for monitoring
122
+ trackio.init(project="synthia-tool-calling")
123
+
124
+ # Get HF username for hub_model_id
125
+ from huggingface_hub import whoami
126
+ try:
127
+ username = whoami()["name"]
128
+ hub_model_id = f"{username}/{OUTPUT_MODEL}"
129
+ print(f"Will push to: {hub_model_id}")
130
+ except Exception as e:
131
+ print(f"Warning: Not logged in to HF Hub ({e})")
132
+ print("Model will be saved locally only. Run 'huggingface-cli login' to enable Hub push.")
133
+ hub_model_id = OUTPUT_MODEL
134
+ global PUSH_TO_HUB
135
+ PUSH_TO_HUB = False
136
+
137
+ # -------------------------------------------------------------------------
138
+ # Load Tokenizer FIRST (needed for tokenization)
139
+ # -------------------------------------------------------------------------
140
+ print(f"\nLoading tokenizer from {BASE_MODEL}...")
141
+
142
+ tokenizer = AutoTokenizer.from_pretrained(
143
+ BASE_MODEL,
144
+ trust_remote_code=True,
145
+ padding_side="right",
146
+ )
147
+
148
+ # Ensure pad token is set
149
+ if tokenizer.pad_token is None:
150
+ tokenizer.pad_token = tokenizer.eos_token
151
+ tokenizer.pad_token_id = tokenizer.eos_token_id
152
+
153
+ print(f"Vocab size: {len(tokenizer):,}")
154
+
155
+ # -------------------------------------------------------------------------
156
+ # Try to Load Pre-tokenized Dataset from Hub
157
+ # -------------------------------------------------------------------------
158
+ train_dataset = None
159
+ eval_dataset = None
160
+
161
+ if LOAD_TOKENIZED_IF_EXISTS:
162
+ print(f"\nChecking for pre-tokenized dataset: {TOKENIZED_DATASET_REPO}")
163
+ try:
164
+ from datasets import load_dataset as hf_load_dataset
165
+
166
+ # Try to load the tokenized dataset
167
+ tokenized_ds = hf_load_dataset(TOKENIZED_DATASET_REPO)
168
+
169
+ # Check if it has the required columns (input_ids, attention_mask)
170
+ if "train" in tokenized_ds and "input_ids" in tokenized_ds["train"].column_names:
171
+ print(" Found pre-tokenized dataset with input_ids!")
172
+ train_dataset = tokenized_ds["train"]
173
+ eval_dataset = tokenized_ds.get("test", tokenized_ds.get("validation"))
174
+ print(f" Train samples: {len(train_dataset):,}")
175
+ if eval_dataset:
176
+ print(f" Eval samples: {len(eval_dataset):,}")
177
+ else:
178
+ print(" Dataset exists but is not tokenized (no input_ids column)")
179
+ print(" Will re-tokenize and save...")
180
+ except Exception as e:
181
+ print(f" Could not load pre-tokenized dataset: {e}")
182
+ print(" Will tokenize from scratch...")
183
+
184
+ # -------------------------------------------------------------------------
185
+ # Load and Tokenize Dataset (if not loaded from Hub)
186
+ # -------------------------------------------------------------------------
187
+ if train_dataset is None:
188
+ print(f"\nLoading dataset: {DATASET_NAME} ({DATASET_SPLIT} split)...")
189
+
190
+ # Download the JSONL file directly from the dataset repo
191
+ jsonl_file = f"data/{DATASET_SPLIT}.jsonl"
192
+ print(f"Downloading {jsonl_file}...")
193
+
194
+ local_path = hf_hub_download(
195
+ repo_id=DATASET_NAME,
196
+ filename=jsonl_file,
197
+ repo_type="dataset"
198
+ )
199
+ print(f"Downloaded to: {local_path}")
200
+
201
+ # Load JSONL manually to handle schema inconsistencies
202
+ print("Loading and processing JSONL file...")
203
+ processed_examples = []
204
+ skipped = 0
205
+
206
+ with open(local_path, 'r', encoding='utf-8') as f:
207
+ for line_num, line in enumerate(f):
208
+ if line_num % 50000 == 0:
209
+ print(f" Processed {line_num:,} lines...")
210
+ try:
211
+ example = json.loads(line.strip())
212
+ messages = example.get("messages", [])
213
+
214
+ # Convert messages to consistent format
215
+ formatted_messages = []
216
+ for msg in messages:
217
+ role = msg.get("role", "user")
218
+ content = msg.get("content", "")
219
+
220
+ # Handle content that might be a list or complex object
221
+ if isinstance(content, list):
222
+ # For tool calls, content is often a list of dicts
223
+ parts = []
224
+ for item in content:
225
+ if isinstance(item, dict):
226
+ if "text" in item:
227
+ parts.append(item["text"])
228
+ else:
229
+ parts.append(json.dumps(item))
230
+ else:
231
+ parts.append(str(item))
232
+ content = "\n".join(parts) if parts else ""
233
+ elif isinstance(content, dict):
234
+ content = json.dumps(content)
235
+ elif content is None:
236
+ content = ""
237
+ else:
238
+ content = str(content)
239
+
240
+ formatted_messages.append({
241
+ "role": role,
242
+ "content": content
243
+ })
244
+
245
+ # Ensure proper role alternation for chat template
246
+ # Merge consecutive messages with same role, handle tool messages
247
+ if formatted_messages:
248
+ merged_messages = []
249
+ for msg in formatted_messages:
250
+ role = msg["role"]
251
+ content = msg["content"]
252
+
253
+ # Map tool role to assistant (tool responses are from assistant's perspective)
254
+ if role == "tool":
255
+ role = "user" # Tool output is provided to the model like user input
256
+ content = f"[Tool Result]\n{content}"
257
+
258
+ # If same role as previous, merge content
259
+ if merged_messages and merged_messages[-1]["role"] == role:
260
+ merged_messages[-1]["content"] += f"\n\n{content}"
261
+ else:
262
+ merged_messages.append({"role": role, "content": content})
263
+
264
+ # Ensure conversation starts with user and alternates
265
+ if merged_messages and merged_messages[0]["role"] != "user":
266
+ # Prepend a placeholder user message if starts with assistant
267
+ merged_messages.insert(0, {"role": "user", "content": "[Start]"})
268
+
269
+ processed_examples.append({"messages": merged_messages})
270
+
271
+ except Exception as e:
272
+ skipped += 1
273
+ if skipped < 5:
274
+ print(f" Warning: Skipped line {line_num}: {e}")
275
+
276
+ print(f"Loaded {len(processed_examples):,} examples (skipped {skipped})")
277
+
278
+ # Create dataset from processed examples
279
+ dataset = Dataset.from_list(processed_examples)
280
+ print(f"Dataset size: {len(dataset):,} examples")
281
+
282
+ if MAX_SAMPLES and len(dataset) > MAX_SAMPLES:
283
+ dataset = dataset.shuffle(seed=42).select(range(MAX_SAMPLES))
284
+ print(f"Limited to {MAX_SAMPLES:,} samples for training")
285
+
286
+ # Create train/eval split
287
+ split_dataset = dataset.train_test_split(test_size=0.02, seed=42)
288
+ train_dataset = split_dataset["train"]
289
+ eval_dataset = split_dataset["test"]
290
+
291
+ print(f"Train samples: {len(train_dataset):,}")
292
+ print(f"Eval samples: {len(eval_dataset):,}")
293
+
294
+ # -------------------------------------------------------------------------
295
+ # TOKENIZE the dataset (this is the key step!)
296
+ # -------------------------------------------------------------------------
297
+ print(f"\nTokenizing dataset with max_length={MAX_SEQ_LENGTH}...")
298
+ print("This may take a while for large datasets...")
299
+
300
+ # Tokenize train dataset
301
+ train_dataset = train_dataset.map(
302
+ lambda x: tokenize_conversation(x, tokenizer, MAX_SEQ_LENGTH),
303
+ remove_columns=["messages"], # Remove text, keep only tokens
304
+ num_proc=4, # Parallelize
305
+ desc="Tokenizing train",
306
+ )
307
+
308
+ # Tokenize eval dataset
309
+ eval_dataset = eval_dataset.map(
310
+ lambda x: tokenize_conversation(x, tokenizer, MAX_SEQ_LENGTH),
311
+ remove_columns=["messages"],
312
+ num_proc=4,
313
+ desc="Tokenizing eval",
314
+ )
315
+
316
+ print(f"Tokenization complete!")
317
+ print(f"Train dataset columns: {train_dataset.column_names}")
318
+ print(f"Sample input_ids length: {len(train_dataset[0]['input_ids'])}")
319
+
320
+ # Save TOKENIZED dataset to Hub for reuse
321
+ if SAVE_TOKENIZED:
322
+ print(f"\nSaving TOKENIZED dataset to Hub: {TOKENIZED_DATASET_REPO}")
323
+ try:
324
+ # Create the repo if it doesn't exist (private!)
325
+ api = HfApi()
326
+ try:
327
+ create_repo(
328
+ TOKENIZED_DATASET_REPO,
329
+ repo_type="dataset",
330
+ private=TOKENIZED_DATASET_PRIVATE,
331
+ exist_ok=True
332
+ )
333
+ print(f" Created/verified repo (private={TOKENIZED_DATASET_PRIVATE})")
334
+
335
+ # Try to update visibility if repo already exists
336
+ if TOKENIZED_DATASET_PRIVATE:
337
+ try:
338
+ api.update_repo_visibility(
339
+ TOKENIZED_DATASET_REPO,
340
+ repo_type="dataset",
341
+ private=True
342
+ )
343
+ print(f" Ensured repo is private")
344
+ except Exception:
345
+ pass # Ignore if already private or no permission
346
+ except Exception as e:
347
+ print(f" Repo creation note: {e}")
348
+
349
+ # Reset format to ensure data is serializable (not torch tensors)
350
+ train_dataset.reset_format()
351
+ eval_dataset.reset_format()
352
+
353
+ # Verify the data looks correct before pushing
354
+ print(f" Verifying tokenized data...")
355
+ print(f" Train columns: {train_dataset.column_names}")
356
+ print(f" Sample input_ids type: {type(train_dataset[0]['input_ids'])}")
357
+ print(f" Sample input_ids length: {len(train_dataset[0]['input_ids'])}")
358
+ print(f" First 10 tokens: {train_dataset[0]['input_ids'][:10]}")
359
+
360
+ # Push tokenized datasets to Hub (private is set at repo creation)
361
+ print(f" Pushing train split ({len(train_dataset):,} examples)...")
362
+ train_dataset.push_to_hub(
363
+ TOKENIZED_DATASET_REPO,
364
+ split="train",
365
+ )
366
+ print(f" Pushing test split ({len(eval_dataset):,} examples)...")
367
+ eval_dataset.push_to_hub(
368
+ TOKENIZED_DATASET_REPO,
369
+ split="test",
370
+ )
371
+ print(f" SUCCESS! Saved TOKENIZED data to: https://huggingface.co/datasets/{TOKENIZED_DATASET_REPO}")
372
+ print(f" Columns saved: {train_dataset.column_names}")
373
+ print(f" Dataset is private: {TOKENIZED_DATASET_PRIVATE}")
374
+
375
+ # Verify the upload by trying to load it back
376
+ print(f" Verifying upload...")
377
+ try:
378
+ from datasets import load_dataset as verify_load
379
+ verify_ds = verify_load(TOKENIZED_DATASET_REPO, split="train", streaming=True)
380
+ sample = next(iter(verify_ds))
381
+ if "input_ids" in sample:
382
+ print(f" VERIFIED: Dataset contains input_ids with {len(sample['input_ids'])} tokens")
383
+ else:
384
+ print(f" WARNING: Dataset uploaded but input_ids not found in columns: {list(sample.keys())}")
385
+ except Exception as ve:
386
+ print(f" Could not verify upload: {ve}")
387
+
388
+ except Exception as e:
389
+ print(f" ERROR saving to Hub: {e}")
390
+ import traceback
391
+ traceback.print_exc()
392
+ print(" Continuing with training anyway...")
393
+
394
+ # -------------------------------------------------------------------------
395
+ # Load Model with Quantization
396
+ # -------------------------------------------------------------------------
397
+ print(f"\nLoading model: {BASE_MODEL}...")
398
+
399
+ if USE_4BIT:
400
+ print("Using 4-bit quantization (QLoRA)")
401
+ bnb_config = BitsAndBytesConfig(
402
+ load_in_4bit=True,
403
+ bnb_4bit_quant_type="nf4",
404
+ bnb_4bit_compute_dtype=torch.bfloat16,
405
+ bnb_4bit_use_double_quant=True,
406
+ )
407
+ else:
408
+ bnb_config = None
409
+
410
+ model = AutoModelForCausalLM.from_pretrained(
411
+ BASE_MODEL,
412
+ quantization_config=bnb_config,
413
+ device_map="auto",
414
+ trust_remote_code=True,
415
+ torch_dtype=torch.bfloat16,
416
+ attn_implementation="sdpa", # Use PyTorch's Scaled Dot Product Attention
417
+ )
418
+
419
+ if USE_4BIT:
420
+ model = prepare_model_for_kbit_training(model)
421
+
422
+ print(f"Model loaded. Parameters: {model.num_parameters():,}")
423
+
424
+ # -------------------------------------------------------------------------
425
+ # Configure LoRA
426
+ # -------------------------------------------------------------------------
427
+ print(f"\nConfiguring LoRA (r={LORA_R}, alpha={LORA_ALPHA})...")
428
+
429
+ # Target modules for Gemma 3 architecture
430
+ target_modules = [
431
+ "q_proj", "k_proj", "v_proj", "o_proj", # Attention
432
+ "gate_proj", "up_proj", "down_proj", # MLP
433
+ ]
434
+
435
+ lora_config = LoraConfig(
436
+ r=LORA_R,
437
+ lora_alpha=LORA_ALPHA,
438
+ lora_dropout=LORA_DROPOUT,
439
+ target_modules=target_modules,
440
+ bias="none",
441
+ task_type="CAUSAL_LM",
442
+ )
443
+
444
+ model = get_peft_model(model, lora_config)
445
+ model.print_trainable_parameters()
446
+
447
+ # -------------------------------------------------------------------------
448
+ # Training Configuration
449
+ # -------------------------------------------------------------------------
450
+ print("\nConfiguring training...")
451
+
452
+ training_args = SFTConfig(
453
+ output_dir=f"./{OUTPUT_MODEL}",
454
+
455
+ # Training params
456
+ num_train_epochs=NUM_EPOCHS,
457
+ per_device_train_batch_size=BATCH_SIZE,
458
+ per_device_eval_batch_size=BATCH_SIZE,
459
+ gradient_accumulation_steps=GRADIENT_ACCUMULATION,
460
+
461
+ # Optimizer
462
+ learning_rate=LEARNING_RATE,
463
+ lr_scheduler_type="cosine",
464
+ warmup_ratio=WARMUP_RATIO,
465
+ weight_decay=0.01,
466
+ optim="adamw_torch",
467
+
468
+ # Memory optimization
469
+ gradient_checkpointing=True,
470
+ gradient_checkpointing_kwargs={"use_reentrant": False},
471
+ max_grad_norm=1.0,
472
+
473
+ # Sequence length
474
+ max_length=MAX_SEQ_LENGTH,
475
+ packing=False, # Disable packing for tool calling (preserve conversation structure)
476
+
477
+ # Evaluation
478
+ eval_strategy="steps",
479
+ eval_steps=500,
480
+
481
+ # Saving
482
+ save_strategy="steps",
483
+ save_steps=500,
484
+ save_total_limit=3,
485
+
486
+ # Hub
487
+ push_to_hub=PUSH_TO_HUB,
488
+ hub_model_id=hub_model_id if PUSH_TO_HUB else None,
489
+ hub_strategy="checkpoint",
490
+ hub_private_repo=HUB_PRIVATE,
491
+
492
+ # Logging
493
+ logging_steps=10,
494
+ report_to="trackio",
495
+ run_name=f"lora-r{LORA_R}-lr{LEARNING_RATE}",
496
+
497
+ # Performance
498
+ bf16=True,
499
+ dataloader_num_workers=4,
500
+ dataloader_pin_memory=True,
501
+
502
+ # Reproducibility
503
+ seed=42,
504
+ )
505
+
506
+ # -------------------------------------------------------------------------
507
+ # Initialize Trainer
508
+ # -------------------------------------------------------------------------
509
+ print("\nInitializing trainer...")
510
+
511
+ # Create data collator for padding pre-tokenized data
512
+ data_collator = DataCollatorForLanguageModeling(
513
+ tokenizer=tokenizer,
514
+ mlm=False, # Causal LM, not masked LM
515
+ )
516
+
517
+ # Check if dataset is pre-tokenized
518
+ is_pretokenized = "input_ids" in train_dataset.column_names
519
+ print(f"Dataset is pre-tokenized: {is_pretokenized}")
520
+ print(f"Dataset columns: {train_dataset.column_names}")
521
+
522
+ trainer = SFTTrainer(
523
+ model=model,
524
+ args=training_args,
525
+ train_dataset=train_dataset,
526
+ eval_dataset=eval_dataset,
527
+ processing_class=tokenizer,
528
+ data_collator=data_collator,
529
+ )
530
+
531
+ # -------------------------------------------------------------------------
532
+ # Train!
533
+ # -------------------------------------------------------------------------
534
+ print("\n" + "=" * 60)
535
+ print("Starting training...")
536
+ print("=" * 60 + "\n")
537
+
538
+ trainer.train()
539
+
540
+ # -------------------------------------------------------------------------
541
+ # Save Final Model
542
+ # -------------------------------------------------------------------------
543
+ print("\nSaving final model...")
544
+ trainer.save_model()
545
+
546
+ if PUSH_TO_HUB:
547
+ print(f"Pushing to Hub: {hub_model_id}")
548
+ trainer.push_to_hub()
549
+ print(f"\n✅ Model available at: https://huggingface.co/{hub_model_id}")
550
+ else:
551
+ print(f"Model saved locally to: ./{OUTPUT_MODEL}")
552
+
553
+ print("\n" + "=" * 60)
554
+ print("Training complete!")
555
+ print("=" * 60)
556
+
557
+
558
+ if __name__ == "__main__":
559
+ main()