# src/models/inference.py
"""
Interactive CLI for paragraph-boundary inference using ONNX models.
Downloads pre-trained ONNX models from Hugging Face Hub (if not cached),
loads SAT-12L for sentence splitting, then enters an interactive loop:
paste text, get boundary predictions.
Usage:
python -m src.models.inference
python -m src.models.inference --model distilbert
python -m src.models.inference --model bert
"""
import argparse
import logging
from pathlib import Path
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer
from src.datasets.combined_pairs_dataset import ID2LABEL
from src.pipelines.sat_loader import load_sat
from src.models.export_and_download import HF_MODELS, download_model
# Module-wide logging: timestamped INFO-level messages to stderr.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger(__name__)
# Separator string inserted between two sentences for each predicted boundary label.
LABEL_SYMBOLS = {
    "SAME_PARAGRAPH": " ",
    "NEW_PARAGRAPH": "\n\n",
    "NEWLINE": "\n",
}
# Root directory for locally trained checkpoints (used with --local).
LOCAL_CHECKPOINTS = Path("checkpoints")
def _load_onnx_model(model_name: str, local: bool = False):
    """Resolve a model directory and build an ONNX session plus tokenizer.

    Looks under checkpoints/<model>/best when `local` is set; otherwise
    downloads (or reuses the cache of) the matching Hugging Face Hub repo.

    Returns:
        (session, tokenizer, input_names) — the ORT inference session, the
        matching tokenizer, and the names of the graph's declared inputs.

    Raises:
        FileNotFoundError: if the resolved directory has no model.onnx.
    """
    if local:
        model_dir = LOCAL_CHECKPOINTS / model_name / "best"
    else:
        model_dir = download_model(HF_MODELS[model_name])

    onnx_path = model_dir / "model.onnx"
    if not onnx_path.exists():
        raise FileNotFoundError(f"No model.onnx found in {model_dir}")

    # Provider list is an ordered preference: CUDA first, CPU as fallback.
    session = ort.InferenceSession(
        str(onnx_path),
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )
    names = [spec.name for spec in session.get_inputs()]
    tokenizer = AutoTokenizer.from_pretrained(str(model_dir))
    return session, tokenizer, names
def _predict_pairs(session, tokenizer, input_names, sentences: list[str], max_length: int = 512) -> list[dict]:
    """Run the ONNX classifier on every adjacent sentence pair.

    Returns one dict per boundary containing the two sentences, the
    predicted label name, and the softmax confidence (rounded to 4
    decimals). Fewer than two sentences yields an empty list.
    """
    # zip against the shifted list produces all consecutive pairs
    # (and is naturally empty when len(sentences) < 2).
    results: list[dict] = []
    for left, right in zip(sentences, sentences[1:]):
        encoded = tokenizer(
            left,
            right,
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="np",
        )
        # Only feed tensors the ONNX graph actually declares as inputs.
        feeds = {name: encoded[name] for name in input_names if name in encoded}
        logits = session.run(None, feeds)[0]
        probs = _softmax(logits[0])
        best = int(np.argmax(probs))
        results.append({
            "sentence1": left,
            "sentence2": right,
            "label": ID2LABEL[best],
            "confidence": round(float(probs[best]), 4),
        })
    return results
def _softmax(x: np.ndarray) -> np.ndarray:
e = np.exp(x - np.max(x))
return e / e.sum()
def _reconstruct(sentences: list[str], predictions: list[dict]) -> str:
    """Stitch sentences back together using the predicted separators.

    Unknown labels fall back to a single space so output never drops text.
    """
    if not sentences:
        return ""
    pieces = [sentences[0]]
    # predictions[i] describes the boundary before sentences[i + 1].
    for idx, pred in enumerate(predictions, start=1):
        separator = LABEL_SYMBOLS.get(pred["label"], " ")
        pieces.append(separator + sentences[idx])
    return "".join(pieces)
def _read_multiline() -> str | None:
    """Collect lines from stdin until a blank line ends the entry.

    Returns None on EOF or when the user types 'quit'; otherwise the
    collected lines joined with newlines. The blank-line terminator only
    fires once at least one line has been entered.
    """
    print("Paste your text (empty line to submit, 'quit' to exit):")
    collected: list[str] = []
    while True:
        try:
            entry = input()
        except EOFError:
            return None
        if entry.strip().lower() == "quit":
            return None
        if not entry and collected:
            break
        collected.append(entry)
    return "\n".join(collected)
def interactive_loop(model_name: str, max_length: int = 512, local: bool = False) -> None:
    """Run the interactive REPL: read text, split it into sentences,
    classify each consecutive-pair boundary, and print both the per-pair
    predictions and the reconstructed text.

    Args:
        model_name: Key into HF_MODELS (e.g. "distilbert", "bert").
        max_length: Tokenizer truncation/padding length per sentence pair.
        local: If True, load from checkpoints/<model>/best instead of HF Hub.
    """
    source = "local checkpoints" if local else "HuggingFace Hub"
    log.info(f"Loading ONNX model '{model_name}' from {source} ...")
    session, tokenizer, input_names = _load_onnx_model(model_name, local=local)
    log.info("Loading SAT-12L ...")
    sat = load_sat()
    print("\n" + "=" * 60)
    print(f"  Paragraph Boundary Inference  [{model_name} / ONNX]")
    print("=" * 60 + "\n")
    while True:
        text = _read_multiline()
        if text is None:
            print("Bye.")
            break
        if not text.strip():
            print("(empty input, skipping)\n")
            continue
        # 1. Sentence-split with SAT, then strip newlines from each sentence
        #    so the classifier sees clean single-line inputs.
        #    NOTE(review): split_on_input_newlines=False presumably lets SAT
        #    ignore the user's original line breaks — confirm against sat_loader.
        sentences = sat.split(text, split_on_input_newlines=False, strip_whitespace=False)
        sentences = [s.replace('\n', '').strip() for s in sentences if s.strip()]
        print(f"\n--- {len(sentences)} sentence(s) detected ---")
        if len(sentences) < 2:
            print(f"  {sentences[0] if sentences else '(none)'}")
            print("  (need at least 2 sentences to classify boundaries)\n")
            continue
        # 2. Predict one boundary label per consecutive sentence pair
        predictions = _predict_pairs(session, tokenizer, input_names, sentences, max_length)
        # 3. Show per-pair results (sentences truncated to 80 chars for display)
        for i, pred in enumerate(predictions):
            print(f"  [{i+1}] {pred['label']:16s} ({pred['confidence']:.2%})")
            print(f"       S1: {pred['sentence1'][:80]}")
            print(f"       S2: {pred['sentence2'][:80]}")
        # 4. Show the text rebuilt with the predicted separators
        reconstructed = _reconstruct(sentences, predictions)
        print("\n--- Reconstructed text ---")
        print(reconstructed)
        print()
def main() -> None:
    """Parse CLI arguments and launch the interactive inference loop."""
    arg_parser = argparse.ArgumentParser(
        description="Interactive paragraph-boundary inference (ONNX)."
    )
    arg_parser.add_argument(
        "--model",
        default="distilbert",
        choices=list(HF_MODELS.keys()),
        help="Which model to use (default: distilbert)",
    )
    arg_parser.add_argument("--max_length", type=int, default=512)
    arg_parser.add_argument(
        "--local",
        action="store_true",
        help="Load from checkpoints/<model>/best instead of HF Hub",
    )
    opts = arg_parser.parse_args()
    interactive_loop(opts.model, opts.max_length, local=opts.local)
# Script entry point: `python -m src.models.inference`.
if __name__ == "__main__":
    main()