ligaments-dev commited on
Commit
4e9803a
·
verified ·
1 Parent(s): f3a75ce

Add training script

Browse files
Files changed (1) hide show
  1. train.py +119 -0
train.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from datasets import load_dataset, Dataset
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ from trl import SFTTrainer, SFTConfig
6
+ import trackio
7
+
8
+ # ---------------------------------------------------------------------------
9
+ # Config
10
+ # ---------------------------------------------------------------------------
11
+ MODEL_ID = "google/gemma-2b-it"
12
+ DATASET_ID = "talkmap/telecom-conversation-corpus"
13
+ OUTPUT_REPO = "ligaments-dev/gemma-2b-telecom-sft"
14
+ MAX_LENGTH = 2048
15
+
16
+ # ---------------------------------------------------------------------------
17
+ # Logging / tracking
18
+ # ---------------------------------------------------------------------------
19
+ trackio.init(project="gemma-telecom-sft")
20
+
21
+ # ---------------------------------------------------------------------------
22
+ # Load tokenizer
23
+ # ---------------------------------------------------------------------------
24
+ print("Loading tokenizer...")
25
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
26
+ if tokenizer.pad_token is None:
27
+ tokenizer.pad_token = tokenizer.eos_token
28
+
29
+ # ---------------------------------------------------------------------------
30
+ # Preprocess dataset: group rows by conversation_id into messages
31
+ # ---------------------------------------------------------------------------
32
+ print("Loading dataset...")
33
+ raw = load_dataset(DATASET_ID, split="train")
34
+
35
+ # The dataset has columns: conversation_id, speaker, date_time, text
36
+ # We group by conversation_id and build a single messages list per conversation.
37
+ print("Grouping conversations...")
38
+ conv_map = {}
39
+ for ex in raw:
40
+ cid = ex["conversation_id"]
41
+ if cid not in conv_map:
42
+ conv_map[cid] = []
43
+ role = "user" if ex["speaker"] == "client" else "assistant"
44
+ conv_map[cid].append({"role": role, "content": ex["text"]})
45
+
46
+ # Build a huggingface Dataset from the grouped conversations
47
+ messages_list = []
48
+ for cid, msgs in conv_map.items():
49
+ # Keep only alternating user/assistant; skip system-like turns if any.
50
+ # Gemma-2b-it chat template expects user/assistant.
51
+ messages_list.append({"messages": msgs})
52
+
53
+ train_dataset = Dataset.from_list(messages_list)
54
+ print(f"Prepared {len(train_dataset)} conversation examples.")
55
+
56
+ # Save processed dataset so we don't recompute on resume
57
+ processed_path = "/tmp/telecom_processed"
58
+ train_dataset.save_to_disk(processed_path)
59
+ print(f"Saved processed dataset to {processed_path}")
60
+
61
+ # ---------------------------------------------------------------------------
62
+ # Load model
63
+ # ---------------------------------------------------------------------------
64
+ print("Loading model...")
65
+ model = AutoModelForCausalLM.from_pretrained(
66
+ MODEL_ID,
67
+ torch_dtype="auto",
68
+ device_map="auto",
69
+ )
70
+
71
+ # ---------------------------------------------------------------------------
72
+ # Training arguments
73
+ # ---------------------------------------------------------------------------
74
+ training_args = SFTConfig(
75
+ output_dir="/tmp/gemma-telecom-sft",
76
+ num_train_epochs=3,
77
+ per_device_train_batch_size=2,
78
+ gradient_accumulation_steps=8,
79
+ learning_rate=2e-5,
80
+ warmup_ratio=0.1,
81
+ lr_scheduler_type="cosine",
82
+ max_length=MAX_LENGTH,
83
+ packing=True,
84
+ gradient_checkpointing=True,
85
+ bf16=True,
86
+ logging_steps=10,
87
+ save_strategy="epoch",
88
+ push_to_hub=True,
89
+ hub_model_id=OUTPUT_REPO,
90
+ hub_private_repo=False,
91
+ disable_tqdm=True,
92
+ logging_strategy="steps",
93
+ logging_first_step=True,
94
+ report_to=["trackio"],
95
+ )
96
+
97
+ # ---------------------------------------------------------------------------
98
+ # Trainer
99
+ # ---------------------------------------------------------------------------
100
+ print("Initializing SFTTrainer...")
101
+ trainer = SFTTrainer(
102
+ model=model,
103
+ args=training_args,
104
+ train_dataset=train_dataset,
105
+ processing_class=tokenizer,
106
+ )
107
+
108
+ # ---------------------------------------------------------------------------
109
+ # Train
110
+ # ---------------------------------------------------------------------------
111
+ print("Starting training...")
112
+ trainer.train()
113
+
114
+ # ---------------------------------------------------------------------------
115
+ # Push to hub
116
+ # ---------------------------------------------------------------------------
117
+ print("Pushing model to hub...")
118
+ trainer.push_to_hub(commit_message="Full SFT on telecom conversation corpus")
119
+ print("Done!")