File size: 1,898 Bytes
2e0fab3
 
 
 
 
 
b1863d7
2e0fab3
 
 
 
 
b1863d7
 
 
 
2e0fab3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import re
import json
import torch
import unicodedata
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from huggingface_hub import hf_hub_download


def load_model_and_tokenizer(model_path: str):
    """Load the classifier, its tokenizer and the label mappings.

    Args:
        model_path: Hugging Face Hub repo id, or a local directory holding
            the checkpoint together with an ``id2label.json`` file.

    Returns:
        A ``(model, tokenizer, id2label, label2id)`` tuple. ``id2label`` is
        the mapping exactly as stored in ``id2label.json`` (JSON object keys
        are strings); ``label2id`` is its inverse with the ids converted to
        ``int``.
    """
    # Function-scope import: only needed to detect a local checkpoint dir.
    import os

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)

    if os.path.isdir(model_path):
        # A local directory is not a valid Hub repo id, so read the label
        # file straight from disk instead of asking the Hub for it.
        id2label_path = os.path.join(model_path, "id2label.json")
    else:
        id2label_path = hf_hub_download(
            repo_id=model_path, filename="id2label.json"
        )

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(id2label_path, encoding="utf-8") as f:
        id2label = json.load(f)

    label2id = {label: int(idx) for idx, label in id2label.items()}
    return model, tokenizer, id2label, label2id


# Patterns are compiled once at import time instead of on every call.
_URL_RE = re.compile(r'https?://\S+|www\.\S+')
_INLINE_MATH_RE = re.compile(r'\$.*?\$')
# The hyphen is escaped on purpose: an unescaped ``+-=`` inside a character
# class is parsed as the *range* from '+' to '=', not as three literals.
_DISALLOWED_CHAR_RE = re.compile(
    r'[^\w\s.,;:!?()\[\]{}\\/\'\"+\-=*&^%$#@<>|~`]'
)
_WHITESPACE_RE = re.compile(r'\s+')


def preprocess_text(text: str) -> str:
    """Normalize raw text before tokenization.

    Steps, in order:
      1. NFKC Unicode normalization.
      2. URLs (``http(s)://...`` or ``www....``) are replaced by ``url``.
      3. Single-line ``$...$`` spans are replaced by ``math``.
      4. Every character that is not a word character, whitespace, or one
         of the explicitly allowed punctuation marks becomes a space.
      5. Whitespace runs are collapsed to one space and the ends stripped.
      6. The text is lower-cased.
      7. The result is cut to its first 1500 characters.

    Args:
        text: Input text. Any non-``str`` value yields an empty string.

    Returns:
        The cleaned text, at most 1500 characters long.
    """
    if not isinstance(text, str):
        return ''
    text = unicodedata.normalize("NFKC", text)
    text = _URL_RE.sub('url', text)
    text = _INLINE_MATH_RE.sub('math', text)
    text = _DISALLOWED_CHAR_RE.sub(' ', text)
    text = _WHITESPACE_RE.sub(' ', text).strip()
    text = text.lower()
    return text[:1500]


def preprocess_for_inference(title: str, abstract: str) -> str:
    """Build the cleaned model input from a paper's title and abstract.

    A value that ``pd.isna`` reports as missing is treated as an empty
    string; any other value is converted with ``str``. Both parts are put
    into one "Title: ... Abstract: ..." string, which is then passed
    through ``preprocess_text``.

    Args:
        title: Paper title, possibly missing.
        abstract: Paper abstract, possibly missing.

    Returns:
        The result of ``preprocess_text`` on the combined string.
    """
    title, abstract = (
        '' if pd.isna(field) else str(field)
        for field in (title, abstract)
    )
    return preprocess_text(f"Title: {title} Abstract: {abstract}")


def predict_category(title: str, abstract: str, tokenizer, model, id2label,
                     max_length: int = 128):
    """Predict the category label for one title/abstract pair.

    Args:
        title: Paper title.
        abstract: Paper abstract.
        tokenizer: Callable tokenizer; called with ``return_tensors="pt"``.
        model: Classifier whose output exposes ``logits`` and which has a
            ``device`` attribute.
        id2label: Mapping from class id to label. Keys may be strings (as
            loaded from JSON) or integers.
        max_length: Length the input is padded/truncated to. Defaults to
            128, the value that was previously hard-coded.

    Returns:
        The ``id2label`` entry for the highest-scoring class.

    Raises:
        KeyError: If the predicted class id is not in ``id2label`` under
            either its string or its integer form.
    """
    preprocessed_text = preprocess_for_inference(title, abstract)
    encoded = tokenizer(
        preprocessed_text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    # Inputs must live on the same device as the model's weights.
    encoded = {key: val.to(model.device) for key, val in encoded.items()}
    with torch.no_grad():
        pred = model(**encoded)
    predicted_class_id = torch.argmax(pred.logits, dim=1).item()

    # String key first (JSON-loaded mapping), then the integer key.
    str_key = str(predicted_class_id)
    if str_key in id2label:
        return id2label[str_key]
    return id2label[predicted_class_id]