# NOTE(review): removed non-code capture artifact ("Spaces: Sleeping" page-status
# lines) that would be a syntax error in a Python module.
from __future__ import annotations

import json
from pathlib import Path

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Paths are anchored to this file's directory so loading works regardless of
# the process's current working directory.
PROJECT_DIR = Path(__file__).resolve().parent
# presumably the fine-tuned checkpoint exported by the training run — verify
DEFAULT_MODEL_DIR = PROJECT_DIR / "artifacts" / "large_model" / "best_model"
# JSON file mapping label name -> integer class id (see ArticleClassifier.__init__).
DEFAULT_LABELS_PATH = PROJECT_DIR / "data" / "processed_large" / "label_mapping.json"
class ClassifierError(RuntimeError):
    """Raised when the classifier fails to initialize or to run inference."""
class ArticleClassifier:
    """Topic classifier built on a fine-tuned HuggingFace
    sequence-classification model.

    Loads a tokenizer/model pair saved with ``save_pretrained`` plus a JSON
    label mapping, then scores (title, abstract) pairs, returning per-label
    probabilities sorted in descending order.
    """

    def __init__(
        self,
        model_dir: Path = DEFAULT_MODEL_DIR,
        labels_path: Path = DEFAULT_LABELS_PATH,
        max_length: int = 256,
    ) -> None:
        """Load the label mapping, tokenizer, and model; pick a device.

        Args:
            model_dir: Directory containing the fine-tuned model/tokenizer.
            labels_path: JSON file mapping label name -> integer class id.
            max_length: Truncation length (in tokens) used at inference time.

        Raises:
            ClassifierError: If any loading stage fails (missing files,
                unreadable/empty label mapping, model load errors).
        """
        self.model_dir = Path(model_dir)
        self.labels_path = Path(labels_path)
        self.max_length = max_length
        # Device preference: Apple MPS > CUDA > CPU.
        self.device = torch.device(
            "mps"
            if torch.backends.mps.is_available()
            else "cuda"
            if torch.cuda.is_available()
            else "cpu"
        )
        if not self.labels_path.exists():
            raise ClassifierError(
                f"Failed to initialize classifier at labels loading stage: labels file not found at {self.labels_path}"
            )
        if not self.model_dir.exists():
            raise ClassifierError(
                f"Failed to initialize classifier at model loading stage: model directory not found at {self.model_dir}"
            )
        try:
            with self.labels_path.open("r", encoding="utf-8") as fh:
                self.label2id = json.load(fh)
        except Exception as exc:
            raise ClassifierError(
                f"Failed to initialize classifier at labels loading stage: {exc}"
            ) from exc
        if not isinstance(self.label2id, dict) or not self.label2id:
            raise ClassifierError(
                "Failed to initialize classifier at labels loading stage: label mapping is empty or invalid"
            )
        # Inverse mapping: class id -> label name, used to decode logits.
        self.id2label = {idx: label for label, idx in self.label2id.items()}
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
            self.model = AutoModelForSequenceClassification.from_pretrained(self.model_dir)
            self.model.to(self.device)
            self.model.eval()
        except Exception as exc:
            raise ClassifierError(
                f"Failed to initialize classifier at model loading stage: {exc}"
            ) from exc

    def labels(self) -> list[str]:
        """Return every label name, ordered by class id."""
        return [self.id2label[idx] for idx in sorted(self.id2label)]

    @staticmethod
    def build_input_text(title: str, abstract: str) -> str:
        """Normalize whitespace and join title/abstract into the model input.

        BUG FIX: now a ``@staticmethod``. The original was a plain function
        inside the class body, so ``self.build_input_text(title=..., abstract=...)``
        bound ``self`` to ``title`` positionally and raised
        ``TypeError: got multiple values for argument 'title'``.
        """
        clean_title = " ".join(title.split()).strip()
        clean_abstract = " ".join(abstract.split()).strip()
        if clean_abstract:
            return f"title: {clean_title} abstract: {clean_abstract}"
        return f"title: {clean_title}"

    def predict(self, title: str, abstract: str = "") -> list[dict[str, float | str]]:
        """Score one article and return all labels with probabilities.

        Args:
            title: Article title (may be blank if an abstract is given).
            abstract: Optional article abstract.

        Returns:
            A list of ``{"label": str, "probability": float}`` dicts covering
            every class, sorted by probability in descending order.

        Raises:
            ValueError: If inputs are not strings, or both are blank.
            ClassifierError: If tokenization, inference, or result
                formatting fails.
        """
        if not isinstance(title, str):
            raise ValueError("Input validation error in predict: title must be a string.")
        if not isinstance(abstract, str):
            raise ValueError("Input validation error in predict: abstract must be a string.")
        if not title.strip() and not abstract.strip():
            raise ValueError(
                "Input validation error in predict: please provide at least a title or an abstract."
            )
        text = self.build_input_text(title=title, abstract=abstract)
        try:
            encoded = self.tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=self.max_length,
            )
            encoded = {key: value.to(self.device) for key, value in encoded.items()}
        except Exception as exc:
            raise ClassifierError(f"Failed during tokenization stage: {exc}") from exc
        try:
            with torch.inference_mode():
                logits = self.model(**encoded).logits
            # Softmax over the class dimension; move to CPU for plain-float output.
            probabilities = torch.softmax(logits, dim=-1).squeeze(0).detach().cpu()
        except Exception as exc:
            raise ClassifierError(f"Failed during model inference stage: {exc}") from exc
        results: list[dict[str, float | str]] = []
        try:
            for class_id, probability in enumerate(probabilities.tolist()):
                results.append(
                    {
                        "label": self.id2label[class_id],
                        "probability": float(probability),
                    }
                )
        except Exception as exc:
            raise ClassifierError(f"Failed during prediction formatting stage: {exc}") from exc
        results.sort(key=lambda item: item["probability"], reverse=True)
        return results

    @staticmethod
    def select_top_95(
        predictions: list[dict[str, float | str]],
    ) -> list[dict[str, float | str]]:
        """Shortcut: smallest prefix covering 95% of the probability mass.

        BUG FIX: now a ``@staticmethod`` so the ``self.select_top_95(...)``
        call in :meth:`predict_top_95` works (the original bound ``self`` to
        ``predictions`` and raised TypeError).
        """
        return ArticleClassifier.select_top_k_by_probability_mass(
            predictions=predictions,
            threshold=0.95,
        )

    @staticmethod
    def select_top_k_by_probability_mass(
        predictions: list[dict[str, float | str]],
        threshold: float = 0.95,
    ) -> list[dict[str, float | str]]:
        """Return the smallest prefix of *predictions* whose cumulative
        probability reaches ``threshold``.

        Assumes *predictions* is sorted by probability descending, as
        returned by :meth:`predict`.

        Raises:
            ValueError: If ``threshold`` is outside the interval (0, 1].
        """
        if not 0 < threshold <= 1:
            raise ValueError("Probability mass threshold must be in the interval (0, 1].")
        cumulative_probability = 0.0
        top_predictions: list[dict[str, float | str]] = []
        for item in predictions:
            top_predictions.append(item)
            cumulative_probability += float(item["probability"])
            # BUG FIX: compare against the caller-supplied threshold instead of
            # a hard-coded 0.95, which silently ignored the parameter.
            if cumulative_probability >= threshold:
                break
        return top_predictions

    def predict_top_95(self, title: str, abstract: str = "") -> list[dict[str, float | str]]:
        """Predict, then keep only the labels covering 95% probability mass."""
        predictions = self.predict(title=title, abstract=abstract)
        return self.select_top_95(predictions)