| """Base pipeline with shared SAT preprocessing and reconstruction.""" |
|
|
| from abc import ABC, abstractmethod |
|
|
# Separator text placed between two consecutive sentences, keyed by the
# label found in a prediction's "label" entry.
LABEL_SYMBOLS = dict(
    SAME_PARAGRAPH=" ",
    NEW_PARAGRAPH="\n\n",
    NEWLINE="\n",
)
|
|
|
|
class BasePipeline(ABC):
    """Base class for newline-fixing inference pipelines.

    Bundles the sentence-splitting and text-reconstruction helpers shared by
    concrete pipelines, which must implement :meth:`predict`.

    ``_split_sentences`` reads ``self._sat``; this class does not assign it,
    so a subclass has to set that attribute before the helper is used.
    """

    def _split_sentences(self, text: str) -> list[str]:
        """Split *text* into cleaned, non-empty sentences.

        Segmentation is delegated to ``self._sat.split`` (called with
        ``split_on_input_newlines=False`` and ``strip_whitespace=False``).
        Whitespace-only segments are discarded; every remaining segment has
        its newline characters removed and surrounding whitespace stripped.

        Args:
            text: Raw input text.

        Returns:
            The cleaned sentences, in the order the splitter produced them.
        """
        segments = self._sat.split(
            text, split_on_input_newlines=False, strip_whitespace=False
        )
        # NOTE(review): newlines are deleted, not replaced by a space, so two
        # words separated only by "\n" end up fused -- confirm this is intended.
        return [seg.replace("\n", "").strip() for seg in segments if seg.strip()]

    @staticmethod
    def _reconstruct(sentences: list[str], predictions: list[dict]) -> str:
        """Join *sentences* with the separators selected by *predictions*.

        ``predictions[i]`` chooses the text inserted between ``sentences[i]``
        and ``sentences[i + 1]``: its ``"label"`` value is looked up in
        ``LABEL_SYMBOLS``, and a single space is used for unknown labels.

        A length mismatch never loses text or crashes: sentences that have no
        corresponding prediction are appended with a single space, and
        predictions beyond the last sentence boundary are ignored.

        Args:
            sentences: Sentences in document order.
            predictions: Mappings carrying a ``"label"`` key, one per
                boundary between consecutive sentences.

        Returns:
            The reassembled text, or ``""`` when *sentences* is empty.

        Raises:
            KeyError: If a prediction that is used has no ``"label"`` key.
        """
        if not sentences:
            return ""

        default_sep = " "
        parts = [sentences[0]]
        following = iter(sentences[1:])

        # zip() pulls from `predictions` first, so when predictions run out no
        # sentence is consumed from `following`; surplus predictions are
        # simply never paired.
        for pred, sentence in zip(predictions, following):
            parts.append(LABEL_SYMBOLS.get(pred["label"], default_sep) + sentence)

        # Anything still left in `following` had no prediction: keep the text
        # instead of silently truncating the output.
        parts.extend(default_sep + sentence for sentence in following)
        return "".join(parts)

    @abstractmethod
    def predict(self, text: str):
        """Process *text*; the behaviour and return value are defined by subclasses."""
        ...
|
|