Vaishnav14220 commited on
Commit
bef2610
·
1 Parent(s): b92ff93

Persist phase completion state to resume reliably

Browse files
Files changed (2) hide show
  1. app.py +93 -16
  2. src/config.py +2 -0
app.py CHANGED
@@ -3,6 +3,8 @@
3
  import os
4
  import sys
5
  import shutil
 
 
6
  import gradio as gr
7
  import subprocess
8
  import threading
@@ -10,7 +12,7 @@ from pathlib import Path
10
  from datetime import datetime
11
  from typing import List, Tuple
12
 
13
- from huggingface_hub import login, hf_hub_download, HfApi
14
  from datasets import load_dataset, DatasetDict
15
  from src.config import (
16
  FORWARD_DATASET_NAME,
@@ -18,6 +20,7 @@ from src.config import (
18
  TOKENIZER_NAME,
19
  FORWARD_MODEL_NAME,
20
  RETRO_MODEL_NAME,
 
21
  )
22
 
23
  # -----------------------------------------------------------------------------
@@ -37,6 +40,8 @@ FORWARD_MODEL_DIR = REPO_ROOT / "forward_model"
37
  RETRO_MODEL_DIR = REPO_ROOT / "retro_model"
38
  TOKENIZER_FILE = REPO_ROOT / "tokenizer.json"
39
 
 
 
40
  # Ensure working directories exist
41
  for path in (CACHE_DIR, HF_CACHE_DIR):
42
  path.mkdir(parents=True, exist_ok=True)
@@ -62,6 +67,65 @@ training_status = {
62
  HF_API = HfApi(token=HF_MODEL_TOKEN)
63
  WEIGHT_FILENAMES = {"pytorch_model.bin", "model.safetensors"}
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  def _dir_has_arrow_files(path: Path) -> bool:
66
  return path.exists() and any(path.glob("*.arrow"))
67
 
@@ -220,7 +284,10 @@ def start_training(start_option: str):
220
 
221
  option = start_option or "Auto (skip completed phases)"
222
  skip_completed = option.startswith("Auto")
223
- start_from = 1
 
 
 
224
  if option.startswith("Start from Phase"):
225
  try:
226
  start_from = int(option.split()[3])
@@ -279,6 +346,8 @@ def start_training(start_option: str):
279
  )
280
  log_f.write(skip_msg)
281
  log_f.flush()
 
 
282
  continue
283
 
284
  if phase_number < start_from and not phase_complete:
@@ -295,13 +364,29 @@ def start_training(start_option: str):
295
  log_f.flush()
296
  training_status["phase"] = f"PHASE {phase_number}: {phase_label}"
297
  training_status["progress"] = "Already complete—skipping."
 
 
298
  continue
299
 
300
  if not script_path.exists():
301
- training_status["progress"] = f"Missing script: {script_name}"
 
 
302
  success = False
303
  break
304
 
 
 
 
 
 
 
 
 
 
 
 
 
305
  phase_header = f"--- Phase {phase_number}: {phase_label} ---\n"
306
  log_f.write(phase_header)
307
  log_f.flush()
@@ -313,24 +398,16 @@ def start_training(start_option: str):
313
  )
314
 
315
  if return_code != 0:
316
- training_status["progress"] = (
317
- f"{phase_label} failed (exit code {return_code}). Check the logs above."
318
  )
 
 
319
  success = False
320
  break
321
 
322
  training_status["progress"] = f"✅ {phase_label} completed."
323
-
324
- if phase_number == 5 and not (_phase_completed(3) and _phase_completed(4)):
325
- msg = (
326
- "⚠️ Skipping evaluation: forward and retro models are not yet available on the Hub."
327
- " Complete Phases 3 and 4 before running evaluation.\n"
328
- )
329
- log_f.write(msg)
330
- log_f.flush()
331
- training_status["phase"] = f"PHASE {phase_number}: {phase_label}"
332
- training_status["progress"] = "Skipped evaluation—models missing."
333
- continue
334
 
335
  except Exception as exc: # pragma: no cover - defensive logging
336
  success = False
 
3
  import os
4
  import sys
5
  import shutil
6
+ import json
7
+ import time
8
  import gradio as gr
9
  import subprocess
10
  import threading
 
12
  from datetime import datetime
13
  from typing import List, Tuple
14
 
15
+ from huggingface_hub import login, hf_hub_download, HfApi, create_repo
16
  from datasets import load_dataset, DatasetDict
17
  from src.config import (
18
  FORWARD_DATASET_NAME,
 
20
  TOKENIZER_NAME,
21
  FORWARD_MODEL_NAME,
22
  RETRO_MODEL_NAME,
23
+ STATE_REPO,
24
  )
25
 
26
  # -----------------------------------------------------------------------------
 
40
  RETRO_MODEL_DIR = REPO_ROOT / "retro_model"
41
  TOKENIZER_FILE = REPO_ROOT / "tokenizer.json"
42
 
43
+ STATE_FILE = REPO_ROOT / "training_state.json"
44
+
45
  # Ensure working directories exist
46
  for path in (CACHE_DIR, HF_CACHE_DIR):
47
  path.mkdir(parents=True, exist_ok=True)
 
67
  HF_API = HfApi(token=HF_MODEL_TOKEN)
68
  WEIGHT_FILENAMES = {"pytorch_model.bin", "model.safetensors"}
69
 
70
+ def load_training_state() -> dict:
71
+ if STATE_FILE.exists():
72
+ try:
73
+ with open(STATE_FILE, "r", encoding="utf-8") as f:
74
+ return json.load(f)
75
+ except Exception:
76
+ pass
77
+ if HF_MODEL_TOKEN:
78
+ try:
79
+ downloaded = hf_hub_download(
80
+ repo_id=STATE_REPO,
81
+ filename="training_state.json",
82
+ repo_type="dataset",
83
+ token=HF_MODEL_TOKEN,
84
+ )
85
+ shutil.copy(downloaded, STATE_FILE)
86
+ with open(STATE_FILE, "r", encoding="utf-8") as f:
87
+ return json.load(f)
88
+ except Exception:
89
+ return {}
90
+ return {}
91
+
92
+
93
+ def save_training_state(state: dict):
94
+ if not HF_MODEL_TOKEN:
95
+ return
96
+ STATE_FILE.write_text(json.dumps(state, indent=2), encoding="utf-8")
97
+ try:
98
+ create_repo(STATE_REPO, repo_type="dataset", exist_ok=True, token=HF_MODEL_TOKEN)
99
+ HF_API.upload_file(
100
+ path_or_fileobj=str(STATE_FILE),
101
+ path_in_repo="training_state.json",
102
+ repo_id=STATE_REPO,
103
+ repo_type="dataset",
104
+ )
105
+ except Exception as exc:
106
+ print(f"⚠️ Could not update training state repo: {exc}")
107
+
108
+
109
+ training_state = load_training_state()
110
+
111
+ def mark_phase_complete(phase_number: int):
112
+ training_state[f"phase_{phase_number}"] = {
113
+ "status": "complete",
114
+ "timestamp": time.time(),
115
+ }
116
+ training_state["last_completed_phase"] = phase_number
117
+ save_training_state(training_state)
118
+
119
+
120
+ def mark_phase_failed(phase_number: int, message: str):
121
+ training_state[f"phase_{phase_number}"] = {
122
+ "status": "failed",
123
+ "timestamp": time.time(),
124
+ "message": message,
125
+ }
126
+ save_training_state(training_state)
127
+
128
+
129
  def _dir_has_arrow_files(path: Path) -> bool:
130
  return path.exists() and any(path.glob("*.arrow"))
131
 
 
284
 
285
  option = start_option or "Auto (skip completed phases)"
286
  skip_completed = option.startswith("Auto")
287
+ if option.startswith("Auto"):
288
+ start_from = max(1, training_state.get("last_completed_phase", 0) + 1)
289
+ else:
290
+ start_from = 1
291
  if option.startswith("Start from Phase"):
292
  try:
293
  start_from = int(option.split()[3])
 
346
  )
347
  log_f.write(skip_msg)
348
  log_f.flush()
349
+ if training_state.get(f"phase_{phase_number}", {}).get("status") != "complete":
350
+ mark_phase_complete(phase_number)
351
  continue
352
 
353
  if phase_number < start_from and not phase_complete:
 
364
  log_f.flush()
365
  training_status["phase"] = f"PHASE {phase_number}: {phase_label}"
366
  training_status["progress"] = "Already complete—skipping."
367
+ if training_state.get(f"phase_{phase_number}", {}).get("status") != "complete":
368
+ mark_phase_complete(phase_number)
369
  continue
370
 
371
  if not script_path.exists():
372
+ message = f"Missing script: {script_name}"
373
+ training_status["progress"] = f"❌ {message}"
374
+ mark_phase_failed(phase_number, message)
375
  success = False
376
  break
377
 
378
+ if phase_number == 5 and not (_phase_completed(3) and _phase_completed(4)):
379
+ msg = (
380
+ "⚠️ Skipping evaluation: forward and retro models are not yet available on the Hub."
381
+ " Complete Phases 3 and 4 before running evaluation.\n"
382
+ )
383
+ log_f.write(msg)
384
+ log_f.flush()
385
+ training_status["phase"] = f"PHASE {phase_number}: {phase_label}"
386
+ training_status["progress"] = "Skipped evaluation—models missing."
387
+ mark_phase_failed(phase_number, "Models missing for evaluation")
388
+ continue
389
+
390
  phase_header = f"--- Phase {phase_number}: {phase_label} ---\n"
391
  log_f.write(phase_header)
392
  log_f.flush()
 
398
  )
399
 
400
  if return_code != 0:
401
+ message = (
402
+ f"{phase_label} failed (exit code {return_code}). Check the logs above."
403
  )
404
+ training_status["progress"] = f"❌ {message}"
405
+ mark_phase_failed(phase_number, message)
406
  success = False
407
  break
408
 
409
  training_status["progress"] = f"✅ {phase_label} completed."
410
+ mark_phase_complete(phase_number)
 
 
 
 
 
 
 
 
 
 
411
 
412
  except Exception as exc: # pragma: no cover - defensive logging
413
  success = False
src/config.py CHANGED
@@ -13,6 +13,8 @@ MODELS_DIR = PROJECT_ROOT / "models"
13
  CACHE_DIR.mkdir(parents=True, exist_ok=True)
14
  MODELS_DIR.mkdir(parents=True, exist_ok=True)
15
 
 
 
16
  # Hugging Face Model and Dataset Names
17
  TOKENIZER_NAME = f"{HF_USERNAME}/ord-tokenizer"
18
  FORWARD_MODEL_NAME = f"{HF_USERNAME}/ord-forward-t5"
 
13
  CACHE_DIR.mkdir(parents=True, exist_ok=True)
14
  MODELS_DIR.mkdir(parents=True, exist_ok=True)
15
 
16
+ STATE_REPO = f"{HF_USERNAME}/ord-training-state"
17
+
18
  # Hugging Face Model and Dataset Names
19
  TOKENIZER_NAME = f"{HF_USERNAME}/ord-tokenizer"
20
  FORWARD_MODEL_NAME = f"{HF_USERNAME}/ord-forward-t5"