Spaces:
Running
Running
| import json | |
| import re | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| from config import MAX_LENGTH, MODEL_DIR, get_tag_name | |
def clean_text(text):
    """Return *text* trimmed, with every whitespace run collapsed to one space."""
    trimmed = text.strip()
    # Any run of whitespace (spaces, tabs, newlines, ...) becomes a single space.
    collapsed = re.sub(r"\s+", " ", trimmed)
    return collapsed
def format_input(title, abstract=None):
    """Build the tagged text that is fed to the tokenizer.

    The title is always cleaned and emitted under a ``[TITLE]`` marker. The
    abstract is appended after ``[SEP] [ABSTRACT]`` only when it is given and
    contains something other than whitespace.

    Args:
        title: Paper title.
        abstract: Optional paper abstract; ``None``, empty, or
            whitespace-only values are treated as "no abstract".

    Returns:
        The formatted input string.
    """
    cleaned_title = clean_text(title)

    # Title-only form when there is no usable abstract.
    has_abstract = bool(abstract) and bool(abstract.strip())
    if not has_abstract:
        return f"[TITLE] {cleaned_title}"

    return f"[TITLE] {cleaned_title} [SEP] [ABSTRACT] {clean_text(abstract)}"
class PaperClassifier:
    """Tag papers using a fine-tuned sequence-classification model.

    On construction the tokenizer, the model and ``label_mapping.json`` are
    loaded from a model directory and the model is moved to the selected
    device in eval mode. :meth:`predict` then returns the most probable tags
    for a title/abstract pair.
    """

    def __init__(self, model_path=None):
        """Load tokenizer, model and label mapping.

        Args:
            model_path: Directory containing the saved model, tokenizer and
                ``label_mapping.json``. Defaults to ``MODEL_DIR / "final"``.

        Raises:
            FileNotFoundError: If ``label_mapping.json`` is missing.
            KeyError: If the mapping file has no ``"id2label"`` entry.
        """
        if model_path is None:
            model_path = str(MODEL_DIR / "final")

        # Device preference: CUDA, then Apple MPS, then CPU.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available()
            else "mps" if torch.backends.mps.is_available()
            else "cpu"
        )

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

        # Explicit encoding so the mapping loads the same way on every
        # platform, regardless of the locale's default encoding.
        with open(Path(model_path) / "label_mapping.json", encoding="utf-8") as f:
            mapping = json.load(f)

        # id2label is looked up with the stringified class index in predict().
        self.id2label = mapping["id2label"]
        # Optional tag -> display-name overrides; get_tag_name() is the fallback.
        self.label_names = mapping.get("label_names", {})

    def predict(self, title, abstract=None, threshold=0.95):
        """Return the most probable tags for a paper.

        Tags are listed in descending probability. The list is cut as soon as
        the accumulated probability reaches ``threshold``, so it always
        contains at least the top tag.

        Args:
            title: Paper title.
            abstract: Optional paper abstract.
            threshold: Cumulative softmax probability at which to stop adding
                tags.

        Returns:
            A list of dicts with keys ``"tag"``, ``"name"`` and
            ``"probability"`` (a Python ``float``).
        """
        text = format_input(title, abstract)
        inputs = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=MAX_LENGTH,
            return_tensors="pt",
        ).to(self.device)

        # eval() only switches layer behaviour (dropout, batch norm); it does
        # not disable autograd. Without no_grad() the logits require grad and
        # Tensor.numpy() raises RuntimeError, besides building a useless graph.
        with torch.no_grad():
            logits = self.model(**inputs).logits[0]

        # Upcast first: NumPy has no bfloat16, and a half-precision softmax
        # would lose accuracy for no benefit.
        logits = logits.float().cpu().numpy()

        # Numerically stable softmax (shift by the max before exponentiating).
        probs = np.exp(logits - logits.max())
        probs /= probs.sum()

        results = []
        cumulative = 0.0
        # Walk the classes from most to least probable.
        for idx in np.argsort(probs)[::-1]:
            tag = self.id2label[str(idx)]
            prob = float(probs[idx])
            results.append({
                "tag": tag,
                "name": self.label_names.get(tag, get_tag_name(tag)),
                "probability": prob,
            })
            cumulative += prob
            if cumulative >= threshold:
                break
        return results