algorythmtechnologies committed on
Commit
b49c004
Β·
verified Β·
1 Parent(s): ec77427

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +168 -164
train.py CHANGED
@@ -1,164 +1,168 @@
1
- import os
2
- import random
3
- import numpy as np
4
- import torch
5
- from datasets import load_dataset
6
- from transformers import AutoTokenizer, AutoModelForCausalLM, EarlyStoppingCallback
7
- from trl import SFTTrainer, SFTConfig
8
- from peft import LoraConfig
9
- from transformers import BitsAndBytesConfig
10
-
11
- # Config from env vars
12
- BASE_MODEL = os.environ.get("BASE_MODEL", "DeepSeek-Coder-V2-Lite-Instruct")
13
- OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "outputs/zenith-lora")
14
- DATA_PATH = os.environ.get("DATA_PATH", "data/zenith.jsonl")
15
- VAL_PATH = os.environ.get("VAL_PATH")
16
- MAX_STEPS = int(os.environ.get("STEPS", 200))
17
- USE_4BIT = os.environ.get("USE_4BIT", "1") == "1"
18
- SEED = int(os.environ.get("SEED", 42))
19
-
20
- os.makedirs(OUTPUT_DIR, exist_ok=True)
21
-
22
- # Set seeds for reproducibility
23
- random.seed(SEED)
24
- np.random.seed(SEED)
25
- torch.manual_seed(SEED)
26
- if torch.cuda.is_available():
27
- torch.cuda.manual_seed_all(SEED)
28
-
29
- print(f"Loading tokenizer and model from: {BASE_MODEL}")
30
- tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
31
- if tokenizer.pad_token is None:
32
- tokenizer.pad_token = tokenizer.eos_token
33
-
34
- # Set compute dtype based on GPU capability
35
- compute_dtype = torch.float16
36
- if torch.cuda.is_available():
37
- device_cap = torch.cuda.get_device_capability(0)
38
- if device_cap[0] >= 8: # Ampere or higher
39
- print("Using bfloat16 for Ampere GPU")
40
- compute_dtype = torch.bfloat16
41
-
42
- # 4-bit quantization config
43
- bnb_config = BitsAndBytesConfig(
44
- load_in_4bit=True,
45
- bnb_4bit_quant_type="nf4",
46
- bnb_4bit_compute_dtype=compute_dtype,
47
- bnb_4bit_use_double_quant=True,
48
- llm_int8_enable_fp32_cpu_offload=True,
49
- )
50
-
51
- print("Loading model with 4-bit quantization...")
52
- model = AutoModelForCausalLM.from_pretrained(
53
- BASE_MODEL,
54
- quantization_config=bnb_config,
55
- device_map="auto",
56
- trust_remote_code=True,
57
- )
58
-
59
- # Memory-saving configurations
60
- model.config.use_cache = False
61
-
62
- data_files = [DATA_PATH, "data/training_data_v2.jsonl"]
63
- print(f"Loading datasets: {data_files}")
64
- raw_train = load_dataset("json", data_files=data_files, split="train")
65
-
66
- # Optional external validation file
67
- if VAL_PATH:
68
- print(f"Loading validation dataset: {VAL_PATH}")
69
- raw_val = load_dataset("json", data_files=VAL_PATH, split="train")
70
- else:
71
- split = raw_train.train_test_split(test_size=0.05, seed=SEED)
72
- raw_train, raw_val = split["train"], split["test"]
73
-
74
- # Validate and format examples safely
75
- MAX_SEQ_LEN = int(os.environ.get("MAX_SEQ_LEN", 2048))
76
-
77
- def _valid(example):
78
- msgs = example.get("messages")
79
- if not isinstance(msgs, list) or not msgs:
80
- return False
81
- for m in msgs:
82
- if not isinstance(m, dict) or "role" not in m or "content" not in m:
83
- return False
84
- return True
85
-
86
- def _to_text(example):
87
- try:
88
- text = tokenizer.apply_chat_template(
89
- example["messages"], tokenize=False, add_generation_prompt=False
90
- )
91
- return {"text": text}
92
- except Exception:
93
- return {"text": ""}
94
-
95
- train_ds = raw_train.filter(_valid)
96
- val_ds = raw_val.filter(_valid)
97
-
98
- train_ds = train_ds.map(_to_text, remove_columns=train_ds.column_names)
99
- val_ds = val_ds.map(_to_text, remove_columns=val_ds.column_names)
100
-
101
- # Drop empty or pathological items
102
- train_ds = train_ds.filter(lambda x: isinstance(x.get("text"), str) and len(x["text"]) > 0)
103
- val_ds = val_ds.filter(lambda x: isinstance(x.get("text"), str) and len(x["text"]) > 0)
104
-
105
- # LoRA config
106
- peft_config = LoraConfig(
107
- r=int(os.environ.get("LORA_R", 16)),
108
- lora_alpha=int(os.environ.get("LORA_ALPHA", 32)),
109
- lora_dropout=float(os.environ.get("LORA_DROPOUT", 0.05)),
110
- bias="none",
111
- task_type="CAUSAL_LM",
112
- )
113
-
114
- # Training config - step-based for quick runs with stability
115
- training_args = SFTConfig(
116
- output_dir=OUTPUT_DIR,
117
- max_steps=MAX_STEPS, # Use steps instead of epochs for precise timing
118
- per_device_train_batch_size=int(os.environ.get("BATCH", 2)),
119
- gradient_accumulation_steps=int(os.environ.get("GRAD_ACC", 2)),
120
- learning_rate=float(os.environ.get("LR", 1e-4)),
121
- lr_scheduler_type=os.environ.get("LR_SCHED", "cosine"),
122
- warmup_ratio=float(os.environ.get("WARMUP_RATIO", 0.05)),
123
- weight_decay=float(os.environ.get("WEIGHT_DECAY", 0.01)),
124
- max_grad_norm=float(os.environ.get("MAX_GRAD_NORM", 1.0)),
125
- logging_steps=int(os.environ.get("LOG_STEPS", 10)),
126
- save_steps=int(os.environ.get("SAVE_STEPS", 50)),
127
- save_total_limit=int(os.environ.get("SAVE_LIMIT", 3)),
128
- evaluation_strategy="steps",
129
- eval_steps=int(os.environ.get("EVAL_STEPS", 50)),
130
- load_best_model_at_end=True,
131
- metric_for_best_model="eval_loss",
132
- greater_is_better=False,
133
- fp16=torch.cuda.is_available(),
134
- bf16=torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8,
135
- packing=False,
136
- max_seq_length=MAX_SEQ_LEN,
137
- dataloader_drop_last=True,
138
- gradient_checkpointing=True,
139
- gradient_checkpointing_kwargs={"use_reentrant": False},
140
- report_to=os.environ.get("REPORT_TO", "none"),
141
- seed=SEED,
142
- )
143
-
144
- print(f"Starting SFT training for {MAX_STEPS} steps...")
145
- trainer = SFTTrainer(
146
- model=model,
147
- tokenizer=tokenizer,
148
- train_dataset=train_ds,
149
- eval_dataset=val_ds,
150
- peft_config=peft_config,
151
- args=training_args,
152
- dataset_text_field="text",
153
- callbacks=[EarlyStoppingCallback(early_stopping_patience=int(os.environ.get("EARLY_STOP_PATIENCE", 3)))]
154
- )
155
-
156
- trainer.train()
157
-
158
- print("Saving LoRA adapter...")
159
- trainer.model.save_pretrained(OUTPUT_DIR)
160
- tokenizer.save_pretrained(OUTPUT_DIR)
161
-
162
- print(f"βœ… ZENITH LoRA adapter saved to: {OUTPUT_DIR}")
163
- print("🎯 World's most advanced autonomous AI development partner ready!")
164
- print("πŸš€ Ready for Aspetos platform integration!")
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import numpy as np
4
+ import torch
5
+ from datasets import load_dataset
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, EarlyStoppingCallback, TrainerCallback
7
+ from trl import SFTTrainer, SFTConfig
8
+ from peft import LoraConfig
9
+ from transformers import BitsAndBytesConfig
10
+
11
# ====== CONFIG ======
# Every knob is overridable through an environment variable; the literals
# below are just the defaults for a plain `python train.py` run.
def _env(name, default=None):
    """Read one configuration value from the environment."""
    return os.environ.get(name, default)

BASE_MODEL = _env("BASE_MODEL", "DeepSeek-Coder-V2-Lite-Instruct")
OUTPUT_DIR = _env("OUTPUT_DIR", "outputs/zenith-lora")
DATA_PATH = _env("DATA_PATH", "data/zenith_combined.jsonl")
VAL_PATH = _env("VAL_PATH")
MAX_STEPS = int(_env("STEPS", 300))  # ~2 hr on A100
SEED = int(_env("SEED", 42))

os.makedirs(OUTPUT_DIR, exist_ok=True)

# ====== SEED CONTROL ======
# Seed every RNG we touch (python, numpy, torch CPU and all CUDA devices)
# so runs are reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# TF32 matmuls are much faster on Ampere+ at negligible precision cost.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
30
+
31
print(f"🚀 Loading tokenizer and model from: {BASE_MODEL}")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# Some checkpoints ship without a pad token; reuse EOS so batch padding
# is well-defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# ====== GPU PRECISION CONFIG ======
_has_cuda = torch.cuda.is_available()
_ampere_plus = _has_cuda and torch.cuda.get_device_capability(0)[0] >= 8
if _ampere_plus:
    print("✅ Using bfloat16 for Ampere+ GPU")
# bf16 needs compute capability >= 8.0; everything older falls back to fp16.
compute_dtype = torch.bfloat16 if _ampere_plus else torch.float16

# ====== 4-BIT QUANTIZATION ======
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
)

print("⚙️ Loading model with 4-bit quantization...")
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)
# The KV cache is incompatible with gradient checkpointing during training.
model.config.use_cache = False
60
+
61
# ====== DATASET LOADING ======
data_files = [DATA_PATH]
print(f"📂 Loading dataset: {data_files}")
raw_train = load_dataset("json", data_files=data_files, split="train")

# Prefer an explicit validation file when one exists on disk; otherwise
# carve a held-out 5% slice off the training data (seeded, reproducible).
if VAL_PATH and os.path.exists(VAL_PATH):
    print(f"📝 Using external validation: {VAL_PATH}")
    raw_val = load_dataset("json", data_files=VAL_PATH, split="train")
else:
    parts = raw_train.train_test_split(test_size=0.05, seed=SEED)
    raw_train = parts["train"]
    raw_val = parts["test"]

MAX_SEQ_LEN = int(os.environ.get("MAX_SEQ_LEN", 2048))
74
+
75
+ def _valid(example):
76
+ msgs = example.get("messages")
77
+ if not isinstance(msgs, list) or not msgs:
78
+ return False
79
+ for m in msgs:
80
+ if not isinstance(m, dict) or "role" not in m or "content" not in m:
81
+ return False
82
+ return True
83
+
84
def _to_text(example):
    """Render one chat example into a single training string.

    Applies the tokenizer's chat template (no generation prompt, since
    the assistant turns are already present). Any rendering failure
    yields ``{"text": ""}`` so the sample is dropped by the downstream
    non-empty filter instead of aborting the map.
    """
    try:
        rendered = tokenizer.apply_chat_template(
            example["messages"], tokenize=False, add_generation_prompt=False
        )
    except Exception:
        return {"text": ""}
    return {"text": rendered}
92
+
93
# Keep only structurally valid chats, render each one to plain text, then
# drop any row whose rendering came back empty (see _to_text).
train_ds = raw_train.filter(_valid)
val_ds = raw_val.filter(_valid)

train_ds = train_ds.map(_to_text, remove_columns=train_ds.column_names)
val_ds = val_ds.map(_to_text, remove_columns=val_ds.column_names)

def _non_empty(row):
    # Rendering failures produce "" — filter those out.
    return len(row.get("text", "")) > 0

train_ds = train_ds.filter(_non_empty)
val_ds = val_ds.filter(_non_empty)

print(f"✅ Training samples: {len(train_ds)}, Validation: {len(val_ds)}")
102
+
103
# ====== LORA CONFIG (gentle mode) ======
# Small rank/alpha plus a higher dropout keep the adapter conservative on
# a short run; all three are overridable via the environment.
_rank = int(os.environ.get("LORA_R", 8))
_alpha = int(os.environ.get("LORA_ALPHA", 16))
_dropout = float(os.environ.get("LORA_DROPOUT", 0.1))

peft_config = LoraConfig(
    r=_rank,
    lora_alpha=_alpha,
    lora_dropout=_dropout,
    bias="none",
    task_type="CAUSAL_LM",
)
111
+
112
# ====== EVAL CALLBACK ======
class EvalEveryCallback(TrainerCallback):
    """Force an evaluation pass every ``eval_steps`` optimizer steps.

    Useful when the trainer arguments themselves do not schedule periodic
    evaluation; the resulting eval metrics feed the early-stopping logic.
    """

    def __init__(self, eval_steps=100):
        # Clamp to >= 1: an EVAL_STEPS=0 env override would otherwise
        # crash the modulo check in on_step_end with ZeroDivisionError.
        self.eval_steps = max(1, int(eval_steps))

    def on_step_end(self, args, state, control, **kwargs):
        """Flag ``control.should_evaluate`` on every eval_steps-th step (skipping step 0)."""
        if state.global_step > 0 and state.global_step % self.eval_steps == 0:
            control.should_evaluate = True
        return control
120
+
121
# ====== TRAINING CONFIG ======
# fp16 and bf16 are mutually exclusive in TrainingArguments: passing both
# as True (which the previous `fp16=cuda, bf16=cuda and ampere` pair did on
# Ampere+ GPUs) raises a ValueError at construction time. Prefer bf16 when
# the hardware supports it and fall back to fp16 otherwise.
_bf16_ok = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8

training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    max_steps=MAX_STEPS,
    per_device_train_batch_size=int(os.environ.get("BATCH", 2)),
    gradient_accumulation_steps=int(os.environ.get("GRAD_ACC", 2)),
    learning_rate=float(os.environ.get("LR", 5e-5)),
    lr_scheduler_type=os.environ.get("LR_SCHED", "cosine"),
    warmup_ratio=float(os.environ.get("WARMUP_RATIO", 0.1)),
    weight_decay=float(os.environ.get("WEIGHT_DECAY", 0.01)),
    max_grad_norm=float(os.environ.get("MAX_GRAD_NORM", 1.0)),
    logging_steps=int(os.environ.get("LOG_STEPS", 10)),
    save_steps=int(os.environ.get("SAVE_STEPS", 50)),
    save_total_limit=int(os.environ.get("SAVE_LIMIT", 2)),
    # Periodic evaluation + best-model tracking: required for the
    # EarlyStoppingCallback attached to the trainer to work at all
    # (it raises when load_best_model_at_end/metric_for_best_model are unset).
    evaluation_strategy="steps",
    eval_steps=int(os.environ.get("EVAL_STEPS", 50)),
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    bf16=_bf16_ok,
    fp16=torch.cuda.is_available() and not _bf16_ok,
    max_seq_length=MAX_SEQ_LEN,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    dataloader_drop_last=True,
    report_to="none",
    seed=SEED,
)
144
+
145
# ====== TRAINER ======
print(f"🏁 Starting Zenith fine-tuning for {MAX_STEPS} steps (~2h runtime)...")

# Early stopping consumes the eval metrics that EvalEveryCallback forces
# the trainer to produce every EVAL_STEPS steps.
_callbacks = [
    EarlyStoppingCallback(
        early_stopping_patience=int(os.environ.get("EARLY_STOP_PATIENCE", 3))
    ),
    EvalEveryCallback(eval_steps=int(os.environ.get("EVAL_STEPS", 50))),
]

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    peft_config=peft_config,
    args=training_args,
    dataset_text_field="text",
    callbacks=_callbacks,
)

trainer.train()

print("💾 Saving LoRA adapter...")
# Only the LoRA adapter weights are written, not the full base model.
trainer.model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

print(f"✅ Zenith LoRA adapter saved to: {OUTPUT_DIR}")
print("🎯 Training complete under 2 hours.")