| """ |
| Interactive CLI for paragraph-boundary inference using ONNX models. |
| |
| Downloads pre-trained ONNX models from Hugging Face Hub (if not cached), |
| loads SAT-12L for sentence splitting, then enters an interactive loop: |
| paste text, get boundary predictions. |
| |
| Usage: |
| python -m src.models.inference |
| python -m src.models.inference --model distilbert |
| python -m src.models.inference --model bert |
| """ |
|
|
| import argparse |
| import logging |
| from pathlib import Path |
|
|
| import numpy as np |
| import onnxruntime as ort |
| from transformers import AutoTokenizer |
|
|
| from src.datasets.combined_pairs_dataset import ID2LABEL |
| from src.pipelines.sat_loader import load_sat |
| from src.models.export_and_download import HF_MODELS, download_model |
|
|
# Module-level logger; basicConfig is acceptable here since this module is the
# program entry point (`python -m src.models.inference`).
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger(__name__)


# Separator inserted between two sentences during text reconstruction,
# keyed by the classifier's label name (labels come from ID2LABEL).
LABEL_SYMBOLS = {
    "SAME_PARAGRAPH": " ",
    "NEW_PARAGRAPH": "\n\n",
    "NEWLINE": "\n",
}




# Root directory of locally exported checkpoints (used with --local).
LOCAL_CHECKPOINTS = Path("checkpoints")
|
|
|
|
def _load_onnx_model(model_name: str, local: bool = False):
    """Build an ONNX inference session and its matching tokenizer.

    The model directory is resolved either from the local checkpoints tree
    (``checkpoints/<model>/best``) or by downloading the corresponding repo
    from the Hugging Face Hub.

    Returns:
        Tuple of (ort.InferenceSession, tokenizer, list of input names).

    Raises:
        FileNotFoundError: if the resolved directory has no ``model.onnx``.
    """
    if local:
        model_dir = LOCAL_CHECKPOINTS / model_name / "best"
    else:
        model_dir = download_model(HF_MODELS[model_name])

    onnx_path = model_dir / "model.onnx"
    if not onnx_path.exists():
        raise FileNotFoundError(f"No model.onnx found in {model_dir}")

    # Prefer GPU when available; onnxruntime silently falls back to CPU.
    providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    session = ort.InferenceSession(str(onnx_path), providers=providers)
    names = [spec.name for spec in session.get_inputs()]
    tokenizer = AutoTokenizer.from_pretrained(str(model_dir))
    return session, tokenizer, names
|
|
|
|
def _predict_pairs(session, tokenizer, input_names, sentences: list[str], max_length: int = 512) -> list[dict]:
    """Run the ONNX classifier on every adjacent sentence pair.

    Each consecutive (sentence[i], sentence[i+1]) pair is tokenized as a
    sentence-pair input and classified; the result records both sentences,
    the predicted label name, and the model's confidence for it.
    """
    if len(sentences) < 2:
        return []

    predictions = []
    for left, right in zip(sentences, sentences[1:]):
        encoded = tokenizer(
            left,
            right,
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="np",
        )
        # Feed only the tensors this particular ONNX graph declares as inputs.
        feeds = {name: encoded[name] for name in input_names if name in encoded}
        logits = session.run(None, feeds)[0]
        probs = _softmax(logits[0])
        best = int(np.argmax(probs))
        predictions.append({
            "sentence1": left,
            "sentence2": right,
            "label": ID2LABEL[best],
            "confidence": round(float(probs[best]), 4),
        })

    return predictions
|
|
|
|
| def _softmax(x: np.ndarray) -> np.ndarray: |
| e = np.exp(x - np.max(x)) |
| return e / e.sum() |
|
|
|
|
| def _reconstruct(sentences: list[str], predictions: list[dict]) -> str: |
| """Rebuild text from sentences and predicted boundaries.""" |
| if not sentences: |
| return "" |
| parts = [sentences[0]] |
| for i, pred in enumerate(predictions): |
| sep = LABEL_SYMBOLS.get(pred["label"], " ") |
| parts.append(sep + sentences[i + 1]) |
| return "".join(parts) |
|
|
|
|
def _read_multiline() -> str | None:
    """Collect stdin lines until a blank line; return None on 'quit' or EOF.

    The blank-line terminator only fires once at least one non-empty
    submission line has been collected, so a leading empty line is kept.
    """
    print("Paste your text (empty line to submit, 'quit' to exit):")
    collected: list[str] = []
    while True:
        try:
            raw = input()
        except EOFError:
            return None
        if raw.strip().lower() == "quit":
            return None
        if not raw and collected:
            break
        collected.append(raw)
    return "\n".join(collected)
|
|
|
|
def interactive_loop(model_name: str, max_length: int = 512, local: bool = False) -> None:
    """Run the REPL: read text, split into sentences, classify each boundary.

    Loads the ONNX classifier and the SAT-12L sentence splitter once, then
    repeatedly reads multi-line input, prints per-boundary predictions, and
    shows the reconstructed text. Exits when input returns None (quit/EOF).
    """
    source = "local checkpoints" if local else "HuggingFace Hub"
    log.info(f"Loading ONNX model '{model_name}' from {source} ...")
    session, tokenizer, input_names = _load_onnx_model(model_name, local=local)

    log.info("Loading SAT-12L ...")
    sat = load_sat()

    banner = "=" * 60
    print("\n" + banner)
    print(f" Paragraph Boundary Inference [{model_name} / ONNX]")
    print(banner + "\n")

    while True:
        text = _read_multiline()
        if text is None:
            print("Bye.")
            break

        if not text.strip():
            print("(empty input, skipping)\n")
            continue

        # Split with SAT, then drop empty fragments and embedded newlines.
        raw_sentences = sat.split(text, split_on_input_newlines=False, strip_whitespace=False)
        sentences = [s.replace('\n', '').strip() for s in raw_sentences if s.strip()]

        print(f"\n--- {len(sentences)} sentence(s) detected ---")
        if len(sentences) < 2:
            print(f" {sentences[0] if sentences else '(none)'}")
            print(" (need at least 2 sentences to classify boundaries)\n")
            continue

        predictions = _predict_pairs(session, tokenizer, input_names, sentences, max_length)

        for idx, pred in enumerate(predictions):
            print(f" [{idx+1}] {pred['label']:16s} ({pred['confidence']:.2%})")
            print(f" S1: {pred['sentence1'][:80]}")
            print(f" S2: {pred['sentence2'][:80]}")

        print("\n--- Reconstructed text ---")
        print(_reconstruct(sentences, predictions))
        print()
|
|
|
|
def main() -> None:
    """Parse command-line arguments and start the interactive loop."""
    parser = argparse.ArgumentParser(description="Interactive paragraph-boundary inference (ONNX).")
    parser.add_argument(
        "--model",
        default="distilbert",
        choices=list(HF_MODELS.keys()),
        help="Which model to use (default: distilbert)",
    )
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument(
        "--local",
        action="store_true",
        help="Load from checkpoints/<model>/best instead of HF Hub",
    )
    cli = parser.parse_args()
    interactive_loop(cli.model, cli.max_length, local=cli.local)
|
|
|
|
# Script entry point — see the module docstring for usage examples.
if __name__ == "__main__":
    main()
|
|