statLens / src /statlens /server.py
domizzz2025's picture
sync: src/ at 0.1.11
a694491 verified
"""
server.py — FastAPI app fronting the statLens orchestrator (two-stage flow).
Mounted by `statlens serve` (see cli.py); also runnable standalone for dev.
Two-stage flow:
POST /api/extract → upload TSV + study context, get back a SchemaSummary
(no label yet — the label is decided in stage 3, /api/run_pipeline)
POST /api/run_pipeline → submit run_id + the confirmed/edited schema. The
server asks the LLM to pick a label for that schema, then runs the pipeline.
Plus:
GET / → serve the packaged index.html
POST /api/run → one-shot legacy path (extract + pick + run)
GET /api/artifact/{run}/{f} → fetch a single file (PNG/CSV) from the result
GET /api/zip/{run} → fetch the whole result as a zip
GET /api/csv_preview/{run}/{f} → first N rows of a CSV as JSON (?limit=, default 20)
Hardening:
- run_id format validated (12-char lowercase hex) on every run-scoped endpoint
- Path traversal blocked: every path is resolved and must stay under RUNS
- Multipart upload size capped at STATLENS_MAX_UPLOAD_MB (default 100 MB)
- LLM raw outputs persisted under <run_id>/llm_*_raw.txt for debugging
Env vars (set by `statlens serve`):
STATLENS_LLM_ENDPOINT default http://127.0.0.1:<LOCAL_LLM_PORT>/v1
STATLENS_LLM_MODEL default statlens
STATLENS_RUNS_DIR default ~/.cache/statlens/runs
STATLENS_MAX_UPLOAD_MB default 100
"""
from __future__ import annotations
import csv
import json
import os
import re
import sys
import uuid
import zipfile
from pathlib import Path
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse, JSONResponse
from . import PKG_DATA, LOCAL_LLM_PORT
from .statlens_run import extract_one, run_pipeline_with_schema, run_one
# ─────────────────── runtime locations / config ───────────────────
# Packaged single-page UI served at "/".
INDEX_HTML = Path(str(PKG_DATA / "index.html"))
# Run ids are the first 12 hex chars of a uuid4 (see _persist_inputs).
RUN_ID_RE = re.compile(r"^[0-9a-f]{12}$")
# Upload cap: STATLENS_MAX_UPLOAD_MB megabytes converted to bytes (x 2**20).
MAX_UPLOAD_BYTES = int(os.getenv("STATLENS_MAX_UPLOAD_MB", "100")) << 20
def _runs_root() -> Path:
    """Return the directory holding per-run folders, creating it if needed.

    Uses $STATLENS_RUNS_DIR (with ``~`` expanded) when it is set to a
    non-empty value; otherwise falls back to ~/.cache/statlens/runs.
    """
    override = os.getenv("STATLENS_RUNS_DIR")
    root = (
        Path(override).expanduser()
        if override
        else Path.home().joinpath(".cache", "statlens", "runs")
    )
    root.mkdir(parents=True, exist_ok=True)
    return root
RUNS = _runs_root()
# Absolute, symlink-resolved form of RUNS, used for path-containment checks.
RUNS_RESOLVED = RUNS.resolve()
# LLM endpoint/model; `statlens serve` normally exports these env vars.
DEFAULT_ENDPOINT = os.getenv(
    "STATLENS_LLM_ENDPOINT",
    f"http://127.0.0.1:{LOCAL_LLM_PORT}/v1",
)
DEFAULT_MODEL = os.getenv("STATLENS_LLM_MODEL", "statlens")
app = FastAPI(title="statLens")
# ─────────────────── validation helpers ───────────────────
def _validate_run_id(run_id: str) -> None:
    """Raise HTTPException(404) unless *run_id* matches RUN_ID_RE exactly."""
    if RUN_ID_RE.fullmatch(run_id) is None:
        raise HTTPException(404, f"invalid run_id format: {run_id!r}")
def _classify_error(exc: Exception) -> dict:
    """Translate an adapter / pipeline exception into user-facing error info.

    ``str(exc)`` is checked against an ordered list of known failure
    signatures; the first match wins. Unrecognized messages fall back to a
    generic "internal_error" entry.

    Returns a freshly built dict with keys:
      - type  : machine code, e.g. "schema_mismatch_batch"
      - title : short human-readable title
      - desc  : one-paragraph description + how to fix
      - field : schema field the user should edit, or None if no schema fix
    """
    text = str(exc)

    def mentions_missing(phrase, field):
        # Either the adapter's "requires a ... column" wording, or a generic
        # "no such column" failure that names the schema field.
        return phrase in text or (field in text and "no such column" in text)

    def info(code, title, desc, field=None):
        return {"type": code, "title": title, "desc": desc, "field": field}

    # Order matters: earlier entries take precedence when several match.
    ordered = [
        (
            mentions_missing("requires a batch column", "batch_variables"),
            info(
                "schema_mismatch_batch",
                "Batch column missing",
                ("The schema says your study has a batch variable, but no "
                 "matching column was found in your TSV. Either set "
                 "'Batch variable?' to No (if there really is no batch), "
                 "or fill 'Batch columns' with the exact column name."),
                "batch_variables",
            ),
        ),
        (
            mentions_missing("requires a time column", "time_variable_column"),
            info(
                "schema_mismatch_time",
                "Time column missing",
                ("The schema says your study has a time variable, but no "
                 "matching column was found in your TSV. Either set "
                 "'Time variable?' to No (if it's cross-sectional), or fill "
                 "'Time column' with the exact column name."),
                "time_variable_column",
            ),
        ),
        (
            mentions_missing("requires a subject id column", "subject_id_column"),
            info(
                "schema_mismatch_subject",
                "Subject ID column missing",
                ("The schema requires a subject id column (paired or "
                 "repeated design), but no matching column was found in "
                 "your TSV. Fill 'Subject ID column' with the exact column "
                 "name, or set 'Paired/repeated?' to No."),
                "subject_id_column",
            ),
        ),
        (
            "reference_level" in text and "not in observed levels" in text,
            info(
                "schema_mismatch_ref",
                "Reference level doesn't match data",
                ("'Reference (control) level' is set to a value that doesn't "
                 "exist in your data. Set it to one of the values shown in "
                 "'Group levels', or clear it to let statLens auto-infer."),
                "reference_level",
            ),
        ),
        (
            mentions_missing("design (group/condition) column",
                             "primary_group_variable"),
            info(
                "schema_mismatch_design",
                "Group/condition column missing",
                ("No column matching your study's design variable was found "
                 "in the TSV. Fill 'Group column' with the exact column "
                 "name."),
                "primary_group_variable",
            ),
        ),
        (
            "No feature columns found" in text,
            info(
                "no_features",
                "No feature columns recognized",
                ("Your TSV doesn't contain any column starting with one of "
                 "the recognized prefixes: gene_, asv_, prot_, metab_, "
                 "otu_, feat_. Rename your feature columns to use one of "
                 "these prefixes and re-upload."),
            ),
        ),
        (
            "missing sample_id" in text,
            info(
                "missing_sample_id",
                "sample_id column missing",
                ("Your TSV must have a 'sample_id' column with one unique "
                 "value per row. Add it and re-upload."),
            ),
        ),
        (
            "Pipeline run failed" in text,
            info(
                "pipeline_error",
                "Pipeline computation failed",
                ("The DEA pipeline ran but errored out. Check that your "
                 "data has enough samples per group and that values match "
                 "the data type (raw counts for DESeq2, log/intensity for "
                 "limma). See technical details below."),
            ),
        ),
    ]
    for matched, details in ordered:
        if matched:
            return details

    # Nothing recognizable: generic, opaque error.
    return info(
        "internal_error",
        "Run failed",
        ("Something went wrong inside statLens. See technical details "
         "below; if you can't tell what to fix, click Start over."),
    )
def _error_payload(run_id: str, exc: Exception) -> dict:
    """Build the JSON body returned to the client when a run fails.

    Combines the classification from _classify_error with the run id and the
    raw ``<ExceptionType>: <message>`` text (presumably what the UI shows as
    "technical details" — confirm against index.html).
    """
    classified = _classify_error(exc)
    raw_text = f"{type(exc).__name__}: {exc}"
    return dict(
        ok=False,
        run_id=run_id,
        error_type=classified["type"],
        error_title=classified["title"],
        error_message=classified["desc"],
        error_field=classified["field"],
        error_raw=raw_text,
    )
def _safe_path_under_runs(*parts: str) -> Path:
    """Join *parts* under RUNS and return the resolved path.

    Raises HTTPException(404) if the resolved path escapes RUNS (e.g. via
    ``..`` segments or symlinks), or if the path cannot be resolved at all
    (e.g. it contains a NUL byte or a symlink loop).
    """
    try:
        candidate = RUNS.joinpath(*parts).resolve()
    except (OSError, ValueError, RuntimeError):
        # ValueError: embedded NUL byte; RuntimeError (<3.13) / OSError:
        # symlink loop, name too long, ... Report as "not found" instead of 500.
        raise HTTPException(404, "invalid path") from None
    # Component-wise containment check; unlike a string-prefix comparison it
    # stays correct when RUNS_RESOLVED is a filesystem root.
    if not candidate.is_relative_to(RUNS_RESOLVED):
        raise HTTPException(404, "path traversal detected")
    return candidate
def _persist_inputs(tsv: UploadFile, context: str) -> tuple[str, Path, Path, Path]:
    """Save the upload to a fresh run_id directory; return (run_id, rd, tsv_path, ctx_path).

    The TSV is streamed to ``<rd>/data.tsv`` in 64 KiB chunks and the
    stripped study context is written to ``<rd>/ctx.txt``.

    Raises HTTPException(413) if the upload exceeds MAX_UPLOAD_BYTES; in that
    case the partial file and the run directory created here are removed.
    """
    # Claim a directory that did not exist before, so a (very unlikely)
    # 48-bit run_id collision can never reuse another run's folder.
    while True:
        run_id = uuid.uuid4().hex[:12]
        rd = RUNS / run_id
        try:
            rd.mkdir(parents=True, exist_ok=False)
            break
        except FileExistsError:
            continue

    tsv_path = rd / "data.tsv"
    chunk_size = 64 * 1024
    written = 0
    too_big = False
    # Streaming size-limited copy
    with tsv_path.open("wb") as f:
        while chunk := tsv.file.read(chunk_size):
            written += len(chunk)
            if written > MAX_UPLOAD_BYTES:
                too_big = True
                break
            f.write(chunk)

    if too_big:
        # The file handle is closed here (required on Windows before unlink).
        tsv_path.unlink(missing_ok=True)
        try:
            rd.rmdir()
        except OSError:
            # Best-effort cleanup: keep the directory if anything else is in it.
            pass
        raise HTTPException(
            413,
            f"upload exceeds {MAX_UPLOAD_BYTES // (1024*1024)} MB limit "
            f"(set STATLENS_MAX_UPLOAD_MB to raise)",
        )

    ctx_path = rd / "ctx.txt"
    ctx_path.write_text(context.strip())
    return run_id, rd, tsv_path, ctx_path
# ─────────────────── routes ───────────────────
@app.get("/")
def root():
    """Serve the packaged single-page UI (index.html from PKG_DATA)."""
    page = str(INDEX_HTML)
    return FileResponse(page)
# ─────────────────── stage 1: extract schema (no label) ───────────────────
@app.post("/api/extract")
async def api_extract(
    tsv: UploadFile = File(...),
    context: str = Form(...),
):
    """LLM call #1 — return the extracted SchemaSummary only.

    The label is not chosen here; /api/run_pipeline (stage 3) picks it once
    the user has confirmed or edited the schema. Bad uploads are rejected
    with HTTP 400. An extraction failure returns a 500 error payload that
    still carries the new run_id.
    """
    upload_name = (tsv.filename or "").lower()
    if not upload_name.endswith((".tsv", ".txt")):
        raise HTTPException(400, "Please upload a .tsv file")
    if not context.strip():
        raise HTTPException(400, "Study context is empty")

    run_id, rd, tsv_path, ctx_path = _persist_inputs(tsv, context)
    # NOTE(review): extract_one is called synchronously inside an async
    # handler, so it holds the event loop for its whole duration. Moving it to
    # a worker thread would let requests overlap — confirm statlens_run is
    # thread-safe before changing that.
    try:
        res = extract_one(
            tsv=tsv_path,
            context_path=ctx_path,
            endpoint=DEFAULT_ENDPOINT,
            model=DEFAULT_MODEL,
        )
    except Exception as err:
        return JSONResponse(_error_payload(run_id, err), status_code=500)

    # Best-effort debug copy of the raw LLM reply: a write failure (disk
    # full, permission denied, ...) is reported on stderr, not raised.
    raw_file = rd / "llm_extract_raw.txt"
    try:
        raw_file.write_text(res.raw or "")
    except Exception as err:
        print(f"[statlens] warning: could not persist llm_extract_raw.txt: {err}",
              file=sys.stderr)

    payload = {
        "ok": True,
        "run_id": run_id,
        "schema": res.schema,
        "schema_warnings": res.schema_warnings,
    }
    return payload
# ─────────────────── stage 3: pick label + run pipeline ───────────────────
@app.post("/api/run_pipeline")
async def api_run_pipeline(
    run_id: str = Form(...),
    schema: str = Form(...),  # JSON-serialized schema dict
):
    """Take a (confirmed/edited) schema → LLM picks label → run pipeline.

    Responses:
      - 404 if *run_id* is malformed or has no uploaded inputs;
      - 400 if *schema* is not valid JSON or is not a JSON object;
      - 500 with an _error_payload body if the pipeline call raises;
      - otherwise whatever _build_result_response returns.
    """
    _validate_run_id(run_id)
    rd = _safe_path_under_runs(run_id)
    tsv_path = rd / "data.tsv"
    ctx_path = rd / "ctx.txt"
    if not tsv_path.exists() or not ctx_path.exists():
        raise HTTPException(404, f"unknown run_id: {run_id}")

    try:
        schema_dict = json.loads(schema)
    except (ValueError, RecursionError) as e:
        # ValueError covers json.JSONDecodeError; RecursionError is raised
        # for pathologically deep nesting.
        raise HTTPException(400, f"schema is not valid JSON: {e}") from e
    if not isinstance(schema_dict, dict):
        # Reject early with a clear 400 instead of an opaque pipeline 500.
        raise HTTPException(
            400, f"schema must be a JSON object, got {type(schema_dict).__name__}"
        )

    out_dir = rd / "out"
    # NOTE(review): run_pipeline_with_schema is called synchronously inside an
    # async handler and holds the event loop while it runs. Confirm
    # statlens_run is thread-safe before moving it to a worker thread.
    try:
        result = run_pipeline_with_schema(
            tsv=tsv_path, context_path=ctx_path,
            schema=schema_dict, out_dir=out_dir,
            endpoint=DEFAULT_ENDPOINT, model=DEFAULT_MODEL,
        )
    except Exception as e:
        return JSONResponse(_error_payload(run_id, e), status_code=500)

    # Persist the actual stage-3 LLM raw output (not just a summary).
    # statlens_run leaves it under '_llm_pick_raw' so the server can write it
    # then strip it from the response sent to the client.
    raw = result.pop("_llm_pick_raw", None)
    if raw is not None:
        try:
            (rd / "llm_pick_raw.txt").write_text(raw)
        except Exception as e:
            print(f"[statlens] warning: could not persist llm_pick_raw.txt: {e}",
                  file=sys.stderr)
    return _build_result_response(run_id, rd, out_dir, result)
# ─────────────────── legacy one-shot ───────────────────
@app.post("/api/run")
async def api_run(
    tsv: UploadFile = File(...),
    context: str = Form(...),
    force_label: str | None = Form(None),
):
    """Legacy single-step path: extract + pick + run, no schema editing.

    *force_label*, when supplied, is forwarded unchanged to run_one.
    Validation and error responses mirror /api/extract.
    """
    filename = (tsv.filename or "").lower()
    if not filename.endswith((".tsv", ".txt")):
        raise HTTPException(400, "Please upload a .tsv file")
    if not context.strip():
        raise HTTPException(400, "Study context is empty")

    run_id, run_dir, tsv_path, ctx_path = _persist_inputs(tsv, context)
    out_dir = run_dir / "out"
    try:
        result = run_one(
            tsv=tsv_path,
            context_path=ctx_path,
            out_dir=out_dir,
            endpoint=DEFAULT_ENDPOINT,
            model=DEFAULT_MODEL,
            force_label=force_label,
        )
    except Exception as err:
        return JSONResponse(_error_payload(run_id, err), status_code=500)

    # Write the raw label-pick LLM output to disk for debugging, then drop
    # it from the dict so it is not sent to the client.
    raw = result.pop("_llm_pick_raw", None)
    if raw is not None:
        try:
            (run_dir / "llm_pick_raw.txt").write_text(raw)
        except Exception as err:
            print(f"[statlens] warning: could not persist llm_pick_raw.txt: {err}",
                  file=sys.stderr)
    return _build_result_response(run_id, run_dir, out_dir, result)
# ─────────────────── result packaging ───────────────────
def _build_result_response(run_id: str, rd: Path, out_dir: Path, result: dict):
    """Assemble the client-facing result for a finished run.

    Reads ``<out_dir>/statlens_report.json`` and lists the PNG images and CSV
    tables in ``<out_dir>/pipeline_output``. The normalized-count and metadata
    CSVs are left out of the table list but still go into the zip. It then
    packs the pipeline output plus statlens_report.{md,json} into
    ``<rd>/result.zip``.

    Returns a plain dict on success, or a 500 JSONResponse when the report
    sidecar is missing, unreadable, or not a JSON object.
    """
    def fail(message: str):
        return JSONResponse(
            {"ok": False, "run_id": run_id, "error": message},
            status_code=500,
        )

    sidecar = out_dir / "statlens_report.json"
    if not sidecar.exists():
        return fail("no statlens_report.json produced")
    try:
        report = json.loads(sidecar.read_text())
    except ValueError as e:  # JSONDecodeError / UnicodeDecodeError
        return fail(f"statlens_report.json is not valid JSON: {e}")
    if not isinstance(report, dict):
        return fail("statlens_report.json is not a JSON object")

    pipeline_dir = out_dir / "pipeline_output"
    zip_path = rd / "result.zip"
    images: list[str] = []
    tables: list[str] = []
    if pipeline_dir.exists():
        images = sorted(p.name for p in pipeline_dir.glob("*.png"))
        skip = {"log_normalized_counts.csv", "metadata_used_for_analysis.csv",
                "normalized_counts.csv"}
        tables = sorted(p.name for p in pipeline_dir.glob("*.csv") if p.name not in skip)

        # Build into a temp file, then rename atomically, so a concurrent
        # /api/zip download never sees a half-written archive.
        tmp_zip = zip_path.with_name(zip_path.name + ".tmp")
        try:
            with zipfile.ZipFile(tmp_zip, "w", zipfile.ZIP_DEFLATED) as zf:
                for p in sorted(pipeline_dir.iterdir()):
                    zf.write(p, arcname=f"pipeline_output/{p.name}")
                for tag in ("statlens_report.md", "statlens_report.json"):
                    src = out_dir / tag
                    if src.exists():
                        zf.write(src, arcname=tag)
        except BaseException:
            tmp_zip.unlink(missing_ok=True)
            raise
        os.replace(tmp_zip, zip_path)
    else:
        # No pipeline output this time: drop any archive left by an earlier
        # run of the same run_id so has_zip reflects this run only.
        zip_path.unlink(missing_ok=True)

    # A present-but-null elapsed_sec would make round() raise TypeError.
    elapsed = report.get("elapsed_sec")
    if elapsed is None:
        elapsed = 0

    return {
        "ok": result.get("ok", True),
        "run_id": run_id,
        "label": report.get("label"),
        "reasoning": report.get("reasoning"),
        "valid": report.get("valid", False),
        "elapsed": round(elapsed, 1),
        "pipeline": report.get("pipeline"),
        "images": images,
        "tables": tables,
        "has_zip": zip_path.exists(),
        "schema": result.get("schema"),
    }
# ─────────────────── artifact / preview / zip ───────────────────
def _validate_filename(filename: str) -> None:
    """Raise HTTPException(404) unless *filename* is a plain, safe file name.

    Rejects empty names, path separators, ``..``, a leading dot (hidden
    files), and NUL bytes. A NUL byte would otherwise make path resolution
    raise ValueError (HTTP 500) instead of a clean 404.
    """
    forbidden = ("/", "\\", "..", "\x00")
    if (
        not filename
        or filename.startswith(".")
        or any(token in filename for token in forbidden)
    ):
        raise HTTPException(404)
@app.get("/api/artifact/{run_id}/{filename}")
def artifact(run_id: str, filename: str):
    """Serve one file from a run's ``out/pipeline_output`` directory.

    Returns 404 for a malformed run id or filename, a path escaping RUNS, or
    a name that is not a regular file.
    """
    _validate_run_id(run_id)
    _validate_filename(filename)
    p = _safe_path_under_runs(run_id, "out", "pipeline_output", filename)
    # is_file(), not exists(): FileResponse fails with a 500 on a directory.
    if not p.is_file():
        raise HTTPException(404)
    return FileResponse(p)
@app.get("/api/zip/{run_id}")
def zip_dl(run_id: str):
    """Download a run's packaged result.zip as ``statlens_<run_id>.zip``."""
    _validate_run_id(run_id)
    archive = _safe_path_under_runs(run_id, "result.zip")
    if archive.exists():
        return FileResponse(
            archive,
            filename=f"statlens_{run_id}.zip",
            media_type="application/zip",
        )
    raise HTTPException(404)
@app.get("/api/csv_preview/{run_id}/{filename}")
def csv_preview(run_id: str, filename: str, limit: int = 20):
    """Return the header and the first *limit* data rows of a pipeline CSV.

    Response keys: ``headers``, ``rows``, ``total`` (number of data rows in
    the file) and ``shown`` (``len(rows)``); the same keys are returned for an
    empty file. 404 for invalid ids/names or a non-file path.
    """
    _validate_run_id(run_id)
    _validate_filename(filename)
    p = _safe_path_under_runs(run_id, "out", "pipeline_output", filename)
    if not p.is_file():
        raise HTTPException(404)

    rows: list[list[str]] = []
    total = 0
    # newline="" is what the csv module requires, so quoted fields that
    # contain line breaks parse correctly. Decode as UTF-8 (tolerating a BOM)
    # rather than the platform locale, and replace undecodable bytes instead
    # of failing with a 500.
    with p.open(newline="", encoding="utf-8-sig", errors="replace") as f:
        rdr = csv.reader(f)
        headers = next(rdr, None)
        if headers is None:
            return {"headers": [], "rows": [], "total": 0, "shown": 0}
        for row in rdr:
            if total < limit:
                rows.append(row)
            total += 1
    return {"headers": headers, "rows": rows, "total": total, "shown": len(rows)}
if __name__ == "__main__":
    # Standalone dev server; `statlens serve` is the packaged entry point.
    # Binds all interfaces (0.0.0.0), so it is reachable from other hosts.
    import uvicorn

    # ASCII separator: the previous mis-encoded dash printed garbage and could
    # raise UnicodeEncodeError on consoles that cannot encode it.
    print(f"statLens web on http://localhost:7860 - endpoint={DEFAULT_ENDPOINT}")
    uvicorn.run(app, host="0.0.0.0", port=7860)