lamossta commited on
Commit
ffcf8df
·
1 Parent(s): 24c3bcf

api and pages

Browse files
app.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Streamlit UI entry point for the Newline Fixer service."""
2
+ import streamlit as st
3
+
4
+ st.set_page_config(initial_sidebar_state="collapsed")
5
+
6
+ st.markdown(
7
+ """
8
+ <style>
9
+ [data-testid="collapsedControl"] { display: none; }
10
+ </style>
11
+ """,
12
+ unsafe_allow_html=True,
13
+ )
14
+
15
+ home = st.Page("pages/home.py", title="Home", default=True)
16
+ config = st.Page("pages/config.py", title="Config")
17
+ result = st.Page("pages/result.py", title="Result")
18
+
19
+ pg = st.navigation([home, config, result], position="hidden")
20
+ pg.run()
pages/config.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from src.fe_handler import fix_newlines, fix_newlines_all_models
3
+ from pages.nav import show_stepper
4
+
5
+
6
+ show_stepper("Config")
7
+
8
+ st.title("Configure request")
9
+
10
+ if st.button("← Back"):
11
+ st.switch_page("pages/home.py")
12
+
13
+ text = st.text_area(
14
+ "Paste your text here:",
15
+ value=st.session_state.get("original_text", ""),
16
+ height=300,
17
+ key="input_text",
18
+ )
19
+
20
+ endpoint = st.radio(
21
+ "Select endpoint:",
22
+ ["fix-newlines", "fix-newlines-all-models"],
23
+ key="endpoint",
24
+ help="**fix-newlines**: single model (distilbert). "
25
+ "**fix-newlines-all-models**: all models side by side.",
26
+ )
27
+
28
+ if st.button("Submit"):
29
+ if not text.strip():
30
+ st.warning("Please enter some text.")
31
+ else:
32
+ try:
33
+ if endpoint == "fix-newlines":
34
+ result = fix_newlines(text)
35
+ else:
36
+ result = fix_newlines_all_models(text)
37
+
38
+ st.session_state["original_text"] = text
39
+ st.session_state["selected_endpoint"] = endpoint
40
+ st.session_state["result"] = result
41
+ st.switch_page("pages/result.py")
42
+ except Exception as e:
43
+ st.error(f"Request failed: {e}")
pages/home.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from src.fe_handler import health
3
+ from pages.nav import show_stepper
4
+
5
+
6
+ show_stepper("Home")
7
+
8
+ st.title("Newline Fixer")
9
+ st.write(
10
+ "An ML service for fixing newline placement in English text. "
11
+ "Paste your text, pick an endpoint, and get properly formatted output."
12
+ )
13
+
14
+ st.subheader("Available endpoints")
15
+
16
+ st.markdown(
17
+ "**`/fix-newlines`** — Runs your text through a single model (distilbert). "
18
+ "Returns the fixed text with corrected newline placement."
19
+ )
20
+
21
+ st.markdown(
22
+ "**`/fix-newlines-all-models`** — Runs your text through all available models "
23
+ "(distilbert, bert, deberta) and returns the results from each, "
24
+ "so you can compare their outputs side by side."
25
+ )
26
+
27
+ try:
28
+ h = health()
29
+ st.success(f"API is running. Available models: {', '.join(h['available_models'])}")
30
+ except Exception:
31
+ st.error("API is not reachable. Make sure the server is running on localhost:8000.")
32
+
33
+ if st.button("Next →"):
34
+ st.switch_page("pages/config.py")
src/api/fix_newlines.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, HTTPException, Request
2
+
3
+ from src.schemas.requests import FixNewlinesRequest
4
+ from src.schemas.responses import FixNewlinesResponse
5
+
6
+ router = APIRouter()
7
+
8
+
9
+ @router.post("/fix-newlines", response_model=FixNewlinesResponse)
10
+ def fix_newlines(request: Request, body: FixNewlinesRequest):
11
+ pipeline = request.app.state.one_model_pipeline
12
+ fixed = pipeline.predict(body.text)
13
+ return FixNewlinesResponse(fixed_text=fixed, model_used=pipeline.model_name)
src/api/fix_newlines_all_models.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, HTTPException, Request
2
+
3
+ from src.schemas.requests import FixNewlinesAllModelsRequest
4
+ from src.schemas.responses import FixNewlinesAllModelsResponse, ModelResult
5
+
6
+ router = APIRouter()
7
+
8
+
9
+ @router.post("/fix-newlines-all-models", response_model=FixNewlinesAllModelsResponse)
10
+ def fix_newlines_all_models(request: Request, body: FixNewlinesAllModelsRequest):
11
+ pipeline = request.app.state.all_models_pipeline
12
+ results_dict = pipeline.predict(body.text)
13
+ results = [
14
+ ModelResult(model_name=name, fixed_text=text)
15
+ for name, text in results_dict.items()
16
+ ]
17
+ return FixNewlinesAllModelsResponse(results=results)
src/api/health.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, Request
2
+
3
+ from src.schemas.responses import HealthResponse
4
+
5
+ router = APIRouter()
6
+
7
+
8
+ @router.get("/health", response_model=HealthResponse)
9
+ def health(request: Request):
10
+ models = []
11
+ if request.app.state.one_model_pipeline is not None:
12
+ models.append(request.app.state.one_model_pipeline.model_name)
13
+ models.extend(request.app.state.all_models_pipeline.model_names)
14
+ seen = set()
15
+ unique = []
16
+ for m in models:
17
+ if m not in seen:
18
+ seen.add(m)
19
+ unique.append(m)
20
+
21
+ return HealthResponse(
22
+ status="ok",
23
+ available_models=unique,
24
+ )
src/datasets/build_pairs.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Build sentence pairs from sentence-labeled JSONL files.
2
+
3
+ For each document with sentences [s1, s2, s3, ...] and labels [l1, l2, l3, ...],
4
+ produce pairs: (s1, s2, l1), (s2, s3, l2), ..., (s_{n-1}, s_n, l_{n-1}).
5
+
6
+ The label describes the boundary between the two sentences in each pair.
7
+ """
8
+
9
+ import json
10
+ from pathlib import Path
11
+ from tqdm import tqdm
12
+
13
+
14
+ def _build_pairs_from_records(
15
+ records: list[dict],
16
+ id_field: str,
17
+ desc: str,
18
+ ) -> list[dict]:
19
+ """Convert sentence-level records into pair-level records."""
20
+ pairs: list[dict] = []
21
+ for record in tqdm(records, desc=desc):
22
+ sentences = record["sentences"]
23
+ labels = record["labels"]
24
+
25
+ doc_id = record.get(id_field, "")
26
+
27
+ for i in range(len(sentences) - 1):
28
+ pairs.append({
29
+ id_field: doc_id,
30
+ "sentence1": sentences[i],
31
+ "sentence2": sentences[i + 1],
32
+ "label": labels[i],
33
+ })
34
+ return pairs
35
+
36
+
37
+ def _load_jsonl(path: Path) -> list[dict]:
38
+ records = []
39
+ with open(path, encoding="utf-8") as f:
40
+ for line in f:
41
+ line = line.strip()
42
+ if line:
43
+ records.append(json.loads(line))
44
+ return records
45
+
46
+
47
+ def _write_jsonl(pairs: list[dict], path: Path) -> None:
48
+ path.parent.mkdir(parents=True, exist_ok=True)
49
+ with open(path, "w", encoding="utf-8") as f:
50
+ for pair in pairs:
51
+ f.write(json.dumps(pair, ensure_ascii=False) + "\n")
52
+ print(f"Wrote {len(pairs):,} pairs → {path}")
53
+
54
+
55
+ def build_pubmed_pairs(
56
+ input_path: str | Path = "data/pubmed/pubmed_sentences.jsonl",
57
+ output_path: str | Path = "data/pubmed/pubmed_pairs.jsonl",
58
+ ) -> None:
59
+ records = _load_jsonl(Path(input_path))
60
+ pairs = _build_pairs_from_records(records, "document_idx", "Building PubMed pairs")
61
+ _write_jsonl(pairs, Path(output_path))
62
+
63
+
64
+ def build_wikipedia_pairs(
65
+ input_path: str | Path = "data/wikipedia/wikipedia_sentences.jsonl",
66
+ output_path: str | Path = "data/wikipedia/wikipedia_pairs.jsonl",
67
+ ) -> None:
68
+ records = _load_jsonl(Path(input_path))
69
+ pairs = _build_pairs_from_records(records, "document_idx", "Building Wikipedia pairs")
70
+ _write_jsonl(pairs, Path(output_path))
71
+
72
+
73
+ def build_gutenberg_pairs(
74
+ input_path: str | Path = "data/gutenberg/gutenberg_sentences.jsonl",
75
+ output_path: str | Path = "data/gutenberg/gutenberg_pairs.jsonl",
76
+ ) -> None:
77
+ records = _load_jsonl(Path(input_path))
78
+ pairs = _build_pairs_from_records(records, "file_name", "Building Gutenberg pairs")
79
+ _write_jsonl(pairs, Path(output_path))
80
+
81
+
82
+ def build_recipes_pairs(
83
+ input_path: str | Path = "data/recipes/recipes_sentences.jsonl",
84
+ output_path: str | Path = "data/recipes/recipes_pairs.jsonl",
85
+ ) -> None:
86
+ records = _load_jsonl(Path(input_path))
87
+ pairs = _build_pairs_from_records(records, "document_idx", "Building recipes pairs")
88
+ _write_jsonl(pairs, Path(output_path))
89
+
90
+
91
+ def build_all_pairs() -> None:
92
+ build_pubmed_pairs()
93
+ build_wikipedia_pairs()
94
+ build_gutenberg_pairs()
95
+ build_recipes_pairs()
96
+
97
+
98
+ if __name__ == "__main__":
99
+ import argparse
100
+
101
+ parser = argparse.ArgumentParser(description="Build sentence pairs from sentence-labeled JSONL files.")
102
+ sub = parser.add_subparsers(dest="dataset")
103
+
104
+ pub = sub.add_parser("pubmed", help="Build PubMed pairs")
105
+ pub.add_argument("--input", default="data/pubmed/pubmed_sentences.jsonl")
106
+ pub.add_argument("--output", default="data/pubmed/pubmed_pairs.jsonl")
107
+
108
+ wiki = sub.add_parser("wikipedia", help="Build Wikipedia pairs")
109
+ wiki.add_argument("--input", default="data/wikipedia/wikipedia_sentences.jsonl")
110
+ wiki.add_argument("--output", default="data/wikipedia/wikipedia_pairs.jsonl")
111
+
112
+ gut = sub.add_parser("gutenberg", help="Build Gutenberg pairs")
113
+ gut.add_argument("--input", default="data/gutenberg/gutenberg_sentences.jsonl")
114
+ gut.add_argument("--output", default="data/gutenberg/gutenberg_pairs.jsonl")
115
+
116
+ rec = sub.add_parser("recipes", help="Build recipes pairs")
117
+ rec.add_argument("--input", default="data/recipes/recipes_sentences.jsonl")
118
+ rec.add_argument("--output", default="data/recipes/recipes_pairs.jsonl")
119
+
120
+ all_p = sub.add_parser("all", help="Build pairs for all datasets")
121
+
122
+ args = parser.parse_args()
123
+
124
+ if args.dataset == "pubmed":
125
+ build_pubmed_pairs(args.input, args.output)
126
+ elif args.dataset == "wikipedia":
127
+ build_wikipedia_pairs(args.input, args.output)
128
+ elif args.dataset == "gutenberg":
129
+ build_gutenberg_pairs(args.input, args.output)
130
+ elif args.dataset == "recipes":
131
+ build_recipes_pairs(args.input, args.output)
132
+ elif args.dataset == "all":
133
+ build_all_pairs()
134
+ else:
135
+ parser.print_help()
src/datasets/create_recipes_dataset.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Create a recipes dataset from RecipeNLG full_dataset.csv.
2
+
3
+ Randomly samples recipes, formats each as a structured document, and writes
4
+ a JSONL file with fields: document_idx, text.
5
+
6
+ Usage:
7
+ python -m src.datasets.create_recipes_dataset
8
+ python -m src.datasets.create_recipes_dataset --n_samples 500 --seed 42
9
+ """
10
+
11
+ import argparse
12
+ import ast
13
+ import json
14
+ import random
15
+ from pathlib import Path
16
+
17
+ import pandas as pd
18
+
19
+
20
+ def format_recipe_as_document(rec: dict) -> str:
21
+ """Turn a raw recipe CSV row into a formatted document.
22
+
23
+ Structure:
24
+ Title\\n\\n
25
+ Ingredients:\\n
26
+ - ingredient 1\\n
27
+ - ingredient 2\\n\\n
28
+ Directions:\\n
29
+ 1. step one\\n
30
+ 2. step two
31
+ """
32
+ title = rec["title"].strip()
33
+
34
+ ingredients = rec["ingredients"]
35
+ if isinstance(ingredients, str):
36
+ ingredients = ast.literal_eval(ingredients)
37
+
38
+ directions = rec["directions"]
39
+ if isinstance(directions, str):
40
+ directions = ast.literal_eval(directions)
41
+
42
+ bullet = random.choice(["- ", "• "])
43
+ num_style = random.choice(["dot", "paren"])
44
+
45
+ ing_lines = [f"{bullet}{ing.strip()}" for ing in ingredients if ing.strip()]
46
+ if num_style == "dot":
47
+ dir_lines = [f"{i+1}. {d.strip()}" for i, d in enumerate(directions) if d.strip()]
48
+ else:
49
+ dir_lines = [f"{i+1}) {d.strip()}" for i, d in enumerate(directions) if d.strip()]
50
+
51
+ parts = [
52
+ title,
53
+ "",
54
+ "Ingredients:",
55
+ *ing_lines,
56
+ "",
57
+ "Directions:",
58
+ *dir_lines,
59
+ ]
60
+ return "\n".join(parts)
61
+
62
+
63
+ def main() -> None:
64
+ parser = argparse.ArgumentParser(description="Create recipes dataset from RecipeNLG CSV.")
65
+ parser.add_argument("--csv_path", type=str, default="data/recipes/full_dataset.csv")
66
+ parser.add_argument("--output", type=str, default="data/recipes/recipes_data.jsonl")
67
+ parser.add_argument("--n_samples", type=int, default=100)
68
+ parser.add_argument("--seed", type=int, default=42)
69
+ args = parser.parse_args()
70
+
71
+ random.seed(args.seed)
72
+ csv_path = Path(args.csv_path)
73
+ out_path = Path(args.output)
74
+ out_path.parent.mkdir(parents=True, exist_ok=True)
75
+
76
+ # Load CSV with pandas
77
+ df = pd.read_csv(csv_path)
78
+ print(f"Total recipes in CSV: {len(df):,}")
79
+
80
+ # Sample random rows
81
+ df_sample = df.sample(n=min(args.n_samples, len(df)), random_state=args.seed)
82
+ recipes = df_sample.to_dict(orient="records")
83
+ print(f"Sampled {len(recipes)} recipes")
84
+
85
+ # Format and write JSONL
86
+ with open(out_path, "w", encoding="utf-8") as f:
87
+ for doc_idx, rec in enumerate(recipes):
88
+ text = format_recipe_as_document(rec)
89
+ record = {"document_idx": doc_idx, "text": text}
90
+ f.write(json.dumps(record, ensure_ascii=False) + "\n")
91
+
92
+ print(f"Wrote {len(recipes)} documents -> {out_path}")
93
+
94
+
95
+ if __name__ == "__main__":
96
+ main()
src/fe_handler.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Frontend handler — bridges the Streamlit UI with the FastAPI backend."""
2
+
3
+ import requests
4
+
5
+ BASE_URL = "http://localhost:8000"
6
+
7
+
8
+ def fix_newlines(text: str, model_name: str | None = None) -> dict:
9
+ payload = {"text": text}
10
+ if model_name:
11
+ payload["model_name"] = model_name
12
+ resp = requests.post(f"{BASE_URL}/fix-newlines", json=payload)
13
+ resp.raise_for_status()
14
+ return resp.json()
15
+
16
+
17
+ def fix_newlines_all_models(text: str) -> dict:
18
+ resp = requests.post(f"{BASE_URL}/fix-newlines-all-models", json={"text": text})
19
+ resp.raise_for_status()
20
+ return resp.json()
21
+
22
+
23
+ def health() -> dict:
24
+ resp = requests.get(f"{BASE_URL}/health")
25
+ resp.raise_for_status()
26
+ return resp.json()
src/models/inference.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Interactive CLI for paragraph-boundary inference using ONNX models.
3
+
4
+ Downloads pre-trained ONNX models from Hugging Face Hub (if not cached),
5
+ loads SAT-12L for sentence splitting, then enters an interactive loop:
6
+ paste text, get boundary predictions.
7
+
8
+ Usage:
9
+ python -m src.models.inference
10
+ python -m src.models.inference --model distilbert
11
+ python -m src.models.inference --model bert
12
+ """
13
+
14
+ import argparse
15
+ import logging
16
+ from pathlib import Path
17
+
18
+ import numpy as np
19
+ import onnxruntime as ort
20
+ from transformers import AutoTokenizer
21
+
22
+ from src.datasets.combined_pairs_dataset import ID2LABEL
23
+ from src.pipelines.sat_loader import load_sat
24
+ from src.models.export_and_download import HF_MODELS, download_model
25
+
26
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
27
+ log = logging.getLogger(__name__)
28
+
29
+ LABEL_SYMBOLS = {
30
+ "SAME_PARAGRAPH": " ",
31
+ "NEW_PARAGRAPH": "\n\n",
32
+ "NEWLINE": "\n",
33
+ }
34
+
35
+
36
+ LOCAL_CHECKPOINTS = Path("checkpoints")
37
+
38
+
39
+ def _load_onnx_model(model_name: str, local: bool = False):
40
+ """Load an ONNX session + tokenizer from local checkpoints or HF Hub."""
41
+ if local:
42
+ model_dir = LOCAL_CHECKPOINTS / model_name / "best"
43
+ else:
44
+ repo_id = HF_MODELS[model_name]
45
+ model_dir = download_model(repo_id)
46
+
47
+ onnx_path = model_dir / "model.onnx"
48
+ if not onnx_path.exists():
49
+ raise FileNotFoundError(f"No model.onnx found in {model_dir}")
50
+
51
+ session = ort.InferenceSession(
52
+ str(onnx_path),
53
+ providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
54
+ )
55
+ input_names = [inp.name for inp in session.get_inputs()]
56
+ tokenizer = AutoTokenizer.from_pretrained(str(model_dir))
57
+ return session, tokenizer, input_names
58
+
59
+
60
+ def _predict_pairs(session, tokenizer, input_names, sentences: list[str], max_length: int = 512) -> list[dict]:
61
+ """Classify boundary between each consecutive sentence pair via ONNX."""
62
+ if len(sentences) < 2:
63
+ return []
64
+
65
+ results = []
66
+ for i in range(len(sentences) - 1):
67
+ enc = tokenizer(
68
+ sentences[i],
69
+ sentences[i + 1],
70
+ truncation=True,
71
+ max_length=max_length,
72
+ padding="max_length",
73
+ return_tensors="np",
74
+ )
75
+ feeds = {k: enc[k] for k in input_names if k in enc}
76
+ logits = session.run(None, feeds)[0]
77
+ probs = _softmax(logits[0])
78
+ pred = int(np.argmax(probs))
79
+
80
+ results.append({
81
+ "sentence1": sentences[i],
82
+ "sentence2": sentences[i + 1],
83
+ "label": ID2LABEL[pred],
84
+ "confidence": round(float(probs[pred]), 4),
85
+ })
86
+
87
+ return results
88
+
89
+
90
+ def _softmax(x: np.ndarray) -> np.ndarray:
91
+ e = np.exp(x - np.max(x))
92
+ return e / e.sum()
93
+
94
+
95
+ def _reconstruct(sentences: list[str], predictions: list[dict]) -> str:
96
+ """Rebuild text from sentences and predicted boundaries."""
97
+ if not sentences:
98
+ return ""
99
+ parts = [sentences[0]]
100
+ for i, pred in enumerate(predictions):
101
+ sep = LABEL_SYMBOLS.get(pred["label"], " ")
102
+ parts.append(sep + sentences[i + 1])
103
+ return "".join(parts)
104
+
105
+
106
+ def _read_multiline() -> str | None:
107
+ """Read multi-line input until an empty line is entered."""
108
+ print("Paste your text (empty line to submit, 'quit' to exit):")
109
+ lines = []
110
+ while True:
111
+ try:
112
+ line = input()
113
+ except EOFError:
114
+ return None
115
+ if line.strip().lower() == "quit":
116
+ return None
117
+ if line == "" and lines:
118
+ break
119
+ lines.append(line)
120
+ return "\n".join(lines)
121
+
122
+
123
+ def interactive_loop(model_name: str, max_length: int = 512, local: bool = False) -> None:
124
+ source = "local checkpoints" if local else "HuggingFace Hub"
125
+ log.info(f"Loading ONNX model '{model_name}' from {source} ...")
126
+ session, tokenizer, input_names = _load_onnx_model(model_name, local=local)
127
+
128
+ log.info("Loading SAT-12L ...")
129
+ sat = load_sat()
130
+
131
+ print("\n" + "=" * 60)
132
+ print(f" Paragraph Boundary Inference [{model_name} / ONNX]")
133
+ print("=" * 60 + "\n")
134
+
135
+ while True:
136
+ text = _read_multiline()
137
+ if text is None:
138
+ print("Bye.")
139
+ break
140
+
141
+ if not text.strip():
142
+ print("(empty input, skipping)\n")
143
+ continue
144
+
145
+ # 1. Sentence-split with SAT first, then strip newlines from each sentence
146
+ sentences = sat.split(text, split_on_input_newlines=False, strip_whitespace=False)
147
+ sentences = [s.replace('\n', '').strip() for s in sentences if s.strip()]
148
+
149
+ print(f"\n--- {len(sentences)} sentence(s) detected ---")
150
+ if len(sentences) < 2:
151
+ print(f" {sentences[0] if sentences else '(none)'}")
152
+ print(" (need at least 2 sentences to classify boundaries)\n")
153
+ continue
154
+
155
+ # 3. Predict boundaries
156
+ predictions = _predict_pairs(session, tokenizer, input_names, sentences, max_length)
157
+
158
+ # 4. Show per-pair results
159
+ for i, pred in enumerate(predictions):
160
+ print(f" [{i+1}] {pred['label']:16s} ({pred['confidence']:.2%})")
161
+ print(f" S1: {pred['sentence1'][:80]}")
162
+ print(f" S2: {pred['sentence2'][:80]}")
163
+
164
+ # 5. Show reconstructed text
165
+ reconstructed = _reconstruct(sentences, predictions)
166
+ print("\n--- Reconstructed text ---")
167
+ print(reconstructed)
168
+ print()
169
+
170
+
171
+ def main() -> None:
172
+ parser = argparse.ArgumentParser(description="Interactive paragraph-boundary inference (ONNX).")
173
+ parser.add_argument(
174
+ "--model",
175
+ default="distilbert",
176
+ choices=list(HF_MODELS.keys()),
177
+ help="Which model to use (default: distilbert)",
178
+ )
179
+ parser.add_argument("--max_length", type=int, default=512)
180
+ parser.add_argument("--local", action="store_true",
181
+ help="Load from checkpoints/<model>/best instead of HF Hub")
182
+ args = parser.parse_args()
183
+
184
+ interactive_loop(args.model, args.max_length, local=args.local)
185
+
186
+
187
+ if __name__ == "__main__":
188
+ main()
tests/conftest.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from fastapi.testclient import TestClient
3
+
4
+ from main import app
5
+
6
+
7
+ @pytest.fixture(scope="session")
8
+ def client():
9
+ """TestClient that runs the full app lifespan (loads SAT + ONNX models once)."""
10
+ with TestClient(app) as c:
11
+ yield c