"""
statlens β€” command-line entry point.
Subcommands:
statlens serve spin up local LLM backend + FastAPI; open browser
statlens download fetch the LoRA adapter into the HF cache
(NOT the 64 GB base β€” that's the user's job)
statlens info show GPU / cache state / paths
statlens classify one-shot CLI run (TSV + context.txt β†’ result folder)
"""
from __future__ import annotations
import os
import signal
import sys
import webbrowser
import click
from . import (
    __version__,
    DEFAULT_BASE_MODEL_REF,
    DEFAULT_LORA_REPO,
    DEFAULT_LORA_SUBFOLDER,
    DEFAULT_QUANTIZATION,
    LOCAL_LLM_PORT,
    LOCAL_WEB_PORT,
)
from . import runtime


def _common_serve_opts(f):
    """Decorator that attaches the model/LoRA/quantization options shared by
    `serve` and `classify`.

    The LLM backend is LLaMA-Factory (`llamafactory-cli api`); options here
    map directly to its YAML config.
    """
    f = click.option("--quantization", default=DEFAULT_QUANTIZATION,
                     type=click.Choice(["bitsandbytes", "awq", "gptq", "none"]),
                     show_default=True,
                     help="LLM weight-quantization mode (passed to LLaMA-Factory)")(f)
    f = click.option("--lora-path", default=None, type=click.Path(),
                     help="local path to the LoRA folder; "
                          "if omitted, downloaded from HF")(f)
    f = click.option("--base-model", default=None, type=click.Path(),
                     help="local path to the BF16 base model directory; "
                          "if omitted, auto-detect from common paths "
                          "(~/models/qwen3-32b, /root/autodl-tmp/..., HF cache)")(f)
    return f


@click.group(context_settings={"help_option_names": ["-h", "--help"]})
@click.version_option(__version__, prog_name="statlens")
def main():
"""statLens β€” DEA method selector. See `statlens serve --help`."""
# ─────────────────────────────── serve ───────────────────────────────
@main.command()
@_common_serve_opts
@click.option("--web-port", default=LOCAL_WEB_PORT, show_default=True, type=int)
@click.option("--llm-port", default=LOCAL_LLM_PORT, show_default=True, type=int)
@click.option("--no-browser", is_flag=True, help="don't auto-open browser")
@click.option("--no-gpu-check", is_flag=True, help="skip GPU sanity check")
def serve(base_model, lora_path, quantization,
          web_port, llm_port, no_browser, no_gpu_check):
    """Start the LLM backend (LLaMA-Factory) + the web app on http://localhost:7860 ."""
    if not no_gpu_check:
        runtime.check_gpu()
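
    # Resolve everything the backend needs up front: auto-detect the base-model
    # directory when --base-model is omitted, and download the LoRA adapter into
    # the HF cache if it is not there yet.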
    base_path = runtime.resolve_base_model(base_model)
    lora_path = runtime.ensure_lora_cached(lora_path)

    llm_proc = runtime.start_server(
        base_path, lora_path,
        port=llm_port,
        quantization=quantization,
    )
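
    # Make sure the backend never outlives the CLI: Ctrl+C / SIGTERM shut the
    # LLM subprocess down cleanly before exiting.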
    def shutdown(signum=None, frame=None):
        runtime.stop_server(llm_proc)
        sys.exit(0)

    signal.signal(signal.SIGINT, shutdown)
    signal.signal(signal.SIGTERM, shutdown)

    try:
        # show animated progress; abort if subprocess dies
        runtime.wait_for_server(port=llm_port, proc=llm_proc)
    except (TimeoutError, RuntimeError) as e:
        click.secho(f"\n[statlens] {e}", fg="red")
        runtime.stop_server(llm_proc)
        sys.exit(2)
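
    # Hand the backend's endpoint and model name to the web app via env vars.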
os.environ["STATLENS_LLM_ENDPOINT"] = f"http://127.0.0.1:{llm_port}/v1"
# LLaMA-Factory's API exposes the model under the name "gpt-3.5-turbo".
os.environ["STATLENS_LLM_MODEL"] = "gpt-3.5-turbo"

    web_url = f"http://localhost:{web_port}/"

    # Prominent ready banner — only what end users care about.
    click.echo("")
    click.secho("══════════════════════════════════════════════════════", fg="green")
    click.secho(" ✅ statLens ready", fg="green", bold=True)
    click.secho(f" open in browser: {web_url}", fg="green")
    click.secho(" Ctrl+C to stop.", fg="green")
    click.secho("══════════════════════════════════════════════════════", fg="green")
    click.echo("")
    if not no_browser:
        import threading
        threading.Timer(1.0, lambda: webbrowser.open(web_url)).start()

    import uvicorn
    from . import server as server_mod

    try:
        # Quiet uvicorn so it doesn't drown the ready banner; users see HTTP
        # requests via the LLM log if they want.
        uvicorn.run(server_mod.app, host="127.0.0.1", port=web_port, log_level="warning")
    finally:
        runtime.stop_server(llm_proc)


# ─────────────────────────── download ───────────────────────────
@main.command()
@click.option("--lora-repo", default=DEFAULT_LORA_REPO, show_default=True)
@click.option("--lora-subfolder", default=DEFAULT_LORA_SUBFOLDER, show_default=True)
def download(lora_repo, lora_subfolder):
"""Pre-fetch the LoRA adapter (~1 GB) into the HF cache.
statLens does NOT auto-download the 64 GB base model. Get it yourself with:
huggingface-cli download Qwen/Qwen3-32B --local-dir ~/models/qwen3-32b
"""
    p = runtime.ensure_lora_cached(None, lora_repo, lora_subfolder)
    click.secho(f"\n✅ LoRA cached at {p}", fg="green")
    click.echo(
        "\nReminder: download the BF16 base model separately, e.g.\n"
        " huggingface-cli download Qwen/Qwen3-32B --local-dir ~/models/qwen3-32b"
    )


# ─────────────────────────── info ───────────────────────────
@main.command()
def info():
"""Show GPU, cache, and default endpoint information."""
click.echo(f"statlens version: {__version__}\n")
try:
import torch
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
p = torch.cuda.get_device_properties(i)
click.echo(f" GPU {i}: {p.name} Β· {p.total_memory/1024**3:.1f} GB")
else:
click.echo(" GPU: (none)")
except Exception as e:
click.echo(f" GPU: (torch not importable: {e})")
click.echo("")
click.echo(f" base model (reference) : {DEFAULT_BASE_MODEL_REF} (NOT auto-downloaded)")
click.echo(f" base model (resolved) : "
f"{os.environ.get('STATLENS_BASE_MODEL', '(unset β€” pass --base-model)')}")
click.echo(f" default quantization : {DEFAULT_QUANTIZATION}")
click.echo(f" lora repo : {DEFAULT_LORA_REPO} (subfolder {DEFAULT_LORA_SUBFOLDER})")
click.echo("")
from huggingface_hub.constants import HF_HUB_CACHE
click.echo(f" HF cache root : {HF_HUB_CACHE}")
    d = runtime.cache_dir_for(DEFAULT_LORA_REPO)
    if d.exists():
        sz = sum(p.stat().st_size for p in d.rglob("*") if p.is_file())
        click.echo(f" · {DEFAULT_LORA_REPO} ✓ ({sz/1024**3:.2f} GB at {d})")
    else:
        click.echo(f" · {DEFAULT_LORA_REPO} ✗ (run `statlens download` to fetch)")


# ─────────────────────────── classify ───────────────────────────
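# Example one-shot run (file names here are placeholders):
#   statlens classify --tsv my_study.tsv --context context.txt --out results/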
@main.command()
@_common_serve_opts
@click.option("--tsv", required=True, type=click.Path(exists=True, dir_okay=False))
@click.option("--context", required=True, type=click.Path(exists=True, dir_okay=False))
@click.option("--out", required=True, type=click.Path(file_okay=False))
@click.option("--endpoint", default=None,
help="OpenAI-compatible endpoint (default: spin up a temporary local LLM backend)")
@click.option("--model", default="gpt-3.5-turbo", show_default=True)
@click.option("--keep-llm", is_flag=True, help="don't shut down the LLM backend after classifying")
def classify(base_model, lora_path, quantization,
             tsv, context, out, endpoint, model, keep_llm):
    """One-shot CLI: classify a single TSV+context, run pipeline, exit."""
    from pathlib import Path
    from .statlens_run import run_one
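
    # No --endpoint given: spin up a temporary local backend for this run and,
    # unless --keep-llm is set, tear it down afterwards.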
    if endpoint is None:
        runtime.check_gpu()
        base_path = runtime.resolve_base_model(base_model)
        lora_resolved = runtime.ensure_lora_cached(lora_path)
        proc = runtime.start_server(base_path, lora_resolved,
                                    quantization=quantization)
        try:
            runtime.wait_for_server()
            endpoint = f"http://127.0.0.1:{LOCAL_LLM_PORT}/v1"
            run_one(Path(tsv), Path(context), Path(out), endpoint, model)
        finally:
            if not keep_llm:
                runtime.stop_server(proc)
    else:
        run_one(Path(tsv), Path(context), Path(out), endpoint, model)


if __name__ == "__main__":
    main()