gvbazhenov commited on
Commit
7120cc1
·
1 Parent(s): 38581dc
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import torch
4
+ from transformers import (
5
+ AutoTokenizer,
6
+ AutoModelForSequenceClassification
7
+ )
8
+ import streamlit as st
9
+
10
+ DEPLOYMENT_PATH = '.'
11
+
12
+ @st.cache_resource
13
+ def setup():
14
+ model_name = 'distilbert-base-cased'
15
+ model = AutoModelForSequenceClassification.from_pretrained(f'{DEPLOYMENT_PATH}/checkpoint')
16
+ model.eval()
17
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
18
+ idx2category = pd.read_csv(f'{DEPLOYMENT_PATH}/categories.csv').values.squeeze()
19
+ return model, tokenizer, idx2category
20
+
21
+ @torch.no_grad()
22
+ def get_probas(title, abstract=None):
23
+ inputs = tokenizer(
24
+ title,
25
+ abstract,
26
+ padding=True,
27
+ truncation=True,
28
+ return_tensors='pt'
29
+ )
30
+ outputs = model(**inputs)
31
+ logits = outputs.logits
32
+ probas = (
33
+ torch.sigmoid(logits)
34
+ .detach().numpy().reshape(-1)
35
+ )
36
+ return probas
37
+
38
+ model, tokenizer, idx2category = setup()
39
+ num_categories = len(idx2category)
40
+
41
+ def get_categories_by_threshold(probas, threshold=0.3):
42
+ categories = [
43
+ idx2category[idx] for idx in range(num_categories)
44
+ if probas[idx] > threshold
45
+ ]
46
+ return categories
47
+
48
+ def get_top_categories(probas, num_predictions=5):
49
+ categories = [
50
+ idx2category[idx] for idx in np.argsort(probas)[::-1][:num_predictions]
51
+ ]
52
+ return categories
53
+
54
+ st.title('ArXiv Papers Categorization')
55
+ title_input = st.text_input('Enter the title of paper:')
56
+ abstract_input = st.text_area('Enter the abstract (optional):')
57
+
58
+ IS_READY = len(title_input) > 0
59
+
60
+ if IS_READY and st.button('Categorize'):
61
+ probas = get_probas(title_input, abstract_input)
62
+ categories_predicted = get_categories_by_threshold(probas)
63
+
64
+ if len(categories_predicted) == 0:
65
+ categories_predicted = get_top_categories(probas)
66
+
67
+ st.write('Relevant arXiv categories:')
68
+ for category in categories_predicted:
69
+ st.markdown(f'- `{category}`')
categories.csv ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ category
2
+ cond-mat.dis-nn
3
+ cs.AI
4
+ cs.AR
5
+ cs.CC
6
+ cs.CE
7
+ cs.CG
8
+ cs.CL
9
+ cs.CR
10
+ cs.CV
11
+ cs.CY
12
+ cs.DB
13
+ cs.DC
14
+ cs.DL
15
+ cs.DM
16
+ cs.DS
17
+ cs.ET
18
+ cs.FL
19
+ cs.GL
20
+ cs.GR
21
+ cs.GT
22
+ cs.HC
23
+ cs.IR
24
+ cs.IT
25
+ cs.LG
26
+ cs.LO
27
+ cs.MA
28
+ cs.MM
29
+ cs.MS
30
+ cs.NA
31
+ cs.NE
32
+ cs.NI
33
+ cs.OH
34
+ cs.OS
35
+ cs.PF
36
+ cs.PL
37
+ cs.RO
38
+ cs.SC
39
+ cs.SD
40
+ cs.SE
41
+ cs.SI
42
+ cs.SY
43
+ econ.EM
44
+ eess.AS
45
+ eess.IV
46
+ eess.SP
47
+ math.AC
48
+ math.AG
49
+ math.AP
50
+ math.AT
51
+ math.CA
52
+ math.CO
53
+ math.CT
54
+ math.CV
55
+ math.DG
56
+ math.DS
57
+ math.FA
58
+ math.GM
59
+ math.GN
60
+ math.GR
61
+ math.GT
62
+ math.HO
63
+ math.IT
64
+ math.LO
65
+ math.MG
66
+ math.MP
67
+ math.NA
68
+ math.NT
69
+ math.OA
70
+ math.OC
71
+ math.PR
72
+ math.QA
73
+ math.RA
74
+ math.RT
75
+ math.SP
76
+ math.ST
77
+ nlin.AO
78
+ nlin.CD
79
+ nlin.CG
80
+ nlin.PS
81
+ physics.ao-ph
82
+ physics.app-ph
83
+ physics.bio-ph
84
+ physics.chem-ph
85
+ physics.class-ph
86
+ physics.comp-ph
87
+ physics.data-an
88
+ physics.flu-dyn
89
+ physics.gen-ph
90
+ physics.geo-ph
91
+ physics.hist-ph
92
+ physics.ins-det
93
+ physics.med-ph
94
+ physics.optics
95
+ physics.pop-ph
96
+ physics.soc-ph
97
+ physics.space-ph
98
+ q-bio
99
+ q-bio.BM
100
+ q-bio.CB
101
+ q-bio.GN
102
+ q-bio.MN
103
+ q-bio.NC
104
+ q-bio.OT
105
+ q-bio.PE
106
+ q-bio.QM
107
+ q-bio.SC
108
+ q-bio.TO
109
+ q-fin.CP
110
+ q-fin.EC
111
+ q-fin.GN
112
+ q-fin.PM
113
+ q-fin.PR
114
+ q-fin.RM
115
+ q-fin.ST
116
+ q-fin.TR
117
+ stat.AP
118
+ stat.CO
119
+ stat.ME
120
+ stat.ML
121
+ stat.OT
122
+ stat.TH
checkpoint/config.json ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation": "gelu",
3
+ "architectures": [
4
+ "DistilBertForSequenceClassification"
5
+ ],
6
+ "attention_dropout": 0.1,
7
+ "dim": 768,
8
+ "dropout": 0.1,
9
+ "hidden_dim": 3072,
10
+ "id2label": {
11
+ "0": "LABEL_0",
12
+ "1": "LABEL_1",
13
+ "2": "LABEL_2",
14
+ "3": "LABEL_3",
15
+ "4": "LABEL_4",
16
+ "5": "LABEL_5",
17
+ "6": "LABEL_6",
18
+ "7": "LABEL_7",
19
+ "8": "LABEL_8",
20
+ "9": "LABEL_9",
21
+ "10": "LABEL_10",
22
+ "11": "LABEL_11",
23
+ "12": "LABEL_12",
24
+ "13": "LABEL_13",
25
+ "14": "LABEL_14",
26
+ "15": "LABEL_15",
27
+ "16": "LABEL_16",
28
+ "17": "LABEL_17",
29
+ "18": "LABEL_18",
30
+ "19": "LABEL_19",
31
+ "20": "LABEL_20",
32
+ "21": "LABEL_21",
33
+ "22": "LABEL_22",
34
+ "23": "LABEL_23",
35
+ "24": "LABEL_24",
36
+ "25": "LABEL_25",
37
+ "26": "LABEL_26",
38
+ "27": "LABEL_27",
39
+ "28": "LABEL_28",
40
+ "29": "LABEL_29",
41
+ "30": "LABEL_30",
42
+ "31": "LABEL_31",
43
+ "32": "LABEL_32",
44
+ "33": "LABEL_33",
45
+ "34": "LABEL_34",
46
+ "35": "LABEL_35",
47
+ "36": "LABEL_36",
48
+ "37": "LABEL_37",
49
+ "38": "LABEL_38",
50
+ "39": "LABEL_39",
51
+ "40": "LABEL_40",
52
+ "41": "LABEL_41",
53
+ "42": "LABEL_42",
54
+ "43": "LABEL_43",
55
+ "44": "LABEL_44",
56
+ "45": "LABEL_45",
57
+ "46": "LABEL_46",
58
+ "47": "LABEL_47",
59
+ "48": "LABEL_48",
60
+ "49": "LABEL_49",
61
+ "50": "LABEL_50",
62
+ "51": "LABEL_51",
63
+ "52": "LABEL_52",
64
+ "53": "LABEL_53",
65
+ "54": "LABEL_54",
66
+ "55": "LABEL_55",
67
+ "56": "LABEL_56",
68
+ "57": "LABEL_57",
69
+ "58": "LABEL_58",
70
+ "59": "LABEL_59",
71
+ "60": "LABEL_60",
72
+ "61": "LABEL_61",
73
+ "62": "LABEL_62",
74
+ "63": "LABEL_63",
75
+ "64": "LABEL_64",
76
+ "65": "LABEL_65",
77
+ "66": "LABEL_66",
78
+ "67": "LABEL_67",
79
+ "68": "LABEL_68",
80
+ "69": "LABEL_69",
81
+ "70": "LABEL_70",
82
+ "71": "LABEL_71",
83
+ "72": "LABEL_72",
84
+ "73": "LABEL_73",
85
+ "74": "LABEL_74",
86
+ "75": "LABEL_75",
87
+ "76": "LABEL_76",
88
+ "77": "LABEL_77",
89
+ "78": "LABEL_78",
90
+ "79": "LABEL_79",
91
+ "80": "LABEL_80",
92
+ "81": "LABEL_81",
93
+ "82": "LABEL_82",
94
+ "83": "LABEL_83",
95
+ "84": "LABEL_84",
96
+ "85": "LABEL_85",
97
+ "86": "LABEL_86",
98
+ "87": "LABEL_87",
99
+ "88": "LABEL_88",
100
+ "89": "LABEL_89",
101
+ "90": "LABEL_90",
102
+ "91": "LABEL_91",
103
+ "92": "LABEL_92",
104
+ "93": "LABEL_93",
105
+ "94": "LABEL_94",
106
+ "95": "LABEL_95",
107
+ "96": "LABEL_96",
108
+ "97": "LABEL_97",
109
+ "98": "LABEL_98",
110
+ "99": "LABEL_99",
111
+ "100": "LABEL_100",
112
+ "101": "LABEL_101",
113
+ "102": "LABEL_102",
114
+ "103": "LABEL_103",
115
+ "104": "LABEL_104",
116
+ "105": "LABEL_105",
117
+ "106": "LABEL_106",
118
+ "107": "LABEL_107",
119
+ "108": "LABEL_108",
120
+ "109": "LABEL_109",
121
+ "110": "LABEL_110",
122
+ "111": "LABEL_111",
123
+ "112": "LABEL_112",
124
+ "113": "LABEL_113",
125
+ "114": "LABEL_114",
126
+ "115": "LABEL_115",
127
+ "116": "LABEL_116",
128
+ "117": "LABEL_117",
129
+ "118": "LABEL_118",
130
+ "119": "LABEL_119",
131
+ "120": "LABEL_120"
132
+ },
133
+ "initializer_range": 0.02,
134
+ "label2id": {
135
+ "LABEL_0": 0,
136
+ "LABEL_1": 1,
137
+ "LABEL_10": 10,
138
+ "LABEL_100": 100,
139
+ "LABEL_101": 101,
140
+ "LABEL_102": 102,
141
+ "LABEL_103": 103,
142
+ "LABEL_104": 104,
143
+ "LABEL_105": 105,
144
+ "LABEL_106": 106,
145
+ "LABEL_107": 107,
146
+ "LABEL_108": 108,
147
+ "LABEL_109": 109,
148
+ "LABEL_11": 11,
149
+ "LABEL_110": 110,
150
+ "LABEL_111": 111,
151
+ "LABEL_112": 112,
152
+ "LABEL_113": 113,
153
+ "LABEL_114": 114,
154
+ "LABEL_115": 115,
155
+ "LABEL_116": 116,
156
+ "LABEL_117": 117,
157
+ "LABEL_118": 118,
158
+ "LABEL_119": 119,
159
+ "LABEL_12": 12,
160
+ "LABEL_120": 120,
161
+ "LABEL_13": 13,
162
+ "LABEL_14": 14,
163
+ "LABEL_15": 15,
164
+ "LABEL_16": 16,
165
+ "LABEL_17": 17,
166
+ "LABEL_18": 18,
167
+ "LABEL_19": 19,
168
+ "LABEL_2": 2,
169
+ "LABEL_20": 20,
170
+ "LABEL_21": 21,
171
+ "LABEL_22": 22,
172
+ "LABEL_23": 23,
173
+ "LABEL_24": 24,
174
+ "LABEL_25": 25,
175
+ "LABEL_26": 26,
176
+ "LABEL_27": 27,
177
+ "LABEL_28": 28,
178
+ "LABEL_29": 29,
179
+ "LABEL_3": 3,
180
+ "LABEL_30": 30,
181
+ "LABEL_31": 31,
182
+ "LABEL_32": 32,
183
+ "LABEL_33": 33,
184
+ "LABEL_34": 34,
185
+ "LABEL_35": 35,
186
+ "LABEL_36": 36,
187
+ "LABEL_37": 37,
188
+ "LABEL_38": 38,
189
+ "LABEL_39": 39,
190
+ "LABEL_4": 4,
191
+ "LABEL_40": 40,
192
+ "LABEL_41": 41,
193
+ "LABEL_42": 42,
194
+ "LABEL_43": 43,
195
+ "LABEL_44": 44,
196
+ "LABEL_45": 45,
197
+ "LABEL_46": 46,
198
+ "LABEL_47": 47,
199
+ "LABEL_48": 48,
200
+ "LABEL_49": 49,
201
+ "LABEL_5": 5,
202
+ "LABEL_50": 50,
203
+ "LABEL_51": 51,
204
+ "LABEL_52": 52,
205
+ "LABEL_53": 53,
206
+ "LABEL_54": 54,
207
+ "LABEL_55": 55,
208
+ "LABEL_56": 56,
209
+ "LABEL_57": 57,
210
+ "LABEL_58": 58,
211
+ "LABEL_59": 59,
212
+ "LABEL_6": 6,
213
+ "LABEL_60": 60,
214
+ "LABEL_61": 61,
215
+ "LABEL_62": 62,
216
+ "LABEL_63": 63,
217
+ "LABEL_64": 64,
218
+ "LABEL_65": 65,
219
+ "LABEL_66": 66,
220
+ "LABEL_67": 67,
221
+ "LABEL_68": 68,
222
+ "LABEL_69": 69,
223
+ "LABEL_7": 7,
224
+ "LABEL_70": 70,
225
+ "LABEL_71": 71,
226
+ "LABEL_72": 72,
227
+ "LABEL_73": 73,
228
+ "LABEL_74": 74,
229
+ "LABEL_75": 75,
230
+ "LABEL_76": 76,
231
+ "LABEL_77": 77,
232
+ "LABEL_78": 78,
233
+ "LABEL_79": 79,
234
+ "LABEL_8": 8,
235
+ "LABEL_80": 80,
236
+ "LABEL_81": 81,
237
+ "LABEL_82": 82,
238
+ "LABEL_83": 83,
239
+ "LABEL_84": 84,
240
+ "LABEL_85": 85,
241
+ "LABEL_86": 86,
242
+ "LABEL_87": 87,
243
+ "LABEL_88": 88,
244
+ "LABEL_89": 89,
245
+ "LABEL_9": 9,
246
+ "LABEL_90": 90,
247
+ "LABEL_91": 91,
248
+ "LABEL_92": 92,
249
+ "LABEL_93": 93,
250
+ "LABEL_94": 94,
251
+ "LABEL_95": 95,
252
+ "LABEL_96": 96,
253
+ "LABEL_97": 97,
254
+ "LABEL_98": 98,
255
+ "LABEL_99": 99
256
+ },
257
+ "max_position_embeddings": 512,
258
+ "model_type": "distilbert",
259
+ "n_heads": 12,
260
+ "n_layers": 6,
261
+ "output_past": true,
262
+ "pad_token_id": 0,
263
+ "problem_type": "multi_label_classification",
264
+ "qa_dropout": 0.1,
265
+ "seq_classif_dropout": 0.2,
266
+ "sinusoidal_pos_embds": false,
267
+ "tie_weights_": true,
268
+ "torch_dtype": "float32",
269
+ "transformers_version": "4.50.3",
270
+ "vocab_size": 28996
271
+ }
checkpoint/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f5408495c388d56779d586a1cea49b5118730777367cc117906705f0854d4e8
3
+ size 263510740
checkpoint/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4750d64b2f40be3f5815be468cb1e00555d1c85d6f15b2fb3ae1bdee389db690
3
+ size 5304
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ numpy
2
+ pandas
3
+ torch
4
+ transformers