""" G.U.I.D.E. — Single entry point. Training sequence (each step skipped if checkpoint already exists): 1. EvidenceNER — synthetic data, no download, ~2-5 min CPU 2. NextActionPredictor — synthetic data, no download, <30 s CPU 3. DomainClassifier — requires data/raw/complaints.csv (CFPB dataset) Server startup: 4. FastAPI backend (src/api/main.py) — stdout prefixed with [API] 5. Gradio frontend (ui/app.py) — stdout prefixed with [UI] Usage: python start.py # auto-train missing, then serve python start.py --cfpb_csv /path/f.csv # custom CFPB CSV location python start.py --no-train # skip training entirely python start.py --train # force re-train all models python start.py --train-only # train then exit (no servers) """ from __future__ import annotations import argparse import os import platform import shutil import subprocess import sys import threading import time import urllib.error import urllib.request from pathlib import Path # ── Windows UTF-8 fix ───────────────────────────────────────────────────────── # The default Windows console encoding (cp1252) cannot represent box-drawing # characters (U+2500 ─) or arrows (U+2192 →) used in our output. # Reconfigure before the first print() so the launcher never crashes on encode. 
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding="utf-8", errors="replace")
if hasattr(sys.stderr, "reconfigure"):
    sys.stderr.reconfigure(encoding="utf-8", errors="replace")

# ── Environment ───────────────────────────────────────────────────────────────
try:
    from dotenv import load_dotenv

    load_dotenv()
except ImportError:
    pass  # FastAPI will load .env on startup; we only need it here for ports

_ROOT = Path(__file__).parent
_API_PORT = int(os.getenv("API_PORT", "8000"))
_GRADIO_PORT = int(os.getenv("GRADIO_PORT", "7860"))
_LOG_LEVEL = os.getenv("LOG_LEVEL", "info")
_HF_REPO = os.getenv("HF_MODEL_REPO", "sarav95/guide-models")

# models/<subdir> → file whose presence marks that checkpoint as complete.
# Used by the download step so "already present" means "usable checkpoint",
# not merely "directory exists" (training pre-creates models/next_action/).
_CKPT_MARKERS = {
    "evidence_ner": "config.json",       # HuggingFace checkpoint directory
    "next_action": "model.pt",           # PyTorch .pt file
    "domain_classifier": "config.json",  # HuggingFace checkpoint (optional)
}

# ── ANSI colours (disabled when stdout is not a TTY) ─────────────────────────
_COLOUR = sys.stdout.isatty()


def _c(code: str, text: str) -> str:
    """Wrap *text* in the ANSI SGR sequence *code* when colour output is on."""
    return f"\033[{code}m{text}\033[0m" if _COLOUR else text


# Every prefix is padded to 7 visible characters so relayed lines align.
_PFX_API = _c("36;1", "[API]  ")    # bright cyan
_PFX_UI = _c("32;1", "[UI]   ")     # bright green
_PFX_START = _c("34;1", "[start]")  # bright blue
_PFX_WARN = _c("33;1", "[warn] ")   # bright yellow
_PFX_ERR = _c("31;1", "[error]")    # bright red


def _log(pfx: str, msg: str) -> None:
    """Print *msg* after prefix *pfx*, flushing immediately."""
    print(f"{pfx} {msg}", flush=True)


def _info(msg: str) -> None:
    """Log an informational launcher message."""
    _log(_PFX_START, msg)


def _warn(msg: str) -> None:
    """Log a non-fatal warning."""
    _log(_PFX_WARN, msg)


def _err(msg: str) -> None:
    """Log an error (callers decide whether to exit)."""
    _log(_PFX_ERR, msg)


# ── Argument parsing ──────────────────────────────────────────────────────────
def _parse_args() -> argparse.Namespace:
    """Parse the launcher's command-line flags from ``sys.argv``.

    ``--no-train`` and ``--train`` are mutually exclusive. ``--no-train`` with
    ``--train-only`` is rejected unless ``--download-models`` is also given,
    because without a download that combination would do nothing at all.
    """
    p = argparse.ArgumentParser(
        prog="python start.py",
        description="G.U.I.D.E. launcher — trains DL models then starts both servers.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=(
            "examples:\n"
            "  python start.py                    # auto-train + serve\n"
            "  python start.py --no-train         # serve only (models must exist)\n"
            "  python start.py --train            # force retrain + serve\n"
            "  python start.py --train-only       # train and exit\n"
            "  python start.py --download-models  # fetch checkpoints, then serve\n"
            "  python start.py --cfpb_csv ~/data/complaints.csv\n"
        ),
    )
    p.add_argument(
        "--cfpb_csv",
        metavar="PATH",
        default=None,
        help=(
            "Path to the CFPB complaints CSV used to train DomainClassifier. "
            "Defaults to data/raw/complaints.csv relative to the project root."
        ),
    )
    mode = p.add_mutually_exclusive_group()
    mode.add_argument(
        "--no-train",
        action="store_true",
        help="Skip all model training; start servers immediately.",
    )
    mode.add_argument(
        "--train",
        action="store_true",
        help="Force re-train all three models even if checkpoints already exist.",
    )
    p.add_argument(
        "--train-only",
        action="store_true",
        help="Run the training phase then exit without starting servers.",
    )
    p.add_argument(
        "--download-models",
        action="store_true",
        help=(
            f"Download model checkpoints from HuggingFace ({_HF_REPO}) "
            "into models/ before starting. Skipped automatically if all "
            "checkpoints already exist."
        ),
    )
    args = p.parse_args()
    if args.no_train and args.train_only and not args.download_models:
        # p.error prints usage and exits with status 2.
        p.error("--train-only cannot be combined with --no-train (nothing would run)")
    return args


# ── Checkpoint detection ──────────────────────────────────────────────────────
def _hf_ckpt(rel_dir: str) -> bool:
    """True when a HuggingFace checkpoint directory contains config.json."""
    return (_ROOT / rel_dir / "config.json").exists()


def _pt_ckpt(rel_path: str) -> bool:
    """True when a PyTorch .pt file exists at the given relative path."""
    return (_ROOT / rel_path).exists()


# ── Model download from HuggingFace ──────────────────────────────────────────
def _models_complete() -> bool:
    """True when required checkpoints exist (domain_classifier is optional — has keyword fallback)."""
    return (
        _hf_ckpt("models/evidence_ner")
        and _pt_ckpt("models/next_action/model.pt")
    )


def _download_failed(msg: str, required: bool) -> None:
    """Report a download failure: exit(1) when *required*, otherwise warn and return."""
    if required:
        _err(msg)
        sys.exit(1)
    _warn(f"{msg} — continuing without downloaded checkpoints.")


def _download_models(required: bool = True) -> None:
    """Copy missing model checkpoints from the HuggingFace repo ``_HF_REPO``.

    A sub-directory listed in ``_CKPT_MARKERS`` is copied into models/ only
    when its marker file is absent. A directory that exists but lacks its
    checkpoint (e.g. one left empty by a failed training run) is filled in
    rather than skipped.

    Args:
        required: True for an explicit ``--download-models`` request. A
            missing ``huggingface_hub`` package or a failed download then
            exits with status 1. When False (automatic download), such
            failures are logged as warnings and the launcher carries on.
    """
    models_dir = _ROOT / "models"
    if all((models_dir / sub / marker).exists() for sub, marker in _CKPT_MARKERS.items()):
        _info("All model checkpoints already present — skipping download.")
        return

    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        _download_failed(
            "huggingface_hub is not installed. Run: pip install huggingface_hub",
            required,
        )
        return

    _info(f"Downloading model checkpoints from {_HF_REPO!r}…")
    models_dir.mkdir(parents=True, exist_ok=True)

    # Download to HF cache (no local_dir) and get the snapshot directory.
    # Using local_dir causes hard-link failures in Docker overlay filesystems,
    # leaving models/ empty even though the download "succeeds".
    try:
        snapshot_path = Path(snapshot_download(repo_id=_HF_REPO, repo_type="model"))
    except Exception as exc:  # network/auth/missing-repo error types vary by hub version
        _download_failed(f"Model download from {_HF_REPO!r} failed: {exc}", required)
        return
    _info(f"Snapshot cached at: {snapshot_path}")

    for subdir, marker in _CKPT_MARKERS.items():
        src = snapshot_path / subdir
        dst = models_dir / subdir
        if (dst / marker).exists():
            _info(f" {subdir} already present — skipped.")
        elif src.exists():
            # dirs_exist_ok: fill a pre-existing but incomplete directory.
            # copytree follows the cache's symlinks and copies real file data.
            shutil.copytree(src, dst, dirs_exist_ok=True)
            _info(f" Copied {subdir}.")
        else:
            _warn(f" {subdir} not found in snapshot — skipping.")

    _info("Model download complete.")


# ── Pre-flight checks ────────────────────────────────────────────────────────
def _check_tesseract() -> None:
    """Exit with an actionable message if the Tesseract OCR binary is missing."""
    # 1. Check if it's already in the System PATH
    if shutil.which("tesseract") is not None:
        return

    # 2. Check default Windows installation path if running on Windows
    if platform.system() == "Windows":
        default_win_path = r"C:\Program Files\Tesseract-OCR\tesseract.exe"
        if os.path.exists(default_win_path):
            # Add it to PATH for this process and the servers it spawns
            # (_spawn copies os.environ).
            os.environ["PATH"] = (
                os.environ.get("PATH", "") + os.pathsep + os.path.dirname(default_win_path)
            )
            return

    # 3. Fail if not found anywhere
    _err("Tesseract OCR not found. Install it before starting G.U.I.D.E.:")
    print("", flush=True)
    print(" macOS : brew install tesseract", flush=True)
    print(" Ubuntu : sudo apt install tesseract-ocr", flush=True)
    print(" Windows : https://github.com/UB-Mannheim/tesseract/wiki", flush=True)
    print("", flush=True)
    print("Then re-run: python start.py", flush=True)
    sys.exit(1)


# ── Output relay (subprocess stdout → our stdout with prefix) ─────────────────
def _relay(stream, prefix: str) -> None:
    """
    Read lines from *stream* and write them to stdout with *prefix*.

    Designed to run as a daemon thread; exits when *stream* closes. A final
    line without a trailing newline is terminated so it does not merge with
    the next prefixed line.
    """
    try:
        for line in stream:
            if not line.endswith("\n"):
                line += "\n"
            sys.stdout.write(f"{prefix} {line}")
            sys.stdout.flush()
    except (OSError, ValueError):
        # Pipe or stdout closed during shutdown (ValueError = I/O on closed
        # file); nothing left to relay.
        pass


# ── Training helpers ──────────────────────────────────────────────────────────
def _run_training(label: str, cmd: list[str]) -> bool:
    """
    Run *cmd* synchronously with inherited stdio (no prefix — output flows
    directly to the terminal so the user can watch training progress).

    Returns True on success (exit code 0), False otherwise.
    """
    _info(f"Training {label}…")
    _info(f" $ {' '.join(str(c) for c in cmd)}")
    result = subprocess.run(cmd, cwd=_ROOT)
    if result.returncode != 0:
        _warn(f"{label} training exited {result.returncode} — continuing with fallback.")
        return False
    _info(f"{label} training complete.")
    return True


def _cfpb_csv_path(cli_value: str | None) -> Path:
    """Return the CFPB CSV path used to train DomainClassifier.

    A user-supplied path has ``~`` expanded and is made absolute against the
    launcher's current directory. The training subprocess runs with
    ``cwd=_ROOT``, so an unresolved relative path would be looked up in the
    wrong place. Without a value, data/raw/complaints.csv under the project
    root is used.
    """
    if cli_value:
        return Path(cli_value).expanduser().resolve()
    return _ROOT / "data" / "raw" / "complaints.csv"


def _train_models(args: argparse.Namespace) -> None:
    """Train all three DL models in dependency order, skipping existing checkpoints."""
    force = args.train

    _info("─" * 60)
    _info("Model training phase")
    _info("─" * 60)

    # ── 1. EvidenceNER ────────────────────────────────────────────────────────
    ner_dir = "models/evidence_ner"
    if not force and _hf_ckpt(ner_dir):
        _info(f"EvidenceNER: checkpoint exists at {ner_dir!r} — skipped.")
    else:
        _run_training(
            "EvidenceNER",
            [sys.executable, "-m", "src.ner.train", "--output_dir", ner_dir],
        )

    # ── 2. NextActionPredictor ────────────────────────────────────────────────
    na_path = "models/next_action/model.pt"
    if not force and _pt_ckpt(na_path):
        _info(f"NextActionPredictor: checkpoint exists at {na_path!r} — skipped.")
    else:
        # Ensure parent directory exists before training writes the file
        (_ROOT / "models" / "next_action").mkdir(parents=True, exist_ok=True)
        _run_training(
            "NextActionPredictor",
            [sys.executable, "-m", "src.next_action.train", "--output_path", na_path],
        )

    # ── 3. DomainClassifier ───────────────────────────────────────────────────
    cls_dir = "models/domain_classifier"
    if not force and _hf_ckpt(cls_dir):
        _info(f"DomainClassifier: checkpoint exists at {cls_dir!r} — skipped.")
    else:
        cfpb_csv = _cfpb_csv_path(args.cfpb_csv)
        if cfpb_csv.exists():
            _run_training(
                "DomainClassifier",
                [sys.executable, "-m", "src.classifier.train",
                 "--cfpb_csv", str(cfpb_csv),
                 "--output_dir", cls_dir],
            )
        else:
            _warn(
                f"DomainClassifier: {cfpb_csv} not found — skipped.\n"
                f"{_PFX_WARN} Download from Kaggle (CFPB Consumer Complaint Database),\n"
                f"{_PFX_WARN} place in data/raw/complaints.csv, then re-run or use --train.\n"
                f"{_PFX_WARN} Keyword-based fallback will be active until a checkpoint exists."
            )

    _info("─" * 60)
    _info("Training phase complete.")
    _info("─" * 60)


# ── Server spawning ───────────────────────────────────────────────────────────
def _spawn(
    cmd: list[str],
    prefix: str,
    extra_env: dict[str, str] | None = None,
) -> subprocess.Popen:
    """
    Launch *cmd* as a subprocess rooted at the project directory.

    Its merged stdout+stderr is relayed to our stdout by a daemon thread,
    with *prefix* prepended to each line. PYTHONUNBUFFERED=1 is set because
    a Python child writing to a pipe would otherwise block-buffer its
    output, delaying log lines. *extra_env* entries override the inherited
    environment.
    """
    env = {**os.environ, "PYTHONUNBUFFERED": "1"}
    if extra_env:
        env.update(extra_env)

    proc = subprocess.Popen(
        cmd,
        cwd=_ROOT,
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        encoding="utf-8",
        errors="replace",
        bufsize=1,
    )
    relay_thread = threading.Thread(
        target=_relay,
        args=(proc.stdout, prefix),
        daemon=True,
        name=f"relay-{prefix.strip()}",
    )
    relay_thread.start()
    return proc


def _start_backend() -> subprocess.Popen:
    """Spawn the FastAPI app under uvicorn on ``_API_PORT`` (all interfaces)."""
    cmd = [
        sys.executable, "-m", "uvicorn",
        "src.api.main:app",
        "--host", "0.0.0.0",
        "--port", str(_API_PORT),
        "--log-level", _LOG_LEVEL,
    ]
    _info(f"Starting FastAPI backend (port {_API_PORT})…")
    return _spawn(cmd, _PFX_API)


def _start_frontend() -> subprocess.Popen:
    """Spawn the Gradio UI script, passing its port via GRADIO_SERVER_PORT."""
    cmd = [sys.executable, str(_ROOT / "ui" / "app.py")]
    _info(f"Starting Gradio frontend (port {_GRADIO_PORT})…")
    return _spawn(
        cmd,
        _PFX_UI,
        extra_env={"GRADIO_SERVER_PORT": str(_GRADIO_PORT)},
    )


# ── Backend health polling ────────────────────────────────────────────────────
def _wait_for_backend(backend: subprocess.Popen, timeout: float = 90.0) -> bool:
    """
    Poll GET /api/health every 2 s until the backend responds 200 or the
    process dies. Returns True when the backend is ready, False on failure.
    """
    _info(f"Waiting for backend health check (up to {int(timeout)} s)…")
    # Probe IPv4 loopback directly: uvicorn binds 0.0.0.0 (IPv4 only), and a
    # "localhost" lookup may try ::1 first, where refused connects can be slow.
    url = f"http://127.0.0.1:{_API_PORT}/api/health"
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        # Abort early if the process already died
        if backend.poll() is not None:
            return False
        try:
            with urllib.request.urlopen(url, timeout=2) as resp:
                if resp.status == 200:
                    return True
        except (urllib.error.URLError, OSError):
            pass
        time.sleep(2)
    return False


# ── Graceful shutdown ─────────────────────────────────────────────────────────
def _shutdown(procs: list[subprocess.Popen], timeout: float = 8.0) -> None:
    """Terminate all processes; kill any still alive after *timeout* s and reap it."""
    for proc in procs:
        if proc.poll() is None:
            proc.terminate()

    # One shared deadline, so total wait is ~timeout regardless of len(procs).
    deadline = time.monotonic() + timeout
    for proc in procs:
        remaining = max(0.1, deadline - time.monotonic())
        try:
            proc.wait(timeout=remaining)
        except subprocess.TimeoutExpired:
            _warn(f"Process {proc.pid} still alive after {timeout:.0f} s — killing.")
            proc.kill()
            proc.wait()  # reap, so the killed child does not linger


# ── Main ──────────────────────────────────────────────────────────────────────
def main() -> None:
    """Run pre-flight, optional download, training, then serve both apps.

    Once the servers are up, both child processes are watched. When either
    exits, the user presses Ctrl+C, or the backend fails its health check,
    every spawned process is shut down.
    """
    args = _parse_args()

    # Banner
    sep = "=" * 66
    print(sep)
    print(" G.U.I.D.E. — Grievance Utility for Information Extraction,")
    print(" Drafting and Enrichment")
    print(sep)
    print()

    # ── Pre-flight ────────────────────────────────────────────────────────────
    _check_tesseract()

    # ── Model download phase ──────────────────────────────────────────────────
    # An explicit --download-models must succeed. The automatic download of
    # missing required checkpoints is best-effort, so training can still run.
    auto_download = not args.no_train and not _models_complete()
    if args.download_models or auto_download:
        _download_models(required=args.download_models)
        print()

    # ── Training phase ────────────────────────────────────────────────────────
    if args.no_train:
        _info("--no-train specified — skipping all model training.")
    else:
        _train_models(args)
        print()

    if args.train_only:
        _info("--train-only specified — exiting without starting servers.")
        return

    # ── Server phase ──────────────────────────────────────────────────────────
    # Every spawned process is registered in procs immediately, and the
    # finally clause shuts them all down. A Ctrl+C during the health wait
    # therefore cannot leave the backend running.
    procs: list[subprocess.Popen] = []
    try:
        backend = _start_backend()
        procs.append(backend)
        if not _wait_for_backend(backend):
            rc = backend.poll()
            _err(
                "Backend failed to start"
                + (f" (exit code {rc})" if rc is not None else " (health check timed out)")
                + ". Check the [API] output above."
            )
            backend.kill()
            backend.wait()
            sys.exit(1)
        _info("Backend is healthy.")

        frontend = _start_frontend()
        procs.append(frontend)

        print()
        print(" " + "─" * 62)
        print(f" Backend API → http://localhost:{_API_PORT}/docs")
        print(f" Frontend UI → http://localhost:{_GRADIO_PORT}")
        print(" " + "─" * 62)
        print(" Press Ctrl+C to stop both servers.")
        print()

        # ── Watch loop ────────────────────────────────────────────────────────
        while True:
            time.sleep(1)
            rc_b = backend.poll()
            rc_f = frontend.poll()
            if rc_b is not None:
                _warn(f"Backend process exited (code {rc_b}).")
                break
            if rc_f is not None:
                _warn(f"Frontend process exited (code {rc_f}).")
                break
    except KeyboardInterrupt:
        print()
        _info("Ctrl+C — shutting down…")
    finally:
        _shutdown(procs)
        _info("All processes stopped.")


if __name__ == "__main__":
    main()