#!/usr/bin/env python3
"""Gradio interface to orchestrate the full ORD reaction training pipeline.

Runs five sequential phases (data prep, tokenizer training, forward model,
retro model, evaluation) as subprocesses, streaming their stdout into a log
file and a live Gradio status panel.  Phase-completion state is persisted
both locally (``training_state.json``) and in a Hugging Face dataset repo
(``STATE_REPO``) so the pipeline can resume after a Space restart.
"""
import os
import sys
import shutil
import json
import time
import gradio as gr
import subprocess
import threading
from pathlib import Path
from datetime import datetime
from typing import List, Tuple
from huggingface_hub import login, hf_hub_download, HfApi, create_repo
from datasets import load_dataset, DatasetDict
from src.config import (
    FORWARD_DATASET_NAME,
    RETRO_DATASET_NAME,
    TOKENIZER_NAME,
    FORWARD_MODEL_NAME,
    RETRO_MODEL_NAME,
    STATE_REPO,
)

# -----------------------------------------------------------------------------
# Paths & configuration
# -----------------------------------------------------------------------------
# Token lookup falls back to the generic HF_TOKEN secret if the dedicated
# HF_MODEL_TOKEN is not configured.
HF_MODEL_TOKEN = os.environ.get("HF_MODEL_TOKEN") or os.environ.get("HF_TOKEN")
REPO_ROOT = Path(__file__).resolve().parent
SRC_DIR = REPO_ROOT / "src"                      # phase scripts live here
CACHE_DIR = REPO_ROOT / "cache"                  # local dataset cache
HF_CACHE_DIR = REPO_ROOT / "hf_cache"            # HF_HOME for subprocesses
LOG_FILE = REPO_ROOT / "training.log"            # rolling pipeline log shown in UI
FORWARD_CACHE_DIR = CACHE_DIR / "forward"
RETRO_CACHE_DIR = CACHE_DIR / "retro"
FORWARD_MODEL_DIR = REPO_ROOT / "forward_model"
RETRO_MODEL_DIR = REPO_ROOT / "retro_model"
TOKENIZER_FILE = REPO_ROOT / "tokenizer.json"
STATE_FILE = REPO_ROOT / "training_state.json"   # local mirror of the state repo

# Ensure working directories exist
for path in (CACHE_DIR, HF_CACHE_DIR):
    path.mkdir(parents=True, exist_ok=True)

# (phase number, human-readable label with duration estimate, script filename)
PHASES: List[Tuple[int, str, str]] = [
    (1, "Data Preparation (12-24 hours)", "dataset_prepare.py"),
    (2, "Tokenizer Training (~30 minutes)", "tokenizer_train.py"),
    (3, "Forward Model Training (4-8 hours)", "train_forward.py"),
    (4, "Retro Model Training (4-8 hours)", "train_retro.py"),
    (5, "Evaluation & Sample Inference (~10 minutes)", "evaluate.py"),
]

# -----------------------------------------------------------------------------
# Runtime status handling
# -----------------------------------------------------------------------------
# Shared mutable status rendered by the UI.  It is written from the background
# pipeline thread and read from Gradio callbacks.
# NOTE(review): access is not guarded by a lock — relies on CPython's atomic
# dict item assignment; confirm this is acceptable for this deployment.
training_status = {
    "running": False,
    "phase": "Idle",
    "progress": "Waiting to start...",
    "last_update": datetime.now(),
}

HF_API = HfApi(token=HF_MODEL_TOKEN)
# Presence of either of these files on the Hub marks a model repo as trained.
WEIGHT_FILENAMES = {"pytorch_model.bin", "model.safetensors"}


def load_training_state() -> dict:
    """Load the persisted phase-completion state.

    Tries the local ``STATE_FILE`` first; on failure (or absence) falls back
    to downloading ``training_state.json`` from the ``STATE_REPO`` dataset
    repo, caching it locally.  Returns an empty dict when neither source is
    available or readable.
    """
    if STATE_FILE.exists():
        try:
            with open(STATE_FILE, "r", encoding="utf-8") as f:
                return json.load(f)
        except Exception:
            # Corrupt/unreadable local file — fall through to the Hub copy.
            pass
    if HF_MODEL_TOKEN:
        try:
            downloaded = hf_hub_download(
                repo_id=STATE_REPO,
                filename="training_state.json",
                repo_type="dataset",
                token=HF_MODEL_TOKEN,
            )
            # Mirror the Hub copy locally so later loads skip the download.
            shutil.copy(downloaded, STATE_FILE)
            with open(STATE_FILE, "r", encoding="utf-8") as f:
                return json.load(f)
        except Exception:
            return {}
    return {}


def save_training_state(state: dict) -> None:
    """Persist *state* locally and push it to the ``STATE_REPO`` dataset repo.

    Best-effort: upload failures are logged, not raised.
    NOTE(review): without a token this returns before writing STATE_FILE, so
    completed phases are not persisted locally either — confirm intended.
    """
    if not HF_MODEL_TOKEN:
        return
    STATE_FILE.write_text(json.dumps(state, indent=2), encoding="utf-8")
    try:
        create_repo(STATE_REPO, repo_type="dataset", exist_ok=True, token=HF_MODEL_TOKEN)
        HF_API.upload_file(
            path_or_fileobj=str(STATE_FILE),
            path_in_repo="training_state.json",
            repo_id=STATE_REPO,
            repo_type="dataset",
        )
    except Exception as exc:
        print(f"⚠️ Could not update training state repo: {exc}")


# Module-level snapshot of persisted state, mutated by the mark_* helpers.
training_state = load_training_state()


def mark_phase_complete(phase_number: int) -> None:
    """Record *phase_number* as complete (with timestamp) and persist state."""
    training_state[f"phase_{phase_number}"] = {
        "status": "complete",
        "timestamp": time.time(),
    }
    # Drives the "Auto" resume option in start_training().
    training_state["last_completed_phase"] = phase_number
    save_training_state(training_state)


def mark_phase_failed(phase_number: int, message: str) -> None:
    """Record *phase_number* as failed with a diagnostic *message*; persist."""
    training_state[f"phase_{phase_number}"] = {
        "status": "failed",
        "timestamp": time.time(),
        "message": message,
    }
    save_training_state(training_state)


def _dir_has_arrow_files(path: Path) -> bool:
    """Return True if *path* exists and contains at least one ``*.arrow`` file
    (the on-disk format of a saved ``datasets`` split)."""
    return path.exists() and any(path.glob("*.arrow"))


def _ensure_clean_dir(path: Path) -> None:
    """Recreate *path* as an empty directory, deleting any previous contents."""
    if path.exists():
        shutil.rmtree(path)
    path.mkdir(parents=True, exist_ok=True)


def _download_dataset(repo_id: str, target_dir: Path) -> bool:
    """Materialise dataset *repo_id* into *target_dir* via ``save_to_disk``.

    Returns True if a complete local copy exists (either already cached or
    freshly downloaded), False on any failure or when no token is available.
    """
    # A saved DatasetDict is identified by dataset_dict.json + arrow shards.
    if (target_dir / "dataset_dict.json").exists() and _dir_has_arrow_files(target_dir):
        return True
    if not HF_MODEL_TOKEN:
        print(f"⚠️ Cannot download dataset {repo_id}: HF_MODEL_TOKEN not set.")
        return False
    try:
        print(f"⬇️ Loading dataset {repo_id} from Hugging Face Hub...")
        ds = load_dataset(repo_id)
        if not isinstance(ds, DatasetDict):
            # Normalise to DatasetDict so save_to_disk writes dataset_dict.json.
            ds = DatasetDict({k: v for k, v in ds.items()})
        _ensure_clean_dir(target_dir)
        ds.save_to_disk(str(target_dir))
        # Re-verify on disk rather than trusting save_to_disk blindly.
        return (target_dir / "dataset_dict.json").exists() and _dir_has_arrow_files(target_dir)
    except Exception as exc:
        print(f"⚠️ Could not download dataset {repo_id}: {exc}")
        return False


def _download_tokenizer() -> bool:
    """Fetch ``tokenizer.json`` from the tokenizer model repo into REPO_ROOT.

    Returns True if the file is present locally afterwards.
    NOTE(review): ``local_dir_use_symlinks`` is deprecated in recent
    huggingface_hub releases — confirm the pinned version still accepts it.
    """
    if TOKENIZER_FILE.exists():
        return True
    try:
        print("⬇️ Downloading tokenizer artifact...")
        hf_hub_download(
            repo_id=TOKENIZER_NAME,
            repo_type="model",
            filename="tokenizer.json",
            local_dir=str(REPO_ROOT),
            token=HF_MODEL_TOKEN,
            local_dir_use_symlinks=False,
        )
        return TOKENIZER_FILE.exists()
    except Exception as exc:
        print(f"⚠️ Could not download tokenizer: {exc}")
        return False


def _phase_completed(phase_number: int) -> bool:
    """Check whether a phase's artifacts already exist (locally or on the Hub).

    May have side effects: for phases 1 and 2 it downloads missing artifacts
    as part of the check.  Phase 5 (evaluation) is never considered complete,
    so it re-runs on every pipeline invocation.
    """
    if phase_number == 1:
        # Data prep is complete when both cached datasets are on disk; try to
        # restore them from the Hub otherwise.
        if _dir_has_arrow_files(FORWARD_CACHE_DIR) and _dir_has_arrow_files(RETRO_CACHE_DIR):
            return True
        forward_ok = _download_dataset(FORWARD_DATASET_NAME, FORWARD_CACHE_DIR)
        retro_ok = _download_dataset(RETRO_DATASET_NAME, RETRO_CACHE_DIR)
        return forward_ok and retro_ok
    if phase_number == 2:
        # Tokenizer: local file, else confirm the repo exists then download.
        if TOKENIZER_FILE.exists():
            return True
        try:
            HF_API.model_info(TOKENIZER_NAME)
            return _download_tokenizer()
        except Exception:
            return False
    if phase_number == 3:
        # Forward model: complete once weight files are visible on the Hub.
        try:
            info = HF_API.model_info(FORWARD_MODEL_NAME)
            filenames = {s.rfilename for s in info.siblings}
            return bool(WEIGHT_FILENAMES & filenames)
        except Exception:
            return False
    if phase_number == 4:
        # Retro model: same weight-file check as phase 3.
        try:
            info = HF_API.model_info(RETRO_MODEL_NAME)
            filenames = {s.rfilename for s in info.siblings}
            return bool(WEIGHT_FILENAMES & filenames)
        except Exception:
            return False
    if phase_number == 5:
        # Evaluation always re-runs.
        return False
    return False


def _stream_process(command: List[str], env: dict, phase_label: str, log_handle) -> int:
    """Run a subprocess while streaming stdout to the log and status panel.

    stderr is merged into stdout; each line is appended to *log_handle*,
    echoed to this process's stdout, and reflected (truncated to 240 chars)
    into the shared ``training_status`` dict.  Returns the exit code.
    """
    process = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,  # line-buffered so progress lines arrive promptly
        cwd=REPO_ROOT,
        env=env,
    )
    try:
        # readline-based iteration; the sentinel "" marks EOF in text mode.
        for raw_line in iter(process.stdout.readline, ""):
            if not raw_line:
                continue
            log_handle.write(raw_line)
            log_handle.flush()
            # Echo to container stdout for real-time terminal visibility
            sys.stdout.write(raw_line)
            sys.stdout.flush()
            line = raw_line.strip()
            training_status["progress"] = f"[{phase_label}] {line}"[:240]
            training_status["last_update"] = datetime.now()
    finally:
        if process.stdout:
            process.stdout.close()
    return process.wait()


# -----------------------------------------------------------------------------
# Helpers exposed to UI
# -----------------------------------------------------------------------------
def get_log_content() -> str:
    """Read the tail of the training log for display."""
    if LOG_FILE.exists():
        try:
            with open(LOG_FILE, "r", encoding="utf-8", errors="replace") as f:
                content = f.read()
            # Only the last 4000 characters are shown in the status panel.
            return content[-4000:] if len(content) > 4000 else content
        except Exception as exc:  # pragma: no cover - best effort logging
            return f"Error reading log file: {exc}"
    return "Logs will appear here once the pipeline starts."


def reset_training() -> str:
    """Reset status and clear log file.

    Refuses to reset while a run is in progress.  Returns the refreshed
    status markdown (or a warning string).
    NOTE(review): this clears the UI status and log only — persisted
    ``training_state`` is untouched, so "Auto" resume still skips phases.
    """
    if training_status.get("running"):
        return "⚠️ Training in progress. Wait for completion before resetting."
    training_status.update({
        "running": False,
        "phase": "Idle",
        "progress": "Reset complete. Ready to start again.",
        "last_update": datetime.now(),
    })
    if LOG_FILE.exists():
        LOG_FILE.unlink()
    return get_status()


def start_training(start_option: str) -> str:
    """Kick off the full multi-phase training pipeline in a background thread.

    *start_option* is one of the dropdown choices: "Auto (skip completed
    phases)" resumes after the last recorded phase and skips any phase whose
    artifacts already exist; "Start from Phase N" forces execution from
    phase N onwards (earlier phases still run if their artifacts are missing).
    Returns the initial status markdown immediately; progress is polled via
    the refresh button.
    """
    if training_status["running"]:
        return "⚠️ Training already running. Use the refresh button to see live updates."
    if not HF_MODEL_TOKEN:
        return "❌ HF_MODEL_TOKEN not found. Please add it to your Space secrets."
    option = start_option or "Auto (skip completed phases)"
    skip_completed = option.startswith("Auto")
    if option.startswith("Auto"):
        # Resume from the phase after the last recorded completion.
        start_from = max(1, training_state.get("last_completed_phase", 0) + 1)
    else:
        start_from = 1
    if option.startswith("Start from Phase"):
        try:
            # "Start from Phase N" -> token index 3 is the phase number.
            start_from = int(option.split()[3])
        except Exception:
            start_from = 1
        # Explicit start phase means: re-run even already-completed phases.
        skip_completed = False

    def run_pipeline() -> None:
        """Worker executed on a background thread: runs phases sequentially."""
        # Propagate credentials and cache locations to the phase subprocesses.
        env = os.environ.copy()
        env.update(
            {
                "HF_MODEL_TOKEN": HF_MODEL_TOKEN,
                "HF_TOKEN": HF_MODEL_TOKEN,
                "HUGGING_FACE_HUB_TOKEN": HF_MODEL_TOKEN,
                "HF_HOME": str(HF_CACHE_DIR),
                "TRANSFORMERS_CACHE": str(HF_CACHE_DIR),
                "HF_DATASETS_CACHE": str(HF_CACHE_DIR / "datasets"),
                "ORD_PROJECT_ROOT": str(REPO_ROOT),
            }
        )
        CACHE_DIR.mkdir(parents=True, exist_ok=True)
        (HF_CACHE_DIR / "datasets").mkdir(parents=True, exist_ok=True)
        # Authenticate with Hugging Face Hub once up front
        try:
            print("🔐 Logging into Hugging Face Hub...")
            login(token=HF_MODEL_TOKEN, add_to_git_credential=False)
            print("✅ Authenticated with Hugging Face Hub")
        except Exception as exc:
            # Non-fatal: subprocesses still receive the token via env vars.
            print(f"⚠️ Login warning: {exc}")
        training_status.update(
            {
                "running": True,
                "phase": "Initializing pipeline...",
                "progress": f"Starting from phase {start_from} ({option})",
            }
        )
        success = True
        try:
            # "w" mode: each run starts a fresh log.
            with open(LOG_FILE, "w", encoding="utf-8") as log_f:
                log_f.write("=" * 72 + "\n")
                log_f.write("ORD Reaction Translator - Full Training Pipeline\n")
                log_f.write("=" * 72 + "\n\n")
                log_f.flush()
                for phase_number, phase_label, script_name in PHASES:
                    script_path = SRC_DIR / script_name
                    phase_complete = _phase_completed(phase_number)
                    # Case 1: before the requested start phase AND artifacts
                    # exist -> skip, backfilling the persisted state if needed.
                    if phase_number < start_from and phase_complete:
                        skip_msg = (
                            f"⏭️ Skipping Phase {phase_number}: {phase_label} (start phase = {start_from})\n"
                        )
                        log_f.write(skip_msg)
                        log_f.flush()
                        if training_state.get(f"phase_{phase_number}", {}).get("status") != "complete":
                            mark_phase_complete(phase_number)
                        continue
                    # Case 2: before the start phase but artifacts are missing
                    # -> warn and fall through to run it anyway (later phases
                    # depend on its outputs).
                    if phase_number < start_from and not phase_complete:
                        warn_msg = (
                            f"⚠️ Phase {phase_number} artifacts missing. Running {phase_label} even though"
                            f" start phase is {start_from}.\n"
                        )
                        log_f.write(warn_msg)
                        log_f.flush()
                    # Case 3 (Auto mode only): phase already complete -> skip.
                    if skip_completed and phase_complete and phase_number >= start_from:
                        skip_msg = f"⏭️ Phase {phase_number} already completed. Skipping {phase_label}.\n"
                        log_f.write(skip_msg)
                        log_f.flush()
                        training_status["phase"] = f"PHASE {phase_number}: {phase_label}"
                        training_status["progress"] = "Already complete—skipping."
                        if training_state.get(f"phase_{phase_number}", {}).get("status") != "complete":
                            mark_phase_complete(phase_number)
                        continue
                    if not script_path.exists():
                        message = f"Missing script: {script_name}"
                        training_status["progress"] = f"❌ {message}"
                        mark_phase_failed(phase_number, message)
                        success = False
                        break
                    # Evaluation requires both trained models on the Hub;
                    # skip (recorded as failed) rather than abort the run.
                    if phase_number == 5 and not (_phase_completed(3) and _phase_completed(4)):
                        msg = (
                            "⚠️ Skipping evaluation: forward and retro models are not yet available on the Hub."
                            " Complete Phases 3 and 4 before running evaluation.\n"
                        )
                        log_f.write(msg)
                        log_f.flush()
                        training_status["phase"] = f"PHASE {phase_number}: {phase_label}"
                        training_status["progress"] = "Skipped evaluation—models missing."
                        mark_phase_failed(phase_number, "Models missing for evaluation")
                        continue
                    # Run the phase script with this interpreter, streaming
                    # its output into the log and status panel.
                    phase_header = f"--- Phase {phase_number}: {phase_label} ---\n"
                    log_f.write(phase_header)
                    log_f.flush()
                    training_status["phase"] = f"PHASE {phase_number}: {phase_label}"
                    training_status["progress"] = "Starting..."
                    return_code = _stream_process(
                        [sys.executable, str(script_path)], env, f"PHASE {phase_number}", log_f
                    )
                    if return_code != 0:
                        message = (
                            f"{phase_label} failed (exit code {return_code}). Check the logs above."
                        )
                        training_status["progress"] = f"❌ {message}"
                        mark_phase_failed(phase_number, message)
                        success = False
                        break
                    training_status["progress"] = f"✅ {phase_label} completed."
                    mark_phase_complete(phase_number)
        except Exception as exc:  # pragma: no cover - defensive logging
            success = False
            training_status["progress"] = f"❌ Pipeline crashed: {exc}"
        finally:
            training_status["running"] = False
            training_status["last_update"] = datetime.now()
            if success:
                training_status.update(
                    {
                        "phase": "Completed ✅",
                        "progress": "Full pipeline finished. Models and tokenizer pushed to Hugging Face Hub.",
                    }
                )
            else:
                # Keep the failing "PHASE N: ..." label if one was set;
                # otherwise mark the run as stopped before any phase began.
                if "phase" not in training_status or "PHASE" not in training_status["phase"]:
                    training_status["phase"] = "Stopped"

    # Non-daemon thread: the pipeline keeps running even if the Gradio
    # request that spawned it completes.
    thread = threading.Thread(target=run_pipeline, daemon=False)
    thread.start()
    return get_status()


def get_status() -> str:
    """Get the current training status and log content.

    Returns a markdown document rendered into the status panel.
    """
    log = get_log_content()
    return f"""
**Phase:** {training_status['phase']}

**Status:** {'Running ⏳' if training_status['running'] else 'Ready ✅'}

**Progress:** {training_status['progress']}

---

## 📋 Real-time Logs:

```
{log}
```
"""


# Create UI
with gr.Blocks(title="ORD Training") as demo:
    gr.Markdown("# 🧪 ORD Reaction Training Pipeline")
    gr.Markdown("Train AI models on 2.4M chemical reactions from Open Reaction Database")
    # Choices here must match the string parsing in start_training().
    phase_selector = gr.Dropdown(
        label="Resume / start phase",
        choices=[
            "Auto (skip completed phases)",
            "Start from Phase 1",
            "Start from Phase 2",
            "Start from Phase 3",
            "Start from Phase 4",
            "Start from Phase 5",
        ],
        value="Auto (skip completed phases)",
    )
    with gr.Row():
        start_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")
        refresh_btn = gr.Button("🔄 Refresh Logs", variant="secondary", size="lg")
        reset_btn = gr.Button("🔧 Reset", size="lg")
    gr.Markdown("### 📊 Status & Real-Time Logs")
    status_box = gr.Markdown()
    # Event handlers
    # Start returns an immediate message, then a chained refresh repaints
    # the panel once the background thread has initialised.
    start_btn.click(start_training, inputs=phase_selector, outputs=status_box).then(get_status, outputs=status_box)
    refresh_btn.click(get_status, outputs=status_box)
    reset_btn.click(reset_training, outputs=status_box)
    # Populate the status panel on initial page load.
    demo.load(get_status, outputs=status_box)

if __name__ == "__main__":
    demo.launch()