"""Gradio demo that predicts the arXiv top-level category of a paper.

The user supplies a title (required) and an abstract (optional). The text is
classified by a sequence-classification model from the Hugging Face Hub, and
the most likely categories are shown until their cumulative probability
reaches CUMULATIVE_PROB_THRESHOLD.
"""

import re

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Category names, indexed by the model's output position.
# NOTE(review): assumed to match the checkpoint's label order — TODO confirm
# against the model's config.id2label.
LABELS = [
    "astro-ph", "cond-mat", "cs", "econ", "eess", "gr-qc", "hep-ex",
    "hep-lat", "hep-ph", "hep-th", "math", "math-ph", "nlin", "nucl-ex",
    "nucl-th", "physics", "q-bio", "q-fin", "quant-ph", "stat",
]

MODEL_NAME = "TochkaMikelya/arxiv-classifier"

# Categories are listed in descending probability until their running sum
# reaches this value (the UI label advertises "топ 95%").
CUMULATIVE_PROB_THRESHOLD = 0.95

# Populated lazily by load_model() on the first prediction, so defining the
# interface does not itself load the model.
tokenizer, model = None, None

# Any run of whitespace, newlines included.
_WHITESPACE_RE = re.compile(r"\s+")


def load_model():
    """Load the tokenizer and model into the module globals on first use.

    Both objects are published together only after both loads succeed, so a
    failed load leaves the globals untouched and the next call retries.
    """
    global tokenizer, model
    if model is not None:
        return
    new_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    new_model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=len(LABELS),
        # NOTE(review): if the checkpoint's classification head does not have
        # len(LABELS) outputs, this re-initialises it with random weights
        # instead of failing — confirm that is intended.
        ignore_mismatched_sizes=True,
    )
    new_model.eval()
    tokenizer, model = new_tokenizer, new_model


def clean_text(text):
    """Collapse every whitespace run (including newlines) to one space and trim the ends."""
    return _WHITESPACE_RE.sub(" ", text).strip()


def predict(title: str, abstract: str):
    """Classify a paper and return its most likely arXiv categories.

    Args:
        title: Paper title; must contain non-whitespace text.
        abstract: Paper abstract; may be empty.

    Returns:
        Dict mapping category name -> probability, holding the highest-probability
        categories (in descending order) up to and including the one at which
        the cumulative probability first reaches CUMULATIVE_PROB_THRESHOLD.

    Raises:
        gr.Error: If the title is empty or whitespace only.
    """
    # Treat a missing value (None) like an empty textbox.
    title = clean_text(title or "")
    abstract = clean_text(abstract or "")
    if not title:
        raise gr.Error("Нужен как минимум заголовок статьи, введите пожалуйста")

    load_model()

    # Input format: "<title> [SEP] <abstract>", or just the title when no
    # abstract was given.
    text = f"{title} [SEP] {abstract}" if abstract else title
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=True,
    )

    with torch.no_grad():
        logits = model(**inputs).logits
    # A single text was tokenized, so take row 0 explicitly rather than
    # squeeze(), which would also drop the class axis if it had size 1.
    probs = F.softmax(logits[0], dim=-1).tolist()

    # sorted() is stable, so ties keep label order.
    ranked = sorted(zip(LABELS, probs), key=lambda pair: pair[1], reverse=True)

    result = {}
    cumulative = 0.0
    for label, prob in ranked:
        result[label] = prob
        cumulative += prob
        if cumulative >= CUMULATIVE_PROB_THRESHOLD:
            break
    return result


demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Название статьи", placeholder="Заголовок статьи: "),
        gr.Textbox(label="Abstract", placeholder="Abstract статьи: ", lines=5),
    ],
    outputs=gr.Label(label="Вероятности классов топ 95%"),
    title="Классификатор статей",
    description="Введи название и abstract статьи, классификация тэга статьи из arxiv",
    flagging_mode="never",
)

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()