# Author: enigmaize
# Commit: 675f8ad (verified) - "Update app.py"
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from scipy.special import softmax
import gradio as gr
import torch
import matplotlib.pyplot as plt
import io
import base64
# Checkpoint identifier passed to `from_pretrained` for both the tokenizer
# and the sequence-classification model (loaded once, at import time).
model_name = "enigmaize/arxiv-nlp_project-scibert"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Class names, listed in the order that is expected to match the model's
# output indices (one entry per `num_labels`) — verify against the model config.
labels = [
    'math.AC',
    'cs.CV',
    'cs.AI',
    'cs.SY',
    'math.GR',
    'cs.CE',
    'cs.PL',
    'cs.IT',
    'cs.DS',
    'cs.NE',
    'math.ST',
]
def _render_chart_html(ranked, top_k=5):
    """Render the highest-probability classes as an inline PNG bar chart.

    Args:
        ranked: Mapping of label -> probability, already ordered from the
            most to the least probable class.
        top_k: Maximum number of classes to draw.

    Returns:
        An ``<img>`` tag whose ``src`` is a base64 ``data:`` URI of the chart.
    """
    # Build the figure directly instead of through pyplot: pyplot keeps a
    # process-wide "current figure" registry, which is not safe when several
    # requests are handled concurrently and leaks figures if an exception
    # occurs before `plt.close`. An unregistered Figure is simply
    # garbage-collected.
    from matplotlib.figure import Figure

    top_items = list(ranked.items())[:top_k]
    top_labels = [label for label, _ in top_items]
    top_probs = [prob for _, prob in top_items]

    palette = ['#4c72b0', '#dd8452', '#55a868', '#c44e52', '#8172b3']
    # Cycle through the palette so any `top_k` gets one colour per bar.
    colors = [palette[i % len(palette)] for i in range(len(top_labels))]

    fig = Figure(figsize=(8, 4))
    ax = fig.subplots()
    bars = ax.barh(top_labels, top_probs, color=colors)
    # `barh` places the first entry at the bottom; flip the axis so the most
    # probable class is shown at the top of the chart.
    ax.invert_yaxis()
    ax.set_xlabel('Probability')
    ax.set_title(f'Top {len(top_labels)} Predicted Categories')
    ax.set_xlim(0, 1)

    # Print the numeric probability at the end of each bar.
    for bar, prob in zip(bars, top_probs):
        ax.text(
            bar.get_width(),
            bar.get_y() + bar.get_height() / 2,
            f'{prob:.3f}',
            va='center',
            ha='left',
            fontsize=10,
        )
    fig.tight_layout()

    # Serialise the figure to PNG in memory and embed it as a data URI.
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    return f'<img src="data:image/png;base64,{img_base64}" alt="Prediction Chart" style="width:100%;">'


def classify_text(text):
    """Classify a paper abstract and build a chart of the top predictions.

    Args:
        text: The abstract to classify. ``None``, an empty string, or a
            whitespace-only string is treated as "no input".

    Returns:
        A ``(scores, chart_html)`` tuple. ``scores`` maps every entry of the
        module-level ``labels`` to its probability, ordered from most to
        least probable. ``chart_html`` is an ``<img>`` tag with the bar chart
        of the top classes. For missing/blank input, ``scores`` maps every
        label to ``0.0`` and ``chart_html`` is ``None``.
    """
    # Guard against None as well as blank text so a cleared input box cannot
    # raise AttributeError on `.strip()`.
    if text is None or not text.strip():
        return {label: 0.0 for label in labels}, None

    # Tokenize; inputs longer than 512 tokens are truncated.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Inference without gradient tracking.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax over the class dimension; drop the leading batch dimension
    # (a single text is passed in, so the batch is assumed to have size 1).
    probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze(0).tolist()

    # Pair labels with probabilities and order by descending probability.
    results = dict(zip(labels, probabilities))
    sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))

    return sorted_results, _render_chart_html(sorted_results, top_k=5)
# --- Custom CSS for the page (passed to the Gradio interface via `css=`) ---
custom_css = """
body {
background-color: #f0f4f8;
}
.gradio-container {
max-width: 900px;
margin: auto;
padding-top: 20px;
padding-bottom: 20px;
background: white;
border-radius: 10px;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
h1 {
color: #2c3e50;
text-align: center;
font-family: 'Arial', sans-serif;
}
h3 {
color: #34495e;
}
label {
font-weight: bold;
color: #2c3e50;
}
"""
# --- HTML block describing the model (used as the interface's `article`) ---
model_info_html = """
<div style="background-color: #ecf0f1; padding: 15px; border-radius: 8px; margin-bottom: 20px;">
<h3>About the Model</h3>
<p>This classifier uses a <strong>SciBERT</strong> model fine-tuned on the <a href="https://huggingface.co/datasets/ccdv/arxiv-classification" target="_blank">arXiv Classification dataset</a>.</p>
<p>It predicts one of 11 categories related to Computer Science and Mathematics.</p>
<p>For best results, input the abstract of a scientific paper.</p>
</div>
"""
# --- HTML for the short usage hint (used as the interface's `description`) ---
description_html = """
<p style="font-size: 1.1em; text-align: center;">Enter the abstract of a scientific paper below, and the model will predict its arXiv category.</p>
"""
# Assemble the Gradio interface: one text input, a label widget plus an HTML
# chart as outputs, and three clickable example abstracts.
_abstract_box = gr.Textbox(
    lines=10,
    placeholder="Paste the abstract of a scientific paper here...",
    label="Paper Abstract",
    elem_classes="textbox_custom",
)

# Output order matches the (scores, chart_html) tuple returned by classify_text.
_result_views = [
    gr.Label(num_top_classes=5, label="Prediction Probabilities"),
    gr.HTML(label="Prediction Chart"),
]

# Each example is a one-element list: the value for the single input component.
_sample_abstracts = [
    [
        "We propose a novel deep learning approach for image recognition using convolutional neural networks. Our method achieves state-of-the-art performance on the ImageNet benchmark, surpassing previous results by a significant margin through architectural innovations and improved training procedures."
    ],
    [
        "We analyze the computational complexity of algorithms for sorting and searching. Specifically, we present a new variant of merge sort that reduces the number of comparisons in the average case. We also discuss the implications for cache performance and practical implementations."
    ],
    [
        "This paper presents a statistical method for analyzing the spread of infectious diseases in populations. Using a modified SIR model with time-dependent transmission rates, we simulate the effects of various intervention strategies on disease dynamics."
    ],
]

interface = gr.Interface(
    fn=classify_text,
    inputs=_abstract_box,
    outputs=_result_views,
    title="🔬 ArXiv Paper Classifier (SciBERT)",
    description=description_html,
    article=model_info_html,
    examples=_sample_abstracts,
    css=custom_css,
)

# Start the web server for the app.
interface.launch()