ligaments-dev committed on
Commit
fee27f5
·
verified ·
1 Parent(s): 0749dbc

Add training script for gemma-2b-it telecom full fine-tuning

Browse files
Files changed (1) hide show
  1. train_telecom_gemma.py +131 -0
train_telecom_gemma.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
Full fine-tuning script:
  Model:   google/gemma-2-2b-it
  Dataset: talkmap/telecom-conversation-corpus

Converts turn-based telecom dialogues into the conversational ``messages``
format used for SFT, then trains and pushes the model to the Hub.

The Gemma chat template only accepts conversations that start with a ``user``
message and strictly alternate ``user``/``assistant``; anything else raises
during templating.  The conversion helpers below therefore normalise every
dialogue into that shape before it reaches the trainer.
"""
import os
from collections import defaultdict
from operator import itemgetter

# ------------------------------------------------------------------
# Config
# ------------------------------------------------------------------
MODEL_ID = "google/gemma-2-2b-it"
DATASET_ID = "talkmap/telecom-conversation-corpus"
OUTPUT_DIR = "./gemma-2b-it-telecom"
HUB_MODEL_ID = "ligaments-dev/gemma-2b-it-telecom"
MAX_SEQ_LENGTH = 2048


# ------------------------------------------------------------------
# Dataset conversion (pure Python, no ML dependencies)
# ------------------------------------------------------------------
def to_alternating_messages(turns):
    """Convert chronologically ordered turns into alternating chat messages.

    Args:
        turns: Iterable of mappings with ``"speaker"`` and ``"text"`` keys,
            already sorted in conversation order.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts that starts with a
        ``user`` message and strictly alternates ``user``/``assistant``.
        The list may be empty.

    Normalisation rules:
        * ``speaker == "client"`` maps to ``user``; every other speaker maps
          to ``assistant``.
        * Turns whose text is missing, not a string, or blank are skipped.
        * Assistant turns before the first user turn are dropped, because the
          chat template requires the conversation to open with ``user``.
        * Consecutive turns by the same role are merged into one message,
          joined with a newline.
    """
    messages = []
    for turn in turns:
        text = turn["text"]
        if not isinstance(text, str) or not text.strip():
            continue

        role = "user" if turn["speaker"] == "client" else "assistant"

        if not messages:
            if role == "assistant":
                # Nothing for the assistant to respond to yet.
                continue
            messages.append({"role": role, "content": text})
        elif messages[-1]["role"] == role:
            messages[-1]["content"] += "\n" + text
        else:
            messages.append({"role": role, "content": text})
    return messages


def build_training_examples(rows):
    """Group flat turn rows into per-conversation SFT examples.

    Args:
        rows: Iterable of mappings with ``"conversation_id"``,
            ``"date_time"``, ``"speaker"`` and ``"text"`` keys.

    Returns:
        A list of ``{"messages": [...]}`` dicts, one per usable conversation,
        in order of first appearance of each ``conversation_id``.
        Conversations that contain no assistant message after normalisation
        are omitted, since they give the model no reply to learn from.
    """
    grouped = defaultdict(list)
    for row in rows:
        grouped[row["conversation_id"]].append(row)

    examples = []
    for turns in grouped.values():
        # Stable sort: rows sharing a timestamp keep their dataset order.
        turns.sort(key=itemgetter("date_time"))
        messages = to_alternating_messages(turns)
        if any(message["role"] == "assistant" for message in messages):
            examples.append({"messages": messages})
    return examples


# ------------------------------------------------------------------
# Training pipeline
# ------------------------------------------------------------------
def main():
    """Run the full pipeline: load data, fine-tune, save and push to the Hub."""
    # Heavy ML dependencies are imported here rather than at module level so
    # the conversion helpers above can be imported and tested without them.
    import torch
    import trackio
    from datasets import Dataset, load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import SFTConfig, SFTTrainer

    # Trackio monitoring
    trackio.init(
        project="gemma-telecom-finetune",
        name="gemma-2b-it-full-sft",
    )

    # Load dataset
    print("Loading dataset...")
    ds = load_dataset(DATASET_ID, split="train")
    print(f"Rows: {len(ds)}, Columns: {ds.column_names}")

    # Group rows by conversation_id, sort by date_time, convert to messages
    print("Grouping conversations...")
    print("Converting to messages format...")
    messages_data = build_training_examples(ds)
    if not messages_data:
        raise ValueError(
            f"No usable conversations found in dataset {DATASET_ID!r}"
        )

    train_dataset = Dataset.from_list(messages_data)
    print(f"Total conversations: {len(train_dataset)}")

    # Tokenizer
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Model
    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )

    model.gradient_checkpointing_enable()

    # Training arguments
    args = SFTConfig(
        output_dir=OUTPUT_DIR,
        hub_model_id=HUB_MODEL_ID,
        push_to_hub=True,
        num_train_epochs=3,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=2e-5,
        # NOTE(review): newer TRL releases name this option `max_length`;
        # confirm against the installed TRL version.
        max_seq_length=MAX_SEQ_LENGTH,
        logging_strategy="steps",
        logging_steps=10,
        logging_first_step=True,
        disable_tqdm=True,
        save_strategy="epoch",
        bf16=True,
        gradient_checkpointing=True,
        report_to=["trackio"],
        remove_unused_columns=False,
    )

    # Trainer
    print("Initializing SFTTrainer...")
    trainer = SFTTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        processing_class=tokenizer,
    )

    # Train
    print("Starting training...")
    trainer.train()

    # Save & Push
    print("Saving and pushing to hub...")
    trainer.save_model(OUTPUT_DIR)
    trainer.push_to_hub()

    print(f"Done! Model at https://huggingface.co/{HUB_MODEL_ID}")


if __name__ == "__main__":
    main()