| import gradio as gr |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| import torch |
| import torch.nn.functional as F |
| import re |
|
|
# arXiv top-level categories, in the order of the classifier's output indices
# (predict() looks labels up by position in this list).
LABELS = [
    "astro-ph",
    "cond-mat",
    "cs",
    "econ",
    "eess",
    "gr-qc",
    "hep-ex",
    "hep-lat",
    "hep-ph",
    "hep-th",
    "math",
    "math-ph",
    "nlin",
    "nucl-ex",
    "nucl-th",
    "physics",
    "q-bio",
    "q-fin",
    "quant-ph",
    "stat",
]

# Identifier passed to ``from_pretrained`` for both tokenizer and model.
MODEL_NAME = "TochkaMikelya/arxiv-classifier"

# Lazily populated by load_model() on the first prediction request.
tokenizer = None
model = None
|
|
|
|
def load_model():
    """Load the tokenizer and classifier into the module globals, once.

    Does nothing when ``model`` is already set. Otherwise fetches the
    tokenizer and the sequence-classification model for ``MODEL_NAME`` and
    switches the model to evaluation mode.
    """
    global tokenizer, model

    # Already initialised by an earlier call.
    if model is not None:
        return

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # NOTE(review): with ignore_mismatched_sizes=True a checkpoint whose head
    # does not have len(LABELS) outputs presumably loads with a freshly
    # initialised head instead of failing -- confirm the checkpoint matches.
    head_options = {
        "num_labels": len(LABELS),
        "ignore_mismatched_sizes": True,
    }
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME, **head_options
    )
    model.eval()
|
|
def clean_text(text):
    """Return *text* on one line: whitespace runs collapsed, ends trimmed."""
    # Newlines become spaces first, then every whitespace run is squeezed
    # down to a single space before trimming.
    return re.sub(r'\s+', ' ', text.replace('\n', ' ')).strip()
|
|
def _top_labels(scored_labels, threshold=0.95):
    """Return the most probable labels whose cumulative mass reaches *threshold*.

    Args:
        scored_labels: Iterable of ``(label, probability)`` pairs.
        threshold: Cumulative probability at which selection stops.

    Returns:
        Dict mapping label to probability in descending probability order.
        The entry that crosses the threshold is included.
    """
    ranked = sorted(scored_labels, key=lambda pair: pair[1], reverse=True)
    selected, cumulative = {}, 0.0
    for label, prob in ranked:
        selected[label] = prob
        cumulative += prob
        if cumulative >= threshold:
            break
    return selected


def predict(title: str, abstract: str):
    """Classify an arXiv paper from its title and optional abstract.

    Args:
        title: Paper title; must contain non-whitespace text.
        abstract: Paper abstract; may be empty or missing.

    Returns:
        Dict mapping category label to probability for the most likely
        categories, stopping once their cumulative probability reaches 0.95.

    Raises:
        gr.Error: If the title is missing or blank.
    """
    # Treat a missing value (None) like an empty field so validation reports
    # it to the user instead of failing with AttributeError on ``.strip()``.
    title = clean_text(title or "")
    abstract = clean_text(abstract or "")

    if not title:
        raise gr.Error("Нужен как минимум заголовок статьи, введите пожалуйста")

    load_model()

    # clean_text() already trims, so a non-empty abstract has real content.
    text = f"{title} [SEP] {abstract}" if abstract else title
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=True,
    )
    with torch.no_grad():
        logits = model(**inputs).logits

    # Exactly one text is encoded, so take row 0 of the batch explicitly;
    # squeeze() would also collapse the class axis if it had size 1.
    probs = F.softmax(logits, dim=-1)[0].tolist()

    return _top_labels(zip(LABELS, probs))
|
|
# Input widgets, in the positional order predict(title, abstract) expects.
_title_box = gr.Textbox(
    label="Название статьи",
    placeholder="Заголовок статьи: ",
)
_abstract_box = gr.Textbox(
    label="Abstract",
    placeholder="Abstract статьи: ",
    lines=5,
)

# Output widget fed with the label -> probability dict returned by predict().
_probabilities_view = gr.Label(label="Вероятности классов топ 95%")

demo = gr.Interface(
    fn=predict,
    inputs=[_title_box, _abstract_box],
    outputs=_probabilities_view,
    title="Классификатор статей",
    description="Введи название и abstract статьи, классификация тэга статьи из arxiv",
    flagging_mode="never",
)

# Start the web UI as soon as this module is executed.
demo.launch()