# TochkaMikelya's picture
# Update app.py
# 495f787 verified
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn.functional as F
import re
# arXiv top-level categories the classifier can emit. The position of each
# entry is the class index used when decoding the model's output.
LABELS = [
    "astro-ph",
    "cond-mat",
    "cs",
    "econ",
    "eess",
    "gr-qc",
    "hep-ex",
    "hep-lat",
    "hep-ph",
    "hep-th",
    "math",
    "math-ph",
    "nlin",
    "nucl-ex",
    "nucl-th",
    "physics",
    "q-bio",
    "q-fin",
    "quant-ph",
    "stat",
]

# Identifier handed to ``from_pretrained`` for both tokenizer and model.
MODEL_NAME = "TochkaMikelya/arxiv-classifier"

# Module-level slots for the tokenizer and model; both start empty and are
# filled in on first use rather than at import time.
tokenizer = None
model = None
def load_model():
    """Load the tokenizer and classifier into the module globals on first use.

    The function is a no-op once ``model`` has been set, so it is cheap to
    call at the start of every request.
    """
    global tokenizer, model

    # Already initialised by an earlier call: nothing to do.
    if model is not None:
        return

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # NOTE(review): ignore_mismatched_sizes=True lets loading succeed even if
    # the checkpoint's classification head does not have len(LABELS) outputs;
    # confirm the published checkpoint really matches LABELS.
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME, num_labels=len(LABELS), ignore_mismatched_sizes=True
    )
    # Inference only: put the network in evaluation mode.
    model.eval()
def clean_text(text) -> str:
    """Collapse every run of whitespace to one space and trim both ends.

    Args:
        text: Raw user-supplied string. ``None`` (or an empty string) is
            accepted and yields an empty string instead of raising.

    Returns:
        The input flattened to a single line with single-space separators.
    """
    if not text:
        return ""
    # str.split() with no separator splits on any whitespace run (newlines and
    # tabs included) and discards leading/trailing whitespace, so one pass
    # replaces the former replace + regex + strip sequence.
    return " ".join(text.split())
def _select_top_labels(probs, labels, threshold=0.95):
    """Pick the most probable labels until their cumulative mass hits ``threshold``.

    Args:
        probs: Per-class probabilities, aligned index-for-index with ``labels``.
        labels: Class names.
        threshold: Cumulative probability at which selection stops.

    Returns:
        A dict mapping label to probability, ordered from most to least
        probable. If the threshold is never reached, every label is returned.
    """
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    selected = {}
    cumulative = 0.0
    for label, prob in ranked:
        selected[label] = prob
        cumulative += prob
        if cumulative >= threshold:
            break
    return selected


def predict(title: str, abstract: str):
    """Classify an article and return the labels covering 95% of probability.

    Args:
        title: Article title; required.
        abstract: Article abstract; may be empty or ``None``.

    Returns:
        A dict mapping label to probability, ordered by descending
        probability, suitable for a ``gr.Label`` output.

    Raises:
        gr.Error: If the title is missing or contains only whitespace.
    """
    # Guard against None as well as blank input so that the user always sees
    # the friendly message rather than an AttributeError from ``None.strip()``.
    if not title or not title.strip():
        raise gr.Error("Нужен как минимум заголовок статьи, введите пожалуйста")

    load_model()

    title = clean_text(title)
    abstract = clean_text(abstract or "")
    # With no abstract the title alone is the model input.
    text = f"{title} [SEP] {abstract}" if abstract.strip() else title

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=True,
    )
    with torch.no_grad():
        logits = model(**inputs).logits

    # Drop only the batch dimension; a bare squeeze() would also collapse any
    # other size-1 dimension and hand back a scalar instead of a list.
    probs = F.softmax(logits, dim=-1).squeeze(0).tolist()
    return _select_top_labels(probs, LABELS)
# Input and output widgets, built up front so the Interface call stays short.
title_input = gr.Textbox(label="Название статьи", placeholder="Заголовок статьи: ")
abstract_input = gr.Textbox(label="Abstract", placeholder="Abstract статьи: ", lines=5)
probabilities_output = gr.Label(label="Вероятности классов топ 95%")

# Wire predict() to the two text inputs and the label output.
demo = gr.Interface(
    fn=predict,
    inputs=[title_input, abstract_input],
    outputs=probabilities_output,
    title="Классификатор статей",
    description="Введи название и abstract статьи, классификация тэга статьи из arxiv",
    flagging_mode="never",
)

# Start the web app as soon as the module is executed.
demo.launch()