Spaces:
Sleeping
Sleeping
| """ | |
G.U.I.D.E. — Single entry point.

Training sequence (each step skipped if checkpoint already exists):
  1. EvidenceNER — synthetic data, no download, ~2-5 min CPU
  2. NextActionPredictor — synthetic data, no download, <30 s CPU
  3. DomainClassifier — requires data/raw/complaints.csv (CFPB dataset)

Server startup:
  4. FastAPI backend (src/api/main.py) — stdout prefixed with [API]
  5. Gradio frontend (ui/app.py) — stdout prefixed with [UI]

Usage:
  python start.py                          # auto-train missing, then serve
  python start.py --cfpb_csv /path/f.csv   # custom CFPB CSV location
  python start.py --no-train               # skip training entirely
  python start.py --train                  # force re-train all models
  python start.py --train-only             # train then exit (no servers)
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import os | |
| import platform | |
| import shutil | |
| import subprocess | |
| import sys | |
| import threading | |
| import time | |
| import urllib.error | |
| import urllib.request | |
| from pathlib import Path | |
# ── Windows UTF-8 fix ─────────────────────────────────────────────────────────
# The default Windows console encoding (cp1252) cannot represent the
# box-drawing characters (U+2500) or arrows (U+2192) used in our output.
# Switch both standard streams to UTF-8 before the first print() so the
# launcher never crashes on encode; errors="replace" substitutes anything
# that still cannot be encoded instead of raising.
for _std_stream in (sys.stdout, sys.stderr):
    # Streams without a reconfigure() method are left untouched.
    if hasattr(_std_stream, "reconfigure"):
        _std_stream.reconfigure(encoding="utf-8", errors="replace")
del _std_stream
# ── Environment ───────────────────────────────────────────────────────────────
try:
    from dotenv import load_dotenv

    load_dotenv()
except ImportError:
    # python-dotenv is optional here. Per the original note, FastAPI loads
    # .env itself on startup; this launcher only needs it for the ports.
    pass

# Directory containing this script; relative model/data paths hang off it.
_ROOT = Path(__file__).parent

# Runtime settings, each overridable through an environment variable.
_API_PORT = int(os.environ.get("API_PORT", "8000"))
_GRADIO_PORT = int(os.environ.get("GRADIO_PORT", "7860"))
_LOG_LEVEL = os.environ.get("LOG_LEVEL", "info")
_HF_REPO = os.environ.get("HF_MODEL_REPO", "sarav95/guide-models")
# ── ANSI colours (disabled when stdout is not a TTY) ──────────────────────────
_COLOUR = sys.stdout.isatty()


def _c(code: str, text: str) -> str:
    """Return *text* wrapped in the ANSI escape sequence for *code*.

    When colour output is disabled (``_COLOUR`` is false) *text* is returned
    unchanged. Otherwise it is preceded by ``ESC[<code>m`` and followed by
    the reset sequence ``ESC[0m``.
    """
    if not _COLOUR:
        return text
    return f"\033[{code}m{text}\033[0m"
# Log-line prefixes, one per output source, coloured through _c().
# NOTE(review): the original comment on _PFX_API said "7 chars incl. spaces",
# but the literal as seen here is 6 characters - confirm no padding was lost.
_PFX_API, _PFX_UI, _PFX_START, _PFX_WARN, _PFX_ERR = (
    _c(_sgr, _tag)
    for _sgr, _tag in (
        ("36;1", "[API] "),   # bright cyan
        ("32;1", "[UI] "),    # bright green
        ("34;1", "[start]"),  # bright blue
        ("33;1", "[warn] "),  # bright yellow
        ("31;1", "[error]"),  # bright red
    )
)
def _log(pfx: str, msg: str) -> None:
    """Print *msg* preceded by *pfx* and a single space, flushing immediately."""
    line = "{} {}".format(pfx, msg)
    print(line, flush=True)
def _info(msg: str) -> None:
    """Log *msg* with the ``[start]`` prefix."""
    _log(_PFX_START, msg)


def _warn(msg: str) -> None:
    """Log *msg* with the ``[warn]`` prefix."""
    _log(_PFX_WARN, msg)


def _err(msg: str) -> None:
    """Log *msg* with the ``[error]`` prefix."""
    _log(_PFX_ERR, msg)
# ── Argument parsing ──────────────────────────────────────────────────────────
def _parse_args() -> argparse.Namespace:
    """Build the launcher's command-line parser and parse ``sys.argv``.

    Returns:
        A namespace with ``cfpb_csv`` (str or None) and the boolean flags
        ``no_train``, ``train``, ``train_only`` and ``download_models``.
        ``--no-train`` and ``--train`` are mutually exclusive.
    """
    example_lines = (
        "examples:",
        " python start.py # auto-train + serve",
        " python start.py --no-train # serve only (models must exist)",
        " python start.py --train # force retrain + serve",
        " python start.py --train-only # train and exit",
        " python start.py --cfpb_csv ~/data/complaints.csv",
    )
    parser = argparse.ArgumentParser(
        prog="python start.py",
        description="G.U.I.D.E. launcher β trains DL models then starts both servers.",
        # Raw formatter keeps the line breaks of the examples epilog intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="\n".join(example_lines) + "\n",
    )

    cfpb_help = (
        "Path to the CFPB complaints CSV used to train DomainClassifier. "
        "Defaults to data/raw/complaints.csv relative to the project root."
    )
    parser.add_argument("--cfpb_csv", metavar="PATH", default=None, help=cfpb_help)

    # --no-train and --train contradict each other, so argparse rejects both.
    train_mode = parser.add_mutually_exclusive_group()
    train_mode.add_argument(
        "--no-train",
        action="store_true",
        help="Skip all model training; start servers immediately.",
    )
    train_mode.add_argument(
        "--train",
        action="store_true",
        help="Force re-train all three models even if checkpoints already exist.",
    )

    parser.add_argument(
        "--train-only",
        action="store_true",
        help="Run the training phase then exit without starting servers.",
    )

    download_help = (
        f"Download model checkpoints from HuggingFace ({_HF_REPO}) "
        "into models/ before starting. Skipped automatically if all "
        "checkpoints already exist."
    )
    parser.add_argument("--download-models", action="store_true", help=download_help)

    return parser.parse_args()
# ── Checkpoint detection ──────────────────────────────────────────────────────
def _hf_ckpt(rel_dir: str) -> bool:
    """Report whether ``<project root>/<rel_dir>/config.json`` exists.

    The presence of ``config.json`` is what this launcher treats as a
    HuggingFace-style checkpoint directory being present.
    """
    config_file = _ROOT.joinpath(rel_dir, "config.json")
    return config_file.exists()
def _pt_ckpt(rel_path: str) -> bool:
    """Report whether the checkpoint file at *rel_path* (relative to the project root) exists."""
    checkpoint = _ROOT.joinpath(rel_path)
    return checkpoint.exists()
# ── Model download from HuggingFace ───────────────────────────────────────────
def _models_complete() -> bool:
    """Report whether both required checkpoints are present.

    Only EvidenceNER and NextActionPredictor are required. The domain
    classifier is treated as optional because a keyword-based fallback exists.
    """
    if not _hf_ckpt("models/evidence_ner"):
        return False
    return _pt_ckpt("models/next_action/model.pt")
def _download_models() -> None:
    """Copy missing model checkpoints from the HuggingFace Hub into ``models/``.

    Nothing is downloaded when the required checkpoints already exist.
    The snapshot of ``_HF_REPO`` is fetched into the HuggingFace cache and
    each model sub-directory whose checkpoint marker file is missing locally
    is then copied into ``<project root>/models``.

    A failed download is reported as a warning and the function returns
    normally, so a later training phase can still produce the checkpoints.

    Raises:
        SystemExit: with status 1 when ``huggingface_hub`` is not installed.
    """
    if _models_complete():
        _info("All model checkpoints already present β skipping download.")
        return
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        _err("huggingface_hub is not installed. Run: pip install huggingface_hub")
        sys.exit(1)

    _info(f"Downloading model checkpoints from {_HF_REPO!r}β¦")
    models_dir = _ROOT / "models"
    models_dir.mkdir(parents=True, exist_ok=True)

    # Download to the HF cache (no local_dir) and get the snapshot directory.
    # Per the original note, local_dir causes hard-link failures on Docker
    # overlay filesystems, leaving models/ empty although the download
    # "succeeds".
    try:
        snapshot_path = Path(snapshot_download(repo_id=_HF_REPO, repo_type="model"))
    except Exception as exc:
        # The hub client raises several unrelated exception types (network,
        # auth, missing repo). None of them should abort the launcher.
        _warn(
            f"Model download from {_HF_REPO!r} failed ({exc}) - "
            "continuing without downloaded checkpoints."
        )
        return
    _info(f"Snapshot cached at: {snapshot_path}")

    # Marker file that proves a checkpoint is really there. These are the same
    # files the checkpoint-detection helpers look for. Testing only for the
    # directory is not enough: an earlier failed training run can leave an
    # empty models/<name>/ directory behind.
    markers = {
        "evidence_ner": "config.json",
        "next_action": "model.pt",
        "domain_classifier": "config.json",
    }
    for subdir, marker in markers.items():
        src = snapshot_path / subdir
        dst = models_dir / subdir
        if (dst / marker).exists():
            _info(f" {subdir} already present β skipped.")
        elif src.exists():
            # dirs_exist_ok lets the copy fill an existing, incomplete directory.
            shutil.copytree(str(src), str(dst), dirs_exist_ok=True)
            _info(f" Copied {subdir}.")
        else:
            _warn(f" {subdir} not found in snapshot β skipping.")
    _info("Model download complete.")
# ── Pre-flight checks ─────────────────────────────────────────────────────────
def _check_tesseract() -> None:
    """Exit with an actionable message if the Tesseract OCR binary is missing.

    Lookup order:
      1. ``tesseract`` on the current PATH.
      2. On Windows only, ``Tesseract-OCR\\tesseract.exe`` under the Program
         Files directories. The first hit has its directory appended to PATH
         for this process (and therefore for child processes started later).

    Raises:
        SystemExit: with status 1 when no Tesseract binary is found.
    """
    # 1. Already reachable through PATH.
    if shutil.which("tesseract") is not None:
        return

    # 2. Default Windows installation directories.
    if platform.system() == "Windows":
        install_roots = (
            os.environ.get("ProgramFiles"),
            os.environ.get("ProgramFiles(x86)"),
            r"C:\Program Files",  # original hard-coded location, kept as a fallback
        )
        for root in install_roots:
            if not root:
                continue
            candidate = os.path.join(root, "Tesseract-OCR", "tesseract.exe")
            if os.path.exists(candidate):
                bin_dir = os.path.dirname(candidate)
                # PATH may be unset; .get() avoids a KeyError in that case.
                current = os.environ.get("PATH", "")
                os.environ["PATH"] = (
                    current + os.pathsep + bin_dir if current else bin_dir
                )
                return

    # 3. Not found anywhere.
    _err("Tesseract OCR not found. Install it before starting G.U.I.D.E.:")
    print("", flush=True)
    print(" macOS : brew install tesseract", flush=True)
    print(" Ubuntu : sudo apt install tesseract-ocr", flush=True)
    print(" Windows : https://github.com/UB-Mannheim/tesseract/wiki", flush=True)
    print("", flush=True)
    print("Then re-run: python start.py", flush=True)
    sys.exit(1)
# ── Output relay (subprocess stdout → our stdout with prefix) ─────────────────
def _relay(stream, prefix: str) -> None:
    """Copy every line of *stream* to stdout, each preceded by *prefix* and a space.

    Meant to be the target of a daemon thread: the loop ends when iteration
    over *stream* is exhausted. Any exception raised while reading or writing
    is swallowed, which makes the relay best-effort.
    """
    try:
        for chunk in stream:
            out = sys.stdout
            out.write(f"{prefix} {chunk}")
            # Flush per line so relayed output appears promptly.
            out.flush()
    except Exception:
        pass
# ── Training helpers ──────────────────────────────────────────────────────────
def _run_training(label: str, cmd: list[str]) -> bool:
    """Run one training command to completion and report the outcome.

    The command runs in the project root with inherited stdio, so its output
    goes straight to the terminal without a prefix.

    Args:
        label: Human-readable model name used in log messages.
        cmd: Argument vector to execute.

    Returns:
        True when the command exits with status 0, False otherwise.
    """
    _info(f"Training {label}β¦")
    printable = " ".join(map(str, cmd))
    _info(f" $ {printable}")

    exit_code = subprocess.run(cmd, cwd=_ROOT).returncode
    succeeded = exit_code == 0
    if succeeded:
        _info(f"{label} training complete.")
    else:
        _warn(f"{label} training exited {exit_code} β continuing with fallback.")
    return succeeded
def _train_models(args: argparse.Namespace) -> None:
    """Train the three models in order, skipping those whose checkpoint exists.

    Args:
        args: Parsed command-line options. ``args.train`` forces retraining
            even when a checkpoint is present. ``args.cfpb_csv`` optionally
            overrides the CFPB CSV location, which defaults to
            ``data/raw/complaints.csv`` under the project root.

    The result of each training run is not checked here, so a failure in one
    step does not prevent the later steps from running. The domain classifier
    is skipped with a warning when its CSV file cannot be found.
    """
    force = args.train
    rule = "β" * 60
    _info(rule)
    _info("Model training phase")
    _info(rule)

    # 1. EvidenceNER
    ner_dir = "models/evidence_ner"
    if not force and _hf_ckpt(ner_dir):
        _info(f"EvidenceNER: checkpoint exists at {ner_dir!r} β skipped.")
    else:
        _run_training(
            "EvidenceNER",
            [sys.executable, "-m", "src.ner.train",
             "--output_dir", ner_dir],
        )

    # 2. NextActionPredictor
    na_path = "models/next_action/model.pt"
    if not force and _pt_ckpt(na_path):
        _info(f"NextActionPredictor: checkpoint exists at {na_path!r} β skipped.")
    else:
        # Ensure the parent directory exists before training writes the file.
        (_ROOT / "models" / "next_action").mkdir(parents=True, exist_ok=True)
        _run_training(
            "NextActionPredictor",
            [sys.executable, "-m", "src.next_action.train",
             "--output_path", na_path],
        )

    # 3. DomainClassifier
    cls_dir = "models/domain_classifier"
    if not force and _hf_ckpt(cls_dir):
        _info(f"DomainClassifier: checkpoint exists at {cls_dir!r} β skipped.")
    else:
        if args.cfpb_csv:
            # Expand "~" ourselves: the shell does not always do it
            # (e.g. in --cfpb_csv=~/file.csv).
            csv_candidate = Path(args.cfpb_csv).expanduser()
        else:
            csv_candidate = _ROOT / "data" / "raw" / "complaints.csv"
        # The trainer runs with cwd=_ROOT, which may differ from the
        # directory this launcher was started from. Make the path absolute
        # so the existence check below and the subprocess see the same file.
        cfpb_csv = Path(os.path.abspath(csv_candidate))

        if cfpb_csv.exists():
            _run_training(
                "DomainClassifier",
                [sys.executable, "-m", "src.classifier.train",
                 "--cfpb_csv", str(cfpb_csv),
                 "--output_dir", cls_dir],
            )
        else:
            # Continuation lines carry the warn prefix by hand because
            # _warn() only prefixes the first line.
            _warn(
                f"DomainClassifier: {cfpb_csv} not found β skipped.\n"
                f"{_PFX_WARN} Download from Kaggle (CFPB Consumer Complaint Database),\n"
                f"{_PFX_WARN} place in data/raw/complaints.csv, then re-run or use --train.\n"
                f"{_PFX_WARN} Keyword-based fallback will be active until a checkpoint exists."
            )

    _info(rule)
    _info("Training phase complete.")
    _info(rule)
# ── Server spawning ───────────────────────────────────────────────────────────
def _spawn(
    cmd: list[str],
    prefix: str,
    extra_env: dict[str, str] | None = None,
) -> subprocess.Popen:
    """Start *cmd* in the project root and relay its output with *prefix*.

    The child's stderr is merged into its stdout, and a daemon thread copies
    every output line to this process's stdout with *prefix* prepended.

    Args:
        cmd: Argument vector to execute.
        prefix: Text prepended to each relayed line.
        extra_env: Optional variables layered on top of the inherited
            environment.

    Returns:
        The running ``subprocess.Popen`` object.
    """
    # PYTHONUNBUFFERED=1 keeps a Python child from holding output in its own
    # buffer, so log lines show up close to real time.
    child_env = dict(os.environ)
    child_env["PYTHONUNBUFFERED"] = "1"
    if extra_env:
        child_env.update(extra_env)

    child = subprocess.Popen(
        cmd,
        cwd=_ROOT,
        env=child_env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr into the relayed stream
        text=True,
        encoding="utf-8",
        errors="replace",
        bufsize=1,  # line-buffered pipe on our side
    )

    pump = threading.Thread(
        target=_relay,
        args=(child.stdout, prefix),
        daemon=True,
        name=f"relay-{prefix.strip()}",
    )
    pump.start()
    return child
def _start_backend() -> subprocess.Popen:
    """Launch the FastAPI app ``src.api.main:app`` under uvicorn.

    The server listens on all interfaces at ``_API_PORT`` with log level
    ``_LOG_LEVEL``. Its output is relayed with the ``[API]`` prefix.
    """
    _info(f"Starting FastAPI backend (port {_API_PORT})β¦")
    uvicorn_cmd = [sys.executable, "-m", "uvicorn", "src.api.main:app"]
    uvicorn_cmd += ["--host", "0.0.0.0"]
    uvicorn_cmd += ["--port", str(_API_PORT)]
    uvicorn_cmd += ["--log-level", _LOG_LEVEL]
    return _spawn(uvicorn_cmd, _PFX_API)
def _start_frontend() -> subprocess.Popen:
    """Launch the Gradio UI script ``ui/app.py`` with this Python interpreter.

    The port is handed to the child through the ``GRADIO_SERVER_PORT``
    environment variable. Its output is relayed with the ``[UI]`` prefix.
    """
    _info(f"Starting Gradio frontend (port {_GRADIO_PORT})β¦")
    ui_script = _ROOT / "ui" / "app.py"
    port_env = {"GRADIO_SERVER_PORT": str(_GRADIO_PORT)}
    return _spawn([sys.executable, str(ui_script)], _PFX_UI, extra_env=port_env)
# ── Backend health polling ────────────────────────────────────────────────────
def _wait_for_backend(backend: subprocess.Popen, timeout: float = 90.0) -> bool:
    """Poll the backend's health endpoint until it is ready.

    ``GET http://localhost:<_API_PORT>/api/health`` is requested about every
    2 seconds until it answers with status 200, the backend process exits,
    or *timeout* seconds have passed.

    Args:
        backend: The backend process, used to detect an early exit.
        timeout: Maximum number of seconds to wait.

    Returns:
        True once the endpoint answers 200; False when the process has
        exited or the timeout elapsed first.
    """
    # Function-scope import: only this function needs http.client.
    import http.client

    _info(f"Waiting for backend health check (up to {int(timeout)} s)β¦")
    url = f"http://localhost:{_API_PORT}/api/health"
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        # Abort early if the process already died.
        if backend.poll() is not None:
            return False
        try:
            with urllib.request.urlopen(url, timeout=2) as resp:
                if resp.status == 200:
                    return True
        except (urllib.error.URLError, http.client.HTTPException, OSError):
            # Not ready yet. HTTPException covers malformed replies (e.g.
            # BadStatusLine from a non-HTTP listener on the port), which are
            # neither URLError nor OSError and would otherwise escape.
            pass
        # Never sleep past the deadline.
        time.sleep(min(2.0, max(0.0, deadline - time.monotonic())))
    return False
# ── Graceful shutdown ─────────────────────────────────────────────────────────
def _shutdown(procs: list[subprocess.Popen], timeout: float = 8.0) -> None:
    """Terminate every process in *procs*, killing those that do not exit in time.

    All processes still running are first asked to stop with ``terminate()``.
    They then share one overall deadline of *timeout* seconds. Any process
    still alive at its turn is killed and reaped.

    Args:
        procs: Processes to stop; already-exited ones are handled gracefully.
        timeout: Total number of seconds to wait before killing.
    """
    for proc in procs:
        if proc.poll() is None:
            proc.terminate()

    deadline = time.monotonic() + timeout
    for proc in procs:
        # Always allow a small grace period, even once the deadline has passed.
        remaining = max(0.1, deadline - time.monotonic())
        try:
            proc.wait(timeout=remaining)
        except subprocess.TimeoutExpired:
            _warn(f"Process {proc.pid} still alive after {timeout:.0f} s β killing.")
            proc.kill()
            # Reap the killed child so it is really gone (no zombie) before
            # the caller reports that everything has stopped.
            proc.wait()
# ── Main ──────────────────────────────────────────────────────────────────────
def main() -> None:
    """Run the launcher end to end.

    Phases, in order: banner, Tesseract pre-flight check, optional model
    download, optional training, then backend and frontend startup followed
    by a watch loop that ends when either server exits or Ctrl+C is pressed.

    Every server process started here is shut down through ``_shutdown`` on
    the way out, whether the exit is caused by a failed health check, an
    interrupt, or an unexpected error.

    Raises:
        SystemExit: with status 1 when the backend does not become healthy.
    """
    args = _parse_args()

    # Banner
    sep = "=" * 66
    print(sep)
    print(" G.U.I.D.E. β Grievance Utility for Information Extraction,")
    print(" Drafting and Enrichment")
    print(sep)
    print()

    # Pre-flight
    _check_tesseract()

    # Model download phase: on request, or automatically when training is
    # allowed and a required checkpoint is missing.
    if args.download_models or (not args.no_train and not _models_complete()):
        _download_models()
        print()

    # Training phase
    if args.no_train:
        _info("--no-train specified β skipping all model training.")
    else:
        _train_models(args)
    print()

    if args.train_only:
        _info("--train-only specified β exiting without starting servers.")
        return

    # Server phase. Everything from the first spawn onwards sits inside one
    # try/finally so no child outlives the launcher.
    procs: list[subprocess.Popen] = []
    try:
        backend = _start_backend()
        procs.append(backend)
        if not _wait_for_backend(backend):
            rc = backend.poll()
            _err(
                "Backend failed to start"
                + (f" (exit code {rc})" if rc is not None else " (health check timed out)")
                + ". Check the [API] output above."
            )
            # The finally clause below stops the backend if it is still alive.
            sys.exit(1)
        _info("Backend is healthy.")

        frontend = _start_frontend()
        procs.append(frontend)

        print()
        print(" " + "β" * 62)
        print(f" Backend API β http://localhost:{_API_PORT}/docs")
        print(f" Frontend UI β http://localhost:{_GRADIO_PORT}")
        print(" " + "β" * 62)
        print(" Press Ctrl+C to stop both servers.")
        print()

        # Watch loop: stop as soon as either server exits.
        while True:
            time.sleep(1)
            rc_b = backend.poll()
            rc_f = frontend.poll()
            if rc_b is not None:
                _warn(f"Backend process exited (code {rc_b}).")
                break
            if rc_f is not None:
                _warn(f"Frontend process exited (code {rc_f}).")
                break
    except KeyboardInterrupt:
        print()
        _info("Ctrl+C β shutting downβ¦")
    finally:
        _shutdown(procs)
        _info("All processes stopped.")


if __name__ == "__main__":
    main()