"""
Interactive CLI for paragraph-boundary inference using ONNX models.

Downloads pre-trained ONNX models from the Hugging Face Hub (if not cached),
loads SAT-12L for sentence splitting, then enters an interactive loop:
paste text, get boundary predictions.

Usage:
    python -m src.models.inference
    python -m src.models.inference --model distilbert
    python -m src.models.inference --model bert
    python -m src.models.inference --local
"""
import argparse
import logging
from pathlib import Path
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer
from src.datasets.combined_pairs_dataset import ID2LABEL
from src.pipelines.sat_loader import load_sat
from src.models.export_and_download import HF_MODELS, download_model
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
log = logging.getLogger(__name__)

# Separator placed between two sentences for each predicted boundary label
# when the text is rebuilt.
LABEL_SYMBOLS = dict(
    SAME_PARAGRAPH=" ",
    NEW_PARAGRAPH="\n\n",
    NEWLINE="\n",
)

# Root directory for locally trained checkpoints (used with --local; the
# model is read from LOCAL_CHECKPOINTS/<model>/best).
LOCAL_CHECKPOINTS = Path("checkpoints")
def _load_onnx_model(model_name: str, local: bool = False):
    """Load an ONNX inference session and its tokenizer.

    Args:
        model_name: Model key. With ``local=True`` it names a sub-directory of
            ``LOCAL_CHECKPOINTS``; otherwise it is looked up in ``HF_MODELS``
            to get the Hub repo id.
        local: Read from ``LOCAL_CHECKPOINTS/<model_name>/best`` instead of
            downloading from the Hugging Face Hub.

    Returns:
        ``(session, tokenizer, input_names)``; ``input_names`` lists the ONNX
        graph's input tensor names so callers can pick matching tokenizer
        outputs.

    Raises:
        FileNotFoundError: If the resolved directory contains no ``model.onnx``.
        KeyError: If ``local`` is False and ``model_name`` is not in ``HF_MODELS``.
    """
    if local:
        model_dir = LOCAL_CHECKPOINTS / model_name / "best"
    else:
        # Wrap in Path so the "/" joins below work whether download_model
        # returns a str or a Path.
        model_dir = Path(download_model(HF_MODELS[model_name]))

    onnx_path = model_dir / "model.onnx"
    if not onnx_path.exists():
        raise FileNotFoundError(f"No model.onnx found in {model_dir}")

    # Only request providers this onnxruntime build actually has, so CPU-only
    # installs do not warn about a missing CUDA provider on every load.
    # CUDA stays first in preference order when available.
    available = set(ort.get_available_providers())
    providers = [
        p for p in ("CUDAExecutionProvider", "CPUExecutionProvider") if p in available
    ] or ["CPUExecutionProvider"]
    session = ort.InferenceSession(str(onnx_path), providers=providers)
    log.info("ONNX Runtime providers in use: %s", session.get_providers())

    input_names = [inp.name for inp in session.get_inputs()]
    tokenizer = AutoTokenizer.from_pretrained(str(model_dir))
    return session, tokenizer, input_names
def _predict_pairs(session, tokenizer, input_names, sentences: list[str], max_length: int = 512) -> list[dict]:
"""Classify boundary between each consecutive sentence pair via ONNX."""
if len(sentences) < 2:
return []
results = []
for i in range(len(sentences) - 1):
enc = tokenizer(
sentences[i],
sentences[i + 1],
truncation=True,
max_length=max_length,
padding="max_length",
return_tensors="np",
)
feeds = {k: enc[k] for k in input_names if k in enc}
logits = session.run(None, feeds)[0]
probs = _softmax(logits[0])
pred = int(np.argmax(probs))
results.append({
"sentence1": sentences[i],
"sentence2": sentences[i + 1],
"label": ID2LABEL[pred],
"confidence": round(float(probs[pred]), 4),
})
return results
def _softmax(x: np.ndarray) -> np.ndarray:
e = np.exp(x - np.max(x))
return e / e.sum()
def _reconstruct(sentences: list[str], predictions: list[dict]) -> str:
"""Rebuild text from sentences and predicted boundaries."""
if not sentences:
return ""
parts = [sentences[0]]
for i, pred in enumerate(predictions):
sep = LABEL_SYMBOLS.get(pred["label"], " ")
parts.append(sep + sentences[i + 1])
return "".join(parts)
def _read_multiline() -> str | None:
"""Read multi-line input until an empty line is entered."""
print("Paste your text (empty line to submit, 'quit' to exit):")
lines = []
while True:
try:
line = input()
except EOFError:
return None
if line.strip().lower() == "quit":
return None
if line == "" and lines:
break
lines.append(line)
return "\n".join(lines)
def interactive_loop(model_name: str, max_length: int = 512, local: bool = False) -> None:
    """Run the interactive paste -> predict -> print loop until the user quits.

    Loads the ONNX boundary classifier and the SAT-12L sentence splitter once.
    It then repeats four steps: read a block of text, split it into sentences,
    classify the boundary between each consecutive pair, and print the
    per-pair results and the text rebuilt from those boundaries.

    Args:
        model_name: Model key passed to ``_load_onnx_model``.
        max_length: Per-pair token limit passed to ``_predict_pairs``.
        local: Load from local checkpoints instead of the Hugging Face Hub.
    """
    source = "local checkpoints" if local else "HuggingFace Hub"
    log.info("Loading ONNX model '%s' from %s ...", model_name, source)
    session, tokenizer, input_names = _load_onnx_model(model_name, local=local)
    log.info("Loading SAT-12L ...")
    sat = load_sat()

    print("\n" + "=" * 60)
    print(f" Paragraph Boundary Inference [{model_name} / ONNX]")
    print("=" * 60 + "\n")

    while True:
        text = _read_multiline()
        if text is None:
            print("Bye.")
            break
        if not text.strip():
            print("(empty input, skipping)\n")
            continue

        # 1. Sentence-split the raw text with SAT.
        raw_sentences = sat.split(text, split_on_input_newlines=False, strip_whitespace=False)

        # 2. Normalise whitespace inside each sentence and drop blank ones.
        #    Line breaks must become a space, not be deleted: a sentence that
        #    was wrapped across lines ("long\nsentence") would otherwise fuse
        #    the words at the break ("longsentence").
        sentences = [" ".join(s.split()) for s in raw_sentences if s.strip()]

        print(f"\n--- {len(sentences)} sentence(s) detected ---")
        if len(sentences) < 2:
            print(f" {sentences[0] if sentences else '(none)'}")
            print(" (need at least 2 sentences to classify boundaries)\n")
            continue

        # 3. Predict boundaries.
        predictions = _predict_pairs(session, tokenizer, input_names, sentences, max_length)

        # 4. Show per-pair results (sentences truncated to 80 characters).
        for pair_no, pred in enumerate(predictions, start=1):
            print(f" [{pair_no}] {pred['label']:16s} ({pred['confidence']:.2%})")
            print(f" S1: {pred['sentence1'][:80]}")
            print(f" S2: {pred['sentence2'][:80]}")

        # 5. Show reconstructed text.
        print("\n--- Reconstructed text ---")
        print(_reconstruct(sentences, predictions))
        print()
def main() -> None:
    """Parse command-line arguments and start the interactive loop."""
    parser = argparse.ArgumentParser(description="Interactive paragraph-boundary inference (ONNX).")
    parser.add_argument(
        "--model",
        default="distilbert",
        choices=list(HF_MODELS.keys()),
        help="Which model to use (default: distilbert)",
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=512,
        help="Maximum number of tokens per encoded sentence pair (default: 512)",
    )
    parser.add_argument(
        "--local",
        action="store_true",
        help="Load from checkpoints/<model>/best instead of HF Hub",
    )
    args = parser.parse_args()
    # Reject a non-positive length with a usage error now, before the models
    # are downloaded/loaded, rather than passing it on to the tokenizer.
    if args.max_length <= 0:
        parser.error("--max_length must be a positive integer")
    interactive_loop(args.model, args.max_length, local=args.local)
# Script entry point, e.g. `python -m src.models.inference --model distilbert`.
if __name__ == "__main__":
    main()
|