"""Inference wrapper around a fine-tuned sequence-classification model for articles."""

from __future__ import annotations

import json
from pathlib import Path

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

PROJECT_DIR = Path(__file__).resolve().parent
DEFAULT_MODEL_DIR = PROJECT_DIR / "artifacts" / "large_model" / "best_model"
DEFAULT_LABELS_PATH = PROJECT_DIR / "data" / "processed_large" / "label_mapping.json"


class ClassifierError(RuntimeError):
    """Raised when the classifier cannot be initialized or a prediction stage fails."""


class ArticleClassifier:
    """Classify an article from its title and optional abstract.

    The tokenizer and model are loaded from ``model_dir``; the label-to-id
    mapping is read from the JSON file at ``labels_path``.
    """

    def __init__(
        self,
        model_dir: Path = DEFAULT_MODEL_DIR,
        labels_path: Path = DEFAULT_LABELS_PATH,
        max_length: int = 256,
    ) -> None:
        """Load the label mapping, tokenizer and model.

        Args:
            model_dir: Directory passed to ``from_pretrained`` for both the
                tokenizer and the model.
            labels_path: JSON file holding a non-empty ``{label: id}`` object.
            max_length: Truncation length (in tokens) used by ``predict``.

        Raises:
            ClassifierError: If either path is missing, the label mapping
                cannot be read or is not a non-empty dict, or loading the
                tokenizer/model fails.
        """
        self.model_dir = Path(model_dir)
        self.labels_path = Path(labels_path)
        self.max_length = max_length
        # Device preference order: MPS, then CUDA, then CPU.
        self.device = torch.device(
            "mps"
            if torch.backends.mps.is_available()
            else "cuda"
            if torch.cuda.is_available()
            else "cpu"
        )

        # Check both paths up front so a missing file yields a specific message
        # instead of a lower-level loader error.
        if not self.labels_path.exists():
            raise ClassifierError(
                f"Failed to initialize classifier at labels loading stage: labels file not found at {self.labels_path}"
            )
        if not self.model_dir.exists():
            raise ClassifierError(
                f"Failed to initialize classifier at model loading stage: model directory not found at {self.model_dir}"
            )

        try:
            with self.labels_path.open("r", encoding="utf-8") as fh:
                self.label2id = json.load(fh)
        except Exception as exc:
            raise ClassifierError(
                f"Failed to initialize classifier at labels loading stage: {exc}"
            ) from exc

        if not isinstance(self.label2id, dict) or not self.label2id:
            raise ClassifierError(
                "Failed to initialize classifier at labels loading stage: label mapping is empty or invalid"
            )

        # Inverse mapping used to turn class indices back into label names.
        # NOTE(review): predict() looks labels up by integer class index, so
        # the JSON values are presumably ints — confirm against the data file.
        self.id2label = {idx: label for label, idx in self.label2id.items()}

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
            self.model = AutoModelForSequenceClassification.from_pretrained(self.model_dir)
            self.model.to(self.device)
            self.model.eval()
        except Exception as exc:
            raise ClassifierError(
                f"Failed to initialize classifier at model loading stage: {exc}"
            ) from exc

    @property
    def labels(self) -> list[str]:
        """Return the label names ordered by ascending class id."""
        return [self.id2label[idx] for idx in sorted(self.id2label)]

    @staticmethod
    def build_input_text(title: str, abstract: str) -> str:
        """Build the model input string from a title and an abstract.

        Runs of whitespace in both fields are collapsed to single spaces.
        The ``abstract:`` part is omitted when the abstract is empty after
        normalization.
        """
        clean_title = " ".join(title.split()).strip()
        clean_abstract = " ".join(abstract.split()).strip()
        if clean_abstract:
            return f"title: {clean_title} abstract: {clean_abstract}"
        return f"title: {clean_title}"

    def predict(self, title: str, abstract: str = "") -> list[dict[str, float | str]]:
        """Return every label with its probability, most probable first.

        Args:
            title: Article title.
            abstract: Article abstract; may be empty if a title is given.

        Returns:
            A list of ``{"label": str, "probability": float}`` dicts sorted by
            descending probability.

        Raises:
            ValueError: If an argument is not a string, or both are blank.
            ClassifierError: If tokenization, inference, or result formatting
                fails.
        """
        if not isinstance(title, str):
            raise ValueError("Input validation error in predict: title must be a string.")
        if not isinstance(abstract, str):
            raise ValueError("Input validation error in predict: abstract must be a string.")
        if not title.strip() and not abstract.strip():
            raise ValueError(
                "Input validation error in predict: please provide at least a title or an abstract."
            )

        text = self.build_input_text(title=title, abstract=abstract)

        try:
            encoded = self.tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=self.max_length,
            )
            # Inputs must live on the same device as the model.
            encoded = {key: value.to(self.device) for key, value in encoded.items()}
        except Exception as exc:
            raise ClassifierError(f"Failed during tokenization stage: {exc}") from exc

        try:
            with torch.inference_mode():
                logits = self.model(**encoded).logits
                # Single input, so drop the leading batch dimension.
                probabilities = torch.softmax(logits, dim=-1).squeeze(0).detach().cpu()
        except Exception as exc:
            raise ClassifierError(f"Failed during model inference stage: {exc}") from exc

        results: list[dict[str, float | str]] = []
        try:
            for class_id, probability in enumerate(probabilities.tolist()):
                results.append(
                    {
                        "label": self.id2label[class_id],
                        "probability": float(probability),
                    }
                )
        except Exception as exc:
            raise ClassifierError(f"Failed during prediction formatting stage: {exc}") from exc

        results.sort(key=lambda item: item["probability"], reverse=True)
        return results

    @staticmethod
    def select_top_95(
        predictions: list[dict[str, float | str]],
    ) -> list[dict[str, float | str]]:
        """Return the leading predictions covering at least 95% probability mass."""
        return ArticleClassifier.select_top_k_by_probability_mass(
            predictions=predictions,
            threshold=0.95,
        )

    @staticmethod
    def select_top_k_by_probability_mass(
        predictions: list[dict[str, float | str]],
        threshold: float = 0.95,
    ) -> list[dict[str, float | str]]:
        """Return the shortest prefix of ``predictions`` reaching ``threshold``.

        Items are taken in the given order (callers are expected to pass them
        sorted by descending probability, as ``predict`` returns them) until
        their cumulative probability is at least ``threshold``. If the
        threshold is never reached, all predictions are returned.

        Args:
            predictions: Dicts each carrying a ``"probability"`` entry.
            threshold: Required cumulative probability, in ``(0, 1]``.

        Raises:
            ValueError: If ``threshold`` is outside ``(0, 1]``.
        """
        if not 0 < threshold <= 1:
            raise ValueError("Probability mass threshold must be in the interval (0, 1].")

        cumulative_probability = 0.0
        top_predictions: list[dict[str, float | str]] = []
        for item in predictions:
            top_predictions.append(item)
            cumulative_probability += float(item["probability"])
            # Compare against the caller's threshold, not a fixed 0.95.
            if cumulative_probability >= threshold:
                break
        return top_predictions

    def predict_top_95(self, title: str, abstract: str = "") -> list[dict[str, float | str]]:
        """Run ``predict`` and keep the labels covering at least 95% probability mass."""
        predictions = self.predict(title=title, abstract=abstract)
        return self.select_top_95(predictions)