File size: 5,527 Bytes
5e2ba76
 
 
 
675f8ad
 
 
5e2ba76
675f8ad
5e2ba76
 
 
 
 
 
 
675f8ad
 
 
 
5e2ba76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
675f8ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e2ba76
675f8ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e2ba76
 
 
675f8ad
 
 
 
 
 
 
 
 
 
 
 
 
 
5e2ba76
675f8ad
 
 
 
 
 
 
 
 
 
 
5e2ba76
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import base64
import io

import gradio as gr
import matplotlib.pyplot as plt
import torch
from matplotlib.figure import Figure
from scipy.special import softmax
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_name = "enigmaize/arxiv-nlp_project-scibert"

# Load the tokenizer and the fine-tuned classifier once, at import time;
# classify_text() reuses these module-level objects for every request.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSequenceClassification.from_pretrained(model_name),
)

# Class names, listed in the same order as the model's output logits
# (classify_text pairs labels[i] with the i-th probability).
labels = [
    "math.AC", "cs.CV", "cs.AI", "cs.SY", "math.GR", "cs.CE",
    "cs.PL", "cs.IT", "cs.DS", "cs.NE", "math.ST",
]

def classify_text(text):
    """Classify a paper abstract into one of the categories in ``labels``.

    Args:
        text: Abstract text from the Gradio textbox. ``None``, empty, or
            whitespace-only input is treated as "no input".

    Returns:
        A ``(scores, chart_html)`` tuple:

        * ``scores``: dict mapping every label to its softmax probability,
          ordered from most to least likely (all ``0.0`` for empty input).
        * ``chart_html``: an ``<img>`` tag embedding a base64 PNG bar chart
          of the top-5 predictions, or ``None`` for empty input.
    """
    if not text or not text.strip():
        # Nothing to classify: show an all-zero distribution and no chart.
        return {label: 0.0 for label in labels}, None

    # Tokenize; inputs longer than 512 tokens are truncated.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Inference only: disable gradient tracking.
    with torch.no_grad():
        logits = model(**inputs).logits

    # The batch holds exactly one text, so take row 0 explicitly. squeeze()
    # would also drop the class axis if the model had a single output.
    probabilities = torch.nn.functional.softmax(logits[0], dim=-1).tolist()

    # NOTE(review): zip() silently drops extras if the model's output size
    # differs from len(labels); confirm they match.
    results = dict(zip(labels, probabilities))

    # Most likely category first.
    sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))

    chart_html = _render_top_k_chart(sorted_results, top_k=5)
    return sorted_results, chart_html


def _render_top_k_chart(sorted_results, top_k=5):
    """Render the first ``top_k`` entries of ``sorted_results`` as an inline PNG ``<img>`` tag.

    ``sorted_results`` is expected to be ordered by descending probability.
    """
    top_labels = list(sorted_results)[:top_k]
    top_probs = list(sorted_results.values())[:top_k]

    # Build a standalone Figure instead of going through pyplot. pyplot keeps
    # global figure state (not thread-safe, and a figure leaks if an exception
    # skips plt.close). A bare Figure is freed by normal garbage collection.
    fig = Figure(figsize=(8, 4))
    ax = fig.subplots()
    bars = ax.barh(top_labels, top_probs, color=['#4c72b0', '#dd8452', '#55a868', '#c44e52', '#8172b3'])
    # barh draws the first item at the bottom; invert so the most likely
    # category is on top, matching the order of the Label output.
    ax.invert_yaxis()
    ax.set_xlabel('Probability')
    ax.set_title(f'Top {len(top_labels)} Predicted Categories')
    ax.set_xlim(0, 1)

    # Print each probability just to the right of its bar.
    for bar, prob in zip(bars, top_probs):
        ax.text(bar.get_width(), bar.get_y() + bar.get_height() / 2, f'{prob:.3f}',
                va='center', ha='left', fontsize=10)

    fig.tight_layout()

    # Serialize to PNG in memory and embed it as a data URI.
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')

    return f'<img src="data:image/png;base64,{img_base64}" alt="Prediction Chart" style="width:100%;">'

# --- Custom CSS for the page (passed to gr.Interface via css=) ---
# Centers the app in a white, rounded card on a light background and styles
# headings and labels.
custom_css = """
body {
    background-color: #f0f4f8;
}
.gradio-container {
    max-width: 900px;
    margin: auto;
    padding-top: 20px;
    padding-bottom: 20px;
    background: white;
    border-radius: 10px;
    box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
h1 {
    color: #2c3e50;
    text-align: center;
    font-family: 'Arial', sans-serif;
}
h3 {
    color: #34495e;
}
label {
    font-weight: bold;
    color: #2c3e50;
}
"""

# --- HTML "About the Model" panel (passed to gr.Interface as article=) ---
# Describes the SciBERT model, links its training dataset, and gives usage advice.
model_info_html = """
<div style="background-color: #ecf0f1; padding: 15px; border-radius: 8px; margin-bottom: 20px;">
    <h3>About the Model</h3>
    <p>This classifier uses a <strong>SciBERT</strong> model fine-tuned on the <a href="https://huggingface.co/datasets/ccdv/arxiv-classification" target="_blank">arXiv Classification dataset</a>.</p>
    <p>It predicts one of 11 categories related to Computer Science and Mathematics.</p>
    <p>For best results, input the abstract of a scientific paper.</p>
</div>
"""

# --- HTML description (passed to gr.Interface as description=) ---
# One centered sentence telling the user what to enter.
description_html = """
<p style="font-size: 1.1em; text-align: center;">Enter the abstract of a scientific paper below, and the model will predict its arXiv category.</p>
"""

# Build the Gradio UI: one abstract textbox in; a probability Label and an
# HTML bar chart (both produced by classify_text) out.
interface = gr.Interface(
    fn=classify_text,
    inputs=gr.Textbox(
        lines=10,
        placeholder="Paste the abstract of a scientific paper here...",
        label="Paper Abstract",
        elem_classes="textbox_custom"
    ),
    outputs=[
        gr.Label(num_top_classes=5, label="Prediction Probabilities"),
        gr.HTML(label="Prediction Chart")
    ],
    title="🔬 ArXiv Paper Classifier (SciBERT)",
    description=description_html,
    article=model_info_html,
    examples=[
        [
            "We propose a novel deep learning approach for image recognition using convolutional neural networks. Our method achieves state-of-the-art performance on the ImageNet benchmark, surpassing previous results by a significant margin through architectural innovations and improved training procedures."
        ],
        [
            "We analyze the computational complexity of algorithms for sorting and searching. Specifically, we present a new variant of merge sort that reduces the number of comparisons in the average case. We also discuss the implications for cache performance and practical implementations."
        ],
        [
            "This paper presents a statistical method for analyzing the spread of infectious diseases in populations. Using a modified SIR model with time-dependent transmission rates, we simulate the effects of various intervention strategies on disease dynamics."
        ]
    ],
    css=custom_css
)

# Start the web server only when this file is run as a script, so that
# importing the module (e.g. from tests or another app) does not block on launch().
if __name__ == "__main__":
    interface.launch()