""" Interactive CLI for paragraph-boundary inference using ONNX models. Downloads pre-trained ONNX models from Hugging Face Hub (if not cached), loads SAT-12L for sentence splitting, then enters an interactive loop: paste text, get boundary predictions. Usage: python -m src.models.inference python -m src.models.inference --model distilbert python -m src.models.inference --model bert """ import argparse import logging from pathlib import Path import numpy as np import onnxruntime as ort from transformers import AutoTokenizer from src.datasets.combined_pairs_dataset import ID2LABEL from src.pipelines.sat_loader import load_sat from src.models.export_and_download import HF_MODELS, download_model logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") log = logging.getLogger(__name__) LABEL_SYMBOLS = { "SAME_PARAGRAPH": " ", "NEW_PARAGRAPH": "\n\n", "NEWLINE": "\n", } LOCAL_CHECKPOINTS = Path("checkpoints") def _load_onnx_model(model_name: str, local: bool = False): """Load an ONNX session + tokenizer from local checkpoints or HF Hub.""" if local: model_dir = LOCAL_CHECKPOINTS / model_name / "best" else: repo_id = HF_MODELS[model_name] model_dir = download_model(repo_id) onnx_path = model_dir / "model.onnx" if not onnx_path.exists(): raise FileNotFoundError(f"No model.onnx found in {model_dir}") session = ort.InferenceSession( str(onnx_path), providers=["CUDAExecutionProvider", "CPUExecutionProvider"], ) input_names = [inp.name for inp in session.get_inputs()] tokenizer = AutoTokenizer.from_pretrained(str(model_dir)) return session, tokenizer, input_names def _predict_pairs(session, tokenizer, input_names, sentences: list[str], max_length: int = 512) -> list[dict]: """Classify boundary between each consecutive sentence pair via ONNX.""" if len(sentences) < 2: return [] results = [] for i in range(len(sentences) - 1): enc = tokenizer( sentences[i], sentences[i + 1], truncation=True, max_length=max_length, padding="max_length", return_tensors="np", ) feeds = {k: enc[k] for k in input_names if k in enc} logits = session.run(None, feeds)[0] probs = _softmax(logits[0]) pred = int(np.argmax(probs)) results.append({ "sentence1": sentences[i], "sentence2": sentences[i + 1], "label": ID2LABEL[pred], "confidence": round(float(probs[pred]), 4), }) return results def _softmax(x: np.ndarray) -> np.ndarray: e = np.exp(x - np.max(x)) return e / e.sum() def _reconstruct(sentences: list[str], predictions: list[dict]) -> str: """Rebuild text from sentences and predicted boundaries.""" if not sentences: return "" parts = [sentences[0]] for i, pred in enumerate(predictions): sep = LABEL_SYMBOLS.get(pred["label"], " ") parts.append(sep + sentences[i + 1]) return "".join(parts) def _read_multiline() -> str | None: """Read multi-line input until an empty line is entered.""" print("Paste your text (empty line to submit, 'quit' to exit):") lines = [] while True: try: line = input() except EOFError: return None if line.strip().lower() == "quit": return None if line == "" and lines: break lines.append(line) return "\n".join(lines) def interactive_loop(model_name: str, max_length: int = 512, local: bool = False) -> None: source = "local checkpoints" if local else "HuggingFace Hub" log.info(f"Loading ONNX model '{model_name}' from {source} ...") session, tokenizer, input_names = _load_onnx_model(model_name, local=local) log.info("Loading SAT-12L ...") sat = load_sat() print("\n" + "=" * 60) print(f" Paragraph Boundary Inference [{model_name} / ONNX]") print("=" * 60 + "\n") while True: text 
        text = _read_multiline()
        if text is None:
            print("Bye.")
            break
        if not text.strip():
            print("(empty input, skipping)\n")
            continue

        # 1. Sentence-split with SAT first, then strip newlines from each sentence
        sentences = sat.split(text, split_on_input_newlines=False, strip_whitespace=False)
        sentences = [s.replace("\n", "").strip() for s in sentences if s.strip()]

        print(f"\n--- {len(sentences)} sentence(s) detected ---")
        if len(sentences) < 2:
            print(f"  {sentences[0] if sentences else '(none)'}")
            print("  (need at least 2 sentences to classify boundaries)\n")
            continue

        # 2. Predict boundaries
        predictions = _predict_pairs(session, tokenizer, input_names, sentences, max_length)

        # 3. Show per-pair results
        for i, pred in enumerate(predictions):
            print(f"  [{i + 1}] {pred['label']:16s} ({pred['confidence']:.2%})")
            print(f"       S1: {pred['sentence1'][:80]}")
            print(f"       S2: {pred['sentence2'][:80]}")

        # 4. Show reconstructed text
        reconstructed = _reconstruct(sentences, predictions)
        print("\n--- Reconstructed text ---")
        print(reconstructed)
        print()


def main() -> None:
    parser = argparse.ArgumentParser(description="Interactive paragraph-boundary inference (ONNX).")
    parser.add_argument(
        "--model",
        default="distilbert",
        choices=list(HF_MODELS.keys()),
        help="Which model to use (default: distilbert)",
    )
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument(
        "--local",
        action="store_true",
        help="Load from checkpoints/<model>/best instead of the HF Hub",
    )
    args = parser.parse_args()
    interactive_loop(args.model, args.max_length, local=args.local)


if __name__ == "__main__":
    main()
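
# A minimal non-interactive sketch of the same pipeline, for reuse from other
# code. It only calls helpers defined above; the sentence list is a made-up
# example, and "distilbert" assumes that key exists in HF_MODELS:
#
#     from src.models.inference import _load_onnx_model, _predict_pairs, _reconstruct
#
#     session, tokenizer, input_names = _load_onnx_model("distilbert")
#     sentences = ["The meeting ended at noon.", "A new chapter began the next day."]
#     preds = _predict_pairs(session, tokenizer, input_names, sentences)
#     print(_reconstruct(sentences, preds))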