Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from transformers import pipeline | |
| import torch | |
| from transformers import AutoModelForSequenceClassification | |
| import pandas as pd | |
| from typing import Dict | |
| from transformers import RobertaTokenizer | |
| from typing import List | |
USED_MODEL = "distilroberta-base"


# NOTE(review): nothing here caches the loaded model, and Streamlit re-executes
# the whole script on every widget interaction, so the model is rebuilt on each
# rerun. Decorating load_model with st.cache_resource (Streamlit >= 1.18) would
# keep one instance alive across reruns — confirm the deployed version first.
def load_model():
    """Load the fine-tuned arXiv category classifier.

    Label indices are assigned to the categories listed in ``arxiv_topics.csv``
    in order of first appearance, and that mapping is passed to the model as
    ``label2id`` / ``id2label``.

    Returns:
        The ``AutoModelForSequenceClassification`` instance, switched to eval mode.
    """
    # Reading a local CSV is fast, so it is intentionally not cached
    # (caching it would be easy to add if that ever changes).
    arxiv_topics_df = pd.read_csv('arxiv_topics.csv')
    # dict.fromkeys keeps each category once, in order of first appearance.
    categories = dict.fromkeys(arxiv_topics_df['category'])
    category_to_index = {category: index for index, category in enumerate(categories)}
    index_to_category = {index: category for category, index in category_to_index.items()}
    model = AutoModelForSequenceClassification.from_pretrained(
        f"bumchik2/train-{USED_MODEL}-tags-classification",
        problem_type="multi_label_classification",
        num_labels=len(category_to_index),
        id2label=index_to_category,
        label2id=category_to_index,
    )
    # Inference only: eval() turns off training-time behaviour such as dropout.
    model.eval()
    return model


model = load_model()
def get_tokenizer():
    """Return the pretrained RoBERTa tokenizer for the base checkpoint ``USED_MODEL``.

    The tokenizer is loaded with ``from_pretrained`` on every call; nothing is cached here.
    """
    tokenizer = RobertaTokenizer.from_pretrained(USED_MODEL)
    return tokenizer
def tokenize_function(text):
    """Tokenize ``text`` for the classifier.

    The sequence is truncated if too long and padded with ``padding="max_length"``,
    so a single text always produces the same number of tokens.
    """
    tokenizer = get_tokenizer()
    encoded = tokenizer(text, truncation=True, padding="max_length")
    return encoded
def get_category_probs_dict(model, title: str, summary: str) -> Dict[str, float]:
    """Predict a probability for every arXiv category of an article.

    Categories and their label indices come from ``arxiv_topics.csv``, in order
    of first appearance (the same mapping ``load_model`` gives the model).

    Args:
        model: sequence-classification model producing one logit per category.
        title: article title.
        summary: article abstract; may be empty.

    Returns:
        Mapping category -> probability. Per-label sigmoid scores are rescaled
        so that the returned values sum to 1.
    """
    # Reading a local CSV is fast, so it is intentionally not cached
    # (caching it would be easy to add if that ever changes).
    arxiv_topics_df = pd.read_csv('arxiv_topics.csv')
    # Unique categories in order of first appearance; position == label index.
    categories = list(dict.fromkeys(arxiv_topics_df['category']))

    text = f'{title} $ {summary or ""}'
    # Add a batch dimension of 1 to every tokenizer output.
    model_inputs = {
        key: torch.tensor(value).to(model.device).unsqueeze(0)
        for key, value in tokenize_function(text).items()
    }
    # Inference only: without no_grad the logits carry autograd history and the
    # .numpy() call below raises "Can't call numpy() on Tensor that requires grad".
    with torch.no_grad():
        category_logits = model(**model_inputs).logits
    category_probs = torch.sigmoid(category_logits.squeeze().cpu()).numpy()
    category_probs /= category_probs.sum()
    return {category: float(category_probs[index]) for index, category in enumerate(categories)}
def get_most_probable_keys(probs_dict: Dict[str, float], target_probability: float, print_probabilities: bool) -> List[str]:
    """Return the most probable keys whose cumulative probability covers a target.

    Keys are visited in descending order of probability (ties broken by key,
    descending). A key is taken while the running total of the keys already
    taken is still <= ``target_probability``. For a non-empty input and a
    non-negative target, at least one key is therefore returned, and the total
    usually ends slightly above the target.

    Args:
        probs_dict: mapping key -> probability.
        target_probability: cumulative probability to reach.
        print_probabilities: if True, each entry is formatted as
            ``"key (probability)"`` instead of the bare key.

    Returns:
        The selected entries, most probable first; an empty list for an empty
        ``probs_dict``.
    """
    ranked = sorted(((value, key) for key, value in probs_dict.items()), reverse=True)
    answer = []
    cumulative_p = 0
    for probability, key in ranked:
        if cumulative_p > target_probability:
            break
        cumulative_p += probability
        answer.append(f'{key} ({probability})' if print_probabilities else key)
    return answer
# --- Streamlit UI -----------------------------------------------------------
# The hint texts are placeholders, not default values: as defaults the inputs
# would start non-empty and the model would classify the hint text itself.
title = st.text_input("Article title", placeholder="Enter title here...")
summary = st.text_input("Article summary", placeholder="Enter summary here...")
need_to_print_probabilities = st.radio("Need to print probabilities: ", ('Yes', 'No'), index=0)
st.session_state['need_to_print_probabilities'] = need_to_print_probabilities
target_probability = st.slider("Select minimum probability sum", 0.0, 1.0, step=0.01, value=0.95)
# Store the selected value itself, not the string 'target_probability'.
st.session_state['target_probability'] = target_probability
# Classify only once the user has entered a title or a summary.
if title or summary:
    category_probs_dict = get_category_probs_dict(model=model, title=title, summary=summary or '')
    result = get_most_probable_keys(
        probs_dict=category_probs_dict,
        target_probability=target_probability,
        print_probabilities=need_to_print_probabilities == 'Yes',
    )
    result_str = " \n ".join(result)
    st.write(result_str)