bc-test / src /pipelines /base.py
lamossta's picture
base pipeline
e2f27ed
raw
history blame contribute delete
947 Bytes
"""Base pipeline with shared SAT preprocessing and reconstruction."""
from abc import ABC, abstractmethod
# Maps a predicted boundary label to the separator string that is placed
# in front of the following sentence when the text is reassembled.
LABEL_SYMBOLS: dict[str, str] = {
    "SAME_PARAGRAPH": " ",     # continue on the same line
    "NEW_PARAGRAPH": "\n\n",   # blank line between the two sentences
    "NEWLINE": "\n",           # single line break
}
class BasePipeline(ABC):
    """Base class for newline-fixing inference pipelines.

    Supplies the sentence splitting and text reconstruction helpers shared by
    concrete pipelines; subclasses must implement :meth:`predict`.

    NOTE(review): ``_split_sentences`` reads ``self._sat``, which this class
    never assigns. Presumably each subclass sets it to a SaT sentence
    splitter exposing ``split(text, split_on_input_newlines=...,
    strip_whitespace=...)`` -- confirm against the subclasses.
    """

    def _split_sentences(self, text: str) -> list[str]:
        """Split *text* into cleaned, non-empty sentences.

        Args:
            text: Raw input text, possibly containing line breaks.

        Returns:
            The segments produced by ``self._sat.split`` with every newline
            removed and surrounding whitespace stripped. Segments that are
            whitespace-only are dropped.
        """
        segments = self._sat.split(
            text, split_on_input_newlines=False, strip_whitespace=False
        )
        # NOTE(review): newlines are deleted rather than replaced by a space,
        # so two words separated only by a line break end up fused together.
        # Kept as-is because downstream behavior may depend on it -- confirm.
        return [s.replace("\n", "").strip() for s in segments if s.strip()]

    @staticmethod
    def _reconstruct(sentences: list[str], predictions: list[dict]) -> str:
        """Join *sentences* using the separator predicted for each gap.

        ``predictions[i]`` describes the gap between ``sentences[i]`` and
        ``sentences[i + 1]``; its ``"label"`` value is looked up in
        ``LABEL_SYMBOLS`` and an unknown label falls back to a single space.

        A length mismatch no longer loses text or crashes: a gap without a
        prediction is joined with a single space, and predictions beyond the
        last gap are ignored.

        Args:
            sentences: Sentences in document order.
            predictions: One mapping per gap, each carrying a ``"label"`` key.

        Returns:
            The reassembled text, or ``""`` when *sentences* is empty.

        Raises:
            KeyError: If a prediction that is used has no ``"label"`` key.
        """
        if not sentences:
            return ""

        parts = [sentences[0]]
        remaining = iter(predictions)
        # Drive the loop by sentences (not predictions) so every sentence is
        # emitted exactly once regardless of how many predictions exist.
        for sentence in sentences[1:]:
            pred = next(remaining, None)
            label = pred["label"] if pred is not None else None
            parts.append(LABEL_SYMBOLS.get(label, " ") + sentence)
        return "".join(parts)

    @abstractmethod
    def predict(self, text: str):
        """Run the pipeline on *text*; concrete pipelines define the result."""
        ...