File size: 6,059 Bytes
ffcf8df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""
Interactive CLI for paragraph-boundary inference using ONNX models.

Downloads pre-trained ONNX models from Hugging Face Hub (if not cached),
loads SAT-12L for sentence splitting, then enters an interactive loop:
paste text, get boundary predictions.

Usage:
    python -m src.models.inference
    python -m src.models.inference --model distilbert
    python -m src.models.inference --model bert
    python -m src.models.inference --model bert --local   # load checkpoints/bert/best
"""

import argparse
import logging
from pathlib import Path

import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

from src.datasets.combined_pairs_dataset import ID2LABEL
from src.pipelines.sat_loader import load_sat
from src.models.export_and_download import HF_MODELS, download_model

# Configure logging once at import time: timestamp, level name, message.
logging.basicConfig(
    format="%(asctime)s  %(levelname)s  %(message)s",
    level=logging.INFO,
)
log = logging.getLogger(__name__)

# Separator placed in front of the second sentence of a pair, keyed by the
# predicted boundary label.
LABEL_SYMBOLS: dict[str, str] = dict(
    SAME_PARAGRAPH=" ",
    NEW_PARAGRAPH="\n\n",
    NEWLINE="\n",
)


# Root directory for locally stored models (a relative path).
LOCAL_CHECKPOINTS: Path = Path("checkpoints")


def _load_onnx_model(model_name: str, local: bool = False):
    """Load an ONNX inference session and its tokenizer.

    Args:
        model_name: Key into ``HF_MODELS`` (Hub mode) or the name of the
            sub-directory under ``LOCAL_CHECKPOINTS`` (local mode).
        local: When true, read from ``LOCAL_CHECKPOINTS/<model_name>/best``;
            otherwise look the repo id up in ``HF_MODELS`` and obtain the
            model directory from ``download_model``.

    Returns:
        A ``(session, tokenizer, input_names)`` tuple, where ``input_names``
        holds the names of the session's inputs.

    Raises:
        FileNotFoundError: If the model directory contains no ``model.onnx``.
        KeyError: If ``local`` is false and ``model_name`` is not a key of
            ``HF_MODELS``.
    """
    if local:
        model_dir = LOCAL_CHECKPOINTS / model_name / "best"
    else:
        repo_id = HF_MODELS[model_name]
        # Path() is a no-op for Path inputs and keeps the "/" join below
        # working should download_model hand back a plain string.
        model_dir = Path(download_model(repo_id))

    onnx_path = model_dir / "model.onnx"
    if not onnx_path.exists():
        raise FileNotFoundError(f"No model.onnx found in {model_dir}")

    # Request only the execution providers this onnxruntime build offers, in
    # order of preference. Asking for CUDA on a CPU-only install otherwise
    # produces a warning on every start-up.
    preferred = ("CUDAExecutionProvider", "CPUExecutionProvider")
    available = set(ort.get_available_providers())
    providers = [name for name in preferred if name in available] or ["CPUExecutionProvider"]

    session = ort.InferenceSession(str(onnx_path), providers=providers)
    input_names = [inp.name for inp in session.get_inputs()]
    tokenizer = AutoTokenizer.from_pretrained(str(model_dir))
    return session, tokenizer, input_names


def _predict_pairs(session, tokenizer, input_names, sentences: list[str], max_length: int = 512) -> list[dict]:
    """Label the boundary between every adjacent pair of sentences.

    Each pair is tokenized as a two-text input, truncated and padded to
    ``max_length``, and run through the ONNX ``session``.

    Returns:
        One dict per adjacent pair with the keys ``sentence1``,
        ``sentence2``, ``label`` (looked up in ``ID2LABEL``) and
        ``confidence`` (softmax probability of the predicted class, rounded
        to four decimals). Empty when fewer than two sentences are given.
    """
    outputs: list[dict] = []
    # Pairing the list with its one-step shift yields (s[i], s[i + 1]) and is
    # naturally empty for zero or one sentence.
    for left, right in zip(sentences, sentences[1:]):
        encoded = tokenizer(
            left,
            right,
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="np",
        )
        # Feed only the tensors that the session declares as inputs.
        model_inputs = {name: encoded[name] for name in input_names if name in encoded}
        first_output = session.run(None, model_inputs)[0]
        class_probs = _softmax(first_output[0])
        best = int(np.argmax(class_probs))
        outputs.append(
            {
                "sentence1": left,
                "sentence2": right,
                "label": ID2LABEL[best],
                "confidence": round(float(class_probs[best]), 4),
            }
        )
    return outputs


def _softmax(x: np.ndarray) -> np.ndarray:
    """Return the softmax of ``x`` taken over all of its elements.

    The overall maximum is subtracted before exponentiating so that large
    inputs do not overflow; the shift does not change the result.
    """
    shifted = np.subtract(x, np.max(x))
    weights = np.exp(shifted)
    return weights / weights.sum()


def _reconstruct(sentences: list[str], predictions: list[dict]) -> str:
    """Join sentences back into one string using the predicted separators.

    The i-th prediction decides what is placed between sentence ``i`` and
    sentence ``i + 1``: its ``label`` is looked up in ``LABEL_SYMBOLS``, and
    a label missing from that table falls back to a single space.

    Returns:
        The rebuilt text, or ``""`` when there are no sentences.
    """
    if not sentences:
        return ""
    # Each entry is "<separator><next sentence>"; positions start at 1
    # because the first sentence is emitted without a leading separator.
    joined_tail = [
        LABEL_SYMBOLS.get(prediction["label"], " ") + sentences[position]
        for position, prediction in enumerate(predictions, start=1)
    ]
    return "".join([sentences[0], *joined_tail])


def _read_multiline() -> str | None:
    """Read multi-line input until an empty line is entered."""
    print("Paste your text (empty line to submit, 'quit' to exit):")
    lines = []
    while True:
        try:
            line = input()
        except EOFError:
            return None
        if line.strip().lower() == "quit":
            return None
        if line == "" and lines:
            break
        lines.append(line)
    return "\n".join(lines)


def interactive_loop(model_name: str, max_length: int = 512, local: bool = False) -> None:
    """Run the read / split / classify / print loop until the user quits.

    Loads the ONNX model and the SAT sentence splitter once, then repeatedly
    reads a block of text, splits it into sentences, classifies the boundary
    between each adjacent pair, and prints the per-pair labels followed by
    the reconstructed text.

    Args:
        model_name: Model key forwarded to ``_load_onnx_model``.
        max_length: Token length forwarded to ``_predict_pairs``.
        local: Load from local checkpoints instead of the Hugging Face Hub.
    """
    source = "local checkpoints" if local else "HuggingFace Hub"
    log.info("Loading ONNX model '%s' from %s ...", model_name, source)
    session, tokenizer, input_names = _load_onnx_model(model_name, local=local)

    log.info("Loading SAT-12L ...")
    sat = load_sat()

    print("\n" + "=" * 60)
    print(f"  Paragraph Boundary Inference  [{model_name} / ONNX]")
    print("=" * 60 + "\n")

    while True:
        text = _read_multiline()
        if text is None:
            print("Bye.")
            break

        if not text.strip():
            print("(empty input, skipping)\n")
            continue

        # 1. Sentence-split the raw text with SAT.
        raw_sentences = sat.split(text, split_on_input_newlines=False, strip_whitespace=False)

        # 2. Flatten each sentence onto one line. A line break inside a
        #    sentence becomes a single space; deleting it outright would glue
        #    the neighbouring words together ("quick\nbrown" -> "quickbrown").
        #    Whitespace-only sentences are dropped.
        sentences = []
        for raw in raw_sentences:
            flat = " ".join(seg.strip() for seg in raw.split("\n") if seg.strip())
            if flat:
                sentences.append(flat)

        print(f"\n--- {len(sentences)} sentence(s) detected ---")
        if len(sentences) < 2:
            print(f"  {sentences[0] if sentences else '(none)'}")
            print("  (need at least 2 sentences to classify boundaries)\n")
            continue

        # 3. Predict boundaries
        predictions = _predict_pairs(session, tokenizer, input_names, sentences, max_length)

        # 4. Show per-pair results (sentences truncated to 80 characters)
        for i, pred in enumerate(predictions):
            print(f"  [{i+1}] {pred['label']:16s} ({pred['confidence']:.2%})")
            print(f"       S1: {pred['sentence1'][:80]}")
            print(f"       S2: {pred['sentence2'][:80]}")

        # 5. Show reconstructed text
        reconstructed = _reconstruct(sentences, predictions)
        print("\n--- Reconstructed text ---")
        print(reconstructed)
        print()


def main() -> None:
    """Parse the command-line options and start the interactive loop.

    Options: ``--model`` (one of the ``HF_MODELS`` keys, default
    ``distilbert``), ``--max_length`` (int, default 512) and the ``--local``
    flag.
    """
    cli = argparse.ArgumentParser(description="Interactive paragraph-boundary inference (ONNX).")

    model_choices = list(HF_MODELS)
    cli.add_argument(
        "--model",
        choices=model_choices,
        default="distilbert",
        help="Which model to use (default: distilbert)",
    )
    cli.add_argument("--max_length", default=512, type=int)
    cli.add_argument(
        "--local",
        action="store_true",
        help="Load from checkpoints/<model>/best instead of HF Hub",
    )

    options = cli.parse_args()
    interactive_loop(options.model, options.max_length, local=options.local)


# Start the CLI only when this module is run as a script (including via
# ``python -m``), not when it is imported.
if __name__ == "__main__":
    main()