Spaces:
Sleeping
Sleeping
Andrei Pavlov committed on
Commit ·
e0b0f3b
1
Parent(s): 340f25a
Paper classifier app and model
Browse files- src/config.py +43 -0
- src/model/final/config.json +85 -0
- src/model/final/label_mapping.json +86 -0
- src/model/final/model.safetensors +3 -0
- src/model/final/tokenizer.json +0 -0
- src/model/final/tokenizer_config.json +14 -0
- src/model/final/training_args.bin +3 -0
- src/model_utils.py +75 -0
- src/streamlit_app.py +95 -38
src/config.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
ROOT = Path(__file__).parent
|
| 5 |
+
DATA_DIR = ROOT / "data"
|
| 6 |
+
MODEL_DIR = ROOT / "model"
|
| 7 |
+
RAW_DATA_PATH = ROOT / "arxivData.json"
|
| 8 |
+
|
| 9 |
+
SEED = 42
|
| 10 |
+
BATCH_SIZE = 16
|
| 11 |
+
NUM_EPOCHS = 10
|
| 12 |
+
VAL_RATIO = 0.1
|
| 13 |
+
TEST_RATIO = 0.1
|
| 14 |
+
LEARNING_RATE = 1e-3
|
| 15 |
+
MAX_LENGTH = 512
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _load_taxonomy(path):
    """Parse a taxonomy file into a ``{tag: human-readable name}`` dict.

    Every non-blank line is expected to look like ``<tag> (<name>)``, where
    the tag is made of word characters, dots and hyphens, e.g.
    ``cs.AI (Artificial Intelligence)``. Lines that do not match this shape
    are skipped. If a tag appears more than once, the last occurrence wins.

    Args:
        path: Location of the taxonomy text file.

    Returns:
        dict mapping each tag to the text found inside the parentheses.
    """
    tag_names = {}
    # A context manager closes the handle deterministically, and pinning the
    # encoding keeps parsing independent of the platform's default locale.
    with open(path, encoding="utf-8") as taxonomy_file:
        for line in taxonomy_file:
            line = line.strip()
            if not line:
                continue

            regex_tag_and_name = re.match(r"^([\w.-]+)\s+\((.+)\)$", line)
            if regex_tag_and_name:
                tag_names[regex_tag_and_name.group(1)] = regex_tag_and_name.group(2)

    return tag_names
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
TAG_NAMES = _load_taxonomy(ROOT / "taxonomy.txt")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_tag_name(tag):
    """Return the display name for *tag*, falling back to coarser matches.

    Lookup order in ``TAG_NAMES``: the full tag first, then the part of the
    tag before its first dot. If neither is present, the tag itself is
    returned unchanged.
    """
    try:
        return TAG_NAMES[tag]
    except KeyError:
        pass

    # For a tag without a dot the "prefix" is simply the whole tag again.
    archive = tag.split(".")[0]
    return TAG_NAMES.get(archive, tag)
|
src/model/final/config.json
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_cross_attention": false,
|
| 3 |
+
"architectures": [
|
| 4 |
+
"BertForSequenceClassification"
|
| 5 |
+
],
|
| 6 |
+
"attention_probs_dropout_prob": 0.1,
|
| 7 |
+
"bos_token_id": null,
|
| 8 |
+
"classifier_dropout": null,
|
| 9 |
+
"dtype": "float32",
|
| 10 |
+
"eos_token_id": null,
|
| 11 |
+
"hidden_act": "gelu",
|
| 12 |
+
"hidden_dropout_prob": 0.1,
|
| 13 |
+
"hidden_size": 768,
|
| 14 |
+
"id2label": {
|
| 15 |
+
"0": "cmp-lg",
|
| 16 |
+
"1": "cs.AI",
|
| 17 |
+
"2": "cs.CE",
|
| 18 |
+
"3": "cs.CL",
|
| 19 |
+
"4": "cs.CR",
|
| 20 |
+
"5": "cs.CV",
|
| 21 |
+
"6": "cs.CY",
|
| 22 |
+
"7": "cs.DB",
|
| 23 |
+
"8": "cs.DC",
|
| 24 |
+
"9": "cs.DS",
|
| 25 |
+
"10": "cs.GT",
|
| 26 |
+
"11": "cs.HC",
|
| 27 |
+
"12": "cs.IR",
|
| 28 |
+
"13": "cs.IT",
|
| 29 |
+
"14": "cs.LG",
|
| 30 |
+
"15": "cs.LO",
|
| 31 |
+
"16": "cs.MM",
|
| 32 |
+
"17": "cs.NE",
|
| 33 |
+
"18": "cs.RO",
|
| 34 |
+
"19": "cs.SD",
|
| 35 |
+
"20": "cs.SE",
|
| 36 |
+
"21": "cs.SI",
|
| 37 |
+
"22": "math.OC",
|
| 38 |
+
"23": "q-bio.NC",
|
| 39 |
+
"24": "stat.ME",
|
| 40 |
+
"25": "stat.ML"
|
| 41 |
+
},
|
| 42 |
+
"initializer_range": 0.02,
|
| 43 |
+
"intermediate_size": 3072,
|
| 44 |
+
"is_decoder": false,
|
| 45 |
+
"label2id": {
|
| 46 |
+
"cmp-lg": 0,
|
| 47 |
+
"cs.AI": 1,
|
| 48 |
+
"cs.CE": 2,
|
| 49 |
+
"cs.CL": 3,
|
| 50 |
+
"cs.CR": 4,
|
| 51 |
+
"cs.CV": 5,
|
| 52 |
+
"cs.CY": 6,
|
| 53 |
+
"cs.DB": 7,
|
| 54 |
+
"cs.DC": 8,
|
| 55 |
+
"cs.DS": 9,
|
| 56 |
+
"cs.GT": 10,
|
| 57 |
+
"cs.HC": 11,
|
| 58 |
+
"cs.IR": 12,
|
| 59 |
+
"cs.IT": 13,
|
| 60 |
+
"cs.LG": 14,
|
| 61 |
+
"cs.LO": 15,
|
| 62 |
+
"cs.MM": 16,
|
| 63 |
+
"cs.NE": 17,
|
| 64 |
+
"cs.RO": 18,
|
| 65 |
+
"cs.SD": 19,
|
| 66 |
+
"cs.SE": 20,
|
| 67 |
+
"cs.SI": 21,
|
| 68 |
+
"math.OC": 22,
|
| 69 |
+
"q-bio.NC": 23,
|
| 70 |
+
"stat.ME": 24,
|
| 71 |
+
"stat.ML": 25
|
| 72 |
+
},
|
| 73 |
+
"layer_norm_eps": 1e-12,
|
| 74 |
+
"max_position_embeddings": 512,
|
| 75 |
+
"model_type": "bert",
|
| 76 |
+
"num_attention_heads": 12,
|
| 77 |
+
"num_hidden_layers": 12,
|
| 78 |
+
"pad_token_id": 0,
|
| 79 |
+
"problem_type": "single_label_classification",
|
| 80 |
+
"tie_word_embeddings": true,
|
| 81 |
+
"transformers_version": "5.5.0",
|
| 82 |
+
"type_vocab_size": 2,
|
| 83 |
+
"use_cache": false,
|
| 84 |
+
"vocab_size": 31090
|
| 85 |
+
}
|
src/model/final/label_mapping.json
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"label2id": {
|
| 3 |
+
"cmp-lg": "0",
|
| 4 |
+
"cs.AI": "1",
|
| 5 |
+
"cs.CE": "2",
|
| 6 |
+
"cs.CL": "3",
|
| 7 |
+
"cs.CR": "4",
|
| 8 |
+
"cs.CV": "5",
|
| 9 |
+
"cs.CY": "6",
|
| 10 |
+
"cs.DB": "7",
|
| 11 |
+
"cs.DC": "8",
|
| 12 |
+
"cs.DS": "9",
|
| 13 |
+
"cs.GT": "10",
|
| 14 |
+
"cs.HC": "11",
|
| 15 |
+
"cs.IR": "12",
|
| 16 |
+
"cs.IT": "13",
|
| 17 |
+
"cs.LG": "14",
|
| 18 |
+
"cs.LO": "15",
|
| 19 |
+
"cs.MM": "16",
|
| 20 |
+
"cs.NE": "17",
|
| 21 |
+
"cs.RO": "18",
|
| 22 |
+
"cs.SD": "19",
|
| 23 |
+
"cs.SE": "20",
|
| 24 |
+
"cs.SI": "21",
|
| 25 |
+
"math.OC": "22",
|
| 26 |
+
"q-bio.NC": "23",
|
| 27 |
+
"stat.ME": "24",
|
| 28 |
+
"stat.ML": "25"
|
| 29 |
+
},
|
| 30 |
+
"id2label": {
|
| 31 |
+
"0": "cmp-lg",
|
| 32 |
+
"1": "cs.AI",
|
| 33 |
+
"2": "cs.CE",
|
| 34 |
+
"3": "cs.CL",
|
| 35 |
+
"4": "cs.CR",
|
| 36 |
+
"5": "cs.CV",
|
| 37 |
+
"6": "cs.CY",
|
| 38 |
+
"7": "cs.DB",
|
| 39 |
+
"8": "cs.DC",
|
| 40 |
+
"9": "cs.DS",
|
| 41 |
+
"10": "cs.GT",
|
| 42 |
+
"11": "cs.HC",
|
| 43 |
+
"12": "cs.IR",
|
| 44 |
+
"13": "cs.IT",
|
| 45 |
+
"14": "cs.LG",
|
| 46 |
+
"15": "cs.LO",
|
| 47 |
+
"16": "cs.MM",
|
| 48 |
+
"17": "cs.NE",
|
| 49 |
+
"18": "cs.RO",
|
| 50 |
+
"19": "cs.SD",
|
| 51 |
+
"20": "cs.SE",
|
| 52 |
+
"21": "cs.SI",
|
| 53 |
+
"22": "math.OC",
|
| 54 |
+
"23": "q-bio.NC",
|
| 55 |
+
"24": "stat.ME",
|
| 56 |
+
"25": "stat.ML"
|
| 57 |
+
},
|
| 58 |
+
"label_names": {
|
| 59 |
+
"cmp-lg": "Computational Linguistics",
|
| 60 |
+
"cs.AI": "Artificial Intelligence",
|
| 61 |
+
"cs.CE": "Computational Engineering, Finance, and Science",
|
| 62 |
+
"cs.CL": "Computation and Language",
|
| 63 |
+
"cs.CR": "Cryptography and Security",
|
| 64 |
+
"cs.CV": "Computer Vision and Pattern Recognition",
|
| 65 |
+
"cs.CY": "Computers and Society",
|
| 66 |
+
"cs.DB": "Databases",
|
| 67 |
+
"cs.DC": "Distributed, Parallel, and Cluster Computing",
|
| 68 |
+
"cs.DS": "Data Structures and Algorithms",
|
| 69 |
+
"cs.GT": "Computer Science and Game Theory",
|
| 70 |
+
"cs.HC": "Human-Computer Interaction",
|
| 71 |
+
"cs.IR": "Information Retrieval",
|
| 72 |
+
"cs.IT": "Information Theory",
|
| 73 |
+
"cs.LG": "Machine Learning",
|
| 74 |
+
"cs.LO": "Logic in Computer Science",
|
| 75 |
+
"cs.MM": "Multimedia",
|
| 76 |
+
"cs.NE": "Neural and Evolutionary Computing",
|
| 77 |
+
"cs.RO": "Robotics",
|
| 78 |
+
"cs.SD": "Sound",
|
| 79 |
+
"cs.SE": "Software Engineering",
|
| 80 |
+
"cs.SI": "Social and Information Networks",
|
| 81 |
+
"math.OC": "Optimization and Control",
|
| 82 |
+
"q-bio.NC": "Neurons and Cognition",
|
| 83 |
+
"stat.ME": "Methodology",
|
| 84 |
+
"stat.ML": "Machine Learning"
|
| 85 |
+
}
|
| 86 |
+
}
|
src/model/final/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a6e8d238bf5418b8d3b730f2ad95291c32d41b9628d9313b667f711d5cdddb90
|
| 3 |
+
size 439777344
|
src/model/final/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
src/model/final/tokenizer_config.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"backend": "tokenizers",
|
| 3 |
+
"cls_token": "[CLS]",
|
| 4 |
+
"do_lower_case": true,
|
| 5 |
+
"is_local": false,
|
| 6 |
+
"mask_token": "[MASK]",
|
| 7 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 8 |
+
"pad_token": "[PAD]",
|
| 9 |
+
"sep_token": "[SEP]",
|
| 10 |
+
"strip_accents": null,
|
| 11 |
+
"tokenize_chinese_chars": true,
|
| 12 |
+
"tokenizer_class": "BertTokenizer",
|
| 13 |
+
"unk_token": "[UNK]"
|
| 14 |
+
}
|
src/model/final/training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cc0ddfa117157db3ff50032a9a59efc659d26c4602a636deec4a8cf00b781bab
|
| 3 |
+
size 5329
|
src/model_utils.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import re
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
| 8 |
+
|
| 9 |
+
from config import MAX_LENGTH, MODEL_DIR, get_tag_name
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def clean_text(text):
    """Strip *text* and collapse each run of whitespace into a single space."""
    trimmed = text.strip()
    collapsed = re.sub(r"\s+", " ", trimmed)
    return collapsed
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def format_input(title, abstract=None):
    """Build the text passed to the tokenizer from a title and optional abstract.

    The title is always cleaned and prefixed with ``[TITLE]``. The abstract
    section is appended only when *abstract* is non-empty after stripping.
    """
    cleaned_title = clean_text(title)

    # None, "" and whitespace-only abstracts all mean "title only".
    if not abstract or not abstract.strip():
        return f"[TITLE] {cleaned_title}"

    return f"[TITLE] {cleaned_title} [SEP] [ABSTRACT] {clean_text(abstract)}"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class PaperClassifier:
    """Wrapper around a saved sequence-classification model for paper tags.

    Loads a tokenizer, a model and a ``label_mapping.json`` file from a single
    directory, and exposes :meth:`predict` to rank category tags for a paper.
    """

    def __init__(self, model_path=None):
        """Load the tokenizer, the model and the label mapping.

        Args:
            model_path: Directory containing the saved model, tokenizer and
                ``label_mapping.json``. Defaults to ``MODEL_DIR / "final"``.

        Raises:
            KeyError: If ``label_mapping.json`` has no ``"id2label"`` entry.
        """
        if model_path is None:
            model_path = str(MODEL_DIR / "final")

        # Device preference: CUDA, then Apple MPS, then CPU.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available()
            else "mps" if torch.backends.mps.is_available()
            else "cpu"
        )

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

        # Read the mapping as UTF-8 explicitly so decoding does not depend on
        # the platform's default locale encoding.
        with open(Path(model_path) / "label_mapping.json", encoding="utf-8") as f:
            mapping = json.load(f)

        # id2label is keyed by the class index rendered as a string.
        self.id2label = mapping["id2label"]
        self.label_names = mapping.get("label_names", {})

    @torch.no_grad()
    def predict(self, title, abstract=None, threshold=0.95):
        """Rank category tags for a paper until *threshold* mass is covered.

        Args:
            title: Paper title.
            abstract: Optional abstract; ignored when empty or whitespace-only.
            threshold: Cumulative probability at which to stop adding tags.

        Returns:
            A list of dicts with keys ``"tag"``, ``"name"`` and
            ``"probability"``, ordered by descending probability. The list
            always holds at least one entry and ends with the entry that
            brings the running total to *threshold* or above (or with the
            last class if the total never reaches it).
        """
        text = format_input(title, abstract)

        # NOTE(review): padding a single sequence out to MAX_LENGTH looks
        # unnecessary for inference and may cost latency - confirm whether
        # padding can be dropped without affecting the predictions.
        inputs = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=MAX_LENGTH,
            return_tensors="pt",
        ).to(self.device)

        logits = self.model(**inputs).logits[0].cpu().numpy()
        # Softmax with the max subtracted first for numerical stability.
        probs = np.exp(logits - logits.max())
        probs /= probs.sum()

        results = []
        cumulative = 0.0
        # Walk the classes from most to least probable.
        for idx in np.argsort(probs)[::-1]:
            tag = self.id2label[str(idx)]
            prob = float(probs[idx])
            results.append({
                "tag": tag,
                # Prefer the name shipped with the model; otherwise fall back
                # to the taxonomy lookup.
                "name": self.label_names.get(tag, get_tag_name(tag)),
                "probability": prob,
            })
            cumulative += prob
            if cumulative >= threshold:
                break

        return results
|
src/streamlit_app.py
CHANGED
|
@@ -1,40 +1,97 @@
|
|
| 1 |
-
import altair as alt
|
| 2 |
-
import numpy as np
|
| 3 |
-
import pandas as pd
|
| 4 |
import streamlit as st
|
|
|
|
| 5 |
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
from model_utils import PaperClassifier
|
| 3 |
|
| 4 |
+
st.set_page_config(page_title="Paper Classifier", layout="centered")
|
| 5 |
+
|
| 6 |
+
st.markdown("""
|
| 7 |
+
<style>
|
| 8 |
+
.result-box {
|
| 9 |
+
background: #4a5568; padding: 1rem; border-radius: 8px; color: white; margin-bottom: 0.5rem;
|
| 10 |
+
}
|
| 11 |
+
.prob-bar {
|
| 12 |
+
background: rgba(255,255,255,0.2); border-radius: 6px; height: 22px; margin-top: 4px; overflow: hidden;
|
| 13 |
+
}
|
| 14 |
+
.prob-fill {
|
| 15 |
+
background: #68d391; height: 100%; border-radius: 6px;
|
| 16 |
+
padding-left: 8px; font-size: 0.85rem; font-weight: 600;
|
| 17 |
+
color: #1a202c; display: flex; align-items: center;
|
| 18 |
+
}
|
| 19 |
+
</style>
|
| 20 |
+
""", unsafe_allow_html=True)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@st.cache_resource(show_spinner="Loading model...")
def load_model():
    """Build a ``PaperClassifier`` with its default model path.

    The function is wrapped in ``st.cache_resource``, which shows the
    "Loading model..." spinner while it runs.
    """
    classifier = PaperClassifier()
    return classifier
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
EXAMPLES = [
|
| 29 |
+
{"title": "Attention Is All You Need",
|
| 30 |
+
"abstract": "We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely."},
|
| 31 |
+
{"title": "A Survey on 3D Gaussian Splatting",
|
| 32 |
+
"abstract": "3D Gaussian splatting (GS) has emerged as a transformative technique in radiance fields. Unlike mainstream implicit neural models, 3D GS uses millions of learnable 3D Gaussians for an explicit scene representation."},
|
| 33 |
+
{"title": "Interior Point Differential Dynamic Programming",
|
| 34 |
+
"abstract": ""},
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
# Seed the session keys used by the title/abstract widgets with empty strings
# if they are not already present in the session state.
for _state_key in ("input_title", "input_abstract"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = ""
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def set_example(idx):
    """Copy the title and abstract of ``EXAMPLES[idx]`` into session state."""
    example = EXAMPLES[idx]
    st.session_state.input_title = example["title"]
    st.session_state.input_abstract = example["abstract"]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def show_results(results):
    """Render a heading plus one HTML card with a probability bar per result.

    Args:
        results: Iterable of dicts with ``"tag"``, ``"name"`` and
            ``"probability"`` keys; ``len(results)`` is shown in the heading.
    """
    st.markdown(f"### Predicted {len(results)} categories")
    for entry in results:
        percent = entry["probability"] * 100
        # The bar width is floored at 3 (percent) while the label keeps the
        # real value.
        bar_width = max(percent, 3)
        card_html = f"""
<div class="result-box">
<b>{entry['tag']}</b> - {entry['name']}
<div class="prob-bar">
<div class="prob-fill" style="width:{bar_width}%">{percent:.1f}%</div>
</div>
</div>"""
        st.markdown(card_html, unsafe_allow_html=True)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def main():
    """Render the page: inputs, example shortcuts and classification output."""
    st.title("Paper Classifier")
    st.write("Classify papers using fine-tuned SciBERT in one click!")

    # Without a model nothing else on the page is usable, so stop here.
    try:
        clf = load_model()
    except Exception as err:
        st.error(f"Could not load model: {err}")
        return

    title = st.text_input("**Title:**", key="input_title", placeholder="Paste paper title here")
    abstract = st.text_area("**Abstract**", key="input_abstract", placeholder="You can leave it empty", height=150)

    st.write("**Use our examples:**")
    columns = st.columns(len(EXAMPLES))
    for position, (column, example) in enumerate(zip(columns, EXAMPLES)):
        example_title = example["title"]
        # Button captions longer than 20 characters are cut and suffixed.
        if len(example_title) > 20:
            caption = example_title[:20] + "..."
        else:
            caption = example_title
        with column:
            st.button(caption, key=f"ex_{position}", on_click=set_example, args=(position,), use_container_width=True)

    if not st.button("Classify", use_container_width=True):
        return

    if not title or not title.strip():
        st.warning("Enter a title first.")
        return

    with st.spinner("Classifying..."):
        try:
            results = clf.predict(title=title, abstract=abstract)
        except Exception as err:
            st.error(f"Error: {err}")
            return

    show_results(results)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
if __name__ == "__main__":
|
| 97 |
+
main()
|