# Commit bef2610 (Vaishnav14220): Persist phase completion state to resume reliably
#!/usr/bin/env python3
"""Gradio interface to orchestrate the full ORD reaction training pipeline."""
import os
import sys
import shutil
import json
import time
import gradio as gr
import subprocess
import threading
from pathlib import Path
from datetime import datetime
from typing import List, Tuple
from huggingface_hub import login, hf_hub_download, HfApi, create_repo
from datasets import load_dataset, DatasetDict
from src.config import (
FORWARD_DATASET_NAME,
RETRO_DATASET_NAME,
TOKENIZER_NAME,
FORWARD_MODEL_NAME,
RETRO_MODEL_NAME,
STATE_REPO,
)
# -----------------------------------------------------------------------------
# Paths & configuration
# -----------------------------------------------------------------------------
# Hub auth token: prefer the dedicated HF_MODEL_TOKEN secret, fall back to the
# generic HF_TOKEN. May be None, in which case Hub sync is effectively disabled.
HF_MODEL_TOKEN = os.environ.get("HF_MODEL_TOKEN") or os.environ.get("HF_TOKEN")
REPO_ROOT = Path(__file__).resolve().parent
SRC_DIR = REPO_ROOT / "src"  # per-phase training scripts live here
CACHE_DIR = REPO_ROOT / "cache"  # processed dataset caches
HF_CACHE_DIR = REPO_ROOT / "hf_cache"  # Hugging Face hub/datasets cache root
LOG_FILE = REPO_ROOT / "training.log"  # streaming log tailed by the UI
FORWARD_CACHE_DIR = CACHE_DIR / "forward"
RETRO_CACHE_DIR = CACHE_DIR / "retro"
FORWARD_MODEL_DIR = REPO_ROOT / "forward_model"
RETRO_MODEL_DIR = REPO_ROOT / "retro_model"
TOKENIZER_FILE = REPO_ROOT / "tokenizer.json"
STATE_FILE = REPO_ROOT / "training_state.json"  # local resume-state snapshot

# Ensure working directories exist
for path in (CACHE_DIR, HF_CACHE_DIR):
    path.mkdir(parents=True, exist_ok=True)

# Ordered pipeline phases: (phase number, human-readable label, script in src/).
PHASES: List[Tuple[int, str, str]] = [
    (1, "Data Preparation (12-24 hours)", "dataset_prepare.py"),
    (2, "Tokenizer Training (~30 minutes)", "tokenizer_train.py"),
    (3, "Forward Model Training (4-8 hours)", "train_forward.py"),
    (4, "Retro Model Training (4-8 hours)", "train_retro.py"),
    (5, "Evaluation & Sample Inference (~10 minutes)", "evaluate.py"),
]

# -----------------------------------------------------------------------------
# Runtime status handling
# -----------------------------------------------------------------------------
# Mutable status dict shared between the background pipeline thread and the UI.
training_status = {
    "running": False,
    "phase": "Idle",
    "progress": "Waiting to start...",
    "last_update": datetime.now(),
}
HF_API = HfApi(token=HF_MODEL_TOKEN)
# File names whose presence in a Hub model repo indicates trained weights.
WEIGHT_FILENAMES = {"pytorch_model.bin", "model.safetensors"}
def load_training_state() -> dict:
    """Return persisted pipeline progress, preferring the local snapshot.

    Tries the local ``training_state.json`` first; when that is absent or
    unreadable and a token is configured, pulls the file from the state
    dataset repo on the Hub. Any failure yields an empty dict.
    """
    # Local snapshot first: read + parse in one EAFP attempt.
    try:
        return json.loads(STATE_FILE.read_text(encoding="utf-8"))
    except Exception:
        pass
    if not HF_MODEL_TOKEN:
        return {}
    # Fall back to the Hub copy; cache it locally for the next start-up.
    try:
        remote_copy = hf_hub_download(
            repo_id=STATE_REPO,
            filename="training_state.json",
            repo_type="dataset",
            token=HF_MODEL_TOKEN,
        )
        shutil.copy(remote_copy, STATE_FILE)
        return json.loads(STATE_FILE.read_text(encoding="utf-8"))
    except Exception:
        return {}
def save_training_state(state: dict) -> None:
    """Persist *state* locally and, when possible, mirror it to the Hub.

    The local JSON snapshot is always written so the pipeline can resume
    after a restart even without Hub credentials (the original skipped the
    local write too when no token was set, losing resume information).
    The Hub upload is best-effort: failures only print a warning.
    """
    STATE_FILE.write_text(json.dumps(state, indent=2), encoding="utf-8")
    if not HF_MODEL_TOKEN:
        # No credentials: keep the local snapshot, skip the remote sync.
        return
    try:
        create_repo(STATE_REPO, repo_type="dataset", exist_ok=True, token=HF_MODEL_TOKEN)
        HF_API.upload_file(
            path_or_fileobj=str(STATE_FILE),
            path_in_repo="training_state.json",
            repo_id=STATE_REPO,
            repo_type="dataset",
        )
    except Exception as exc:
        print(f"⚠️ Could not update training state repo: {exc}")
# Module-level snapshot of persisted progress (per-phase records plus
# "last_completed_phase"); mutated by the mark_phase_* helpers below.
training_state = load_training_state()
def mark_phase_complete(phase_number: int) -> None:
    """Record a successful phase in the shared state and persist it."""
    completion_record = {
        "status": "complete",
        "timestamp": time.time(),
    }
    training_state[f"phase_{phase_number}"] = completion_record
    training_state["last_completed_phase"] = phase_number
    save_training_state(training_state)
def mark_phase_failed(phase_number: int, message: str) -> None:
    """Record a phase failure (with its reason) and persist the state."""
    failure_record = {
        "status": "failed",
        "timestamp": time.time(),
        "message": message,
    }
    training_state[f"phase_{phase_number}"] = failure_record
    save_training_state(training_state)
def _dir_has_arrow_files(path: Path) -> bool:
return path.exists() and any(path.glob("*.arrow"))
def _ensure_clean_dir(path: Path):
if path.exists():
shutil.rmtree(path)
path.mkdir(parents=True, exist_ok=True)
def _download_dataset(repo_id: str, target_dir: Path) -> bool:
    """Materialize dataset *repo_id* into *target_dir* unless already cached.

    Returns True when the on-disk copy looks complete (metadata file plus
    at least one Arrow shard) after the call.
    """
    def _cache_looks_complete() -> bool:
        return (target_dir / "dataset_dict.json").exists() and _dir_has_arrow_files(target_dir)

    if _cache_looks_complete():
        return True
    if not HF_MODEL_TOKEN:
        print(f"⚠️ Cannot download dataset {repo_id}: HF_MODEL_TOKEN not set.")
        return False
    try:
        print(f"⬇️ Loading dataset {repo_id} from Hugging Face Hub...")
        ds = load_dataset(repo_id)
        # Normalize to a DatasetDict so save_to_disk lays out splits uniformly.
        if not isinstance(ds, DatasetDict):
            ds = DatasetDict(dict(ds.items()))
        _ensure_clean_dir(target_dir)
        ds.save_to_disk(str(target_dir))
        return _cache_looks_complete()
    except Exception as exc:
        print(f"⚠️ Could not download dataset {repo_id}: {exc}")
        return False
def _download_tokenizer() -> bool:
    """Ensure tokenizer.json exists locally, fetching it from the Hub if needed."""
    if TOKENIZER_FILE.exists():
        return True
    try:
        print("⬇️ Downloading tokenizer artifact...")
        hf_hub_download(
            repo_id=TOKENIZER_NAME,
            repo_type="model",
            filename="tokenizer.json",
            token=HF_MODEL_TOKEN,
            local_dir=str(REPO_ROOT),
            local_dir_use_symlinks=False,
        )
    except Exception as exc:
        print(f"⚠️ Could not download tokenizer: {exc}")
        return False
    # Success is judged by the artifact actually landing on disk.
    return TOKENIZER_FILE.exists()
def _phase_completed(phase_number: int) -> bool:
    """Best-effort probe for whether a phase's artifacts already exist.

    Phase 1: local dataset caches (downloading them from the Hub if absent).
    Phase 2: local tokenizer file, or its Hub repo (downloaded on success).
    Phases 3/4: trained weight files present in the Hub model repos.
    Phase 5: evaluation is never treated as done, so it always reruns.
    """
    def _hub_has_weights(model_repo: str) -> bool:
        # Trained iff a recognized weight file has been pushed to the repo.
        try:
            repo_info = HF_API.model_info(model_repo)
            uploaded = {sibling.rfilename for sibling in repo_info.siblings}
            return bool(WEIGHT_FILENAMES & uploaded)
        except Exception:
            return False

    if phase_number == 1:
        if _dir_has_arrow_files(FORWARD_CACHE_DIR) and _dir_has_arrow_files(RETRO_CACHE_DIR):
            return True
        # Attempt both downloads (no short-circuit) so each cache gets filled.
        forward_ok = _download_dataset(FORWARD_DATASET_NAME, FORWARD_CACHE_DIR)
        retro_ok = _download_dataset(RETRO_DATASET_NAME, RETRO_CACHE_DIR)
        return forward_ok and retro_ok
    if phase_number == 2:
        if TOKENIZER_FILE.exists():
            return True
        try:
            HF_API.model_info(TOKENIZER_NAME)
            return _download_tokenizer()
        except Exception:
            return False
    if phase_number == 3:
        return _hub_has_weights(FORWARD_MODEL_NAME)
    if phase_number == 4:
        return _hub_has_weights(RETRO_MODEL_NAME)
    if phase_number == 5:
        return False
    return False
def _stream_process(command: List[str], env: dict, phase_label: str, log_handle) -> int:
    """Run a subprocess while streaming stdout to the log and status panel.

    Each line is mirrored to *log_handle*, to the container's stdout, and
    (truncated) into the shared status dict. Returns the process exit code.
    """
    proc = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # interleave stderr into the same stream
        text=True,
        bufsize=1,  # line-buffered so progress shows up promptly
        cwd=REPO_ROOT,
        env=env,
    )
    try:
        while True:
            raw_line = proc.stdout.readline()
            if raw_line == "":  # EOF: child closed its stdout
                break
            log_handle.write(raw_line)
            log_handle.flush()
            # Echo to container stdout for real-time terminal visibility
            sys.stdout.write(raw_line)
            sys.stdout.flush()
            training_status["progress"] = f"[{phase_label}] {raw_line.strip()}"[:240]
            training_status["last_update"] = datetime.now()
    finally:
        if proc.stdout:
            proc.stdout.close()
    return proc.wait()
# -----------------------------------------------------------------------------
# Helpers exposed to UI
# -----------------------------------------------------------------------------
def get_log_content() -> str:
    """Read the tail (last ~4000 chars) of the training log for display."""
    if not LOG_FILE.exists():
        return "Logs will appear here once the pipeline starts."
    try:
        text = LOG_FILE.read_text(encoding="utf-8", errors="replace")
    except Exception as exc:  # pragma: no cover - best effort logging
        return f"Error reading log file: {exc}"
    # Slicing handles both short and long logs uniformly.
    return text[-4000:]
def reset_training():
    """Reset the status panel and delete the log file (no-op while running)."""
    if training_status.get("running"):
        return "⚠️ Training in progress. Wait for completion before resetting."
    training_status["running"] = False
    training_status["phase"] = "Idle"
    training_status["progress"] = "Reset complete. Ready to start again."
    training_status["last_update"] = datetime.now()
    if LOG_FILE.exists():
        LOG_FILE.unlink()
    return get_status()
def start_training(start_option: str):
    """Kick off the full multi-phase training pipeline in a background thread.

    *start_option* is one of the dropdown choices: "Auto (skip completed
    phases)" resumes after the last phase recorded in ``training_state``;
    "Start from Phase N" forces the start point to N and disables the
    skip-completed shortcut for phases at or after N.
    Returns the rendered status Markdown for the UI.
    """
    if training_status["running"]:
        return "⚠️ Training already running. Use the refresh button to see live updates."
    if not HF_MODEL_TOKEN:
        return "❌ HF_MODEL_TOKEN not found. Please add it to your Space secrets."
    option = start_option or "Auto (skip completed phases)"
    skip_completed = option.startswith("Auto")
    if option.startswith("Auto"):
        # Resume right after the last phase recorded as complete.
        start_from = max(1, training_state.get("last_completed_phase", 0) + 1)
    else:
        start_from = 1
    if option.startswith("Start from Phase"):
        try:
            # "Start from Phase N" -> the 4th whitespace token is N.
            start_from = int(option.split()[3])
        except Exception:
            start_from = 1
        skip_completed = False

    def run_pipeline():
        # Child processes inherit this env: propagate the token under every
        # name the ecosystem reads, and keep all HF caches inside the repo.
        env = os.environ.copy()
        env.update(
            {
                "HF_MODEL_TOKEN": HF_MODEL_TOKEN,
                "HF_TOKEN": HF_MODEL_TOKEN,
                "HUGGING_FACE_HUB_TOKEN": HF_MODEL_TOKEN,
                "HF_HOME": str(HF_CACHE_DIR),
                "TRANSFORMERS_CACHE": str(HF_CACHE_DIR),
                "HF_DATASETS_CACHE": str(HF_CACHE_DIR / "datasets"),
                "ORD_PROJECT_ROOT": str(REPO_ROOT),
            }
        )
        CACHE_DIR.mkdir(parents=True, exist_ok=True)
        (HF_CACHE_DIR / "datasets").mkdir(parents=True, exist_ok=True)
        # Authenticate with Hugging Face Hub once up front
        try:
            print("🔐 Logging into Hugging Face Hub...")
            login(token=HF_MODEL_TOKEN, add_to_git_credential=False)
            print("✅ Authenticated with Hugging Face Hub")
        except Exception as exc:
            print(f"⚠️ Login warning: {exc}")
        training_status.update(
            {
                "running": True,
                "phase": "Initializing pipeline...",
                "progress": f"Starting from phase {start_from} ({option})",
            }
        )
        success = True
        try:
            with open(LOG_FILE, "w", encoding="utf-8") as log_f:
                log_f.write("=" * 72 + "\n")
                log_f.write("ORD Reaction Translator - Full Training Pipeline\n")
                log_f.write("=" * 72 + "\n\n")
                log_f.flush()
                for phase_number, phase_label, script_name in PHASES:
                    script_path = SRC_DIR / script_name
                    phase_complete = _phase_completed(phase_number)
                    # Phases before the chosen start point are skipped when
                    # their artifacts exist; persisted state is backfilled.
                    if phase_number < start_from and phase_complete:
                        skip_msg = (
                            f"⏭️ Skipping Phase {phase_number}: {phase_label} (start phase = {start_from})\n"
                        )
                        log_f.write(skip_msg)
                        log_f.flush()
                        if training_state.get(f"phase_{phase_number}", {}).get("status") != "complete":
                            mark_phase_complete(phase_number)
                        continue
                    # Artifacts missing for an "earlier" phase: warn, then fall
                    # through and run it anyway so later phases have inputs.
                    if phase_number < start_from and not phase_complete:
                        warn_msg = (
                            f"⚠️ Phase {phase_number} artifacts missing. Running {phase_label} even though"
                            f" start phase is {start_from}.\n"
                        )
                        log_f.write(warn_msg)
                        log_f.flush()
                    # Auto mode: don't redo work whose outputs already exist.
                    if skip_completed and phase_complete and phase_number >= start_from:
                        skip_msg = f"⏭️ Phase {phase_number} already completed. Skipping {phase_label}.\n"
                        log_f.write(skip_msg)
                        log_f.flush()
                        training_status["phase"] = f"PHASE {phase_number}: {phase_label}"
                        training_status["progress"] = "Already complete—skipping."
                        if training_state.get(f"phase_{phase_number}", {}).get("status") != "complete":
                            mark_phase_complete(phase_number)
                        continue
                    if not script_path.exists():
                        message = f"Missing script: {script_name}"
                        training_status["progress"] = f"❌ {message}"
                        mark_phase_failed(phase_number, message)
                        success = False
                        break
                    # Evaluation needs both trained models on the Hub first.
                    if phase_number == 5 and not (_phase_completed(3) and _phase_completed(4)):
                        msg = (
                            "⚠️ Skipping evaluation: forward and retro models are not yet available on the Hub."
                            " Complete Phases 3 and 4 before running evaluation.\n"
                        )
                        log_f.write(msg)
                        log_f.flush()
                        training_status["phase"] = f"PHASE {phase_number}: {phase_label}"
                        training_status["progress"] = "Skipped evaluation—models missing."
                        mark_phase_failed(phase_number, "Models missing for evaluation")
                        continue
                    phase_header = f"--- Phase {phase_number}: {phase_label} ---\n"
                    log_f.write(phase_header)
                    log_f.flush()
                    training_status["phase"] = f"PHASE {phase_number}: {phase_label}"
                    training_status["progress"] = "Starting..."
                    return_code = _stream_process(
                        [sys.executable, str(script_path)], env, f"PHASE {phase_number}", log_f
                    )
                    if return_code != 0:
                        message = (
                            f"{phase_label} failed (exit code {return_code}). Check the logs above."
                        )
                        training_status["progress"] = f"❌ {message}"
                        mark_phase_failed(phase_number, message)
                        success = False
                        break
                    training_status["progress"] = f"✅ {phase_label} completed."
                    mark_phase_complete(phase_number)
        except Exception as exc:  # pragma: no cover - defensive logging
            success = False
            training_status["progress"] = f"❌ Pipeline crashed: {exc}"
        finally:
            training_status["running"] = False
            training_status["last_update"] = datetime.now()
            if success:
                training_status.update(
                    {
                        "phase": "Completed ✅",
                        "progress": "Full pipeline finished. Models and tokenizer pushed to Hugging Face Hub.",
                    }
                )
            else:
                # Keep the failing phase label visible; otherwise mark stopped.
                if "phase" not in training_status or "PHASE" not in training_status["phase"]:
                    training_status["phase"] = "Stopped"

    # Non-daemon thread so an in-flight phase keeps the process alive.
    thread = threading.Thread(target=run_pipeline, daemon=False)
    thread.start()
    return get_status()
def get_status() -> str:
    """Render the current phase, run state, and log tail as Markdown for the UI."""
    run_label = 'Running ⏳' if training_status['running'] else 'Ready ✅'
    log_tail = get_log_content()
    return f"""
**Phase:** {training_status['phase']}
**Status:** {run_label}
**Progress:** {training_status['progress']}
---
## 📋 Real-time Logs:
```
{log_tail}
```
"""
# Create UI
with gr.Blocks(title="ORD Training") as demo:
    gr.Markdown("# 🧪 ORD Reaction Training Pipeline")
    gr.Markdown("Train AI models on 2.4M chemical reactions from Open Reaction Database")
    # Lets the user either auto-resume or force a specific start phase.
    phase_selector = gr.Dropdown(
        label="Resume / start phase",
        choices=[
            "Auto (skip completed phases)",
            "Start from Phase 1",
            "Start from Phase 2",
            "Start from Phase 3",
            "Start from Phase 4",
            "Start from Phase 5",
        ],
        value="Auto (skip completed phases)",
    )
    with gr.Row():
        start_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")
        refresh_btn = gr.Button("🔄 Refresh Logs", variant="secondary", size="lg")
        reset_btn = gr.Button("🔧 Reset", size="lg")
    gr.Markdown("### 📊 Status & Real-Time Logs")
    status_box = gr.Markdown()
    # Event handlers
    # Start returns immediately (training runs in a thread), then re-renders
    # the status panel once via .then(); refresh/reset just re-render.
    start_btn.click(start_training, inputs=phase_selector, outputs=status_box).then(get_status, outputs=status_box)
    refresh_btn.click(get_status, outputs=status_box)
    reset_btn.click(reset_training, outputs=status_box)
    # Populate the status panel when the page first loads.
    demo.load(get_status, outputs=status_box)


if __name__ == "__main__":
    demo.launch()