Spaces:
Sleeping
Sleeping
File size: 5,902 Bytes
70b2ea0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 | from __future__ import annotations
import json
from pathlib import Path
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Folder that contains this module; every default artifact path is resolved
# relative to it, so the defaults do not depend on the current working directory.
PROJECT_DIR = Path(__file__).resolve().parent
# Default checkpoint directory passed to ArticleClassifier (tokenizer + model).
DEFAULT_MODEL_DIR = PROJECT_DIR.joinpath("artifacts", "large_model", "best_model")
# Default JSON file that ArticleClassifier reads as its label -> class-id mapping.
DEFAULT_LABELS_PATH = PROJECT_DIR.joinpath("data", "processed_large", "label_mapping.json")
class ClassifierError(RuntimeError):
    """Error raised by ArticleClassifier when classification cannot proceed.

    Covers failures while loading the label mapping or the model/tokenizer,
    and failures during tokenization, model inference or formatting of
    predictions. Invalid user input to ``predict`` is reported with
    ``ValueError`` instead.
    """
class ArticleClassifier:
    """Classify an article (title plus optional abstract) with a fine-tuned model.

    On construction the tokenizer and sequence-classification model are loaded
    from ``model_dir`` via ``transformers`` ``from_pretrained``, moved to the
    best available torch device (MPS, then CUDA, then CPU) and switched to eval
    mode. ``labels_path`` must be a JSON object mapping each label name to its
    integer class index.

    Args:
        model_dir: Directory passed to ``from_pretrained`` for both the
            tokenizer and the model.
        labels_path: JSON file containing the ``{label: class_id}`` mapping.
        max_length: Maximum number of tokens fed to the model; longer inputs
            are truncated.

    Raises:
        ClassifierError: If the labels file or model directory does not exist,
            the label mapping cannot be read or is invalid, or the tokenizer /
            model fails to load.
    """

    def __init__(
        self,
        model_dir: Path = DEFAULT_MODEL_DIR,
        labels_path: Path = DEFAULT_LABELS_PATH,
        max_length: int = 256,
    ) -> None:
        self.model_dir = Path(model_dir)
        self.labels_path = Path(labels_path)
        self.max_length = max_length
        self.device = self._select_device()

        # Check both paths up front so a missing artifact is reported before
        # any parsing or (slow) model loading happens.
        if not self.labels_path.exists():
            raise ClassifierError(
                f"Failed to initialize classifier at labels loading stage: labels file not found at {self.labels_path}"
            )
        if not self.model_dir.exists():
            raise ClassifierError(
                f"Failed to initialize classifier at model loading stage: model directory not found at {self.model_dir}"
            )

        self.label2id: dict[str, int] = self._load_label_mapping(self.labels_path)
        # Inverse mapping used by ``predict`` to name each logit index.
        self.id2label: dict[int, str] = {idx: label for label, idx in self.label2id.items()}

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
            self.model = AutoModelForSequenceClassification.from_pretrained(self.model_dir)
            self.model.to(self.device)
            self.model.eval()
        except Exception as exc:
            raise ClassifierError(
                f"Failed to initialize classifier at model loading stage: {exc}"
            ) from exc

    @staticmethod
    def _select_device() -> torch.device:
        """Pick the torch device to run on: Apple MPS, then CUDA, then CPU."""
        if torch.backends.mps.is_available():
            return torch.device("mps")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    @staticmethod
    def _to_class_id(raw_id: object) -> int | None:
        """Return ``raw_id`` as an ``int`` class id, or ``None`` if it is not one.

        Accepts ints, integral floats (``2.0``) and integer strings (``"2"``).
        """
        if isinstance(raw_id, int):
            return int(raw_id)
        if isinstance(raw_id, float):
            return int(raw_id) if raw_id.is_integer() else None
        if isinstance(raw_id, str):
            try:
                return int(raw_id)
            except ValueError:
                return None
        return None

    @classmethod
    def _load_label_mapping(cls, labels_path: Path) -> dict[str, int]:
        """Read the ``{label: class_id}`` JSON file and normalise ids to ``int``.

        Ids are converted to ``int`` so they match the logit indices that
        ``predict`` enumerates; otherwise e.g. ``"3"`` would never be found.
        Ids must also be unique, because duplicates would make the inverse
        ``id2label`` mapping silently drop labels.

        Raises:
            ClassifierError: If the file cannot be read or parsed, is not a
                non-empty JSON object, or contains a non-integer or duplicate
                class id.
        """
        try:
            with labels_path.open("r", encoding="utf-8") as fh:
                raw_mapping = json.load(fh)
        except Exception as exc:
            raise ClassifierError(
                f"Failed to initialize classifier at labels loading stage: {exc}"
            ) from exc
        if not isinstance(raw_mapping, dict) or not raw_mapping:
            raise ClassifierError(
                "Failed to initialize classifier at labels loading stage: label mapping is empty or invalid"
            )

        label2id: dict[str, int] = {}
        for label, raw_id in raw_mapping.items():
            class_id = cls._to_class_id(raw_id)
            if class_id is None:
                raise ClassifierError(
                    "Failed to initialize classifier at labels loading stage: "
                    f"class id for label {label!r} is not an integer: {raw_id!r}"
                )
            label2id[label] = class_id

        if len(set(label2id.values())) != len(label2id):
            raise ClassifierError(
                "Failed to initialize classifier at labels loading stage: "
                "label mapping contains duplicate class ids"
            )
        return label2id

    @property
    def labels(self) -> list[str]:
        """Label names ordered by ascending class id."""
        return [self.id2label[idx] for idx in sorted(self.id2label)]

    @staticmethod
    def build_input_text(title: str, abstract: str) -> str:
        """Build the single text string that is passed to the tokenizer.

        Runs of whitespace (including newlines) are collapsed to single spaces.
        The ``abstract:`` segment is omitted when the abstract is blank.
        """
        # " ".join(s.split()) already has no leading/trailing whitespace.
        clean_title = " ".join(title.split())
        clean_abstract = " ".join(abstract.split())
        if clean_abstract:
            return f"title: {clean_title} abstract: {clean_abstract}"
        return f"title: {clean_title}"

    def predict(self, title: str, abstract: str = "") -> list[dict[str, float | str]]:
        """Return the softmax probability of every label for one article.

        Args:
            title: Article title.
            abstract: Optional article abstract.

        Returns:
            One ``{"label": str, "probability": float}`` dict per class,
            sorted by descending probability.

        Raises:
            ValueError: If ``title`` or ``abstract`` is not a string, or both
                are blank.
            ClassifierError: If tokenization, inference or formatting of the
                results fails.
        """
        if not isinstance(title, str):
            raise ValueError("Input validation error in predict: title must be a string.")
        if not isinstance(abstract, str):
            raise ValueError("Input validation error in predict: abstract must be a string.")
        if not title.strip() and not abstract.strip():
            raise ValueError(
                "Input validation error in predict: please provide at least a title or an abstract."
            )

        text = self.build_input_text(title=title, abstract=abstract)

        try:
            encoded = self.tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=self.max_length,
            )
            encoded = {key: value.to(self.device) for key, value in encoded.items()}
        except Exception as exc:
            raise ClassifierError(f"Failed during tokenization stage: {exc}") from exc

        try:
            with torch.inference_mode():
                logits = self.model(**encoded).logits
                # A single text was tokenized, so drop the batch dimension of
                # size 1 to get one probability per class.
                probabilities = torch.softmax(logits, dim=-1).squeeze(0).detach().cpu()
        except Exception as exc:
            raise ClassifierError(f"Failed during model inference stage: {exc}") from exc

        try:
            # KeyError here means the model emits a class index the label
            # mapping does not name.
            results: list[dict[str, float | str]] = [
                {"label": self.id2label[class_id], "probability": float(probability)}
                for class_id, probability in enumerate(probabilities.tolist())
            ]
        except Exception as exc:
            raise ClassifierError(f"Failed during prediction formatting stage: {exc}") from exc

        results.sort(key=lambda item: item["probability"], reverse=True)
        return results

    @staticmethod
    def select_top_95(
        predictions: list[dict[str, float | str]],
    ) -> list[dict[str, float | str]]:
        """Return the leading predictions covering at least 95% probability mass."""
        return ArticleClassifier.select_top_k_by_probability_mass(
            predictions=predictions,
            threshold=0.95,
        )

    @staticmethod
    def select_top_k_by_probability_mass(
        predictions: list[dict[str, float | str]],
        threshold: float = 0.95,
    ) -> list[dict[str, float | str]]:
        """Return the leading predictions whose cumulative probability reaches ``threshold``.

        Items are taken in the given order until their summed ``"probability"``
        is at least ``threshold``. Pass predictions sorted by descending
        probability (as ``predict`` returns them) to get the smallest such set.
        If the total never reaches ``threshold``, every prediction is returned.

        Raises:
            ValueError: If ``threshold`` is not in the interval ``(0, 1]``.
        """
        if not 0 < threshold <= 1:
            raise ValueError("Probability mass threshold must be in the interval (0, 1].")
        cumulative_probability = 0.0
        top_predictions: list[dict[str, float | str]] = []
        for item in predictions:
            top_predictions.append(item)
            cumulative_probability += float(item["probability"])
            # Stop at the caller-supplied threshold, not a fixed constant.
            if cumulative_probability >= threshold:
                break
        return top_predictions

    def predict_top_95(self, title: str, abstract: str = "") -> list[dict[str, float | str]]:
        """Predict and keep only the top labels covering 95% probability mass."""
        predictions = self.predict(title=title, abstract=abstract)
        return self.select_top_95(predictions)
|