Spaces:
Sleeping
Sleeping
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from scipy.special import softmax | |
| import gradio as gr | |
| import torch | |
| import matplotlib.pyplot as plt | |
| import io | |
| import base64 | |
# Checkpoint identifier; the same value is handed to both `from_pretrained`
# calls so the tokenizer and the classifier come from the same source.
model_name = "enigmaize/arxiv-nlp_project-scibert"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Class names, in the order corresponding to the model's `num_labels` outputs.
# NOTE(review): this order is zipped positionally against the model's output
# probabilities in `classify_text` — confirm it matches the checkpoint.
labels = [
    'math.AC',
    'cs.CV',
    'cs.AI',
    'cs.SY',
    'math.GR',
    'cs.CE',
    'cs.PL',
    'cs.IT',
    'cs.DS',
    'cs.NE',
    'math.ST',
]
def classify_text(text):
    """Classify a paper abstract and render a chart of the top predictions.

    Args:
        text: Abstract to classify. ``None``, empty, or whitespace-only input
            is treated as "nothing to classify".

    Returns:
        A ``(scores, chart_html)`` tuple. ``scores`` maps each entry of the
        module-level ``labels`` list to its softmax probability, ordered from
        most to least likely. ``chart_html`` is an ``<img>`` tag embedding a
        base64-encoded PNG bar chart of the five most likely labels. For blank
        input, ``scores`` is all zeros and ``chart_html`` is ``None``.
    """
    # Blank input (including None, e.g. from a cleared textbox or an API call):
    # return an all-zero result and no chart instead of running the model.
    if text is None or not text.strip():
        return {label: 0.0 for label in labels}, None

    # Tokenize; anything beyond 512 tokens is truncated.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Inference without gradient tracking.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax over the class axis, then take the single batch row explicitly.
    # (A bare `.squeeze()` would also collapse the class axis if it had size 1.)
    probabilities = torch.softmax(logits, dim=-1)[0].tolist()

    # Pair labels with probabilities and rank them, most likely first.
    # NOTE(review): assumes `labels` is ordered like the model's output
    # indices — confirm against the checkpoint's label mapping.
    ranked = sorted(zip(labels, probabilities), key=lambda pair: pair[1], reverse=True)
    sorted_results = dict(ranked)

    # --- Chart of the top-k predictions ---
    top_k = 5
    top_labels = [label for label, _ in ranked[:top_k]]
    top_probs = [prob for _, prob in ranked[:top_k]]

    fig, ax = plt.subplots(figsize=(8, 4))
    try:
        bars = ax.barh(
            top_labels,
            top_probs,
            color=['#4c72b0', '#dd8452', '#55a868', '#c44e52', '#8172b3'],
        )
        # barh draws the first item at the bottom; flip the axis so the most
        # likely category is the top bar, matching the sorted label output.
        ax.invert_yaxis()
        ax.set_xlabel('Probability')
        ax.set_title(f'Top {top_k} Predicted Categories')
        ax.set_xlim(0, 1)

        # Print the numeric value at the end of each bar.
        for bar, prob in zip(bars, top_probs):
            ax.text(
                bar.get_width(),
                bar.get_y() + bar.get_height() / 2,
                f'{prob:.3f}',
                va='center',
                ha='left',
                fontsize=10,
            )

        # Use the figure's own methods rather than the pyplot equivalents:
        # pyplot acts on the global "current figure", which may belong to a
        # different request when several are being served concurrently.
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

    img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    chart_html = f'<img src="data:image/png;base64,{img_base64}" alt="Prediction Chart" style="width:100%;">'
    return sorted_results, chart_html
# --- Custom CSS for the page: background, centered card container, headings, labels ---
custom_css = """
body {
background-color: #f0f4f8;
}
.gradio-container {
max-width: 900px;
margin: auto;
padding-top: 20px;
padding-bottom: 20px;
background: white;
border-radius: 10px;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
h1 {
color: #2c3e50;
text-align: center;
font-family: 'Arial', sans-serif;
}
h3 {
color: #34495e;
}
label {
font-weight: bold;
color: #2c3e50;
}
"""

# --- HTML panel with information about the model (passed as the interface `article`) ---
model_info_html = """
<div style="background-color: #ecf0f1; padding: 15px; border-radius: 8px; margin-bottom: 20px;">
<h3>About the Model</h3>
<p>This classifier uses a <strong>SciBERT</strong> model fine-tuned on the <a href="https://huggingface.co/datasets/ccdv/arxiv-classification" target="_blank">arXiv Classification dataset</a>.</p>
<p>It predicts one of 11 categories related to Computer Science and Mathematics.</p>
<p>For best results, input the abstract of a scientific paper.</p>
</div>
"""

# --- HTML for the short description (passed as the interface `description`) ---
description_html = """
<p style="font-size: 1.1em; text-align: center;">Enter the abstract of a scientific paper below, and the model will predict its arXiv category.</p>
"""
# --- Build the Gradio interface ---

# Single multi-line text input for the abstract.
_abstract_input = gr.Textbox(
    lines=10,
    placeholder="Paste the abstract of a scientific paper here...",
    label="Paper Abstract",
    elem_classes="textbox_custom",
)

# Two outputs, matching the (scores, chart_html) pair returned by classify_text.
_prediction_outputs = [
    gr.Label(num_top_classes=5, label="Prediction Probabilities"),
    gr.HTML(label="Prediction Chart"),
]

# Sample abstracts offered as clickable examples.
_example_abstracts = [
    "We propose a novel deep learning approach for image recognition using convolutional neural networks. Our method achieves state-of-the-art performance on the ImageNet benchmark, surpassing previous results by a significant margin through architectural innovations and improved training procedures.",
    "We analyze the computational complexity of algorithms for sorting and searching. Specifically, we present a new variant of merge sort that reduces the number of comparisons in the average case. We also discuss the implications for cache performance and practical implementations.",
    "This paper presents a statistical method for analyzing the spread of infectious diseases in populations. Using a modified SIR model with time-dependent transmission rates, we simulate the effects of various intervention strategies on disease dynamics.",
]

interface = gr.Interface(
    fn=classify_text,
    inputs=_abstract_input,
    outputs=_prediction_outputs,
    title="🔬 ArXiv Paper Classifier (SciBERT)",
    description=description_html,
    article=model_info_html,
    # One single-element row per example, since the interface has one input.
    examples=[[abstract] for abstract in _example_abstracts],
    css=custom_css,
)

interface.launch()