"""Standalone inference script for TRACT CRE hub assignment. Dependencies: sentence-transformers, torch, numpy No TRACT package required — all inference logic is inlined. Usage: python predict.py "Ensure AI models are tested for bias" python predict.py --file controls.txt --top-k 10 """ import argparse import json import sys import unicodedata from pathlib import Path import numpy as np from sentence_transformers import SentenceTransformer def sanitize_text(text: str) -> str: """Full sanitization pipeline matching training-time preprocessing. Steps: null bytes → NFC → zero-width chars → HTML unescape+strip → PDF ligatures → broken hyphenation → whitespace collapse → strip. Must match tract/sanitize.py exactly to avoid train/inference skew. """ import html import re text = text.replace("\x00", " ") text = unicodedata.normalize("NFC", text) text = re.sub("[\u200b\u200c\u200d\ufeff]", "", text) text = re.sub(r"]*>", "", html.unescape(text)) for lig, repl in [("\ufb04", "ffl"), ("\ufb03", "ffi"), ("\ufb00", "ff"), ("\ufb01", "fi"), ("\ufb02", "fl")]: text = text.replace(lig, repl) text = re.sub(r"(\w)-\n(\w)", r"\1\2", text) text = re.sub(r"\s+", " ", text) return text.strip() def softmax(x): """Numerically stable softmax.""" e = np.exp(x - np.max(x, axis=-1, keepdims=True)) return e / e.sum(axis=-1, keepdims=True) def predict( texts: list[str], model_dir: str = ".", top_k: int = 5, ) -> list[list[dict]]: """Predict CRE hub assignments for input texts. Args: texts: List of control text strings. model_dir: Path to this repository (contains model + bundled data). top_k: Number of top predictions to return. Returns: List of prediction lists, one per input text. """ base = Path(model_dir) model = SentenceTransformer(str(base)) with open(base / "calibration.json") as f: cal = json.load(f) with open(base / "hub_ids.json") as f: hub_ids = json.load(f) with open(base / "cre_hierarchy.json") as f: hierarchy = json.load(f) hub_emb = np.load(str(base / "hub_embeddings.npy")) temperature = cal["t_deploy"] ood_threshold = cal["ood_threshold"] cleaned = [sanitize_text(t) for t in texts] query_emb = model.encode(cleaned, normalize_embeddings=True, show_progress_bar=False) similarities = query_emb @ hub_emb.T calibrated = softmax(similarities / temperature) results = [] for i in range(len(texts)): sims = similarities[i] confs = calibrated[i] max_sim = float(np.max(sims)) is_ood = max_sim < ood_threshold top_indices = np.argsort(confs)[-top_k:][::-1] preds = [] for idx in top_indices: hub_id = hub_ids[idx] hub_info = hierarchy.get("hubs", {}).get(hub_id, {}) preds.append({ "hub_id": hub_id, "hub_name": hub_info.get("name", hub_id), "hierarchy_path": hub_info.get("hierarchy_path", ""), "raw_similarity": round(float(sims[idx]), 4), "calibrated_confidence": round(float(confs[idx]), 4), "is_ood": is_ood, }) results.append(preds) return results def main(): parser = argparse.ArgumentParser(description="TRACT CRE hub assignment") parser.add_argument("text", nargs="?", help="Control text to assign") parser.add_argument("--file", help="File with one control per line") parser.add_argument("--top-k", type=int, default=5, help="Number of predictions") parser.add_argument("--model-dir", default=".", help="Path to model directory") parser.add_argument("--json", action="store_true", help="JSON output") args = parser.parse_args() if args.file: with open(args.file) as f: texts = [line.strip() for line in f if line.strip()] elif args.text: texts = [args.text] else: parser.print_help() sys.exit(1) results = predict(texts, model_dir=args.model_dir, top_k=args.top_k) if args.json: print(json.dumps(results, indent=2)) else: for i, preds in enumerate(results): if len(texts) > 1: print(f"\n--- Control {i+1}: {texts[i][:80]} ---") for p in preds: ood = " [OOD]" if p["is_ood"] else "" print(f" {p['hub_id']} ({p['calibrated_confidence']:.3f}){ood} {p['hub_name']}") if __name__ == "__main__": main()