# guide/start.py
# Last change 7bb43f2 (sangram kumar yerra): "fix: use copytree from HF snapshot
# to avoid overlay hard-link failures"
"""
G.U.I.D.E. β€” Single entry point.
Training sequence (each step skipped if checkpoint already exists):
1. EvidenceNER β€” synthetic data, no download, ~2-5 min CPU
2. NextActionPredictor β€” synthetic data, no download, <30 s CPU
3. DomainClassifier β€” requires data/raw/complaints.csv (CFPB dataset)
Server startup:
4. FastAPI backend (src/api/main.py) β€” stdout prefixed with [API]
5. Gradio frontend (ui/app.py) β€” stdout prefixed with [UI]
Usage:
python start.py # auto-train missing, then serve
python start.py --cfpb_csv /path/f.csv # custom CFPB CSV location
python start.py --no-train # skip training entirely
python start.py --train # force re-train all models
python start.py --train-only # train then exit (no servers)
"""
from __future__ import annotations
import argparse
import os
import platform
import shutil
import subprocess
import sys
import threading
import time
import urllib.error
import urllib.request
from pathlib import Path
# ── Windows UTF-8 fix ─────────────────────────────────────────────────────────
# The default Windows console encoding (cp1252) cannot represent box-drawing
# characters (U+2500 ─) or arrows (U+2192 →) used in our output.
# Switch stdout and stderr to UTF-8 before the first print() so the launcher
# never crashes on encode; characters that still cannot be written are
# replaced.  Streams without a reconfigure() method are left untouched.
for _std_stream in (sys.stdout, sys.stderr):
    if hasattr(_std_stream, "reconfigure"):
        _std_stream.reconfigure(encoding="utf-8", errors="replace")
del _std_stream
# ── Environment ───────────────────────────────────────────────────────────────
# Load a .env file when python-dotenv is available so the port / log-level /
# model-repo settings read just below can come from it.
try:
    from dotenv import load_dotenv
except ImportError:
    # Optional dependency: FastAPI will load .env on startup; we only need it
    # here for ports, so the plain process environment is good enough.
    pass
else:
    load_dotenv()
# Directory holding this script, made absolute (symlinks are preserved) so
# cwd=_ROOT and paths built from it stay valid however the script was invoked.
_ROOT = Path(os.path.abspath(__file__)).parent


def _env_int(name: str, default: int) -> int:
    """
    Return environment variable *name* as an int.

    Unset or blank (whitespace-only) values yield *default*.

    Raises:
        ValueError: the value is present but not an integer; the message names
            the offending variable instead of just echoing the bad literal.
    """
    raw = os.getenv(name, "").strip()
    if not raw:
        return default
    try:
        return int(raw)
    except ValueError:
        raise ValueError(
            f"environment variable {name}={raw!r} is not a valid integer"
        ) from None


_API_PORT = _env_int("API_PORT", 8000)
_GRADIO_PORT = _env_int("GRADIO_PORT", 7860)
# uvicorn's --log-level only accepts lowercase names ("info", "debug", ...),
# so normalise values such as LOG_LEVEL=INFO coming from a .env file.
_LOG_LEVEL = (os.getenv("LOG_LEVEL") or "info").strip().lower()
# A blank HF_MODEL_REPO falls back to the default repo instead of "".
_HF_REPO = os.getenv("HF_MODEL_REPO") or "sarav95/guide-models"
# ── ANSI colours (disabled when stdout is not a TTY) ─────────────────────────
# Colour is also off when NO_COLOR is set to a non-empty value (no-color.org
# convention) and when there is no stdout at all (sys.stdout is None, e.g.
# under pythonw), which previously raised AttributeError here.
_COLOUR = (
    sys.stdout is not None
    and sys.stdout.isatty()
    and not os.environ.get("NO_COLOR")
)


def _c(code: str, text: str) -> str:
    """Wrap *text* in ANSI SGR sequence *code* when colour output is enabled."""
    return f"\033[{code}m{text}\033[0m" if _COLOUR else text


# Every tag is padded to the same visible width so log messages line up in
# one column whichever process produced them.
_PFX_WIDTH = 7
_PFX_API = _c("36;1", "[API]".ljust(_PFX_WIDTH))    # bright cyan
_PFX_UI = _c("32;1", "[UI]".ljust(_PFX_WIDTH))      # bright green
_PFX_START = _c("34;1", "[start]".ljust(_PFX_WIDTH))  # bright blue
_PFX_WARN = _c("33;1", "[warn]".ljust(_PFX_WIDTH))  # bright yellow
_PFX_ERR = _c("31;1", "[error]".ljust(_PFX_WIDTH))  # bright red


def _log(pfx: str, msg: str) -> None:
    """Print *msg* after tag *pfx*; flush at once so it interleaves correctly
    with child-process output relayed on other threads."""
    print(f"{pfx} {msg}", flush=True)


def _info(msg: str) -> None:
    """Log an informational launcher message."""
    _log(_PFX_START, msg)


def _warn(msg: str) -> None:
    """Log a non-fatal warning."""
    _log(_PFX_WARN, msg)


def _err(msg: str) -> None:
    """Log an error (printed to stdout, like every other tag)."""
    _log(_PFX_ERR, msg)
# ── Argument parsing ──────────────────────────────────────────────────────────
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """
    Build the launcher's command-line parser and parse *argv*.

    Args:
        argv: arguments to parse.  ``None`` (the default) means
            ``sys.argv[1:]``, exactly as before; an explicit list makes the
            parser usable from tests.

    Returns:
        Namespace with ``cfpb_csv``, ``no_train``, ``train``, ``train_only``
        and ``download_models``.

    --no-train and --train are mutually exclusive.  --train-only combined
    with --no-train is rejected (argparse error, exit status 2) unless
    --download-models is also given, since that pair alone runs nothing.
    """
    p = argparse.ArgumentParser(
        prog="python start.py",
        description="G.U.I.D.E. launcher β€” trains DL models then starts both servers.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=(
            "examples:\n"
            " python start.py # auto-train + serve\n"
            " python start.py --no-train # serve only (models must exist)\n"
            " python start.py --train # force retrain + serve\n"
            " python start.py --train-only # train and exit\n"
            " python start.py --download-models # fetch HF checkpoints, then train missing + serve\n"
            " python start.py --cfpb_csv ~/data/complaints.csv\n"
        ),
    )
    p.add_argument(
        "--cfpb_csv",
        metavar="PATH",
        default=None,
        help=(
            "Path to the CFPB complaints CSV used to train DomainClassifier. "
            "Defaults to data/raw/complaints.csv relative to the project root."
        ),
    )
    mode = p.add_mutually_exclusive_group()
    mode.add_argument(
        "--no-train",
        action="store_true",
        help="Skip all model training; start servers immediately.",
    )
    mode.add_argument(
        "--train",
        action="store_true",
        help="Force re-train all three models even if checkpoints already exist.",
    )
    p.add_argument(
        "--train-only",
        action="store_true",
        help="Run the training phase then exit without starting servers.",
    )
    p.add_argument(
        "--download-models",
        action="store_true",
        help=(
            f"Download model checkpoints from HuggingFace ({_HF_REPO}) "
            "into models/ before starting. Skipped automatically if all "
            "checkpoints already exist."
        ),
    )
    args = p.parse_args(argv)
    # --no-train --train-only (without --download-models) would skip training
    # and then exit before the servers: a silent no-op, so reject it.
    if args.train_only and args.no_train and not args.download_models:
        p.error(
            "--train-only with --no-train does nothing; "
            "add --download-models or drop one of the flags"
        )
    return args
# ── Checkpoint detection ──────────────────────────────────────────────────────
def _hf_ckpt(rel_dir: str) -> bool:
    """
    Report whether *rel_dir* (relative to the project root) holds a
    HuggingFace-style checkpoint, judged by the presence of ``config.json``.
    """
    marker = _ROOT.joinpath(rel_dir, "config.json")
    return marker.exists()
def _pt_ckpt(rel_path: str) -> bool:
    """Report whether the PyTorch checkpoint *rel_path* (relative to the project root) exists."""
    checkpoint = _ROOT.joinpath(rel_path)
    return checkpoint.exists()
# ── Model download from HuggingFace ──────────────────────────────────────────
def _models_complete() -> bool:
    """
    Report whether every *required* checkpoint exists locally.

    Only EvidenceNER and NextActionPredictor are required; DomainClassifier is
    optional because a keyword-based fallback covers for it.
    """
    ner_ready = _hf_ckpt("models/evidence_ner")
    next_action_ready = _pt_ckpt("models/next_action/model.pt")
    return all((ner_ready, next_action_ready))
def _download_models() -> None:
    """
    Populate models/ with checkpoints from the HuggingFace repo ``_HF_REPO``.

    Returns immediately when all required checkpoints already exist.  Each
    model subdirectory is skipped only if it already contains its checkpoint
    marker file; a directory that exists but is empty or partial (e.g. created
    by a failed training run) is filled in from the snapshot.

    A failing download (network, auth, missing repo, ...) is reported as a
    warning and the function returns without copying anything.  Exits with
    status 1 when huggingface_hub is not installed.
    """
    if _models_complete():
        _info("All model checkpoints already present β€” skipping download.")
        return
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        _err("huggingface_hub is not installed. Run: pip install huggingface_hub")
        sys.exit(1)

    _info(f"Downloading model checkpoints from {_HF_REPO!r}…")
    models_dir = _ROOT / "models"
    models_dir.mkdir(parents=True, exist_ok=True)

    # Download to the HF cache (no local_dir) and get the snapshot directory.
    # Using local_dir causes hard-link failures in Docker overlay filesystems,
    # leaving models/ empty even though the download "succeeds".
    try:
        snapshot_path = Path(snapshot_download(repo_id=_HF_REPO, repo_type="model"))
    except Exception as exc:  # huggingface_hub / network errors vary widely
        _warn(
            f"Model download failed ({type(exc).__name__}: {exc}); "
            "continuing without downloaded checkpoints."
        )
        return
    _info(f"Snapshot cached at: {snapshot_path}")

    # (subdirectory, file whose presence means the checkpoint is usable)
    checkpoints = (
        ("evidence_ner", "config.json"),
        ("next_action", "model.pt"),
        ("domain_classifier", "config.json"),
    )
    for subdir, marker in checkpoints:
        src = snapshot_path / subdir
        dst = models_dir / subdir
        if (dst / marker).exists():
            _info(f" {subdir} already present β€” skipped.")
        elif src.exists():
            # symlinks=False (the default) copies the file contents behind the
            # cache's symlinks; dirs_exist_ok lets us fill a partial directory.
            shutil.copytree(src, dst, dirs_exist_ok=True)
            _info(f" Copied {subdir}.")
        else:
            _warn(f" {subdir} not found in snapshot β€” skipping.")
    _info("Model download complete.")
# ── Pre-flight checks ────────────────────────────────────────────────────────
def _check_tesseract() -> None:
    """
    Ensure the Tesseract OCR binary is reachable, or exit with install hints.

    PATH is checked first.  On Windows a few common install directories are
    probed as well; the first one containing tesseract.exe is appended to
    PATH for this process, which subprocesses started afterwards inherit via
    os.environ.  Exits with status 1 when Tesseract cannot be found.
    """
    # 1. Already on PATH?
    if shutil.which("tesseract") is not None:
        return

    # 2. Windows: the install directory may not have been added to PATH.
    if platform.system() == "Windows":
        candidates = [Path(r"C:\Program Files\Tesseract-OCR")]
        # NOTE(review): the (x86) and per-user locations are assumed
        # UB-Mannheim installer defaults; the env vars cover non-C: drives.
        for env_var, sub_path in (
            ("ProgramFiles", "Tesseract-OCR"),
            ("ProgramFiles(x86)", "Tesseract-OCR"),
            ("LOCALAPPDATA", r"Programs\Tesseract-OCR"),
        ):
            base = os.environ.get(env_var)
            if base:
                candidates.append(Path(base) / sub_path)
        for install_dir in candidates:
            if (install_dir / "tesseract.exe").exists():
                # .get(): PATH may be missing entirely; avoid a KeyError.
                current = os.environ.get("PATH", "")
                os.environ["PATH"] = (
                    f"{current}{os.pathsep}{install_dir}" if current else str(install_dir)
                )
                return

    # 3. Not found anywhere.
    _err("Tesseract OCR not found. Install it before starting G.U.I.D.E.:")
    print("", flush=True)
    print(" macOS : brew install tesseract", flush=True)
    print(" Ubuntu : sudo apt install tesseract-ocr", flush=True)
    print(" Windows : https://github.com/UB-Mannheim/tesseract/wiki", flush=True)
    print("", flush=True)
    print("Then re-run: python start.py", flush=True)
    sys.exit(1)
# ── Output relay (subprocess stdout → our stdout with prefix) ─────────────────
def _relay(stream, prefix: str) -> None:
    """
    Copy every line of *stream* to our stdout, tagged with *prefix*.

    Meant to be the target of a daemon thread: it returns once *stream* is
    exhausted.  Best effort only: any exception while reading or writing
    silently ends the relay.
    """
    try:
        for chunk in stream:
            tagged = f"{prefix} {chunk}"
            sys.stdout.write(tagged)
            sys.stdout.flush()
    except Exception:
        return
# ── Training helpers ──────────────────────────────────────────────────────────
def _run_training(label: str, cmd: list[str]) -> bool:
    """
    Run one training command to completion from the project root.

    stdio is inherited (no prefixing), so training progress shows up directly
    in the terminal.  A non-zero exit is logged as a warning only; the caller
    carries on.

    Returns True when the command exits with status 0, False otherwise.
    """
    _info(f"Training {label}…")
    shown = " ".join(map(str, cmd))
    _info(f" $ {shown}")
    exit_code = subprocess.run(cmd, cwd=_ROOT).returncode
    if exit_code == 0:
        _info(f"{label} training complete.")
        return True
    _warn(f"{label} training exited {exit_code} β€” continuing with fallback.")
    return False
def _train_models(args: argparse.Namespace) -> None:
    """
    Train EvidenceNER, NextActionPredictor and DomainClassifier in that order.

    Each model is skipped when its checkpoint already exists, unless
    ``args.train`` (force re-train) is set.  A failed training run does not
    stop the remaining steps.  DomainClassifier is skipped with a warning when
    the CFPB CSV (``args.cfpb_csv`` or data/raw/complaints.csv) is missing.
    """
    force = args.train
    _info("─" * 60)
    _info("Model training phase")
    _info("─" * 60)

    # ── 1. EvidenceNER ────────────────────────────────────────────────────────
    ner_dir = "models/evidence_ner"
    if not force and _hf_ckpt(ner_dir):
        _info(f"EvidenceNER: checkpoint exists at {ner_dir!r} β€” skipped.")
    else:
        _run_training(
            "EvidenceNER",
            [sys.executable, "-m", "src.ner.train",
             "--output_dir", ner_dir],
        )

    # ── 2. NextActionPredictor ────────────────────────────────────────────────
    na_path = "models/next_action/model.pt"
    if not force and _pt_ckpt(na_path):
        _info(f"NextActionPredictor: checkpoint exists at {na_path!r} β€” skipped.")
    else:
        # Ensure the parent directory exists before training writes the file.
        (_ROOT / na_path).parent.mkdir(parents=True, exist_ok=True)
        _run_training(
            "NextActionPredictor",
            [sys.executable, "-m", "src.next_action.train",
             "--output_path", na_path],
        )

    # ── 3. DomainClassifier ───────────────────────────────────────────────────
    cls_dir = "models/domain_classifier"
    if not force and _hf_ckpt(cls_dir):
        _info(f"DomainClassifier: checkpoint exists at {cls_dir!r} β€” skipped.")
    else:
        if args.cfpb_csv:
            # Resolve against the *caller's* working directory now: the
            # training subprocess runs with cwd=_ROOT, where a relative path
            # would point somewhere else.  expanduser() handles a quoted "~/".
            cfpb_csv = Path(args.cfpb_csv).expanduser().resolve()
        else:
            cfpb_csv = _ROOT / "data" / "raw" / "complaints.csv"
        if cfpb_csv.exists():
            _run_training(
                "DomainClassifier",
                [sys.executable, "-m", "src.classifier.train",
                 "--cfpb_csv", str(cfpb_csv),
                 "--output_dir", cls_dir],
            )
        else:
            _warn(
                f"DomainClassifier: {cfpb_csv} not found β€” skipped.\n"
                f"{_PFX_WARN} Download from Kaggle (CFPB Consumer Complaint Database),\n"
                f"{_PFX_WARN} place in data/raw/complaints.csv, then re-run or use --train.\n"
                f"{_PFX_WARN} Keyword-based fallback will be active until a checkpoint exists."
            )

    _info("─" * 60)
    _info("Training phase complete.")
    _info("─" * 60)
# ── Server spawning ───────────────────────────────────────────────────────────
def _spawn(
    cmd: list[str],
    prefix: str,
    extra_env: dict[str, str] | None = None,
) -> subprocess.Popen:
    """
    Launch *cmd* from the project root and relay its output.

    The child's stdout and stderr are merged into one pipe; a daemon thread
    copies each line to our stdout, prefixed with *prefix*.

    Environment (os.environ, plus the entries below, then *extra_env* on top):
      * PYTHONUNBUFFERED=1: a Python child writing to a pipe would otherwise
        block-buffer its output; unbuffered output gives near-real-time logs.
      * PYTHONIOENCODING=utf-8 (only if not already set): we decode the pipe
        as UTF-8, so the child must encode UTF-8 too.  On Windows a piped
        Python child otherwise uses the ANSI code page (e.g. cp1252), which
        garbles non-ASCII text or raises on characters it cannot encode.

    Returns the Popen handle of the child.
    """
    env = {**os.environ, "PYTHONUNBUFFERED": "1"}
    env.setdefault("PYTHONIOENCODING", "utf-8")
    if extra_env:
        env.update(extra_env)
    proc = subprocess.Popen(
        cmd,
        cwd=_ROOT,
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        encoding="utf-8",
        errors="replace",
        bufsize=1,  # line-buffered on our (reading) end; valid in text mode
    )
    relay_thread = threading.Thread(
        target=_relay,
        args=(proc.stdout, prefix),
        daemon=True,
        name=f"relay-{prefix.strip()}",
    )
    relay_thread.start()
    return proc
def _start_backend() -> subprocess.Popen:
    """
    Spawn the FastAPI app (``src.api.main:app``) under uvicorn.

    It binds every interface (0.0.0.0) on _API_PORT with log level _LOG_LEVEL;
    output is relayed with the [API] prefix.
    """
    _info(f"Starting FastAPI backend (port {_API_PORT})…")
    uvicorn_args = (
        "src.api.main:app",
        "--host", "0.0.0.0",
        "--port", str(_API_PORT),
        "--log-level", _LOG_LEVEL,
    )
    return _spawn([sys.executable, "-m", "uvicorn", *uvicorn_args], _PFX_API)
def _start_frontend() -> subprocess.Popen:
    """
    Spawn the Gradio UI script (ui/app.py) with its output prefixed [UI].

    The port is handed over through the GRADIO_SERVER_PORT environment
    variable.
    """
    ui_script = _ROOT / "ui" / "app.py"
    _info(f"Starting Gradio frontend (port {_GRADIO_PORT})…")
    env_overrides = {"GRADIO_SERVER_PORT": str(_GRADIO_PORT)}
    return _spawn([sys.executable, str(ui_script)], _PFX_UI, extra_env=env_overrides)
# ── Backend health polling ────────────────────────────────────────────────────
def _wait_for_backend(
    backend: subprocess.Popen,
    timeout: float = 90.0,
    interval: float = 2.0,
) -> bool:
    """
    Poll ``GET /api/health`` on the local backend until it answers 200.

    Args:
        backend: the backend process; polling stops early once it has exited.
        timeout: overall time budget in seconds.
        interval: pause between attempts in seconds (default keeps the
            original 2 s cadence).

    Returns True when the endpoint answered 200, False if the process died or
    *timeout* elapsed first.

    The probe targets 127.0.0.1 and bypasses HTTP(S)_PROXY settings: urllib
    honours proxy environment variables, and sending a loopback request to a
    corporate/container proxy would make the health check fail.
    """
    _info(f"Waiting for backend health check (up to {int(timeout)} s)…")
    url = f"http://127.0.0.1:{_API_PORT}/api/health"
    # An empty ProxyHandler disables proxies for this opener only.
    opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        # Abort early if the process already died.
        if backend.poll() is not None:
            return False
        try:
            with opener.open(url, timeout=2) as resp:
                if resp.status == 200:
                    return True
        except (urllib.error.URLError, OSError):
            pass  # not listening yet / still starting up
        time.sleep(interval)
    return False
# ── Graceful shutdown ─────────────────────────────────────────────────────────
def _shutdown(procs: list[subprocess.Popen], timeout: float = 8.0) -> None:
    """
    Stop every process in *procs*.

    terminate() is sent to each one still running (SIGTERM on POSIX,
    TerminateProcess on Windows); all of them share a *timeout*-second
    budget to exit.  Stragglers get kill() and are then waited on so they
    are reaped instead of lingering as zombies.  Already-exited processes
    are handled without error.
    """
    for proc in procs:
        if proc.poll() is None:
            proc.terminate()
    deadline = time.monotonic() + timeout
    for proc in procs:
        remaining = max(0.1, deadline - time.monotonic())
        try:
            proc.wait(timeout=remaining)
        except subprocess.TimeoutExpired:
            _warn(f"Process {proc.pid} still alive after {timeout:.0f} s β€” killing.")
            proc.kill()
            try:
                proc.wait(timeout=5)  # reap the killed child
            except subprocess.TimeoutExpired:
                _warn(f"Process {proc.pid} did not exit after kill().")
# ── Main ──────────────────────────────────────────────────────────────────────
def _print_banner() -> None:
    """Print the G.U.I.D.E. start-up banner."""
    sep = "=" * 66
    print(sep)
    print(" G.U.I.D.E. β€” Grievance Utility for Information Extraction,")
    print(" Drafting and Enrichment")
    print(sep)
    print()


def _install_sigterm_handler() -> None:
    """
    Make SIGTERM raise SystemExit(143) in the main thread.

    Python's default SIGTERM action ends the process without running
    ``finally`` blocks, which would leave the server subprocesses running;
    as PID 1 in a container (e.g. ``docker stop``) the signal is typically
    ignored altogether.  Raising SystemExit lets main()'s cleanup stop the
    children.  Must be called from the main thread.
    """
    import signal

    def _on_sigterm(signum, _frame):
        raise SystemExit(128 + signum)

    signal.signal(signal.SIGTERM, _on_sigterm)


def _run_servers(procs: list[subprocess.Popen]) -> int:
    """
    Start the backend, wait for its health check, start the frontend, then
    block until either server exits.

    Each process is appended to *procs* as soon as it exists, so the caller
    can always shut it down (even if an exception interrupts this function).

    Returns 1 if the backend never became healthy or a server exited on its
    own; KeyboardInterrupt / SystemExit propagate to the caller.
    """
    backend = _start_backend()
    procs.append(backend)
    if not _wait_for_backend(backend):
        rc = backend.poll()
        _err(
            f"Backend failed to start"
            + (f" (exit code {rc})" if rc is not None else " (health check timed out)")
            + ". Check the [API] output above."
        )
        return 1
    _info("Backend is healthy.")

    frontend = _start_frontend()
    procs.append(frontend)

    print()
    print(" " + "─" * 62)
    print(f" Backend API β†’ http://localhost:{_API_PORT}/docs")
    print(f" Frontend UI β†’ http://localhost:{_GRADIO_PORT}")
    print(" " + "─" * 62)
    print(" Press Ctrl+C to stop both servers.")
    print()

    # ── Watch loop: check once per second until either server exits.
    while True:
        time.sleep(1)
        rc_b = backend.poll()
        if rc_b is not None:
            _warn(f"Backend process exited (code {rc_b}).")
            return 1
        rc_f = frontend.poll()
        if rc_f is not None:
            _warn(f"Frontend process exited (code {rc_f}).")
            return 1


def main() -> None:
    """
    Launcher entry point.

    The steps, in order:
      1. Pre-flight checks (Tesseract is only needed when servers will run).
      2. Model download, when requested or when required checkpoints are missing.
      3. Training.
      4. Run backend and frontend until one exits or we are interrupted.

    Ctrl+C stops both servers and returns normally.  SIGTERM stops both and
    exits with status 143.  A failed backend health check, or a server exiting
    on its own, stops everything and exits with status 1.
    """
    args = _parse_args()
    _install_sigterm_handler()
    _print_banner()

    # ── Pre-flight ────────────────────────────────────────────────────────────
    # Training does not use OCR, so --train-only must not require Tesseract.
    if not args.train_only:
        _check_tesseract()

    # ── Model download phase ──────────────────────────────────────────────────
    if args.download_models or (not args.no_train and not _models_complete()):
        _download_models()
        print()

    # ── Training phase ────────────────────────────────────────────────────────
    if args.no_train:
        _info("--no-train specified β€” skipping all model training.")
    else:
        _train_models(args)
        print()

    if args.train_only:
        _info("--train-only specified β€” exiting without starting servers.")
        return

    # ── Server phase ──────────────────────────────────────────────────────────
    procs: list[subprocess.Popen] = []
    exit_code = 0
    try:
        exit_code = _run_servers(procs)
    except KeyboardInterrupt:
        print()
        _info("Ctrl+C β€” shutting down…")
    finally:
        # Also runs on SystemExit (e.g. SIGTERM) or an unexpected error, so
        # children are never left running.
        _shutdown(procs)
        _info("All processes stopped.")
    if exit_code:
        sys.exit(exit_code)
if __name__ == "__main__":
    # main() is annotated to return None, so a normal return exits with status 0.
    raise SystemExit(main())