|
|
|
|
|
"""Gradio interface to orchestrate the full ORD reaction training pipeline.""" |
|
|
import os |
|
|
import sys |
|
|
import shutil |
|
|
import json |
|
|
import time |
|
|
import gradio as gr |
|
|
import subprocess |
|
|
import threading |
|
|
from pathlib import Path |
|
|
from datetime import datetime |
|
|
from typing import List, Tuple |
|
|
|
|
|
from huggingface_hub import login, hf_hub_download, HfApi, create_repo |
|
|
from datasets import load_dataset, DatasetDict |
|
|
from src.config import ( |
|
|
FORWARD_DATASET_NAME, |
|
|
RETRO_DATASET_NAME, |
|
|
TOKENIZER_NAME, |
|
|
FORWARD_MODEL_NAME, |
|
|
RETRO_MODEL_NAME, |
|
|
STATE_REPO, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hub credential: prefer the Space-specific secret, fall back to the generic one.
# May be None; most pipeline actions are disabled without it.
HF_MODEL_TOKEN = os.environ.get("HF_MODEL_TOKEN") or os.environ.get("HF_TOKEN")

# Filesystem layout — everything lives next to this script.
REPO_ROOT = Path(__file__).resolve().parent
SRC_DIR = REPO_ROOT / "src"          # per-phase pipeline scripts
CACHE_DIR = REPO_ROOT / "cache"      # processed dataset caches
HF_CACHE_DIR = REPO_ROOT / "hf_cache"  # HF_HOME / transformers / datasets cache
LOG_FILE = REPO_ROOT / "training.log"  # combined pipeline log shown in the UI

# Saved-to-disk DatasetDict locations for the two translation directions.
FORWARD_CACHE_DIR = CACHE_DIR / "forward"
RETRO_CACHE_DIR = CACHE_DIR / "retro"
FORWARD_MODEL_DIR = REPO_ROOT / "forward_model"
RETRO_MODEL_DIR = REPO_ROOT / "retro_model"
TOKENIZER_FILE = REPO_ROOT / "tokenizer.json"

# Local snapshot of the pipeline progress (mirrored to STATE_REPO on the Hub).
STATE_FILE = REPO_ROOT / "training_state.json"

# Make sure cache roots exist before anything tries to write into them.
for path in (CACHE_DIR, HF_CACHE_DIR):
    path.mkdir(parents=True, exist_ok=True)

# Ordered pipeline: (phase number, human-readable label, script under src/).
PHASES: List[Tuple[int, str, str]] = [
    (1, "Data Preparation (12-24 hours)", "dataset_prepare.py"),
    (2, "Tokenizer Training (~30 minutes)", "tokenizer_train.py"),
    (3, "Forward Model Training (4-8 hours)", "train_forward.py"),
    (4, "Retro Model Training (4-8 hours)", "train_retro.py"),
    (5, "Evaluation & Sample Inference (~10 minutes)", "evaluate.py"),
]

# In-memory status shared between the background training thread and the UI.
# NOTE(review): mutated from both the Gradio handlers and the worker thread
# without a lock — dict item assignment is atomic in CPython, so this is
# best-effort rather than strictly synchronized.
training_status = {
    "running": False,
    "phase": "Idle",
    "progress": "Waiting to start...",
    "last_update": datetime.now(),
}

# Shared Hub client; token may be None, in which case only public reads work.
HF_API = HfApi(token=HF_MODEL_TOKEN)
# Presence of either file in a model repo counts as "trained weights exist".
WEIGHT_FILENAMES = {"pytorch_model.bin", "model.safetensors"}
|
|
|
|
|
def load_training_state() -> dict:
    """Load the persisted pipeline state, preferring the local snapshot.

    Falls back to downloading ``training_state.json`` from the state
    dataset repo on the Hub when the local file is absent or unreadable.
    Returns an empty dict when neither source yields valid JSON.
    """
    # First choice: a readable local state file.
    try:
        if STATE_FILE.exists():
            return json.loads(STATE_FILE.read_text(encoding="utf-8"))
    except Exception:
        # Corrupt or unreadable local copy — fall through to the Hub.
        pass

    # Second choice: the mirror persisted on the Hub (requires a token).
    if not HF_MODEL_TOKEN:
        return {}
    try:
        remote_copy = hf_hub_download(
            repo_id=STATE_REPO,
            filename="training_state.json",
            repo_type="dataset",
            token=HF_MODEL_TOKEN,
        )
        # Cache it locally so subsequent loads skip the network.
        shutil.copy(remote_copy, STATE_FILE)
        return json.loads(STATE_FILE.read_text(encoding="utf-8"))
    except Exception:
        return {}
|
|
|
|
|
|
|
|
def save_training_state(state: dict):
    """Write *state* to the local JSON file and mirror it to the Hub.

    A missing token makes this a no-op (nothing is written at all).
    Hub failures are logged and swallowed — state persistence is
    best-effort and must never abort the pipeline.
    """
    if not HF_MODEL_TOKEN:
        return

    # Refresh the local snapshot first so at least that copy is current.
    payload = json.dumps(state, indent=2)
    STATE_FILE.write_text(payload, encoding="utf-8")

    try:
        # Idempotent: creates the dataset repo only if it does not exist.
        create_repo(STATE_REPO, repo_type="dataset", exist_ok=True, token=HF_MODEL_TOKEN)
        HF_API.upload_file(
            path_or_fileobj=str(STATE_FILE),
            path_in_repo="training_state.json",
            repo_id=STATE_REPO,
            repo_type="dataset",
        )
    except Exception as exc:
        print(f"⚠️ Could not update training state repo: {exc}")
|
|
|
|
|
|
|
|
training_state = load_training_state() |
|
|
|
|
|
def mark_phase_complete(phase_number: int):
    """Record *phase_number* as complete and persist the shared state."""
    key = f"phase_{phase_number}"
    training_state[key] = dict(status="complete", timestamp=time.time())
    # Auto-resume uses this to decide which phase to start from next run.
    training_state["last_completed_phase"] = phase_number
    save_training_state(training_state)
|
|
|
|
|
|
|
|
def mark_phase_failed(phase_number: int, message: str):
    """Record *phase_number* as failed (with *message*) and persist state."""
    key = f"phase_{phase_number}"
    training_state[key] = dict(status="failed", timestamp=time.time(), message=message)
    save_training_state(training_state)
|
|
|
|
|
|
|
|
def _dir_has_arrow_files(path: Path) -> bool: |
|
|
return path.exists() and any(path.glob("*.arrow")) |
|
|
|
|
|
|
|
|
def _ensure_clean_dir(path: Path): |
|
|
if path.exists(): |
|
|
shutil.rmtree(path) |
|
|
path.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
|
|
|
def _download_dataset(repo_id: str, target_dir: Path) -> bool:
    """Materialise dataset *repo_id* under *target_dir* as a saved DatasetDict.

    Returns True when the on-disk copy is usable (``dataset_dict.json``
    plus at least one ``*.arrow`` shard). Reuses a complete local copy;
    otherwise downloads from the Hub, which requires a token.
    """
    # Reuse an existing, complete local copy without touching the network.
    if (target_dir / "dataset_dict.json").exists() and _dir_has_arrow_files(target_dir):
        return True
    if not HF_MODEL_TOKEN:
        print(f"⚠️ Cannot download dataset {repo_id}: HF_MODEL_TOKEN not set.")
        return False
    try:
        print(f"⬇️ Loading dataset {repo_id} from Hugging Face Hub...")
        # BUG FIX: pass the token explicitly so private dataset repos work.
        # The original guarded on HF_MODEL_TOKEN but never used it here,
        # relying on ambient credentials only.
        ds = load_dataset(repo_id, token=HF_MODEL_TOKEN)
        if not isinstance(ds, DatasetDict):
            # Normalise a single-split result into a DatasetDict for save_to_disk.
            ds = DatasetDict({k: v for k, v in ds.items()})
        _ensure_clean_dir(target_dir)
        ds.save_to_disk(str(target_dir))
        # Verify the save actually produced the expected on-disk layout.
        return (target_dir / "dataset_dict.json").exists() and _dir_has_arrow_files(target_dir)
    except Exception as exc:
        print(f"⚠️ Could not download dataset {repo_id}: {exc}")
        return False
|
|
|
|
|
|
|
|
def _download_tokenizer() -> bool:
    """Ensure ``tokenizer.json`` exists locally, fetching it from the Hub if needed.

    Returns True when the local file is present afterwards; download
    failures are logged and reported as False rather than raised.
    """
    # Nothing to do when a previous run already fetched the artifact.
    if TOKENIZER_FILE.exists():
        return True
    try:
        print("⬇️ Downloading tokenizer artifact...")
        hf_hub_download(
            repo_id=TOKENIZER_NAME,
            repo_type="model",
            filename="tokenizer.json",
            local_dir=str(REPO_ROOT),
            local_dir_use_symlinks=False,  # want a real file next to the script
            token=HF_MODEL_TOKEN,
        )
        return TOKENIZER_FILE.exists()
    except Exception as exc:
        print(f"⚠️ Could not download tokenizer: {exc}")
        return False
|
|
|
|
|
|
|
|
def _model_has_weights(repo_id: str) -> bool:
    """Return True when Hub model repo *repo_id* contains a known weights file."""
    try:
        info = HF_API.model_info(repo_id)
    except Exception:
        # Repo missing, private without access, or network error.
        return False
    filenames = {sibling.rfilename for sibling in info.siblings}
    return bool(WEIGHT_FILENAMES & filenames)


def _phase_completed(phase_number: int) -> bool:
    """Check whether a pipeline phase's artifacts already exist.

    Phase 1: both dataset caches usable locally (downloading from the
    Hub if needed). Phase 2: tokenizer file present locally or fetchable.
    Phases 3/4: trained weights published on the Hub. Phase 5
    (evaluation) is never considered resumable.
    """
    if phase_number == 1:
        if _dir_has_arrow_files(FORWARD_CACHE_DIR) and _dir_has_arrow_files(RETRO_CACHE_DIR):
            return True
        # Try to restore both caches from the Hub; both must succeed.
        forward_ok = _download_dataset(FORWARD_DATASET_NAME, FORWARD_CACHE_DIR)
        retro_ok = _download_dataset(RETRO_DATASET_NAME, RETRO_CACHE_DIR)
        return forward_ok and retro_ok

    if phase_number == 2:
        if TOKENIZER_FILE.exists():
            return True
        try:
            # Only attempt the download when the tokenizer repo exists.
            HF_API.model_info(TOKENIZER_NAME)
        except Exception:
            return False
        return _download_tokenizer()

    # Phases 3 and 4 share identical logic; only the repo differs.
    if phase_number == 3:
        return _model_has_weights(FORWARD_MODEL_NAME)
    if phase_number == 4:
        return _model_has_weights(RETRO_MODEL_NAME)

    # Phase 5 (and anything unknown): always re-run.
    return False
|
|
|
|
|
|
|
|
def _stream_process(command: List[str], env: dict, phase_label: str, log_handle) -> int:
    """Run a subprocess while streaming stdout to the log and status panel.

    stderr is merged into stdout; every line is appended to *log_handle*,
    echoed to this process's stdout, and surfaced (truncated) in the
    shared ``training_status`` dict. Returns the subprocess exit code.
    """
    process = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # interleave stderr with stdout in the log
        text=True,
        bufsize=1,  # line-buffered so progress appears promptly
        cwd=REPO_ROOT,
        env=env,
    )

    try:
        # Iterating the pipe yields lines until EOF. The previous
        # `iter(process.stdout.readline, "")` plus an emptiness check was
        # redundant: the sentinel already terminates on "", so the
        # `if not raw_line: continue` branch was unreachable.
        for raw_line in process.stdout:
            log_handle.write(raw_line)
            log_handle.flush()

            sys.stdout.write(raw_line)
            sys.stdout.flush()

            # Show a truncated copy of the latest line in the UI panel.
            line = raw_line.strip()
            training_status["progress"] = f"[{phase_label}] {line}"[:240]
            training_status["last_update"] = datetime.now()
    finally:
        if process.stdout:
            process.stdout.close()

    return process.wait()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_log_content() -> str:
    """Read the tail (last ~4000 chars) of the training log for display."""
    if not LOG_FILE.exists():
        return "Logs will appear here once the pipeline starts."
    try:
        with open(LOG_FILE, "r", encoding="utf-8", errors="replace") as handle:
            text = handle.read()
    except Exception as exc:
        return f"Error reading log file: {exc}"
    # Only the tail is shown so the Markdown panel stays responsive.
    return text[-4000:] if len(text) > 4000 else text
|
|
|
|
|
|
|
|
def reset_training():
    """Reset the in-memory status and delete the log file (idle only).

    Refuses to reset while the pipeline is running; returns the rendered
    status panel (or a warning string) for the UI.
    """
    if training_status.get("running"):
        return "⚠️ Training in progress. Wait for completion before resetting."

    training_status["running"] = False
    training_status["phase"] = "Idle"
    training_status["progress"] = "Reset complete. Ready to start again."
    training_status["last_update"] = datetime.now()

    # Remove the old log so the next run starts with a clean panel.
    if LOG_FILE.exists():
        LOG_FILE.unlink()
    return get_status()
|
|
|
|
|
|
|
|
def start_training(start_option: str):
    """Kick off the full multi-phase training pipeline in a background thread.

    Args:
        start_option: dropdown value — "Auto (skip completed phases)" or
            "Start from Phase N". Auto resumes after the last completed
            phase recorded in ``training_state``; an explicit phase forces
            re-running from that phase onward (no skipping).

    Returns:
        The rendered status panel string, or a warning when the pipeline
        is already running / no Hub token is configured.
    """
    if training_status["running"]:
        return "⚠️ Training already running. Use the refresh button to see live updates."

    if not HF_MODEL_TOKEN:
        return "❌ HF_MODEL_TOKEN not found. Please add it to your Space secrets."

    # Resolve the dropdown choice into (start_from, skip_completed).
    option = start_option or "Auto (skip completed phases)"
    skip_completed = option.startswith("Auto")
    if option.startswith("Auto"):
        # Resume one past the last phase recorded as complete.
        start_from = max(1, training_state.get("last_completed_phase", 0) + 1)
    else:
        start_from = 1
    if option.startswith("Start from Phase"):
        try:
            # "Start from Phase N" → tokens ["Start","from","Phase","N"].
            start_from = int(option.split()[3])
        except Exception:
            start_from = 1
        # Explicit start phase means: re-run everything from there.
        skip_completed = False

    def run_pipeline():
        # Worker thread body: runs each phase script as a subprocess.
        # Propagate credentials and cache locations to child processes.
        env = os.environ.copy()
        env.update(
            {
                "HF_MODEL_TOKEN": HF_MODEL_TOKEN,
                "HF_TOKEN": HF_MODEL_TOKEN,
                "HUGGING_FACE_HUB_TOKEN": HF_MODEL_TOKEN,
                "HF_HOME": str(HF_CACHE_DIR),
                "TRANSFORMERS_CACHE": str(HF_CACHE_DIR),
                "HF_DATASETS_CACHE": str(HF_CACHE_DIR / "datasets"),
                "ORD_PROJECT_ROOT": str(REPO_ROOT),
            }
        )

        CACHE_DIR.mkdir(parents=True, exist_ok=True)
        (HF_CACHE_DIR / "datasets").mkdir(parents=True, exist_ok=True)

        # Best-effort login; phase scripts also receive the token via env.
        try:
            print("🔐 Logging into Hugging Face Hub...")
            login(token=HF_MODEL_TOKEN, add_to_git_credential=False)
            print("✅ Authenticated with Hugging Face Hub")
        except Exception as exc:
            print(f"⚠️ Login warning: {exc}")

        training_status.update(
            {
                "running": True,
                "phase": "Initializing pipeline...",
                "progress": f"Starting from phase {start_from} ({option})",
            }
        )

        success = True
        try:
            # The log file is truncated at the start of every run.
            with open(LOG_FILE, "w", encoding="utf-8") as log_f:
                log_f.write("=" * 72 + "\n")
                log_f.write("ORD Reaction Translator - Full Training Pipeline\n")
                log_f.write("=" * 72 + "\n\n")
                log_f.flush()

                for phase_number, phase_label, script_name in PHASES:
                    script_path = SRC_DIR / script_name
                    # NOTE: may download artifacts from the Hub as a side effect.
                    phase_complete = _phase_completed(phase_number)

                    # Case 1: before the requested start phase AND artifacts
                    # exist — skip, backfilling the persisted state if needed.
                    if phase_number < start_from and phase_complete:
                        skip_msg = (
                            f"⏭️ Skipping Phase {phase_number}: {phase_label} (start phase = {start_from})\n"
                        )
                        log_f.write(skip_msg)
                        log_f.flush()
                        if training_state.get(f"phase_{phase_number}", {}).get("status") != "complete":
                            mark_phase_complete(phase_number)
                        continue

                    # Case 2: before the start phase but artifacts are missing
                    # — warn and fall through to run it anyway (later phases
                    # depend on its outputs).
                    if phase_number < start_from and not phase_complete:
                        warn_msg = (
                            f"⚠️ Phase {phase_number} artifacts missing. Running {phase_label} even though"
                            f" start phase is {start_from}.\n"
                        )
                        log_f.write(warn_msg)
                        log_f.flush()

                    # Case 3: Auto mode — skip any phase that is already done.
                    if skip_completed and phase_complete and phase_number >= start_from:
                        skip_msg = f"⏭️ Phase {phase_number} already completed. Skipping {phase_label}.\n"
                        log_f.write(skip_msg)
                        log_f.flush()
                        training_status["phase"] = f"PHASE {phase_number}: {phase_label}"
                        training_status["progress"] = "Already complete—skipping."
                        if training_state.get(f"phase_{phase_number}", {}).get("status") != "complete":
                            mark_phase_complete(phase_number)
                        continue

                    # Missing phase script is fatal for the whole run.
                    if not script_path.exists():
                        message = f"Missing script: {script_name}"
                        training_status["progress"] = f"❌ {message}"
                        mark_phase_failed(phase_number, message)
                        success = False
                        break

                    # Evaluation needs both trained models on the Hub; skip
                    # (not fail the run) when they are absent.
                    if phase_number == 5 and not (_phase_completed(3) and _phase_completed(4)):
                        msg = (
                            "⚠️ Skipping evaluation: forward and retro models are not yet available on the Hub."
                            " Complete Phases 3 and 4 before running evaluation.\n"
                        )
                        log_f.write(msg)
                        log_f.flush()
                        training_status["phase"] = f"PHASE {phase_number}: {phase_label}"
                        training_status["progress"] = "Skipped evaluation—models missing."
                        mark_phase_failed(phase_number, "Models missing for evaluation")
                        continue

                    phase_header = f"--- Phase {phase_number}: {phase_label} ---\n"
                    log_f.write(phase_header)
                    log_f.flush()

                    training_status["phase"] = f"PHASE {phase_number}: {phase_label}"
                    training_status["progress"] = "Starting..."
                    # Run the phase script with the same interpreter, streaming
                    # its output into the shared log.
                    return_code = _stream_process(
                        [sys.executable, str(script_path)], env, f"PHASE {phase_number}", log_f
                    )

                    if return_code != 0:
                        message = (
                            f"{phase_label} failed (exit code {return_code}). Check the logs above."
                        )
                        training_status["progress"] = f"❌ {message}"
                        mark_phase_failed(phase_number, message)
                        success = False
                        break

                    training_status["progress"] = f"✅ {phase_label} completed."
                    mark_phase_complete(phase_number)

        except Exception as exc:
            success = False
            training_status["progress"] = f"❌ Pipeline crashed: {exc}"
        finally:
            # Always release the "running" flag, even after a crash.
            training_status["running"] = False
            training_status["last_update"] = datetime.now()

        if success:
            training_status.update(
                {
                    "phase": "Completed ✅",
                    "progress": "Full pipeline finished. Models and tokenizer pushed to Hugging Face Hub.",
                }
            )
        else:
            # Keep the failing phase label if one was set; otherwise mark stopped.
            if "phase" not in training_status or "PHASE" not in training_status["phase"]:
                training_status["phase"] = "Stopped"

    # daemon=False so a long-running training thread is not killed if the
    # Gradio front-end exits.
    thread = threading.Thread(target=run_pipeline, daemon=False)
    thread.start()

    # Return immediately; the UI polls get_status() for live updates.
    return get_status()
|
|
|
|
|
|
|
|
def get_status() -> str:
    """Render the current phase, run state, progress line and log tail as Markdown."""
    # Pre-compute the dynamic pieces so the template below stays readable.
    tail = get_log_content()
    run_label = 'Running ⏳' if training_status['running'] else 'Ready ✅'

    return f"""
**Phase:** {training_status['phase']}
**Status:** {run_label}
**Progress:** {training_status['progress']}

---

## 📋 Real-time Logs:

```
{tail}
```
"""
|
|
|
|
|
|
|
|
# --- Gradio UI ------------------------------------------------------------
# Single-page dashboard: choose a start phase, launch the pipeline, and poll
# the combined status/log panel.
with gr.Blocks(title="ORD Training") as demo:
    gr.Markdown("# 🧪 ORD Reaction Training Pipeline")
    gr.Markdown("Train AI models on 2.4M chemical reactions from Open Reaction Database")

    # Feeds start_training(); "Auto" resumes after the last completed phase.
    phase_selector = gr.Dropdown(
        label="Resume / start phase",
        choices=[
            "Auto (skip completed phases)",
            "Start from Phase 1",
            "Start from Phase 2",
            "Start from Phase 3",
            "Start from Phase 4",
            "Start from Phase 5",
        ],
        value="Auto (skip completed phases)",
    )

    with gr.Row():
        start_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")
        refresh_btn = gr.Button("🔄 Refresh Logs", variant="secondary", size="lg")
        reset_btn = gr.Button("🔧 Reset", size="lg")

    gr.Markdown("### 📊 Status & Real-Time Logs")
    # Combined status + log tail, rendered as Markdown by get_status().
    status_box = gr.Markdown()

    # start_training returns immediately (work happens in a thread); the
    # chained .then() refreshes the panel once right after launch.
    start_btn.click(start_training, inputs=phase_selector, outputs=status_box).then(get_status, outputs=status_box)
    refresh_btn.click(get_status, outputs=status_box)
    reset_btn.click(reset_training, outputs=status_box)

    # Populate the panel as soon as the page loads.
    demo.load(get_status, outputs=status_box)

if __name__ == "__main__":
    demo.launch()
|
|
|