Mindigenous committed on
Commit
6a1099b
·
1 Parent(s): 3132f2e

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +13 -42
train.py CHANGED
@@ -20,7 +20,7 @@ from utils import ensure_dirs, setup_logger
20
 
21
 
22
  # ==============================
23
- # 🔥 FIXED BACKUP CALLBACK
24
  # ==============================
25
  class BackupCallback(TrainerCallback):
26
  def on_save(self, args, state, control, **kwargs):
@@ -46,42 +46,15 @@ class BackupCallback(TrainerCallback):
46
 
47
  print(f"[BACKUP] Saved: {backup_path}")
48
 
49
- # =========================
50
- # 🔥 FIXED NUMERIC SORT
51
- # =========================
52
- backups = [
53
- f for f in os.listdir("backups")
54
- if f.endswith(".tar.gz")
55
- ]
56
-
57
- backups = sorted(
58
- backups,
59
- key=lambda x: int(x.split("step")[1].split(".")[0])
60
- )
61
-
62
- # =========================
63
- # KEEP LAST 5 BACKUPS
64
- # =========================
65
- if len(backups) > 5:
66
- old_backup = backups[0]
67
- old_path = os.path.join("backups", old_backup)
68
-
69
- if os.path.isfile(old_path):
70
- os.remove(old_path)
71
- print(f"[BACKUP] Removed old backup: {old_backup}")
72
-
73
  except Exception as e:
74
  print(f"[BACKUP ERROR] {e}")
75
- # Never crash training
76
 
77
 
78
  # ==============================
79
  # MODEL PATH RESOLUTION
80
  # ==============================
81
  def _is_valid_hf_model_dir(path: Path) -> bool:
82
- if not path.exists():
83
- return False
84
- return (path / "config.json").exists()
85
 
86
 
87
  def _resolve_model_path(logger) -> Path:
@@ -93,8 +66,7 @@ def _resolve_model_path(logger) -> Path:
93
 
94
  if _is_valid_hf_model_dir(fallback):
95
  logger.warning(
96
- "Primary model path %s is missing HF files. Falling back to %s",
97
- primary.resolve(),
98
  fallback.resolve(),
99
  )
100
  return fallback
@@ -103,7 +75,7 @@ def _resolve_model_path(logger) -> Path:
103
 
104
 
105
  # ==============================
106
- # BUILD MODEL
107
  # ==============================
108
  def _build_model_and_tokenizer(model_path: Path):
109
  tokenizer = AutoTokenizer.from_pretrained(
@@ -115,13 +87,14 @@ def _build_model_and_tokenizer(model_path: Path):
115
  if tokenizer.pad_token is None:
116
  tokenizer.pad_token = tokenizer.eos_token
117
 
 
118
  model = AutoModelForCausalLM.from_pretrained(
119
  model_path,
120
  trust_remote_code=True,
121
- local_files_only=True,
122
- dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
123
  )
124
 
 
125
  lora_cfg = LoraConfig(
126
  r=16,
127
  lora_alpha=32,
@@ -136,7 +109,7 @@ def _build_model_and_tokenizer(model_path: Path):
136
 
137
 
138
  # ==============================
139
- # SMART RESUME
140
  # ==============================
141
  def get_latest_checkpoint(checkpoint_dir):
142
  if not os.path.exists(checkpoint_dir):
@@ -162,19 +135,18 @@ def safe_train(trainer, checkpoint_dir, logger):
162
  latest_checkpoint = get_latest_checkpoint(checkpoint_dir)
163
 
164
  if latest_checkpoint:
165
- logger.info(f"Resuming from checkpoint: {latest_checkpoint}")
166
  try:
167
  trainer.train(resume_from_checkpoint=latest_checkpoint)
168
  return
169
  except Exception as e:
170
- logger.warning(f"Resume failed: {e}")
171
 
172
- logger.warning("No valid checkpoint → starting fresh training")
173
  trainer.train()
174
 
175
 
176
  # ==============================
177
- # MAIN TRAIN FUNCTION
178
  # ==============================
179
  def train(resume: bool):
180
  ensure_dirs([
@@ -210,7 +182,6 @@ def train(resume: bool):
210
  logging_steps=50,
211
  save_steps=250,
212
  save_total_limit=3,
213
- gradient_checkpointing=False,
214
  report_to="none",
215
  remove_unused_columns=False,
216
  )
@@ -229,11 +200,11 @@ def train(resume: bool):
229
  trainer.model.save_pretrained(str(PATHS.lora_output_dir))
230
  tokenizer.save_pretrained(str(PATHS.tokenizer_output_dir))
231
 
232
- print("\n✅ Training complete. Model saved.")
233
 
234
 
235
  # ==============================
236
- # ENTRY POINT
237
  # ==============================
238
  if __name__ == "__main__":
239
  parser = argparse.ArgumentParser()
 
20
 
21
 
22
  # ==============================
23
+ # 🔥 BACKUP CALLBACK
24
  # ==============================
25
  class BackupCallback(TrainerCallback):
26
  def on_save(self, args, state, control, **kwargs):
 
46
 
47
  print(f"[BACKUP] Saved: {backup_path}")
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  except Exception as e:
50
  print(f"[BACKUP ERROR] {e}")
 
51
 
52
 
53
  # ==============================
54
  # MODEL PATH RESOLUTION
55
  # ==============================
56
  def _is_valid_hf_model_dir(path: Path) -> bool:
57
+ return path.exists() and (path / "config.json").exists()
 
 
58
 
59
 
60
  def _resolve_model_path(logger) -> Path:
 
66
 
67
  if _is_valid_hf_model_dir(fallback):
68
  logger.warning(
69
+ "Primary model missing using fallback %s",
 
70
  fallback.resolve(),
71
  )
72
  return fallback
 
75
 
76
 
77
  # ==============================
78
+ # BUILD MODEL (FIXED)
79
  # ==============================
80
  def _build_model_and_tokenizer(model_path: Path):
81
  tokenizer = AutoTokenizer.from_pretrained(
 
87
  if tokenizer.pad_token is None:
88
  tokenizer.pad_token = tokenizer.eos_token
89
 
90
+ # 🔥 FIXED MODEL LOADING
91
  model = AutoModelForCausalLM.from_pretrained(
92
  model_path,
93
  trust_remote_code=True,
94
+ use_safetensors=True, # IMPORTANT
 
95
  )
96
 
97
+ # LoRA
98
  lora_cfg = LoraConfig(
99
  r=16,
100
  lora_alpha=32,
 
109
 
110
 
111
  # ==============================
112
+ # CHECKPOINT RESUME (SAFE)
113
  # ==============================
114
  def get_latest_checkpoint(checkpoint_dir):
115
  if not os.path.exists(checkpoint_dir):
 
135
  latest_checkpoint = get_latest_checkpoint(checkpoint_dir)
136
 
137
  if latest_checkpoint:
138
+ logger.info(f"Trying resume from: {latest_checkpoint}")
139
  try:
140
  trainer.train(resume_from_checkpoint=latest_checkpoint)
141
  return
142
  except Exception as e:
143
+ logger.warning(f"Resume failed → starting fresh: {e}")
144
 
 
145
  trainer.train()
146
 
147
 
148
  # ==============================
149
+ # MAIN TRAIN
150
  # ==============================
151
  def train(resume: bool):
152
  ensure_dirs([
 
182
  logging_steps=50,
183
  save_steps=250,
184
  save_total_limit=3,
 
185
  report_to="none",
186
  remove_unused_columns=False,
187
  )
 
200
  trainer.model.save_pretrained(str(PATHS.lora_output_dir))
201
  tokenizer.save_pretrained(str(PATHS.tokenizer_output_dir))
202
 
203
+ print("\n✅ Training complete.")
204
 
205
 
206
  # ==============================
207
+ # ENTRY
208
  # ==============================
209
  if __name__ == "__main__":
210
  parser = argparse.ArgumentParser()