from transformers import AutoTokenizer, AutoModelForSequenceClassification
from scipy.special import softmax
import gradio as gr
import torch
import matplotlib.pyplot as plt
import io
import base64

model_name = "enigmaize/arxiv-nlp_project-scibert"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Class names, in the same order as the model's output logits (must match the
# model's `num_labels`).
labels = ['math.AC', 'cs.CV', 'cs.AI', 'cs.SY', 'math.GR', 'cs.CE', 'cs.PL',
          'cs.IT', 'cs.DS', 'cs.NE', 'math.ST']

# classify_text() pairs labels with probabilities via zip(), which silently
# truncates on a length mismatch and would mislabel predictions; fail fast at
# startup instead.
if len(labels) != model.config.num_labels:
    raise ValueError(
        f"Expected {model.config.num_labels} class names for {model_name}, "
        f"got {len(labels)}"
    )


def _render_chart(top_labels, top_probs):
    """Draw a horizontal bar chart of probabilities and return it as HTML.

    Args:
        top_labels: Category names, most probable first.
        top_probs: Probabilities aligned with ``top_labels``.

    Returns:
        An ``<img>`` tag whose ``src`` is the PNG chart inlined as a base64
        data URI, so the ``gr.HTML`` output needs no separate file serving.
    """
    fig, ax = plt.subplots(figsize=(8, 4))
    try:
        bars = ax.barh(top_labels, top_probs,
                       color=['#4c72b0', '#dd8452', '#55a868', '#c44e52', '#8172b3'])
        # barh draws the first entry at the bottom; flip the axis so the most
        # probable category is shown on top.
        ax.invert_yaxis()
        ax.set_xlabel('Probability')
        ax.set_title(f'Top {len(top_labels)} Predicted Categories')
        ax.set_xlim(0, 1)

        # Print the numeric value just to the right of each bar.
        for bar, prob in zip(bars, top_probs):
            ax.text(bar.get_width(), bar.get_y() + bar.get_height() / 2,
                    f'{prob:.3f}', va='center', ha='left', fontsize=10)

        # Use the figure's own methods rather than plt.tight_layout() /
        # plt.savefig(), which act on pyplot's global "current figure".
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Release the figure even if rendering fails, to free its memory.
        plt.close(fig)

    img_base64 = base64.b64encode(buf.getvalue()).decode('ascii')
    return (f'<img src="data:image/png;base64,{img_base64}" '
            f'alt="Prediction Chart" style="max-width: 100%;"/>')


def classify_text(text):
    """Classify a paper abstract into one of the arXiv categories in ``labels``.

    Args:
        text: Text entered by the user. Input longer than 512 tokens is
            truncated by the tokenizer.

    Returns:
        A ``(probabilities, chart_html)`` tuple: a dict mapping every label to
        its softmax probability, ordered from most to least likely, and an
        HTML ``<img>`` bar chart of the top 5 labels. For empty or blank input
        every probability is 0.0 and the chart is ``None``.
    """
    # Blank input: skip inference and return an empty result.
    if not text or not text.strip():
        return {label: 0.0 for label in labels}, None

    inputs = tokenizer(text, return_tensors="pt", truncation=True,
                       padding=True, max_length=512)

    with torch.no_grad():
        logits = model(**inputs).logits

    # A single text is tokenized, so the batch has one row; take it explicitly.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()

    # label -> probability, sorted by probability (descending).
    sorted_results = dict(sorted(zip(labels, probabilities),
                                 key=lambda item: item[1], reverse=True))

    top_k = 5
    top_labels = list(sorted_results)[:top_k]
    top_probs = list(sorted_results.values())[:top_k]
    chart_html = _render_chart(top_labels, top_probs)

    return sorted_results, chart_html


# --- Custom page CSS (adjacent literals concatenate into one string) ---
custom_css = (
    " body { background-color: #f0f4f8; }"
    " .gradio-container { max-width: 900px; margin: auto; padding-top: 20px;"
    " padding-bottom: 20px; background: white; border-radius: 10px;"
    " box-shadow: 0 4px 8px rgba(0,0,0,0.1); }"
    " h1 { color: #2c3e50; text-align: center; font-family: 'Arial', sans-serif; }"
    " h3 { color: #34495e; }"
    " label { font-weight: bold; color: #2c3e50; } "
)

# --- Model information, passed to the interface as `article` ---
model_info_html = """

About the Model

This classifier uses a SciBERT model fine-tuned on the arXiv Classification dataset.

It predicts one of 11 categories related to Computer Science and Mathematics.

For best results, input the abstract of a scientific paper.

"""

# --- Usage description, passed to the interface as `description` ---
description_html = """

Enter the abstract of a scientific paper below, and the model will predict its arXiv category.

"""

# Build the Gradio interface: one textbox in, a label widget and an HTML chart out.
interface = gr.Interface(
    fn=classify_text,
    inputs=gr.Textbox(
        lines=10,
        placeholder="Paste the abstract of a scientific paper here...",
        label="Paper Abstract",
        elem_classes="textbox_custom"
    ),
    outputs=[
        gr.Label(num_top_classes=5, label="Prediction Probabilities"),
        gr.HTML(label="Prediction Chart")
    ],
    title="🔬 ArXiv Paper Classifier (SciBERT)",
    description=description_html,
    article=model_info_html,
    examples=[
        [
            "We propose a novel deep learning approach for image recognition using convolutional neural networks. Our method achieves state-of-the-art performance on the ImageNet benchmark, surpassing previous results by a significant margin through architectural innovations and improved training procedures."
        ],
        [
            "We analyze the computational complexity of algorithms for sorting and searching. Specifically, we present a new variant of merge sort that reduces the number of comparisons in the average case. We also discuss the implications for cache performance and practical implementations."
        ],
        [
            "This paper presents a statistical method for analyzing the spread of infectious diseases in populations. Using a modified SIR model with time-dependent transmission rates, we simulate the effects of various intervention strategies on disease dynamics."
        ]
    ],
    css=custom_css
)

interface.launch()