ilushado commited on
Commit
9e69dbf
·
1 Parent(s): ecf8794

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -29
app.py CHANGED
@@ -7,9 +7,8 @@ import transformers
7
  import json
8
  from torch.utils.data import Dataset, DataLoader
9
  from transformers import RobertaModel, RobertaTokenizer
10
- from transformers import AutoModel, DistilBertTokenizer
11
  import transformers
12
- from transformers import pipeline
13
 
14
  idx_to_tag = {0: 'cs',
15
  1: 'stat',
@@ -17,8 +16,10 @@ idx_to_tag = {0: 'cs',
17
  3: 'math',
18
  4: 'q-bio',
19
  5: 'eess',
20
- 6: 'economics, finances'
21
- }
 
 
22
 
23
 
24
  tag_to_idx = {'cs': 0,
@@ -27,12 +28,36 @@ tag_to_idx = {'cs': 0,
27
  'math': 3,
28
  'q-bio': 4,
29
  'eess': 5,
30
- 'economics, finances': 6
31
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
 
 
34
 
35
- model = pipeline('./model')
 
 
36
 
37
 
38
  st.markdown("### Угадыватель")
@@ -45,32 +70,46 @@ ans = None
45
 
46
 
47
  if st.button('Предположить'):
48
- inputs = tokenizer(title + ' : ' + abstract, return_tensors='pt')
49
- inputs['input_ids'] = inputs['input_ids']
50
- inputs['attention_mask'] = inputs['attention_mask']
51
-
52
- with torch.no_grad():
53
- logits = model(**inputs).logits
 
 
 
 
54
 
55
- idx = torch.nn.functional.softmax(logits[0], dim=0).argmax().item()
 
 
 
 
 
56
  st.markdown(f'{idx_to_tag[idx]}')
57
 
58
  if st.button("Посмотреть топ"):
59
- if not logits:
60
- inputs = tokenizer(title + ' : ' + abstract, return_tensors='pt')
61
- inputs['input_ids'] = inputs['input_ids']
62
- inputs['attention_mask'] = inputs['attention_mask']
63
-
64
- with torch.no_grad():
65
- logits = model(**inputs).logits
66
-
67
- idx = torch.nn.functional.softmax(logits[0], dim=0).argmax().item()
68
-
69
-
 
 
 
 
 
70
 
71
- elems = [el.item() for el in logits[0].argsort(descending=True)]
72
- print(len(elems))
73
- probs = logits[0].softmax(dim=0)
74
  str_ans = ''
75
  current_prob = 0
76
  current_elems = []
 
7
  import json
8
  from torch.utils.data import Dataset, DataLoader
9
  from transformers import RobertaModel, RobertaTokenizer
 
10
  import transformers
11
+
12
 
13
  idx_to_tag = {0: 'cs',
14
  1: 'stat',
 
16
  3: 'math',
17
  4: 'q-bio',
18
  5: 'eess',
19
+ 6: 'economics, finances',
20
+ 7: 'gr-qc',
21
+ 8: 'hep-ex',
22
+ 9: 'hep-lat'}
23
 
24
 
25
  tag_to_idx = {'cs': 0,
 
28
  'math': 3,
29
  'q-bio': 4,
30
  'eess': 5,
31
+ 'economics, finances': 6,
32
+ 'gr-qc': 7,
33
+ 'hep-ex': 8,
34
+ 'hep-lat': 9}
35
+
36
+ class RobertaClass(torch.nn.Module):
37
+ def __init__(self):
38
+ super(RobertaClass, self).__init__()
39
+ self.l1 = RobertaModel.from_pretrained("roberta-base")
40
+ self.pre_classifier = torch.nn.Linear(768, 768)
41
+ self.dropout = torch.nn.Dropout(0.3)
42
+ self.classifier = torch.nn.Linear(768, 5)
43
+
44
+ def forward(self, input_ids, attention_mask, token_type_ids):
45
+ output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
46
+ hidden_state = output_1[0]
47
+ pooler = hidden_state[:, 0]
48
+ pooler = self.pre_classifier(pooler)
49
+ pooler = torch.nn.ReLU()(pooler)
50
+ pooler = self.dropout(pooler)
51
+ output = self.classifier(pooler)
52
+ return output
53
 
54
+ tokenizer = RobertaTokenizer.from_pretrained('roberta-base', truncation=True, do_lower_case=True,
55
+ vocab_file='model/vocab.json',
56
+ merges_file='model/merges.txt')
57
 
58
+
59
+
60
+ model = torch.load('model/pytorch_roberta_sentiment.bin', map_location=torch.device('cpu'))
61
 
62
 
63
  st.markdown("### Угадыватель")
 
70
 
71
 
72
  if st.button('Предположить'):
73
+ text = title + " : " + abstract
74
+ inputs = tokenizer.encode_plus(
75
+ text,
76
+ None,
77
+ add_special_tokens=True,
78
+ max_length=256,
79
+ pad_to_max_length=True,
80
+ return_token_type_ids=True
81
+ )
82
+
83
 
84
+ ids = torch.Tensor(inputs['input_ids']).long()
85
+ mask = torch.Tensor(inputs['attention_mask']).long()
86
+ token_type_ids = torch.Tensor(inputs['token_type_ids']).long()
87
+
88
+ ans = model(ids.unsqueeze(0), mask.unsqueeze(0), token_type_ids.unsqueeze(0))
89
+ idx = torch.nn.functional.softmax(ans[0], dim=0).argmax().item()
90
  st.markdown(f'{idx_to_tag[idx]}')
91
 
92
  if st.button("Посмотреть топ"):
93
+ if not ans:
94
+ print(1)
95
+ text = title + " : " + abstract
96
+ inputs = tokenizer.encode_plus(
97
+ text,
98
+ None,
99
+ add_special_tokens=True,
100
+ max_length=256,
101
+ pad_to_max_length=True,
102
+ return_token_type_ids=True
103
+ )
104
+
105
+
106
+ ids = torch.Tensor(inputs['input_ids']).long()
107
+ mask = torch.Tensor(inputs['attention_mask']).long()
108
+ token_type_ids = torch.Tensor(inputs['token_type_ids']).long()
109
 
110
+ ans = model(ids.unsqueeze(0), mask.unsqueeze(0), token_type_ids.unsqueeze(0))
111
+ elems = [el.item() for el in ans[0].argsort(descending=True)]
112
+ probs = ans[0].softmax(dim=0)
113
  str_ans = ''
114
  current_prob = 0
115
  current_elems = []