| import re | |
| import json | |
| import torch | |
| import unicodedata | |
| import pandas as pd | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from huggingface_hub import hf_hub_download | |
def load_model_and_tokenizer(model_path: str):
    """Load a sequence-classification model, its tokenizer, and label maps.

    Args:
        model_path: Hugging Face Hub repo id (or local path) of the model.

    Returns:
        Tuple ``(model, tokenizer, id2label, label2id)`` where ``id2label``
        maps string class ids to label names (exactly as stored in
        ``id2label.json``) and ``label2id`` is the inverse with int ids.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    # hf_hub_download is already imported at module level; the original
    # re-imported it locally here, which was redundant.
    id2label_path = hf_hub_download(repo_id=model_path, filename="id2label.json")
    with open(id2label_path) as f:
        id2label = json.load(f)
    # JSON object keys are strings; invert to name -> int id for training use.
    label2id = {v: int(k) for k, v in id2label.items()}
    return model, tokenizer, id2label, label2id
def preprocess_text(text: str) -> str:
    """Normalize raw text for model input.

    Steps: NFKC unicode normalization; replace URLs with ``url`` and
    ``$...$`` math spans with ``math``; map characters outside the allowed
    set to spaces; collapse whitespace; lowercase; truncate to 1500 chars.

    Args:
        text: Arbitrary input; non-strings yield ``''``.

    Returns:
        The cleaned, lowercased string (at most 1500 characters).
    """
    if not isinstance(text, str):
        return ''
    text = unicodedata.normalize("NFKC", text)
    text = re.sub(r'https?://\S+|www\.\S+', 'url', text)
    # Non-greedy so each $...$ span is replaced separately.
    text = re.sub(r'\$.*?\$', 'math', text)
    # Keep word chars, whitespace, and common punctuation; everything else
    # becomes a space. The hyphen is escaped: the original '+-=' formed an
    # accidental character range ('+'..'='), which only coincidentally
    # matched the same characters as the intended literals '+', '-', '='.
    text = re.sub(r'[^\w\s.,;:!?()\[\]{}\\/\'\"+\-=*&^%$#@<>|~`]', ' ', text)
    text = re.sub(r'\s+', ' ', text).strip()
    text = text.lower()
    return text[:1500]
def preprocess_for_inference(title: str, abstract: str) -> str:
    """Combine a title and abstract into one cleaned string for the model.

    Missing fields (NaN/None per ``pd.isna``) are treated as empty strings.

    Args:
        title: Paper title, possibly missing.
        abstract: Paper abstract, possibly missing.

    Returns:
        The preprocessed combined text.
    """
    safe_title = str(title) if not pd.isna(title) else ''
    safe_abstract = str(abstract) if not pd.isna(abstract) else ''
    combined = f"Title: {safe_title} Abstract: {safe_abstract}"
    return preprocess_text(combined)
def predict_category(title: str, abstract: str, tokenizer, model, id2label):
    """Predict the class label for a (title, abstract) pair.

    Args:
        title: Paper title (may be NaN/None).
        abstract: Paper abstract (may be NaN/None).
        tokenizer: Tokenizer matching ``model``.
        model: Sequence-classification model.
        id2label: Mapping from string class id to label name.

    Returns:
        The predicted label name from ``id2label``.
    """
    text = preprocess_for_inference(title, abstract)
    inputs = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )
    # Move every input tensor to the model's device before the forward pass.
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    class_id = torch.argmax(outputs.logits, dim=1).item()
    # id2label keys are strings (loaded from JSON), so stringify the id.
    return id2label[str(class_id)]