tract-cre-assignment / predict.py
rockCO78's picture
Upload folder using huggingface_hub
4535ee1 verified
"""Standalone inference script for TRACT CRE hub assignment.
Dependencies: sentence-transformers, torch, numpy
No TRACT package required — all inference logic is inlined.

Usage:
    python predict.py "Ensure AI models are tested for bias"
    python predict.py --file controls.txt --top-k 10
"""
import argparse
import json
import sys
import unicodedata
from pathlib import Path
import numpy as np
from sentence_transformers import SentenceTransformer
def sanitize_text(text: str) -> str:
    """Clean raw control text the same way the training data was cleaned.

    The steps run in a fixed order: replace null bytes with spaces,
    NFC-normalize, drop zero-width characters, unescape HTML entities and
    strip tags, expand PDF ligatures, re-join words hyphenated across a
    line break, collapse whitespace runs, and trim the ends.

    Must match tract/sanitize.py exactly to avoid train/inference skew,
    so do not reorder or alter the individual steps.
    """
    import html
    import re

    # Three-letter ligatures come first, mirroring the training pipeline.
    ligature_table = (
        ("\ufb04", "ffl"),
        ("\ufb03", "ffi"),
        ("\ufb00", "ff"),
        ("\ufb01", "fi"),
        ("\ufb02", "fl"),
    )

    cleaned = unicodedata.normalize("NFC", text.replace("\x00", " "))
    cleaned = re.sub("[\u200b\u200c\u200d\ufeff]", "", cleaned)

    # Entities are decoded first so that escaped markup is stripped as well.
    unescaped = html.unescape(cleaned)
    cleaned = re.sub(r"</?[a-zA-Z][^>]*>", "", unescaped)

    for glyph, expansion in ligature_table:
        cleaned = cleaned.replace(glyph, expansion)

    # "infor-\nmation" -> "information"
    cleaned = re.sub(r"(\w)-\n(\w)", r"\1\2", cleaned)

    return re.sub(r"\s+", " ", cleaned).strip()
def softmax(x):
    """Numerically stable softmax over the last axis.

    The maximum along the last axis is subtracted before exponentiating so
    that large inputs cannot overflow; the normalized result is the same
    as for the unshifted values.
    """
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    totals = np.sum(exps, axis=-1, keepdims=True)
    return exps / totals
def predict(
    texts: list[str],
    model_dir: str = ".",
    top_k: int = 5,
) -> list[list[dict]]:
    """Predict CRE hub assignments for input texts.

    Each text is sanitized, embedded with the SentenceTransformer stored in
    ``model_dir`` (embeddings are L2-normalized), and scored against the
    bundled hub embeddings with a dot product. The scores are divided by
    the calibrated temperature (``t_deploy``) and passed through a softmax
    to obtain per-hub confidences.

    Args:
        texts: List of control text strings.
        model_dir: Path to this repository. Besides the model it must
            contain ``calibration.json``, ``hub_ids.json``,
            ``cre_hierarchy.json`` and ``hub_embeddings.npy``.
        top_k: Number of top predictions to return per text. Must be at
            least 1; values larger than the number of hubs return every
            hub.

    Returns:
        List of prediction lists, one per input text, each ordered by
        descending calibrated confidence. Every prediction is a dict with
        the keys ``hub_id``, ``hub_name`` (falls back to the hub id),
        ``hierarchy_path`` (falls back to ``""``), ``raw_similarity`` and
        ``calibrated_confidence`` (both rounded to 4 decimals), and
        ``is_ood``. ``is_ood`` is computed once per text (best raw
        similarity below ``ood_threshold``) and repeated on each of its
        predictions. An empty ``texts`` yields an empty list.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    # A slice of [-0:] selects *everything* and a negative top_k selects an
    # arbitrary tail, so reject those instead of returning misleading output.
    if top_k < 1:
        raise ValueError(f"top_k must be a positive integer, got {top_k!r}")

    # Nothing to score: skip the expensive model load and avoid a matmul on
    # an empty embedding batch.
    if len(texts) == 0:
        return []

    base = Path(model_dir)
    model = SentenceTransformer(str(base))

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(base / "calibration.json", encoding="utf-8") as f:
        cal = json.load(f)
    with open(base / "hub_ids.json", encoding="utf-8") as f:
        hub_ids = json.load(f)
    with open(base / "cre_hierarchy.json", encoding="utf-8") as f:
        hierarchy = json.load(f)
    hub_emb = np.load(str(base / "hub_embeddings.npy"))

    temperature = cal["t_deploy"]
    ood_threshold = cal["ood_threshold"]
    hubs = hierarchy.get("hubs", {})

    cleaned = [sanitize_text(t) for t in texts]
    query_emb = model.encode(cleaned, normalize_embeddings=True, show_progress_bar=False)
    similarities = query_emb @ hub_emb.T
    calibrated = softmax(similarities / temperature)

    results = []
    for sims, confs in zip(similarities, calibrated):
        # Out-of-distribution is judged on the raw similarity, not on the
        # temperature-scaled confidence.
        is_ood = float(np.max(sims)) < ood_threshold

        # argsort is ascending: take the last top_k entries and reverse them.
        top_indices = np.argsort(confs)[-top_k:][::-1]

        preds = []
        for idx in top_indices:
            hub_id = hub_ids[idx]
            hub_info = hubs.get(hub_id, {})
            preds.append({
                "hub_id": hub_id,
                "hub_name": hub_info.get("name", hub_id),
                "hierarchy_path": hub_info.get("hierarchy_path", ""),
                "raw_similarity": round(float(sims[idx]), 4),
                "calibrated_confidence": round(float(confs[idx]), 4),
                "is_ood": is_ood,
            })
        results.append(preds)
    return results
def main():
    """Command-line entry point for TRACT CRE hub assignment.

    Controls come either from the positional ``text`` argument or, when
    ``--file`` is given, from a UTF-8 text file with one control per line
    (blank lines are skipped; ``--file`` takes precedence over ``text``).
    Predictions are printed as JSON with ``--json``, otherwise as one line
    per predicted hub, with a header per control when there are several.

    Exits with status 1 after printing the help text when no input is
    given, and with an argparse usage error when ``--top-k`` is below 1 or
    the input file contains no controls.
    """
    parser = argparse.ArgumentParser(description="TRACT CRE hub assignment")
    parser.add_argument("text", nargs="?", help="Control text to assign")
    parser.add_argument("--file", help="File with one control per line")
    parser.add_argument("--top-k", type=int, default=5, help="Number of predictions")
    parser.add_argument("--model-dir", default=".", help="Path to model directory")
    parser.add_argument("--json", action="store_true", help="JSON output")
    args = parser.parse_args()

    # A non-positive top-k would make the ranking slice return the wrong
    # set of hubs, so reject it before any model loading happens.
    if args.top_k < 1:
        parser.error("--top-k must be a positive integer")

    if args.file:
        # Explicit encoding: the locale default is not UTF-8 on every platform.
        with open(args.file, encoding="utf-8") as f:
            texts = [line.strip() for line in f if line.strip()]
        if not texts:
            parser.error(f"no controls found in {args.file}")
    elif args.text:
        texts = [args.text]
    else:
        parser.print_help()
        sys.exit(1)

    results = predict(texts, model_dir=args.model_dir, top_k=args.top_k)

    if args.json:
        print(json.dumps(results, indent=2))
        return

    show_headers = len(texts) > 1
    for number, (text, preds) in enumerate(zip(texts, results), start=1):
        if show_headers:
            print(f"\n--- Control {number}: {text[:80]} ---")
        for p in preds:
            ood = " [OOD]" if p["is_ood"] else ""
            print(f" {p['hub_id']} ({p['calibrated_confidence']:.3f}){ood} {p['hub_name']}")


if __name__ == "__main__":
    main()