Spaces:
Sleeping
Sleeping
File size: 5,527 Bytes
5e2ba76 675f8ad 5e2ba76 675f8ad 5e2ba76 675f8ad 5e2ba76 675f8ad 5e2ba76 675f8ad 5e2ba76 675f8ad 5e2ba76 675f8ad 5e2ba76 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 | from transformers import AutoTokenizer, AutoModelForSequenceClassification
from scipy.special import softmax
import gradio as gr
import torch
import matplotlib.pyplot as plt
import io
import base64
# Hugging Face Hub checkpoint: SciBERT fine-tuned for arXiv category classification.
model_name = "enigmaize/arxiv-nlp_project-scibert"

# Load the tokenizer and the sequence-classification model once, at import
# time, so every call to classify_text reuses the same objects.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Human-readable class names; entry i is paired with output logit i.
# NOTE(review): this order must match the label ids the model was fine-tuned
# with (and its num_labels) — confirm against the model's config.
labels = [
    'math.AC',
    'cs.CV',
    'cs.AI',
    'cs.SY',
    'math.GR',
    'cs.CE',
    'cs.PL',
    'cs.IT',
    'cs.DS',
    'cs.NE',
    'math.ST',
]
def _render_top_k_chart(top_labels, top_probs):
    """Render a horizontal bar chart of the top predictions as an inline HTML <img>.

    Args:
        top_labels: Category names, highest probability first.
        top_probs: Probabilities aligned element-wise with ``top_labels``.

    Returns:
        An ``<img>`` tag whose ``src`` is a base64-encoded PNG data URI.
    """
    # Work on the explicit Figure object instead of pyplot's implicit
    # "current figure": requests may be served concurrently, and
    # plt.savefig()/plt.tight_layout() would otherwise act on whichever
    # figure another request made current most recently.
    fig, ax = plt.subplots(figsize=(8, 4))
    try:
        bars = ax.barh(top_labels, top_probs,
                       color=['#4c72b0', '#dd8452', '#55a868', '#c44e52', '#8172b3'])
        ax.set_xlabel('Probability')
        ax.set_title(f'Top {len(top_labels)} Predicted Categories')
        ax.set_xlim(0, 1)
        # Print each probability just to the right of the end of its bar.
        for bar, prob in zip(bars, top_probs):
            ax.text(bar.get_width(), bar.get_y() + bar.get_height() / 2,
                    f'{prob:.3f}', va='center', ha='left', fontsize=10)
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even if rendering fails, so pyplot's
        # figure registry does not grow across requests.
        plt.close(fig)
    img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    return (f'<img src="data:image/png;base64,{img_base64}" '
            'alt="Prediction Chart" style="width:100%;">')


def classify_text(text):
    """Classify a paper abstract into the arXiv categories listed in ``labels``.

    Args:
        text: Abstract text from the Gradio textbox. ``None`` or a
            whitespace-only string is treated as empty input.

    Returns:
        A ``(scores, chart_html)`` tuple. ``scores`` maps every label to its
        softmax probability, ordered from most to least likely. ``chart_html``
        is an ``<img>`` tag showing the top 5 predictions as a bar chart.
        For empty input, every score is 0.0 and ``chart_html`` is ``None``.
    """
    if not text or not text.strip():
        # Nothing to classify: zero scores and no chart.
        return {label: 0.0 for label in labels}, None

    # Inputs are capped at 512 tokens (presumably the model's maximum
    # sequence length — confirm for other checkpoints).
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Take row 0 of the single-item batch rather than squeeze(), which would
    # collapse a one-label output to a bare float and break zip() below.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()

    # NOTE(review): zip() silently truncates if the model's output size and
    # len(labels) ever disagree.
    results = dict(zip(labels, probabilities))
    sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))

    top_k = 5
    top_items = list(sorted_results.items())[:top_k]
    top_labels = [label for label, _ in top_items]
    top_probs = [prob for _, prob in top_items]
    return sorted_results, _render_top_k_chart(top_labels, top_probs)
# --- Custom CSS for the page (passed to gr.Interface as css=custom_css) ---
# Styles the page background, the centered card-like .gradio-container,
# headings, and form labels.
custom_css = """
body {
background-color: #f0f4f8;
}
.gradio-container {
max-width: 900px;
margin: auto;
padding-top: 20px;
padding-bottom: 20px;
background: white;
border-radius: 10px;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
h1 {
color: #2c3e50;
text-align: center;
font-family: 'Arial', sans-serif;
}
h3 {
color: #34495e;
}
label {
font-weight: bold;
color: #2c3e50;
}
"""
# --- HTML "About the Model" panel (passed to gr.Interface as article=) ---
# Describes the SciBERT model, the dataset it was fine-tuned on, and the
# expected input.
model_info_html = """
<div style="background-color: #ecf0f1; padding: 15px; border-radius: 8px; margin-bottom: 20px;">
<h3>About the Model</h3>
<p>This classifier uses a <strong>SciBERT</strong> model fine-tuned on the <a href="https://huggingface.co/datasets/ccdv/arxiv-classification" target="_blank">arXiv Classification dataset</a>.</p>
<p>It predicts one of 11 categories related to Computer Science and Mathematics.</p>
<p>For best results, input the abstract of a scientific paper.</p>
</div>
"""
# --- HTML description shown with the title (passed to gr.Interface as description=) ---
description_html = """
<p style="font-size: 1.1em; text-align: center;">Enter the abstract of a scientific paper below, and the model will predict its arXiv category.</p>
"""
# Build the Gradio UI: one multi-line textbox in; a label widget with the
# per-category probabilities and an HTML bar chart out.
interface = gr.Interface(
    fn=classify_text,
    inputs=gr.Textbox(
        lines=10,
        placeholder="Paste the abstract of a scientific paper here...",
        label="Paper Abstract",
        elem_classes="textbox_custom",
    ),
    outputs=[
        gr.Label(num_top_classes=5, label="Prediction Probabilities"),
        gr.HTML(label="Prediction Chart"),
    ],
    title="🔬 ArXiv Paper Classifier (SciBERT)",
    description=description_html,
    article=model_info_html,
    # Clickable sample abstracts; each inner list holds one value per input component.
    examples=[
        [
            "We propose a novel deep learning approach for image recognition using convolutional neural networks. Our method achieves state-of-the-art performance on the ImageNet benchmark, surpassing previous results by a significant margin through architectural innovations and improved training procedures."
        ],
        [
            "We analyze the computational complexity of algorithms for sorting and searching. Specifically, we present a new variant of merge sort that reduces the number of comparisons in the average case. We also discuss the implications for cache performance and practical implementations."
        ],
        [
            "This paper presents a statistical method for analyzing the spread of infectious diseases in populations. Using a modified SIR model with time-dependent transmission rates, we simulate the effects of various intervention strategies on disease dynamics."
        ],
    ],
    css=custom_css,
)

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not block on launching a web server.
if __name__ == "__main__":
    interface.launch()