Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| import torch.serialization | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
def load_model():
    """Load the fine-tuned classifier, its tokenizer and the label mapping.

    Reads a checkpoint dict from disk and rebuilds the base architecture with
    one output per entry of the checkpoint's ``idx_to_category``. It then
    loads the saved weights into that architecture.

    Returns:
        tuple: ``(model, tokenizer, idx_to_category)`` where ``tokenizer`` and
        ``idx_to_category`` are taken directly from the checkpoint.
    """
    # Alternative pair (swap both together):
    #   'TinyBERT_cls_model.pt' on 'huawei-noah/TinyBERT_General_4L_312D'
    ckpt_path = 'distilbert-base_cls_model.pt'
    backbone = 'distilbert-base-uncased'

    # weights_only=False unpickles arbitrary Python objects (the checkpoint
    # carries a tokenizer object), so only load checkpoints you trust.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    labels = ckpt['idx_to_category']

    net = AutoModelForSequenceClassification.from_pretrained(
        backbone,
        num_labels=len(labels),
    )
    net.load_state_dict(ckpt['model_state_dict'])
    return net, ckpt['tokenizer'], labels
def predict(title, abstract, model, tokenizer, idx_to_category, threshold=0.95):
    """Classify an article and return its most likely categories.

    Categories are taken in descending order of probability until their
    cumulative probability reaches ``threshold`` (top-p style selection).

    Args:
        title: Article title.
        abstract: Optional abstract; a falsy value means "classify by title
            only".
        model: Sequence-classification model, called with the tokenizer's
            output; its result must expose ``.logits``.
        tokenizer: Callable tokenizer returning PyTorch tensors.
        idx_to_category: Mapping (or sequence) from class index to category
            name.
        threshold: Stop adding categories once the accumulated probability is
            at least this value.

    Returns:
        tuple: ``(results, cumulative_prob)``. ``results`` is a list of
        ``{"category": ..., "probability": float}`` dicts in descending
        probability order. ``cumulative_prob`` is the sum of their
        probabilities.
    """
    # Join title and abstract with a real newline.
    # NOTE(review): if the checkpoint was fine-tuned on text joined with a
    # different separator (e.g. the literal "/n"), use that same separator
    # here -- confirm against the training code.
    text = f"{title}\n{abstract}" if abstract else title
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        logits = model(**inputs).logits
    # A single text was tokenized, so take the first (only) row.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)

    results = []
    cumulative_prob = 0.0
    for prob, idx in zip(sorted_probs.tolist(), sorted_indices.tolist()):
        if cumulative_prob >= threshold:
            break
        results.append({"category": idx_to_category[idx], "probability": prob})
        cumulative_prob += prob
    return results, cumulative_prob
def main():
    """Render the Streamlit article-classification page.

    The page shows a form with a required title and an optional abstract.
    On submit it runs ``predict`` and displays the selected categories with
    their probabilities.
    """
    # Streamlit re-executes the whole script on every interaction; caching the
    # loader keeps the checkpoint from being re-read from disk each rerun.
    model, tokenizer, idx_to_category = st.cache_resource(load_model)()

    st.title("Классификатор статей")
    st.markdown("Определение тематики научных статей по названию и аннотации")

    with st.form("input_form"):
        title = st.text_input("Название статьи*", placeholder="Введите название...")
        abstract = st.text_area("Аннотация", placeholder="Введите текст аннотации (необязательно)...", height=150)
        submitted = st.form_submit_button("Классифицировать")

    if not submitted:
        return

    # Treat whitespace-only input as empty.
    title = title.strip()
    abstract = abstract.strip()
    if not title:
        st.error("Пожалуйста, введите название статьи")
        return

    with st.spinner("Анализируем статью..."):
        results, total_prob = predict(
            title=title,
            abstract=abstract,
            model=model,
            tokenizer=tokenizer,
            idx_to_category=idx_to_category,
        )

    st.success("Результаты классификации:")
    st.metric("Общая вероятность", f"{total_prob*100:.1f}%")
    for i, res in enumerate(results, 1):
        col1, col2 = st.columns([1, 4])
        with col1:
            st.metric(f"Топ {i}", f"{res['probability']*100:.1f}%")
        with col2:
            st.progress(res['probability'], text=res['category'])
# Entry point: render the app only when this file is executed as the main
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()