Article-classifier / inference.py
Pyotr Lisov
Add article classifier app
70b2ea0
from __future__ import annotations
import json
from pathlib import Path
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Directory containing this module; the default artifact paths below are
# resolved relative to it, independent of the current working directory.
PROJECT_DIR = Path(__file__).resolve().parent
# Model + tokenizer directory used when ArticleClassifier gets no ``model_dir``.
DEFAULT_MODEL_DIR = PROJECT_DIR.joinpath("artifacts", "large_model", "best_model")
# JSON label mapping used when ArticleClassifier gets no ``labels_path``.
DEFAULT_LABELS_PATH = PROJECT_DIR.joinpath("data", "processed_large", "label_mapping.json")
class ClassifierError(RuntimeError):
    """Raised when the article classifier fails to initialise or to produce a prediction."""
class ArticleClassifier:
    """Article classifier backed by a fine-tuned Hugging Face sequence-classification model.

    A title and an optional abstract are normalised into a single
    ``"title: ... abstract: ..."`` string, tokenised, and scored by the model.
    The model's class indices are mapped back to label names using a JSON file
    that maps ``label -> class id``.
    """

    def __init__(
        self,
        model_dir: Path = DEFAULT_MODEL_DIR,
        labels_path: Path = DEFAULT_LABELS_PATH,
        max_length: int = 256,
    ) -> None:
        """Load the label mapping, tokenizer and model, and move the model to a device.

        Args:
            model_dir: Directory passed to ``from_pretrained`` for both the
                tokenizer and the sequence-classification model.
            labels_path: JSON file holding a non-empty ``{label: class_id}``
                object. Class ids may be JSON integers, integral floats, or
                integer strings; they are normalised to ``int``.
            max_length: Token limit for the model input; longer inputs are
                truncated.

        Raises:
            ClassifierError: If the labels file or model directory does not
                exist, the label mapping cannot be read or is invalid (empty,
                not an object, non-integer or duplicate class ids), or the
                tokenizer/model fails to load.
        """
        self.model_dir = Path(model_dir)
        self.labels_path = Path(labels_path)
        self.max_length = max_length
        # Device preference: Apple MPS first, then CUDA, then CPU.
        self.device = torch.device(
            "mps"
            if torch.backends.mps.is_available()
            else "cuda"
            if torch.cuda.is_available()
            else "cpu"
        )

        # Fail fast with stage-specific messages before any expensive loading.
        if not self.labels_path.exists():
            raise ClassifierError(
                f"Failed to initialize classifier at labels loading stage: labels file not found at {self.labels_path}"
            )
        if not self.model_dir.exists():
            raise ClassifierError(
                f"Failed to initialize classifier at model loading stage: model directory not found at {self.model_dir}"
            )

        try:
            with self.labels_path.open("r", encoding="utf-8") as fh:
                raw_mapping = json.load(fh)
        except (OSError, ValueError) as exc:
            # ValueError also covers json.JSONDecodeError and UnicodeDecodeError.
            raise ClassifierError(
                f"Failed to initialize classifier at labels loading stage: {exc}"
            ) from exc

        self.label2id: dict[str, int] = self._parse_label_mapping(raw_mapping)
        # Inverse mapping used to turn model output indices into label names.
        self.id2label: dict[int, str] = {
            idx: label for label, idx in self.label2id.items()
        }

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
            self.model = AutoModelForSequenceClassification.from_pretrained(self.model_dir)
            self.model.to(self.device)
            self.model.eval()  # evaluation mode (e.g. dropout disabled) for inference
        except Exception as exc:
            # Loading can fail in many library-specific ways; wrap them all in
            # the domain error at this boundary.
            raise ClassifierError(
                f"Failed to initialize classifier at model loading stage: {exc}"
            ) from exc

    @staticmethod
    def _parse_label_mapping(raw_mapping: object) -> dict[str, int]:
        """Validate a decoded ``{label: class_id}`` JSON object and return it with ``int`` ids.

        Raises:
            ClassifierError: If the mapping is not a non-empty dict, a class id
                is not integer-valued, or two labels share the same class id.
        """
        if not isinstance(raw_mapping, dict) or not raw_mapping:
            raise ClassifierError(
                "Failed to initialize classifier at labels loading stage: label mapping is empty or invalid"
            )
        label2id: dict[str, int] = {}
        for label, idx in raw_mapping.items():
            not_integer_message = (
                "Failed to initialize classifier at labels loading stage: "
                f"class id {idx!r} for label {label!r} is not an integer"
            )
            # Reject fractional / non-finite floats instead of silently truncating them.
            if isinstance(idx, float) and not idx.is_integer():
                raise ClassifierError(not_integer_message)
            try:
                # int() also accepts integer strings such as "3", which would
                # otherwise never match the int indices produced in predict().
                label2id[label] = int(idx)
            except (TypeError, ValueError) as exc:
                raise ClassifierError(not_integer_message) from exc
        if len(set(label2id.values())) != len(label2id):
            # Duplicate ids would silently drop labels when building id2label.
            raise ClassifierError(
                "Failed to initialize classifier at labels loading stage: label mapping contains duplicate class ids"
            )
        return label2id

    @property
    def labels(self) -> list[str]:
        """Label names ordered by ascending class id."""
        return [self.id2label[idx] for idx in sorted(self.id2label)]

    @staticmethod
    def build_input_text(title: str, abstract: str) -> str:
        """Build the model input string from a title and an optional abstract.

        Runs of whitespace (including newlines) are collapsed to single spaces.
        The ``abstract:`` segment is omitted when the abstract is blank.
        """
        # NOTE(review): presumably this must match the text format used when the
        # model was trained — confirm before changing it.
        clean_title = " ".join(title.split())
        clean_abstract = " ".join(abstract.split())
        if clean_abstract:
            return f"title: {clean_title} abstract: {clean_abstract}"
        return f"title: {clean_title}"

    def predict(self, title: str, abstract: str = "") -> list[dict[str, float | str]]:
        """Score every class for one article.

        Args:
            title: Article title.
            abstract: Optional article abstract.

        Returns:
            One ``{"label": str, "probability": float}`` dict per model output
            class (softmax probabilities), sorted by descending probability.

        Raises:
            ValueError: If ``title``/``abstract`` are not strings, or both are blank.
            ClassifierError: If tokenization, model inference, or mapping the
                output indices to labels fails.
        """
        if not isinstance(title, str):
            raise ValueError("Input validation error in predict: title must be a string.")
        if not isinstance(abstract, str):
            raise ValueError("Input validation error in predict: abstract must be a string.")
        if not title.strip() and not abstract.strip():
            raise ValueError(
                "Input validation error in predict: please provide at least a title or an abstract."
            )

        text = self.build_input_text(title=title, abstract=abstract)

        try:
            encoded = self.tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=self.max_length,
            )
            encoded = {key: value.to(self.device) for key, value in encoded.items()}
        except Exception as exc:
            raise ClassifierError(f"Failed during tokenization stage: {exc}") from exc

        try:
            with torch.inference_mode():
                logits = self.model(**encoded).logits
                # Single input -> drop the batch dimension before converting.
                probabilities = torch.softmax(logits, dim=-1).squeeze(0).detach().cpu()
        except Exception as exc:
            raise ClassifierError(f"Failed during model inference stage: {exc}") from exc

        try:
            # A KeyError here means the model emits a class index that is
            # missing from the label mapping.
            results: list[dict[str, float | str]] = [
                {
                    "label": self.id2label[class_id],
                    "probability": float(probability),
                }
                for class_id, probability in enumerate(probabilities.tolist())
            ]
        except Exception as exc:
            raise ClassifierError(f"Failed during prediction formatting stage: {exc}") from exc

        results.sort(key=lambda item: item["probability"], reverse=True)
        return results

    @staticmethod
    def select_top_95(
        predictions: list[dict[str, float | str]],
    ) -> list[dict[str, float | str]]:
        """Return the most probable predictions covering at least 95% probability mass."""
        return ArticleClassifier.select_top_k_by_probability_mass(
            predictions=predictions,
            threshold=0.95,
        )

    @staticmethod
    def select_top_k_by_probability_mass(
        predictions: list[dict[str, float | str]],
        threshold: float = 0.95,
    ) -> list[dict[str, float | str]]:
        """Return the smallest prefix of most-probable predictions whose mass reaches ``threshold``.

        Predictions are ordered by descending ``"probability"``. The sort is
        stable, so the order of already-sorted ``predict`` output is kept. They
        are then accumulated until the running sum is ``>= threshold``. If the
        total mass never reaches the threshold, all predictions are returned.

        Args:
            predictions: Dicts with at least a numeric ``"probability"`` entry.
            threshold: Target cumulative probability in ``(0, 1]``.

        Returns:
            A new list containing the selected prediction dicts (not copies).

        Raises:
            ValueError: If ``threshold`` is outside ``(0, 1]``.
        """
        if not 0 < threshold <= 1:
            raise ValueError("Probability mass threshold must be in the interval (0, 1].")
        ranked = sorted(
            predictions,
            key=lambda item: float(item["probability"]),
            reverse=True,
        )
        cumulative_probability = 0.0
        top_predictions: list[dict[str, float | str]] = []
        for item in ranked:
            top_predictions.append(item)
            cumulative_probability += float(item["probability"])
            if cumulative_probability >= threshold:
                break
        return top_predictions

    def predict_top_95(self, title: str, abstract: str = "") -> list[dict[str, float | str]]:
        """Run ``predict`` and keep the most probable labels covering 95% of the mass."""
        predictions = self.predict(title=title, abstract=abstract)
        return self.select_top_95(predictions)