algorythmtechnologies commited on
Commit
f2d3b70
Β·
verified Β·
1 Parent(s): 706b5ee

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +16 -34
train.py CHANGED
@@ -3,7 +3,7 @@ import random
3
  import numpy as np
4
  import torch
5
  from datasets import load_dataset
6
- from transformers import AutoTokenizer, AutoModelForCausalLM, EarlyStoppingCallback, TrainerCallback
7
  from trl import SFTTrainer, SFTConfig
8
  from peft import LoraConfig
9
  from transformers import BitsAndBytesConfig
@@ -13,7 +13,7 @@ BASE_MODEL = os.environ.get("BASE_MODEL", "DeepSeek-Coder-V2-Lite-Instruct")
13
  OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "outputs/zenith-lora")
14
  DATA_PATH = os.environ.get("DATA_PATH", "data/zenith_combined.jsonl")
15
  VAL_PATH = os.environ.get("VAL_PATH")
16
- MAX_STEPS = int(os.environ.get("STEPS", 300)) # ~2 hr on A100
17
  SEED = int(os.environ.get("SEED", 42))
18
 
19
  os.makedirs(OUTPUT_DIR, exist_ok=True)
@@ -24,24 +24,20 @@ np.random.seed(SEED)
24
  torch.manual_seed(SEED)
25
  if torch.cuda.is_available():
26
  torch.cuda.manual_seed_all(SEED)
27
-
28
  torch.backends.cuda.matmul.allow_tf32 = True
29
  torch.backends.cudnn.allow_tf32 = True
30
 
 
31
  print(f"πŸš€ Loading tokenizer and model from: {BASE_MODEL}")
32
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
33
  if tokenizer.pad_token is None:
34
  tokenizer.pad_token = tokenizer.eos_token
35
 
36
- # ====== GPU PRECISION CONFIG ======
37
  compute_dtype = torch.float16
38
- if torch.cuda.is_available():
39
- major, _ = torch.cuda.get_device_capability(0)
40
- if major >= 8:
41
- print("βœ… Using bfloat16 for Ampere+ GPU")
42
- compute_dtype = torch.bfloat16
43
 
44
- # ====== 4-BIT QUANTIZATION ======
45
  bnb_config = BitsAndBytesConfig(
46
  load_in_4bit=True,
47
  bnb_4bit_quant_type="nf4",
@@ -58,20 +54,16 @@ model = AutoModelForCausalLM.from_pretrained(
58
  )
59
  model.config.use_cache = False
60
 
61
- # ====== DATASET LOADING ======
62
  data_files = [DATA_PATH]
63
- print(f"πŸ“‚ Loading dataset: {data_files}")
64
  raw_train = load_dataset("json", data_files=data_files, split="train")
65
 
66
  if VAL_PATH and os.path.exists(VAL_PATH):
67
- print(f"πŸ“ Using external validation: {VAL_PATH}")
68
  raw_val = load_dataset("json", data_files=VAL_PATH, split="train")
69
  else:
70
  split = raw_train.train_test_split(test_size=0.05, seed=SEED)
71
  raw_train, raw_val = split["train"], split["test"]
72
 
73
- MAX_SEQ_LEN = int(os.environ.get("MAX_SEQ_LEN", 2048))
74
-
75
  def _valid(example):
76
  msgs = example.get("messages")
77
  if not isinstance(msgs, list) or not msgs:
@@ -83,30 +75,27 @@ def _valid(example):
83
 
84
  def _to_text(example):
85
  try:
86
- text = tokenizer.apply_chat_template(
87
- example["messages"], tokenize=False, add_generation_prompt=False
88
- )
89
  return {"text": text}
90
  except Exception:
91
  return {"text": ""}
92
 
93
- train_ds = raw_train.filter(_valid)
94
- val_ds = raw_val.filter(_valid)
95
- train_ds = train_ds.map(_to_text, remove_columns=train_ds.column_names)
96
- val_ds = val_ds.map(_to_text, remove_columns=val_ds.column_names)
97
 
98
  train_ds = train_ds.filter(lambda x: len(x.get("text", "")) > 0)
99
  val_ds = val_ds.filter(lambda x: len(x.get("text", "")) > 0)
100
 
101
  print(f"βœ… Training samples: {len(train_ds)}, Validation: {len(val_ds)}")
102
 
103
- # ====== LORA CONFIG (gentle mode) ======
104
  peft_config = LoraConfig(
105
  r=int(os.environ.get("LORA_R", 8)),
106
  lora_alpha=int(os.environ.get("LORA_ALPHA", 16)),
107
  lora_dropout=float(os.environ.get("LORA_DROPOUT", 0.1)),
108
  bias="none",
109
  task_type="CAUSAL_LM",
 
110
  )
111
 
112
  # ====== EVAL CALLBACK ======
@@ -132,9 +121,8 @@ training_args = SFTConfig(
132
  logging_steps=int(os.environ.get("LOG_STEPS", 10)),
133
  save_steps=int(os.environ.get("SAVE_STEPS", 50)),
134
  save_total_limit=int(os.environ.get("SAVE_LIMIT", 2)),
135
- fp16=torch.cuda.is_available(),
136
- bf16=torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8,
137
- max_seq_length=MAX_SEQ_LEN,
138
  gradient_checkpointing=True,
139
  gradient_checkpointing_kwargs={"use_reentrant": False},
140
  dataloader_drop_last=True,
@@ -143,19 +131,13 @@ training_args = SFTConfig(
143
  )
144
 
145
  # ====== TRAINER ======
146
- print(f"🏁 Starting Zenith fine-tuning for {MAX_STEPS} steps (~2h runtime)...")
147
  trainer = SFTTrainer(
148
  model=model,
149
- tokenizer=tokenizer,
150
  train_dataset=train_ds,
151
  eval_dataset=val_ds,
152
  peft_config=peft_config,
153
  args=training_args,
154
- dataset_text_field="text",
155
- callbacks=[
156
- EarlyStoppingCallback(early_stopping_patience=int(os.environ.get("EARLY_STOP_PATIENCE", 3))),
157
- EvalEveryCallback(eval_steps=int(os.environ.get("EVAL_STEPS", 50)))
158
- ],
159
  )
160
 
161
  trainer.train()
@@ -165,4 +147,4 @@ trainer.model.save_pretrained(OUTPUT_DIR)
165
  tokenizer.save_pretrained(OUTPUT_DIR)
166
 
167
  print(f"βœ… Zenith LoRA adapter saved to: {OUTPUT_DIR}")
168
- print("🎯 Training complete under 2 hours.")
 
3
  import numpy as np
4
  import torch
5
  from datasets import load_dataset
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TrainerCallback, EarlyStoppingCallback
7
  from trl import SFTTrainer, SFTConfig
8
  from peft import LoraConfig
9
  from transformers import BitsAndBytesConfig
 
13
  OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "outputs/zenith-lora")
14
  DATA_PATH = os.environ.get("DATA_PATH", "data/zenith_combined.jsonl")
15
  VAL_PATH = os.environ.get("VAL_PATH")
16
+ MAX_STEPS = int(os.environ.get("STEPS", 300))
17
  SEED = int(os.environ.get("SEED", 42))
18
 
19
  os.makedirs(OUTPUT_DIR, exist_ok=True)
 
24
  torch.manual_seed(SEED)
25
  if torch.cuda.is_available():
26
  torch.cuda.manual_seed_all(SEED)
 
27
  torch.backends.cuda.matmul.allow_tf32 = True
28
  torch.backends.cudnn.allow_tf32 = True
29
 
30
+ # ====== TOKENIZER & MODEL ======
31
  print(f"πŸš€ Loading tokenizer and model from: {BASE_MODEL}")
32
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
33
  if tokenizer.pad_token is None:
34
  tokenizer.pad_token = tokenizer.eos_token
35
 
 
36
  compute_dtype = torch.float16
37
+ if torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8:
38
+ compute_dtype = torch.bfloat16
39
+ print("βœ… Ampere+ GPU detected β€” will prefer bf16 where supported.")
 
 
40
 
 
41
  bnb_config = BitsAndBytesConfig(
42
  load_in_4bit=True,
43
  bnb_4bit_quant_type="nf4",
 
54
  )
55
  model.config.use_cache = False
56
 
57
+ # ====== DATASET ======
58
  data_files = [DATA_PATH]
 
59
  raw_train = load_dataset("json", data_files=data_files, split="train")
60
 
61
  if VAL_PATH and os.path.exists(VAL_PATH):
 
62
  raw_val = load_dataset("json", data_files=VAL_PATH, split="train")
63
  else:
64
  split = raw_train.train_test_split(test_size=0.05, seed=SEED)
65
  raw_train, raw_val = split["train"], split["test"]
66
 
 
 
67
  def _valid(example):
68
  msgs = example.get("messages")
69
  if not isinstance(msgs, list) or not msgs:
 
75
 
76
  def _to_text(example):
77
  try:
78
+ text = tokenizer.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=False)
 
 
79
  return {"text": text}
80
  except Exception:
81
  return {"text": ""}
82
 
83
+ train_ds = raw_train.filter(_valid).map(_to_text, remove_columns=raw_train.column_names)
84
+ val_ds = raw_val.filter(_valid).map(_to_text, remove_columns=raw_val.column_names)
 
 
85
 
86
  train_ds = train_ds.filter(lambda x: len(x.get("text", "")) > 0)
87
  val_ds = val_ds.filter(lambda x: len(x.get("text", "")) > 0)
88
 
89
  print(f"βœ… Training samples: {len(train_ds)}, Validation: {len(val_ds)}")
90
 
91
+ # ====== LORA CONFIG ======
92
  peft_config = LoraConfig(
93
  r=int(os.environ.get("LORA_R", 8)),
94
  lora_alpha=int(os.environ.get("LORA_ALPHA", 16)),
95
  lora_dropout=float(os.environ.get("LORA_DROPOUT", 0.1)),
96
  bias="none",
97
  task_type="CAUSAL_LM",
98
+ target_modules=["q_proj", "v_proj"], # Required for LoRA injection
99
  )
100
 
101
  # ====== EVAL CALLBACK ======
 
121
  logging_steps=int(os.environ.get("LOG_STEPS", 10)),
122
  save_steps=int(os.environ.get("SAVE_STEPS", 50)),
123
  save_total_limit=int(os.environ.get("SAVE_LIMIT", 2)),
124
+ fp16=torch.cuda.is_available() and compute_dtype==torch.float16,
125
+ bf16=torch.cuda.is_available() and compute_dtype==torch.bfloat16,
 
126
  gradient_checkpointing=True,
127
  gradient_checkpointing_kwargs={"use_reentrant": False},
128
  dataloader_drop_last=True,
 
131
  )
132
 
133
  # ====== TRAINER ======
134
+ print(f"🏁 Starting Zenith fine-tuning for {MAX_STEPS} steps (~2h config)...")
135
  trainer = SFTTrainer(
136
  model=model,
 
137
  train_dataset=train_ds,
138
  eval_dataset=val_ds,
139
  peft_config=peft_config,
140
  args=training_args,
 
 
 
 
 
141
  )
142
 
143
  trainer.train()
 
147
  tokenizer.save_pretrained(OUTPUT_DIR)
148
 
149
  print(f"βœ… Zenith LoRA adapter saved to: {OUTPUT_DIR}")
150
+ print("🎯 Training complete under ~2 hours.")