# LLM_Screener / app.py
# diogo.rodrigues.silva
# Adding .ris download functionality
# b156781
# raw history blame 20.1 kB
#!/usr/bin/env python3
"""Hugging Face Space app: parse references and screen Excel via Azure Foundry."""
from __future__ import annotations

import hashlib
import hmac
import json
import os
import re
import secrets
import shutil
import tempfile
import threading
import time
from datetime import datetime, timezone
from pathlib import Path
from tempfile import mkdtemp
from urllib.parse import quote

import gradio as gr
import pandas as pd
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Query, status
from fastapi.responses import FileResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials

import reference_parser
import screen_excel_foundry_safe
ALLOWED_REFERENCE_SUFFIXES = {".txt", ".medline", ".ris"}
ALLOWED_CRITERIA_SUFFIXES = {".yml", ".yaml", ".json"}
APP_ROOT_DIR = Path(__file__).resolve().parent
APP_TMP_DIR = Path(os.getenv("APP_TMP_DIR", str(APP_ROOT_DIR / "artifacts"))).resolve()
APP_RUNS_DIR = APP_TMP_DIR / "runs"
APP_TMP_FILE_EXTENSIONS = {".xlsx", ".json", ".ris"}
MAX_UPLOAD_FILES = int(os.getenv("MAX_UPLOAD_FILES", "20"))
MAX_UPLOAD_FILE_MB = int(os.getenv("MAX_UPLOAD_FILE_MB", "25"))
MAX_CRITERIA_FILE_MB = int(os.getenv("MAX_CRITERIA_FILE_MB", "2"))
DOWNLOAD_LINK_TTL_SECONDS = int(os.getenv("DOWNLOAD_LINK_TTL_SECONDS", "86400"))
APP_USERNAME_ENV = "APP_USERNAME"
APP_PASSWORD_ENV = "APP_PASSWORD"
_HTTP_BASIC = HTTPBasic()
def _required_secret_names() -> list[str]:
    """List every environment variable (Space secret) the app needs to run."""
    foundry = screen_excel_foundry_safe
    return [
        foundry.FOUNDRY_ENDPOINT_ENV,
        foundry.FOUNDRY_DEPLOYMENT_ENV,
        foundry.FOUNDRY_API_KEY_ENV,
        APP_USERNAME_ENV,
        APP_PASSWORD_ENV,
    ]
def _missing_secrets() -> list[str]:
    """Return the names of required secrets that are unset or blank."""
    return [
        name
        for name in _required_secret_names()
        if not (os.getenv(name) or "").strip()
    ]
def _setup_temp_directories() -> None:
    """Create artifact directories and point all temp-file machinery at them."""
    for directory in (APP_TMP_DIR, APP_RUNS_DIR):
        directory.mkdir(parents=True, exist_ok=True)
    # Gradio can reliably serve downloads from registered static paths in Spaces.
    gr.set_static_paths(paths=[str(APP_TMP_DIR)])
    os.environ["TMPDIR"] = str(APP_TMP_DIR)
    tempfile.tempdir = str(APP_TMP_DIR)
def _cleanup_old_temp_files(max_age_hours: int = 24) -> None:
    """Best-effort removal of run directories and artifacts older than *max_age_hours*."""
    if not APP_RUNS_DIR.exists():
        return
    cutoff = time.time() - max_age_hours * 3600
    for run_dir in APP_RUNS_DIR.iterdir():
        try:
            if not run_dir.is_dir():
                continue
            if run_dir.stat().st_mtime < cutoff:
                # Whole run is stale: drop it entirely and move on.
                shutil.rmtree(run_dir, ignore_errors=True)
                continue
            for artifact in run_dir.iterdir():
                stale_artifact = (
                    artifact.is_file()
                    and artifact.suffix.lower() in APP_TMP_FILE_EXTENSIONS
                    and artifact.stat().st_mtime < cutoff
                )
                if stale_artifact:
                    artifact.unlink(missing_ok=True)
        except Exception:
            # A failure on one run dir must not stop cleanup of the others.
            continue
def _new_run_dir() -> Path:
    """Create and return a fresh, uniquely-named per-run directory.

    ``mkdtemp`` already creates the directory (with a unique name and
    owner-only permissions), so the follow-up ``mkdir`` in the original
    was redundant and has been removed.
    """
    return Path(mkdtemp(prefix="screening_run_", dir=str(APP_RUNS_DIR)))
def _timestamp_slug() -> str:
return datetime.utcnow().strftime("%Y%m%d_%H%M%S")
def _parse_criteria_lines(text: str) -> list[str]:
lines = []
for raw in (text or "").splitlines():
line = raw.strip()
if not line:
continue
if line.startswith("-"):
line = line[1:].strip()
lines.append(line)
return lines
def _validate_upload_file(path: Path, allowed_suffixes: set[str], max_mb: int) -> str | None:
if path.suffix.lower() not in allowed_suffixes:
allowed = ", ".join(sorted(allowed_suffixes))
return f"Unsupported file type: {path.name}. Allowed: {allowed}."
size_mb = path.stat().st_size / (1024 * 1024)
if size_mb > max_mb:
return f"File too large: {path.name} ({size_mb:.1f} MB), max {max_mb} MB."
return None
def _download_href(file_path: Path) -> str:
    """Build a signed, time-limited /download URL for a file under APP_TMP_DIR."""
    rel_path = file_path.resolve().relative_to(APP_TMP_DIR).as_posix()
    expiry = int(time.time()) + DOWNLOAD_LINK_TTL_SECONDS
    sig = _sign_download(rel_path, expiry)
    return f"/download?path={quote(rel_path, safe='')}&exp={expiry}&sig={sig}"
def _download_markdown(file_path: Path, label: str) -> str:
    """Render a markdown link labelled *label* pointing at the signed download URL."""
    href = _download_href(file_path)
    return f"[{label}]({href})"
def _screening_verdict_counts(screened_excel_path: Path) -> dict[str, int]:
    """Count include/exclude/unclear verdicts in the screened Excel output.

    Raises:
        KeyError: if the 'LLM_verdict' column is missing.
    """
    df = pd.read_excel(screened_excel_path, engine="openpyxl")
    if "LLM_verdict" not in df.columns:
        raise KeyError("Expected 'LLM_verdict' column was not found in screening output.")
    # Normalize casing/whitespace before counting so 'Include ' == 'include'.
    normalized = df["LLM_verdict"].astype(str).str.strip().str.lower()
    counts = normalized.value_counts().to_dict()
    return {
        verdict: int(counts.get(verdict, 0))
        for verdict in ("include", "exclude", "unclear")
    }
def _write_included_unclear_ris(screened_excel_path: Path, output_ris_path: Path) -> int:
    """Write a RIS file containing only include/unclear references.

    Rows whose LLM_verdict is 'include' or 'unclear' (case-insensitive) are
    serialised as RIS journal-article records; the LLM verdict and rationale
    are embedded as N1 notes.

    Returns:
        The number of references written.

    Raises:
        KeyError: if the 'LLM_verdict' column is missing.
    """
    df = pd.read_excel(screened_excel_path, engine="openpyxl")
    if "LLM_verdict" not in df.columns:
        raise KeyError("Expected 'LLM_verdict' column was not found in screening output.")
    selected = df[
        df["LLM_verdict"]
        .astype(str)
        .str.strip()
        .str.lower()
        .isin({"include", "unclear"})
    ]

    def _clean_value(value) -> str:
        # NaN/None become ""; runs of internal whitespace collapse to one space.
        if pd.isna(value):
            return ""
        return re.sub(r"\s+", " ", str(value).strip())

    def _authors_from_row(row: pd.Series) -> list[str]:
        # Clean each candidate column before falling back: a NaN cell is
        # truthy, so the previous `row.get("Authors") or row.get("FullAuthors")`
        # kept the NaN and never consulted FullAuthors.
        raw = _clean_value(row.get("Authors"))
        if not raw:
            raw = _clean_value(row.get("FullAuthors"))
        if not raw:
            return []
        return [a.strip() for a in re.split(r";|\n", raw) if a.strip()]

    lines: list[str] = []
    for _, row in selected.iterrows():
        title = _clean_value(row.get("Title", ""))
        abstract = _clean_value(row.get("Abstract", ""))
        journal = _clean_value(row.get("Journal", ""))
        doi = _clean_value(row.get("DOI", ""))
        pmid = _clean_value(row.get("PMID", ""))
        url = _clean_value(row.get("URL", ""))
        verdict = _clean_value(row.get("LLM_verdict", "")).lower()
        rationale = _clean_value(row.get("LLM_rationale", ""))
        year = _clean_value(row.get("Year", ""))
        if not year:
            # No Year column/value: take the leading 4 chars of PublicationDate.
            year = _clean_value(row.get("PublicationDate", ""))[:4]
        elif len(year) > 4:
            year = year[:4]
        lines.append("TY - JOUR")
        if title:
            lines.append(f"TI - {title}")
        for author in _authors_from_row(row):
            lines.append(f"AU - {author}")
        if journal:
            lines.append(f"JO - {journal}")
        if year:
            lines.append(f"PY - {year}")
        if abstract:
            lines.append(f"AB - {abstract}")
        if doi:
            lines.append(f"DO - {doi}")
        if pmid:
            lines.append(f"ID - PMID:{pmid}")
        if url:
            lines.append(f"UR - {url}")
        if verdict:
            lines.append(f"N1 - LLM verdict: {verdict}")
        if rationale:
            lines.append(f"N1 - LLM rationale: {rationale}")
        lines.append("ER -")
        lines.append("")
    output_ris_path.write_text("\n".join(lines), encoding="utf-8")
    return int(len(selected))
def _auth_credentials() -> tuple[str, str]:
    """Return (username, password) from Space secrets; raise if either is blank."""
    creds = [(os.getenv(name) or "").strip() for name in (APP_USERNAME_ENV, APP_PASSWORD_ENV)]
    if not all(creds):
        raise RuntimeError(f"Set Space secrets {APP_USERNAME_ENV} and {APP_PASSWORD_ENV}.")
    return creds[0], creds[1]
def _sign_download(path: str, expires: int) -> str:
    """HMAC-SHA256 sign "path|expires" using the app password as the key."""
    secret_key = _auth_credentials()[1].encode("utf-8")
    message = f"{path}|{expires}".encode("utf-8")
    return hmac.new(secret_key, message, hashlib.sha256).hexdigest()
def _verify_download_signature(path: str, expires: int, signature: str) -> bool:
    """Validate a download link: must be unexpired and carry a matching HMAC."""
    if expires < int(time.time()):
        return False
    # Constant-time comparison to avoid leaking the expected signature.
    return secrets.compare_digest(signature, _sign_download(path, expires))
def _require_basic_auth(credentials: HTTPBasicCredentials = Depends(_HTTP_BASIC)) -> str:
    """FastAPI dependency enforcing HTTP Basic auth against the Space secrets."""
    want_user, want_password = _auth_credentials()
    # compare_digest keeps each comparison constant-time.
    user_ok = secrets.compare_digest(credentials.username, want_user)
    password_ok = secrets.compare_digest(credentials.password, want_password)
    if not (user_ok and password_ok):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid credentials.",
            headers={"WWW-Authenticate": "Basic"},
        )
    return credentials.username
def parse_files(reference_files: list[str] | None):
    """Parse uploaded reference files into one deduplicated Excel workbook.

    Returns a (status message, parsed-file path or None, download markdown)
    tuple matching the Gradio outputs wired up in build_app().
    """
    if not reference_files:
        return "Upload at least one reference file.", None, ""
    if len(reference_files) > MAX_UPLOAD_FILES:
        return f"Too many files. Maximum allowed is {MAX_UPLOAD_FILES}.", None, ""
    uploads = [Path(item) for item in reference_files]
    for upload in uploads:
        problem = _validate_upload_file(
            path=upload,
            allowed_suffixes=ALLOWED_REFERENCE_SUFFIXES,
            max_mb=MAX_UPLOAD_FILE_MB,
        )
        if problem:
            return problem, None, ""
    run_dir = _new_run_dir()
    parsed_output = run_dir / f"parsed_{_timestamp_slug()}.xlsx"
    records = reference_parser.parse_references(input_paths=uploads, output_path=parsed_output)
    if records.empty:
        return "No records found in the uploaded files.", None, ""
    return (
        f"Parsed {len(records)} deduplicated records.",
        str(parsed_output),
        _download_markdown(parsed_output, "Download Parsed Excel"),
    )
def screen_excel(
    parsed_excel_path: str | None,
    criteria_file: str | None,
    criteria_topic: str,
    criteria_inclusion_text: str,
    criteria_exclusion_text: str,
    criteria_notes: str,
    title_column: str = "Title",
    abstract_column: str = "Abstract",
    progress=gr.Progress(track_tqdm=True),
):
    """Screen the parsed Excel against inclusion/exclusion criteria via Azure Foundry.

    Gradio generator handler: yields (status text, download markdown) tuples.
    Criteria come either from an uploaded file or from the form fields; the
    screening itself runs on a background thread while this generator polls
    its progress once per second. The final yield carries verdict counts plus
    signed download links for the screened Excel and the include/unclear RIS.
    """
    # Refuse to start if any required Space secret is missing.
    missing = _missing_secrets()
    if missing:
        yield "Missing required Space secrets: " + ", ".join(missing), ""
        return
    if not parsed_excel_path:
        yield "Run parser first to generate a parsed Excel file.", ""
        return
    parsed_path = Path(parsed_excel_path)
    if not parsed_path.exists():
        yield "Parsed Excel file was not found. Please run parser again.", ""
        return
    # Outputs live next to the parsed file, inside the same per-run directory.
    screened_output = parsed_path.parent / f"screened_{_timestamp_slug()}.xlsx"
    criteria_path = None
    if criteria_file:
        # An uploaded criteria file takes precedence over the form fields.
        candidate = Path(criteria_file)
        validation_error = _validate_upload_file(
            path=candidate,
            allowed_suffixes=ALLOWED_CRITERIA_SUFFIXES,
            max_mb=MAX_CRITERIA_FILE_MB,
        )
        if validation_error:
            yield validation_error, ""
            return
        criteria_path = candidate
    else:
        # Build a criteria JSON from the form; topic and both lists are required.
        inclusion = _parse_criteria_lines(criteria_inclusion_text)
        exclusion = _parse_criteria_lines(criteria_exclusion_text)
        topic = (criteria_topic or "").strip()
        if not topic:
            yield "Provide a criteria topic (or upload criteria file).", ""
            return
        if not inclusion:
            yield "Provide at least one inclusion criterion (or upload criteria file).", ""
            return
        if not exclusion:
            yield "Provide at least one exclusion criterion (or upload criteria file).", ""
            return
        criteria_obj = {
            "topic": topic,
            "inclusion_criteria": inclusion,
            "exclusion_criteria": exclusion,
            "notes": (criteria_notes or "").strip(),
        }
        generated_criteria_path = parsed_path.parent / f"criteria_{_timestamp_slug()}.json"
        generated_criteria_path.write_text(json.dumps(criteria_obj, indent=2), encoding="utf-8")
        criteria_path = generated_criteria_path
    progress(0, desc="Preparing screening...")
    started = time.perf_counter()
    # Shared state: written by the worker's callbacks, read by the polling loop.
    progress_state = {"done": 0, "total": 0, "eta_seconds": 0, "tqdm_line": ""}
    worker_error = {"exc": None}
    def _on_progress(done: int, total: int):
        # Derive an ETA from the average throughput so far.
        elapsed = max(time.perf_counter() - started, 1e-6)
        rate = done / elapsed
        remaining = max(total - done, 0)
        eta_seconds = int(remaining / rate) if rate > 0 else 0
        progress_state["done"] = done
        progress_state["total"] = total
        progress_state["eta_seconds"] = eta_seconds
        eta_min = eta_seconds // 60
        eta_sec = eta_seconds % 60
        progress(done / max(total, 1), desc=f"Screening {done}/{total} (ETA {eta_min:02d}:{eta_sec:02d})")
    def _on_progress_text(tqdm_line: str, done: int, total: int):
        # Preferred status: the screener's own tqdm line, passed through verbatim.
        progress_state["tqdm_line"] = tqdm_line
        progress_state["done"] = done
        progress_state["total"] = total
    def _run_screening():
        # Runs on the worker thread; exceptions are stashed for the main loop.
        try:
            screen_excel_foundry_safe.process_excel_file(
                input_excel_path=str(parsed_path),
                output_excel_path=str(screened_output),
                criteria_path=str(criteria_path),
                title_column=title_column,
                abstract_column=abstract_column,
                progress_callback=_on_progress,
                progress_text_callback=_on_progress_text,
            )
        except Exception as exc:
            worker_error["exc"] = exc
    worker = threading.Thread(target=_run_screening, daemon=True)
    worker.start()
    yield "Preparing screening...", ""
    # Poll once per second, streaming a human-readable status line to the UI.
    while worker.is_alive():
        elapsed_s = int(time.perf_counter() - started)
        eta_s = int(progress_state["eta_seconds"])
        done = int(progress_state["done"])
        total = int(progress_state["total"])
        tqdm_line = progress_state["tqdm_line"].strip()
        elapsed_min, elapsed_sec = divmod(elapsed_s, 60)
        if tqdm_line:
            status = tqdm_line
        elif total > 0:
            eta_min, eta_sec = divmod(eta_s, 60)
            status = (
                f"Screening {done}/{total} | Elapsed {elapsed_min:02d}:{elapsed_sec:02d} "
                f"| ETA {eta_min:02d}:{eta_sec:02d}"
            )
        else:
            status = f"Initializing client... Elapsed {elapsed_min:02d}:{elapsed_sec:02d}"
        yield status, ""
        time.sleep(1)
    if worker_error["exc"] is not None:
        yield f"Screening failed: {worker_error['exc']}", ""
        return
    progress(1, desc="Screening complete.")
    # Export include/unclear rows to RIS; failures are reported but non-fatal.
    screened_ris_output = parsed_path.parent / f"screened_included_unclear_{_timestamp_slug()}.ris"
    ris_count = 0
    ris_error = None
    try:
        ris_count = _write_included_unclear_ris(screened_output, screened_ris_output)
    except Exception as exc:
        ris_error = str(exc)
    try:
        verdict_counts = _screening_verdict_counts(screened_output)
        completed_status = (
            "Screening complete: "
            f"Included {verdict_counts['include']} | "
            f"Excluded {verdict_counts['exclude']} | "
            f"Unclear {verdict_counts['unclear']} | "
            f"RIS references {ris_count}"
        )
        if ris_error:
            completed_status += f" | RIS export failed: {ris_error}"
    except Exception:
        # Counting verdicts is cosmetic; fall back to a generic completion message.
        completed_status = "Screening complete."
        if ris_error:
            completed_status += f" RIS export failed: {ris_error}"
    downloads = [_download_markdown(screened_output, "Download Screened Excel")]
    if screened_ris_output.exists():
        downloads.append(
            _download_markdown(screened_ris_output, "Download Included + Unclear RIS")
        )
    yield (
        completed_status,
        " | ".join(downloads),
    )
def build_app() -> gr.Blocks:
    """Build the Gradio Blocks UI: parse step, screening step, criteria inputs."""
    # Surface missing-secret problems directly in the UI header.
    missing = _missing_secrets()
    secrets_note = (
        "All required secrets are configured."
        if not missing
        else "Missing Space secrets: " + ", ".join(missing)
    )
    with gr.Blocks(title="Reference Parser + Foundry Screener") as demo:
        gr.Markdown("# Reference Parsing, Deduplication and LLM-assisted Screening")
        gr.Markdown(
            "Upload `.txt/.medline/.ris` files, parse into one Excel, then screen for inclusion/exclusion criteria."
        )
        gr.Markdown(f"**Secrets status:** {secrets_note}")
        # Carries the parsed-Excel path from the parse step to the screen step.
        parsed_excel_state = gr.State(value=None)
        with gr.Row():
            with gr.Column():
                reference_files = gr.Files(
                    label="Reference Files (.txt/.medline/.ris)",
                    type="filepath",
                    file_count="multiple",
                )
                parse_btn = gr.Button("1) Run Parser", variant="primary")
                parse_status = gr.Textbox(label="Parser Status", interactive=False)
                parsed_excel_download = gr.Markdown("")
                screen_btn = gr.Button("2) Screen Excel", variant="primary")
                screen_status = gr.Textbox(label="Screening Status", interactive=False)
                screened_excel_download = gr.Markdown("")
            with gr.Column():
                # Criteria: either one uploaded file, or the form fields below it.
                criteria_file = gr.File(
                    label="Criteria File (.yml/.yaml/.json) - Optional if form below is filled",
                    type="filepath",
                )
                criteria_topic = gr.Textbox(
                    label="Criteria Topic (used if no file uploaded)",
                    placeholder="Example: Digital health interventions for adult diabetes self-management",
                )
                criteria_inclusion_text = gr.Textbox(
                    label="Inclusion Criteria (one per line)",
                    lines=2,
                )
                criteria_exclusion_text = gr.Textbox(
                    label="Exclusion Criteria (one per line)",
                    lines=2,
                )
                criteria_notes = gr.Textbox(label="Notes (optional)", lines=1)
        # api_name=False keeps both handlers out of the public Gradio API surface.
        parse_btn.click(
            fn=parse_files,
            inputs=[reference_files],
            outputs=[parse_status, parsed_excel_state, parsed_excel_download],
            api_name=False,
        )
        screen_btn.click(
            fn=screen_excel,
            inputs=[
                parsed_excel_state,
                criteria_file,
                criteria_topic,
                criteria_inclusion_text,
                criteria_exclusion_text,
                criteria_notes,
            ],
            outputs=[screen_status, screened_excel_download],
            api_name=False,
        )
    return demo
def _resolve_download_path(path: str) -> Path:
    """Map a relative download path to a real file under APP_TMP_DIR.

    Raises HTTP 403 for paths that escape the artifacts directory
    (path traversal) and HTTP 404 for anything that is not an existing
    regular file.
    """
    resolved = (APP_TMP_DIR / path).resolve()
    try:
        resolved.relative_to(APP_TMP_DIR)
    except ValueError as exc:
        raise HTTPException(status_code=403, detail="Invalid download path.") from exc
    if not (resolved.exists() and resolved.is_file()):
        raise HTTPException(status_code=404, detail="File not found.")
    return resolved
def _build_server(demo: gr.Blocks) -> FastAPI:
    """Wrap the Gradio app in a FastAPI server that also serves signed downloads."""
    # Resolved once at startup; also used as Gradio's (username, password) auth pair.
    auth = _auth_credentials()
    server = FastAPI()
    @server.get("/download")
    def download(
        path: str = Query(...),
        exp: int = Query(...),
        sig: str = Query(...),
    ):
        # Links are HMAC-signed and time-limited, so no session auth is needed here.
        if not _verify_download_signature(path, exp, sig):
            raise HTTPException(status_code=403, detail="Invalid or expired download link.")
        target = _resolve_download_path(path)
        return FileResponse(path=str(target), filename=target.name, media_type="application/octet-stream")
    return gr.mount_gradio_app(server, demo, path="/", auth=auth)
if __name__ == "__main__":
    # Prepare artifact directories and purge anything older than a day.
    _setup_temp_directories()
    _cleanup_old_temp_files(max_age_hours=24)
    gradio_app = build_app()
    gradio_app.queue(api_open=False)
    uvicorn.run(
        _build_server(gradio_app),
        host="0.0.0.0",
        port=int(os.getenv("PORT", "7860")),
    )