File size: 2,414 Bytes
3da2267
 
 
 
a3b0757
3da2267
45757d0
 
 
 
3da2267
45757d0
3da2267
 
 
45757d0
3da2267
 
 
 
 
 
 
 
 
 
 
a3b0757
 
 
 
3da2267
 
 
89e0dae
3da2267
 
 
a3b0757
 
 
3da2267
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89e0dae
 
3da2267
89e0dae
 
 
7bbd71e
3da2267
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn.functional as F
import re

# Class names for the classifier; predict() maps output index i to LABELS[i],
# so the order here must not change.
LABELS = [
    "astro-ph",
    "cond-mat",
    "cs",
    "econ",
    "eess",
    "gr-qc",
    "hep-ex",
    "hep-lat",
    "hep-ph",
    "hep-th",
    "math",
    "math-ph",
    "nlin",
    "nucl-ex",
    "nucl-th",
    "physics",
    "q-bio",
    "q-fin",
    "quant-ph",
    "stat",
]

# Model identifier handed to the `from_pretrained` loaders.
MODEL_NAME = "TochkaMikelya/arxiv-classifier"

# Filled in lazily by load_model(); both stay None until the first call.
tokenizer = None
model = None


def load_model():
    """Load the tokenizer and classifier into the module globals on first use.

    Returns immediately when ``model`` is already set, so it is safe to call
    before every prediction.
    """
    global tokenizer, model
    if model is not None:
        return

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # The classification head is sized to LABELS; ignore_mismatched_sizes asks
    # the loader not to fail if checkpoint tensor shapes differ from that.
    head_options = {
        "num_labels": len(LABELS),
        "ignore_mismatched_sizes": True,
    }
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME, **head_options
    )
    # Put the module in evaluation mode for inference.
    model.eval()

def clean_text(text):
    """Return *text* with whitespace runs collapsed to one space and trimmed."""
    single_line = text.replace('\n', ' ')
    return re.sub(r'\s+', ' ', single_line).strip()

def predict(title: str, abstract: str):
    """Classify an article and return its most probable labels.

    The title and abstract are whitespace-normalised with ``clean_text``,
    joined as ``"<title> [SEP] <abstract>"`` (the title alone when there is no
    abstract), tokenised with truncation to 512 tokens and scored by the
    model.

    Args:
        title: Article title; required.
        abstract: Article abstract; may be empty.

    Returns:
        A dict mapping label to probability, in descending probability order,
        holding the shortest prefix of the ranking whose cumulative
        probability reaches 0.95.

    Raises:
        gr.Error: If the title is missing or contains only whitespace.
    """
    # Gradio can deliver a Textbox value as None; treat that as empty text so
    # a missing title yields the user-facing error rather than AttributeError.
    title = clean_text(title or "")
    abstract = clean_text(abstract or "")
    if not title:
        raise gr.Error("Нужен как минимум заголовок статьи, введите пожалуйста")

    load_model()

    # clean_text already strips, so a non-empty abstract is simply truthy.
    text = f"{title} [SEP] {abstract}" if abstract else title
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=True,
    )
    with torch.no_grad():
        logits = model(**inputs).logits
    # Only one text is encoded, so drop just the batch dimension; a bare
    # squeeze() would also collapse any other size-1 dimension.
    probs = F.softmax(logits, dim=-1).squeeze(0).tolist()

    # sorted() is stable, so equal probabilities keep their LABELS order.
    ranked = sorted(enumerate(probs), key=lambda item: item[1], reverse=True)

    top_labels, cumsum = {}, 0.0
    for idx, prob in ranked:
        top_labels[LABELS[idx]] = prob
        cumsum += prob
        if cumsum >= 0.95:
            break

    return top_labels

# Two text inputs (title, abstract) are passed positionally to predict(); the
# Label output displays the {label: probability} dict it returns.
demo = gr.Interface(
    title="Классификатор статей",
    description="Введи название и abstract статьи, классификация тэга статьи из arxiv",
    fn=predict,
    inputs=[
        gr.Textbox(
            label="Название статьи",
            placeholder="Заголовок статьи: ",
        ),
        gr.Textbox(
            label="Abstract",
            placeholder="Abstract статьи: ",
            lines=5,
        ),
    ],
    outputs=gr.Label(label="Вероятности классов топ 95%"),
    flagging_mode="never",
)

# NOTE(review): launch() runs whenever this module is executed or imported;
# consider an `if __name__ == "__main__":` guard if it is ever imported.
demo.launch()