File size: 2,269 Bytes
2c26172
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import streamlit as st

from transformers import AutoTokenizer, DistilBertForSequenceClassification
import torch
from torch.nn.functional import softmax



# Hugging Face checkpoint name the tokenizer is loaded from.
base_model_name = 'distilbert-base-uncased'

@st.cache
def load_tags_info():
    """Load tag descriptions from ``tags.txt``.

    Returns:
        dict[int, str]: mapping from class index (0-based line number in
        the file) to the tag description on that line.
    """
    id_to_description = {}
    with open('tags.txt', 'r') as file:
        for i, line in enumerate(file):
            # rstrip('\n') instead of line[:-1]: the old slice chopped the
            # last character of the final line when it had no trailing newline.
            id_to_description[i] = line.rstrip('\n')
    return id_to_description

# Built once at import time; maps class index -> human-readable tag name.
id_to_description = load_tags_info()

# allow_output_mutation=True: a torch model is not hashable/immutable, and
# st.cache would otherwise try to hash the returned model on every run
# (raising UnhashableTypeError or emitting CachedObjectMutationWarning).
@st.cache(allow_output_mutation=True)
def load_model():
    """Load the fine-tuned classifier from the app directory (cached)."""
    return DistilBertForSequenceClassification.from_pretrained('./')

@st.cache(allow_output_mutation=True)
def load_tokenizer():
    """Load the base tokenizer (cached, consistent with ``load_model``).

    Uses the module-level ``base_model_name`` constant instead of
    repeating the checkpoint-name literal.
    """
    return AutoTokenizer.from_pretrained(base_model_name)

def top_xx(preds, xx=95, id_map=None):
    """Return tag descriptions whose cumulative probability reaches ``xx``%.

    Tags are taken in order of decreasing probability until their
    cumulative mass first reaches ``xx / 100``.

    Args:
        preds: probability tensor of shape (1, num_classes); rows are
            expected to sum to 1 (softmax output).
        xx: cumulative-probability threshold, in percent (default 95).
        id_map: optional mapping from class index to description;
            defaults to the module-level ``id_to_description``.

    Returns:
        list[str]: descriptions of the selected tags, most probable first.
    """
    if id_map is None:
        id_map = id_to_description
    threshold = xx / 100
    order = torch.argsort(preds, 1, descending=True)
    num_classes = order.shape[1]
    total = 0.0
    index = 0
    result = []
    # Also bound the loop by num_classes: floating-point rounding can leave
    # the cumulative sum just below the threshold even after every class has
    # been taken, which previously raised an IndexError on `order[0, index]`.
    while total < threshold and index < num_classes:
        next_id = order[0, index].item()
        total += preds[0, next_id].item()
        result.append(id_map[next_id])
        index += 1
    return result

# Load the (cached) model and tokenizer once per session.
model = load_model()
tokenizer = load_tokenizer()
temperature = 1  # softmax temperature; 1 = raw model probabilities

st.title('ArXivTager')
st.caption('Напишите тему (Title) и параграф из статьи (Abstract). Поля должны быть ЗАПОЛНЕНЫ для корректной классификации.')

with st.form("ArXivTager"):

    title = st.text_area(label='Title', height=30)
    abstract = st.text_area(label='Abstract (optional)', height=200)
    st.caption('ВЫВОД: набор тем в порядке уменьшения вероятностей.')

    submitted = st.form_submit_button("Get tags")
    if submitted:
        if not title:
            st.markdown("Нужно хоть что-то написать")
        else:
            # Combine both fields into a single prompt; presumably this
            # mirrors the format the classifier was fine-tuned on — confirm
            # against the training pipeline.
            prompt = 'Title: ' + title + ' Abstract: ' + abstract
            tokens = tokenizer(prompt, truncation=True, padding='max_length', return_tensors='pt')['input_ids']
            # Temperature-scaled softmax over the logits -> class probabilities.
            preds = softmax(model(tokens.reshape(1, -1)).logits / temperature, dim=1)
            tags = top_xx(preds)
            st.header('Inferred tags:')
            # One bullet per inferred tag, most probable first.
            # (Removed the unused `other_tags` list and the unused
            # enumerate index from the original.)
            for tag_data in tags:
                st.markdown('* ' + tag_data)