Andrei Pavlov commited on
Commit
e0b0f3b
·
1 Parent(s): 340f25a

Paper classifier app and model

Browse files
src/config.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import re
3
+
4
+ ROOT = Path(__file__).parent
5
+ DATA_DIR = ROOT / "data"
6
+ MODEL_DIR = ROOT / "model"
7
+ RAW_DATA_PATH = ROOT / "arxivData.json"
8
+
9
+ SEED = 42
10
+ BATCH_SIZE = 16
11
+ NUM_EPOCHS = 10
12
+ VAL_RATIO = 0.1
13
+ TEST_RATIO = 0.1
14
+ LEARNING_RATE = 1e-3
15
+ MAX_LENGTH = 512
16
+
17
+
18
+ def _load_taxonomy(path):
19
+ tag_names = {}
20
+ for line in open(path):
21
+ line = line.strip()
22
+ if not line:
23
+ continue
24
+
25
+ regex_tag_and_name = re.match(r"^([\w.-]+)\s+\((.+)\)$", line)
26
+ if regex_tag_and_name:
27
+ tag_names[regex_tag_and_name.group(1)] = regex_tag_and_name.group(2)
28
+
29
+ return tag_names
30
+
31
+
32
+ TAG_NAMES = _load_taxonomy(ROOT / "taxonomy.txt")
33
+
34
+
35
+ def get_tag_name(tag):
36
+ if tag in TAG_NAMES:
37
+ return TAG_NAMES[tag]
38
+
39
+ prefix = tag.split(".")[0] if "." in tag else tag
40
+ if prefix in TAG_NAMES:
41
+ return TAG_NAMES[prefix]
42
+
43
+ return tag
src/model/final/config.json ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_cross_attention": false,
3
+ "architectures": [
4
+ "BertForSequenceClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "bos_token_id": null,
8
+ "classifier_dropout": null,
9
+ "dtype": "float32",
10
+ "eos_token_id": null,
11
+ "hidden_act": "gelu",
12
+ "hidden_dropout_prob": 0.1,
13
+ "hidden_size": 768,
14
+ "id2label": {
15
+ "0": "cmp-lg",
16
+ "1": "cs.AI",
17
+ "2": "cs.CE",
18
+ "3": "cs.CL",
19
+ "4": "cs.CR",
20
+ "5": "cs.CV",
21
+ "6": "cs.CY",
22
+ "7": "cs.DB",
23
+ "8": "cs.DC",
24
+ "9": "cs.DS",
25
+ "10": "cs.GT",
26
+ "11": "cs.HC",
27
+ "12": "cs.IR",
28
+ "13": "cs.IT",
29
+ "14": "cs.LG",
30
+ "15": "cs.LO",
31
+ "16": "cs.MM",
32
+ "17": "cs.NE",
33
+ "18": "cs.RO",
34
+ "19": "cs.SD",
35
+ "20": "cs.SE",
36
+ "21": "cs.SI",
37
+ "22": "math.OC",
38
+ "23": "q-bio.NC",
39
+ "24": "stat.ME",
40
+ "25": "stat.ML"
41
+ },
42
+ "initializer_range": 0.02,
43
+ "intermediate_size": 3072,
44
+ "is_decoder": false,
45
+ "label2id": {
46
+ "cmp-lg": 0,
47
+ "cs.AI": 1,
48
+ "cs.CE": 2,
49
+ "cs.CL": 3,
50
+ "cs.CR": 4,
51
+ "cs.CV": 5,
52
+ "cs.CY": 6,
53
+ "cs.DB": 7,
54
+ "cs.DC": 8,
55
+ "cs.DS": 9,
56
+ "cs.GT": 10,
57
+ "cs.HC": 11,
58
+ "cs.IR": 12,
59
+ "cs.IT": 13,
60
+ "cs.LG": 14,
61
+ "cs.LO": 15,
62
+ "cs.MM": 16,
63
+ "cs.NE": 17,
64
+ "cs.RO": 18,
65
+ "cs.SD": 19,
66
+ "cs.SE": 20,
67
+ "cs.SI": 21,
68
+ "math.OC": 22,
69
+ "q-bio.NC": 23,
70
+ "stat.ME": 24,
71
+ "stat.ML": 25
72
+ },
73
+ "layer_norm_eps": 1e-12,
74
+ "max_position_embeddings": 512,
75
+ "model_type": "bert",
76
+ "num_attention_heads": 12,
77
+ "num_hidden_layers": 12,
78
+ "pad_token_id": 0,
79
+ "problem_type": "single_label_classification",
80
+ "tie_word_embeddings": true,
81
+ "transformers_version": "5.5.0",
82
+ "type_vocab_size": 2,
83
+ "use_cache": false,
84
+ "vocab_size": 31090
85
+ }
src/model/final/label_mapping.json ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "label2id": {
3
+ "cmp-lg": "0",
4
+ "cs.AI": "1",
5
+ "cs.CE": "2",
6
+ "cs.CL": "3",
7
+ "cs.CR": "4",
8
+ "cs.CV": "5",
9
+ "cs.CY": "6",
10
+ "cs.DB": "7",
11
+ "cs.DC": "8",
12
+ "cs.DS": "9",
13
+ "cs.GT": "10",
14
+ "cs.HC": "11",
15
+ "cs.IR": "12",
16
+ "cs.IT": "13",
17
+ "cs.LG": "14",
18
+ "cs.LO": "15",
19
+ "cs.MM": "16",
20
+ "cs.NE": "17",
21
+ "cs.RO": "18",
22
+ "cs.SD": "19",
23
+ "cs.SE": "20",
24
+ "cs.SI": "21",
25
+ "math.OC": "22",
26
+ "q-bio.NC": "23",
27
+ "stat.ME": "24",
28
+ "stat.ML": "25"
29
+ },
30
+ "id2label": {
31
+ "0": "cmp-lg",
32
+ "1": "cs.AI",
33
+ "2": "cs.CE",
34
+ "3": "cs.CL",
35
+ "4": "cs.CR",
36
+ "5": "cs.CV",
37
+ "6": "cs.CY",
38
+ "7": "cs.DB",
39
+ "8": "cs.DC",
40
+ "9": "cs.DS",
41
+ "10": "cs.GT",
42
+ "11": "cs.HC",
43
+ "12": "cs.IR",
44
+ "13": "cs.IT",
45
+ "14": "cs.LG",
46
+ "15": "cs.LO",
47
+ "16": "cs.MM",
48
+ "17": "cs.NE",
49
+ "18": "cs.RO",
50
+ "19": "cs.SD",
51
+ "20": "cs.SE",
52
+ "21": "cs.SI",
53
+ "22": "math.OC",
54
+ "23": "q-bio.NC",
55
+ "24": "stat.ME",
56
+ "25": "stat.ML"
57
+ },
58
+ "label_names": {
59
+ "cmp-lg": "Computational Linguistics",
60
+ "cs.AI": "Artificial Intelligence",
61
+ "cs.CE": "Computational Engineering, Finance, and Science",
62
+ "cs.CL": "Computation and Language",
63
+ "cs.CR": "Cryptography and Security",
64
+ "cs.CV": "Computer Vision and Pattern Recognition",
65
+ "cs.CY": "Computers and Society",
66
+ "cs.DB": "Databases",
67
+ "cs.DC": "Distributed, Parallel, and Cluster Computing",
68
+ "cs.DS": "Data Structures and Algorithms",
69
+ "cs.GT": "Computer Science and Game Theory",
70
+ "cs.HC": "Human-Computer Interaction",
71
+ "cs.IR": "Information Retrieval",
72
+ "cs.IT": "Information Theory",
73
+ "cs.LG": "Machine Learning",
74
+ "cs.LO": "Logic in Computer Science",
75
+ "cs.MM": "Multimedia",
76
+ "cs.NE": "Neural and Evolutionary Computing",
77
+ "cs.RO": "Robotics",
78
+ "cs.SD": "Sound",
79
+ "cs.SE": "Software Engineering",
80
+ "cs.SI": "Social and Information Networks",
81
+ "math.OC": "Optimization and Control",
82
+ "q-bio.NC": "Neurons and Cognition",
83
+ "stat.ME": "Methodology",
84
+ "stat.ML": "Machine Learning"
85
+ }
86
+ }
src/model/final/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a6e8d238bf5418b8d3b730f2ad95291c32d41b9628d9313b667f711d5cdddb90
3
+ size 439777344
src/model/final/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
src/model/final/tokenizer_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "backend": "tokenizers",
3
+ "cls_token": "[CLS]",
4
+ "do_lower_case": true,
5
+ "is_local": false,
6
+ "mask_token": "[MASK]",
7
+ "model_max_length": 1000000000000000019884624838656,
8
+ "pad_token": "[PAD]",
9
+ "sep_token": "[SEP]",
10
+ "strip_accents": null,
11
+ "tokenize_chinese_chars": true,
12
+ "tokenizer_class": "BertTokenizer",
13
+ "unk_token": "[UNK]"
14
+ }
src/model/final/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc0ddfa117157db3ff50032a9a59efc659d26c4602a636deec4a8cf00b781bab
3
+ size 5329
src/model_utils.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import torch
7
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
8
+
9
+ from config import MAX_LENGTH, MODEL_DIR, get_tag_name
10
+
11
+
12
+ def clean_text(text):
13
+ return re.sub(r"\s+", " ", text.strip())
14
+
15
+
16
+ def format_input(title, abstract=None):
17
+ title = clean_text(title)
18
+ if abstract and abstract.strip():
19
+ return f"[TITLE] {title} [SEP] [ABSTRACT] {clean_text(abstract)}"
20
+ return f"[TITLE] {title}"
21
+
22
+
23
+ class PaperClassifier:
24
+ def __init__(self, model_path=None):
25
+ if model_path is None:
26
+ model_path = str(MODEL_DIR / "final")
27
+
28
+ self.device = torch.device(
29
+ "cuda" if torch.cuda.is_available()
30
+ else "mps" if torch.backends.mps.is_available()
31
+ else "cpu"
32
+ )
33
+
34
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
35
+ self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
36
+ self.model.to(self.device)
37
+ self.model.eval()
38
+
39
+ with open(Path(model_path) / "label_mapping.json") as f:
40
+ mapping = json.load(f)
41
+
42
+ self.id2label = mapping["id2label"]
43
+ self.label_names = mapping.get("label_names", {})
44
+
45
+ @torch.no_grad()
46
+ def predict(self, title, abstract=None, threshold=0.95):
47
+ text = format_input(title, abstract)
48
+
49
+ inputs = self.tokenizer(
50
+ text,
51
+ padding="max_length",
52
+ truncation=True,
53
+ max_length=MAX_LENGTH,
54
+ return_tensors="pt",
55
+ ).to(self.device)
56
+
57
+ logits = self.model(**inputs).logits[0].cpu().numpy()
58
+ probs = np.exp(logits - logits.max())
59
+ probs /= probs.sum()
60
+
61
+ results = []
62
+ cumulative = 0.0
63
+ for idx in np.argsort(probs)[::-1]:
64
+ tag = self.id2label[str(idx)]
65
+ prob = float(probs[idx])
66
+ results.append({
67
+ "tag": tag,
68
+ "name": self.label_names.get(tag, get_tag_name(tag)),
69
+ "probability": prob,
70
+ })
71
+ cumulative += prob
72
+ if cumulative >= threshold:
73
+ break
74
+
75
+ return results
src/streamlit_app.py CHANGED
@@ -1,40 +1,97 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from model_utils import PaperClassifier
3
 
4
+ st.set_page_config(page_title="Paper Classifier", layout="centered")
5
+
6
+ st.markdown("""
7
+ <style>
8
+ .result-box {
9
+ background: #4a5568; padding: 1rem; border-radius: 8px; color: white; margin-bottom: 0.5rem;
10
+ }
11
+ .prob-bar {
12
+ background: rgba(255,255,255,0.2); border-radius: 6px; height: 22px; margin-top: 4px; overflow: hidden;
13
+ }
14
+ .prob-fill {
15
+ background: #68d391; height: 100%; border-radius: 6px;
16
+ padding-left: 8px; font-size: 0.85rem; font-weight: 600;
17
+ color: #1a202c; display: flex; align-items: center;
18
+ }
19
+ </style>
20
+ """, unsafe_allow_html=True)
21
+
22
+
23
+ @st.cache_resource(show_spinner="Loading model...")
24
+ def load_model():
25
+ return PaperClassifier()
26
+
27
+
28
+ EXAMPLES = [
29
+ {"title": "Attention Is All You Need",
30
+ "abstract": "We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely."},
31
+ {"title": "A Survey on 3D Gaussian Splatting",
32
+ "abstract": "3D Gaussian splatting (GS) has emerged as a transformative technique in radiance fields. Unlike mainstream implicit neural models, 3D GS uses millions of learnable 3D Gaussians for an explicit scene representation."},
33
+ {"title": "Interior Point Differential Dynamic Programming",
34
+ "abstract": ""},
35
+ ]
36
+
37
+ if "input_title" not in st.session_state:
38
+ st.session_state.input_title = ""
39
+ if "input_abstract" not in st.session_state:
40
+ st.session_state.input_abstract = ""
41
+
42
+
43
+ def set_example(idx):
44
+ st.session_state.input_title = EXAMPLES[idx]["title"]
45
+ st.session_state.input_abstract = EXAMPLES[idx]["abstract"]
46
+
47
+
48
+ def show_results(results):
49
+ st.markdown(f"### Predicted {len(results)} categories")
50
+ for r in results:
51
+ pct = r["probability"] * 100
52
+ st.markdown(f"""
53
+ <div class="result-box">
54
+ <b>{r['tag']}</b> - {r['name']}
55
+ <div class="prob-bar">
56
+ <div class="prob-fill" style="width:{max(pct,3)}%">{pct:.1f}%</div>
57
+ </div>
58
+ </div>""", unsafe_allow_html=True)
59
+
60
+
61
+ def main():
62
+ st.title("Paper Classifier")
63
+ st.write("Classify papers using fine-tuned SciBERT in one click!")
64
+
65
+ try:
66
+ clf = load_model()
67
+ except Exception as err:
68
+ st.error(f"Could not load model: {err}")
69
+ return
70
+
71
+ title = st.text_input("**Title:**", key="input_title", placeholder="Paste paper title here")
72
+ abstract = st.text_area("**Abstract**", key="input_abstract", placeholder="You can leave it empty", height=150)
73
+
74
+ st.write("**Use our examples:**")
75
+ cols = st.columns(len(EXAMPLES))
76
+ for i, (col, ex) in enumerate(zip(cols, EXAMPLES)):
77
+ with col:
78
+ label = ex["title"][:20] + "..." if len(ex["title"]) > 20 else ex["title"]
79
+ st.button(label, key=f"ex_{i}", on_click=set_example, args=(i,), use_container_width=True)
80
+
81
+ if st.button("Classify", use_container_width=True):
82
+ if not title or not title.strip():
83
+ st.warning("Enter a title first.")
84
+ return
85
+
86
+ with st.spinner("Classifying..."):
87
+ try:
88
+ results = clf.predict(title=title, abstract=abstract)
89
+ except Exception as err:
90
+ st.error(f"Error: {err}")
91
+ return
92
+
93
+ show_results(results)
94
+
95
+
96
+ if __name__ == "__main__":
97
+ main()