File size: 19,513 Bytes
cc525d6
 
 
 
41cfc91
cc525d6
 
 
c1a3e2a
cc525d6
 
 
 
 
 
 
 
 
 
c1a3e2a
cc525d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1daed99
 
cc525d6
1daed99
 
 
 
 
 
 
 
 
 
 
 
cc525d6
 
 
1daed99
 
cc525d6
1daed99
 
 
 
 
 
 
 
 
 
 
 
 
 
cc525d6
 
 
1daed99
 
cc525d6
1daed99
 
 
 
 
 
 
 
 
 
 
 
 
 
cc525d6
 
 
1daed99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc525d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61ef0ad
 
 
 
 
 
 
e14f368
 
 
61ef0ad
 
 
 
 
e14f368
 
 
cc525d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41cfc91
61ef0ad
 
 
 
 
cc525d6
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
import json
from pathlib import Path

import matplotlib.pyplot as plt
import streamlit as st
import torch
from transformers import AutoTokenizer

from model import ArticleClassifier


# Browser-tab configuration; must be the first Streamlit call in the script.
# The icon is written as a named escape (SPARKLING HEART, U+1F496) so the
# source stays ASCII-only and the emoji cannot be mis-decoded into mojibake.
st.set_page_config(
    page_title="Scientific Article Classification",
    page_icon="\N{SPARKLING HEART}",
    layout="wide",
)

# Artifact locations are resolved relative to this script, not the current
# working directory.
APP_DIR = Path(__file__).resolve().parent
MODEL_NAME = "allenai/scibert_scivocab_cased"  # tokenizer + encoder identifier
CHECKPOINT_PATH = APP_DIR / "scibert_full.pt"
FIELD_LABELS_PATH = APP_DIR / "field_labels.json"
MAX_LEN = 256  # tokenizer max_length; longer inputs are truncated

# Quartile class index -> display name: {0: "Q1", 1: "Q2", 2: "Q3", 3: "Q4"}.
QUARTILE_LABELS = {index: f"Q{index + 1}" for index in range(4)}

# Palette for ranked topics; callers index it modulo its length, so it cycles
# when more topics than colours are displayed.
TOPIC_COLORS = (
    "#FF4FA3 #FF7A7A #FFAA5B #FFD45C #7FE1D8 "
    "#53C8F2 #7C8CFF #B784F7 #FF8FCF #F06292"
).split()

# Demo articles offered as one-click examples in the "Examples" panel.
# Each entry has:
#   "id"       - unique identifier; used as the example button's widget key and
#                stored in st.session_state["pending_example"] when clicked
#   "title"    - text copied into the "Article title" field
#   "abstract" - text copied into the "Abstract" field; it is built from
#                adjacent string literals (implicitly concatenated), so word
#                breaks between fragments need an explicit trailing space
EXAMPLES = [
    {
        "id": "alphafold",
        "title": "Highly accurate protein structure prediction with AlphaFold",
        "abstract": (
            "Proteins are central to biology, but experimentally determining their three-"
            "dimensional structures remains far slower than the growth of sequence databases. "
            "This paper presents a redesigned AlphaFold system, a neural network that combines "
            "evolutionary information from multiple-sequence alignments with geometric and "
            "physical constraints on protein structure. Evaluated in the blind CASP14 benchmark, "
            "the method reaches near-experimental accuracy for many targets, including cases "
            "without close structural homologues, and substantially outperforms competing "
            "approaches. The model predicts full atomic coordinates, scales to long proteins, "
            "and reports confidence estimates that help users judge which regions of a predicted "
            "structure are reliable. By turning sequence data into high-quality structural "
            "models at scale, the work positions deep learning as a practical tool for "
            "structural biology, functional annotation, and downstream biological discovery."
        ),
    },
    {
        "id": "weather",
        "title": "Skilful precipitation nowcasting using deep generative models of radar",
        "abstract": (
            "Short-range precipitation nowcasting is critical for emergency response, energy "
            "operations, transport, flood warning, and other weather-sensitive decisions, yet "
            "traditional radar-advection systems struggle with nonlinear events such as "
            "convective initiation. Earlier deep learning approaches improve some low-intensity "
            "forecasts but often become blurry at longer lead times and perform poorly on rarer, "
            "heavier rain. This paper introduces a probabilistic deep generative model that "
            "predicts future radar fields directly from recent radar observations while "
            "preserving realistic spatial and temporal structure. Across statistical, economic, "
            "and expert-evaluation measures, the model improves forecast quality, consistency, "
            "and operational value. It generates realistic predictions over regions up to 1536 "
            "by 1280 kilometres and for lead times from 5 to 90 minutes. In evaluations "
            "involving more than fifty meteorologists, it ranked first against strong baseline "
            "systems in most cases, showing that generative models can produce useful "
            "high-resolution nowcasts without relying on blur."
        ),
    },
    {
        "id": "materials",
        "title": "Experimental search for high-temperature ferroelectric perovskites guided by two-step machine learning",
        "abstract": (
            "Searching for high-temperature ferroelectric perovskites is difficult because the "
            "chemical space is enormous, experimental guidance is limited, and many candidate "
            "compositions fail to form the desired phase. This study proposes a two-step machine "
            "learning workflow to guide synthesis in xBi(Me' y Me'' 1-y)O3-(1-x)PbTiO3-type "
            "systems. A classification model first screens compositions likely to crystallize as "
            "perovskites, and a regression model coupled with active learning then prioritizes "
            "candidates with high Curie temperature for experimental testing. The search spans "
            "roughly 61,500 possible compositions, whereas only 167 had been characterized "
            "beforehand. By iterating between prediction, synthesis, and feedback from both "
            "successful and failed experiments, the authors efficiently refine the models and "
            "focus the search. Out of ten newly synthesized candidates, six form perovskites, "
            "including three previously unexplored cation pairs, and one composition reaches a "
            "measured Curie temperature of 898 K. The work shows how machine learning can "
            "meaningfully reduce experimental burden in materials discovery."
        ),
    },
    {
        "id": "chem_xai",
        "title": "Chemistry-intuitive explanation of graph neural networks for molecular property prediction with substructure masking",
        "abstract": (
            "Graph neural networks are widely used for molecular property prediction, but their "
            "explanations often highlight individual atoms, bonds, or fragments that do not "
            "match how chemists reason about structure-property relationships. This paper "
            "introduces Substructure Mask Explanation, an interpretation method built around "
            "chemically meaningful molecular segmentations such as functional groups and other "
            "established fragment schemes. Instead of assigning importance to isolated nodes or "
            "edges, the method reveals which substructures drive a graph model's prediction in a "
            "way that is easier for domain experts to inspect. The authors apply the approach to "
            "models for aqueous solubility, genotoxicity, cardiotoxicity, and blood-brain "
            "barrier permeation in small molecules. Across these cases, the explanations align "
            "better with chemical intuition, help identify unreliable model behaviour, and offer "
            "actionable guidance for structural optimization. The paper argues that "
            "substructure-level explainability can make graph neural networks more useful for "
            "medicinal chemistry and drug-discovery workflows."
        ),
    },
]


def inject_styles():
    """Inject the app-wide pink theme as a raw ``<style>`` block.

    The rules target:
    - the ``st-key-*`` classes Streamlit adds to keyed containers;
    - buttons, by their ``kind`` attribute;
    - link buttons and progress bars;
    - the ``.topic-card`` / ``.quartile-card`` HTML emitted by the render helpers.
    """
    stylesheet = """
        <style>
        .stApp {
            color: #7d2b58;
            background: #fff4fa;
        }
        .stApp h1 {
            color: #d63384;
            font-size: clamp(2.1rem, 3vw, 3rem);
            font-weight: 900;
            letter-spacing: 0.01em;
            white-space: nowrap;
        }
        .st-key-input-shell, .st-key-examples-shell, .st-key-topics-shell, .st-key-quartile-shell {
            border: 1px solid rgba(255, 109, 177, 0.22);
            border-radius: 18px;
            padding: 1rem 1.1rem;
            background: white;
            box-shadow: 0 6px 14px rgba(214, 51, 132, 0.06);
        }
        div.stButton > button {
            border-radius: 14px;
        }
        div.stButton > button[kind="primary"] {
            border: none;
            background: #ff4fa3;
            color: white;
            font-weight: 800;
        }
        div.stButton > button[kind="secondary"] {
            min-height: 140px;
            text-align: left;
            white-space: normal;
            border: 1px solid rgba(255, 118, 186, 0.22);
            background: #fff8fc;
        }
        div.stButton > button[kind="secondary"] p { margin: 0; }
        div.stButton > button[kind="secondary"] p strong { color: #d63384; }
        div.stButton > button[kind="tertiary"] {
            border: 1px solid rgba(255, 116, 184, 0.44);
            background: white;
            color: #d63384;
            font-weight: 700;
        }
        div.stLinkButton > a {
            border-radius: 14px;
            border: none;
            background: #ff4fa3;
            color: white;
            font-weight: 800;
        }
        div.stLinkButton > a:visited {
            color: white;
        }
        div.stLinkButton > a:hover {
            background: #f33f97;
            color: white;
            border: none;
        }
        div.stLinkButton > a:active {
            color: white;
        }
        .topic-card, .quartile-card {
            border-radius: 14px;
            border: 1px solid rgba(255, 127, 186, 0.18);
            background: white;
            box-shadow: 0 6px 14px rgba(214, 51, 132, 0.05);
        }
        .topic-card {
            padding: 16px 18px 10px;
            margin-bottom: 10px;
        }
        .quartile-card {
            padding: 26px 14px 20px;
            min-height: 145px;
            text-align: center;
        }
        .quartile-title {
            color: #d63384;
            font-size: 2.1rem;
            font-weight: 800;
            line-height: 1;
            margin-bottom: 12px;
        }
        div[data-testid="stProgressBar"] > div {
            background-color: rgba(255, 201, 227, 0.42);
            border-radius: 999px;
        }
        div[data-testid="stProgressBar"] div[role="progressbar"] {
            background: #ff4fa3;
            border-radius: 999px;
        }
        </style>
        """
    # Raw HTML is required for a <style> tag; the content is a static literal.
    st.markdown(stylesheet, unsafe_allow_html=True)


def get_top_95_classes(probabilities, id2label, *, threshold=0.95):
    """Return the smallest set of most-probable classes reaching ``threshold`` mass.

    Classes are taken in descending probability order until their cumulative
    probability is at least ``threshold``. The class that crosses the
    threshold is included.

    Args:
        probabilities: Per-class probabilities, indexed by class id.
        id2label: Mapping from class id to its human-readable label.
        threshold: Cumulative probability at which to stop. The default of
            0.95 is the app's "top-95%" rule. If the probabilities never reach
            it, every class is returned.

    Returns:
        A list of ``{"label": str, "probability": float}`` dicts, highest
        probability first. Ties keep their original index order, because the
        sort is stable. The list is empty when ``probabilities`` is empty.
    """
    ranked = sorted(enumerate(probabilities), key=lambda pair: pair[1], reverse=True)
    selected = []
    cumulative = 0.0
    for class_id, raw_prob in ranked:
        # Convert once so the stored value and the running sum agree, and so
        # non-builtin numeric scalars become plain floats.
        prob = float(raw_prob)
        selected.append({"label": id2label[class_id], "probability": prob})
        cumulative += prob
        if cumulative >= threshold:
            break
    return selected


@st.cache_resource(show_spinner=False)
def load_artifacts():
    """Build the tokenizer, classifier and label mapping (cached as a resource).

    Returns:
        tuple: ``(tokenizer, model, id2field, device_name)``.
        - ``id2field`` maps a class index to its topic label; it is the
          inverse of the JSON mapping on disk.
        - ``device_name`` is ``"cuda"`` when CUDA is available, otherwise
          ``"cpu"``. The model has already been moved to that device.

    Raises:
        FileNotFoundError: If the topic-label JSON or the checkpoint is missing.
    """
    if not FIELD_LABELS_PATH.exists():
        raise FileNotFoundError(f"Topic labels file not found: {FIELD_LABELS_PATH}")
    if not CHECKPOINT_PATH.exists():
        raise FileNotFoundError(f"Model checkpoint not found: {CHECKPOINT_PATH}")

    # On disk: label -> class index.
    field_labels = json.loads(FIELD_LABELS_PATH.read_text(encoding="utf-8"))

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    classifier = ArticleClassifier(
        num_fields=len(field_labels),
        num_quartiles=4,
        model_name=MODEL_NAME,
    )

    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # torch version's weights_only default; only ship trusted checkpoints.
    weights = torch.load(CHECKPOINT_PATH, map_location="cpu")
    # The checkpoint presumably still carries the training loss modules'
    # class-weight tensors. They are removed before the (strict) load below;
    # verify against the training script.
    for unused_key in ("criterion_field.weight", "criterion_quartile.weight"):
        weights.pop(unused_key, None)
    classifier.load_state_dict(weights)
    classifier.eval()

    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    classifier.to(target)

    id2field = {index: name for name, index in field_labels.items()}
    return tokenizer, classifier, id2field, str(target)


# Bounded cache: the key is arbitrary user text, so an unbounded cache would
# keep every distinct submission in memory for the lifetime of the server.
@st.cache_data(show_spinner=False, max_entries=256)
def predict_article_cached(title, abstract):
    """Classify one article by research topic and journal quartile.

    Results are memoized per ``(title, abstract)`` pair, up to 256 entries.

    Args:
        title: Article title. It may be empty when an abstract is given.
        abstract: Article abstract. Empty, whitespace-only or ``None`` means
            "classify from the title only".

    Returns:
        dict with:
        - ``"top95_fields"``: ranked topics as produced by
          ``get_top_95_classes``.
        - ``"quartiles"``: every quartile as ``{"label", "probability"}``,
          most likely first.
        - ``"best_quartile"``: the first entry of ``"quartiles"``.
    """
    tokenizer, model, id2field, device_str = load_artifacts()
    device = torch.device(device_str)

    # Normalise before deciding on the input format. A whitespace-only
    # abstract then counts as missing instead of producing an empty
    # "Abstract:" section.
    title = (title or "").strip()
    abstract = (abstract or "").strip()
    text = f"Title: {title} Abstract: {abstract}" if abstract else f"Title: {title}"

    enc = tokenizer(
        text,
        max_length=MAX_LEN,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        output = model(
            input_ids=enc["input_ids"].to(device),
            attention_mask=enc["attention_mask"].to(device),
        )
        # Softmax over the class dimension; [0] selects the single input row.
        field_probs = torch.softmax(output.logits, dim=-1)[0].cpu().tolist()
        quartile_probs = torch.softmax(output.quartile_logits, dim=-1)[0].cpu().tolist()

    quartiles = sorted(
        (
            {"label": QUARTILE_LABELS[i], "probability": float(prob)}
            for i, prob in enumerate(quartile_probs)
        ),
        key=lambda x: x["probability"],
        reverse=True,
    )

    return {
        "top95_fields": get_top_95_classes(field_probs, id2field),
        "quartiles": quartiles,
        "best_quartile": quartiles[0],
    }


def render_topics(top95_fields):
    """Render each ranked topic as an HTML card followed by a progress bar.

    Args:
        top95_fields: A list of ``{"label": str, "probability": float}`` dicts
            in display order. Card colours cycle through ``TOPIC_COLORS`` by
            position.

    After the cards, a caption shows the summed probability of all displayed
    topics.
    """
    n_colors = len(TOPIC_COLORS)
    for i, item in enumerate(top95_fields, start=1):
        color = TOPIC_COLORS[(i - 1) % n_colors]
        card_html = f"""
            <div class="topic-card">
                <div style="display:flex; gap:12px; align-items:flex-start;">
                    <div style="width:14px;height:14px;min-width:14px;border-radius:50%;background:{color};margin-top:8px;"></div>
                    <div>
                        <div style="color:#d63384;font-weight:800;margin-bottom:6px;">{i}. {item["label"]}</div>
                        <div style="color:#7d2b58;">Probability: {item["probability"] * 100:.2f}%</div>
                    </div>
                </div>
            </div>
            """
        st.markdown(card_html, unsafe_allow_html=True)
        # Clamp at 1.0 so float noise can never push the bar out of range.
        st.progress(min(item["probability"], 1.0))
    shown_mass = sum(entry["probability"] for entry in top95_fields)
    st.caption(f"Total probability of displayed topics: {shown_mass * 100:.2f}%")


def render_pie(top95_fields):
    """Draw the displayed topics as a donut chart with a "TOPICS" centre label.

    Wedges use the same palette order as the topic cards. A wedge's
    percentage text is drawn only when the value matplotlib passes to
    ``autopct`` is at least 4.
    """
    backdrop = "#fff7fb"
    palette_len = len(TOPIC_COLORS)
    wedge_sizes = [entry["probability"] for entry in top95_fields]
    wedge_colors = [TOPIC_COLORS[pos % palette_len] for pos, _ in enumerate(top95_fields)]

    def pct_label(pct):
        # Leave tiny wedges unlabeled so their text does not overlap.
        return f"{pct:.1f}%" if pct >= 4 else ""

    fig, ax = plt.subplots(figsize=(6.2, 6.2), facecolor=backdrop)
    ax.set_facecolor(backdrop)
    ax.pie(
        wedge_sizes,
        labels=None,
        colors=wedge_colors,
        autopct=pct_label,
        startangle=90,
        pctdistance=0.72,
        # width < 1 turns the pie into a ring.
        wedgeprops={"width": 0.42, "edgecolor": backdrop, "linewidth": 2.8},
        textprops={"color": "#8f2f67", "fontweight": "semibold"},
    )
    ax.text(0, 0, "TOPICS", ha="center", va="center", fontsize=15, fontweight="bold", color="#c43b82")
    ax.axis("equal")
    st.pyplot(fig, clear_figure=True)
    # Release the figure so figures do not accumulate across reruns.
    plt.close(fig)


def render_quartiles(quartiles):
    """Show one quartile card per column in a row of four columns.

    Each card's gradient becomes more opaque the closer its probability is to
    the highest one. The opacity ranges from 0.18 up to 0.60 for the top
    quartile.

    Args:
        quartiles: A list of ``{"label": str, "probability": float}`` dicts.
            Cards are placed in list order; extra entries beyond four are not
            shown.
    """
    peak = max((q["probability"] for q in quartiles), default=1.0)
    columns = st.columns(4)
    for col, item in zip(columns, quartiles):
        # Guard against a zero peak to avoid dividing by zero.
        relative = item["probability"] / peak if peak else 0
        alpha = 0.18 + 0.42 * relative
        with col:
            st.markdown(
                f"""
                <div class="quartile-card" style="background: linear-gradient(180deg, rgba(255,255,255,0.96), rgba(255,79,163,{alpha:.3f}));">
                    <div class="quartile-title">{item["label"]}</div>
                    <div style="color:#7d2b58;font-weight:700;">{item["probability"] * 100:.2f}%</div>
                </div>
                """,
                unsafe_allow_html=True,
            )


inject_styles()

# Session defaults:
# - widget-bound text keys;
# - one-shot "pending" flags set by the Example / Clear buttons.
for _state_key, _default in (
    ("title_input", ""),
    ("abstract_input", ""),
    ("pending_example", None),
    ("pending_clear", False),
):
    st.session_state.setdefault(_state_key, _default)

# Deferred updates are applied here, before the text widgets that own these
# keys are created further down. Streamlit does not allow assigning to a
# widget's key once that widget exists in the current run.
_pending_id = st.session_state["pending_example"]
if _pending_id is not None:
    _chosen = next((ex for ex in EXAMPLES if ex["id"] == _pending_id), None)
    if _chosen is not None:
        st.session_state["title_input"] = _chosen["title"]
        st.session_state["abstract_input"] = _chosen["abstract"]
    st.session_state["pending_example"] = None

if st.session_state["pending_clear"]:
    st.session_state["title_input"] = ""
    st.session_state["abstract_input"] = ""
    st.session_state["pending_clear"] = False

# The emoji is written as a named escape (SPARKLING HEART, U+1F496) so the
# source stays ASCII-only and cannot be mis-decoded into mojibake.
st.title("\N{SPARKLING HEART} Top-95% Scientific Article Classification")
st.write(
    "The app takes an article title and abstract, identifies the most likely topics, "
    "and also predicts the journal quartile."
)

# Two-column page: input form on the left, example gallery on the right.
left_col, right_col = st.columns([1.0, 0.97], gap="medium")

with left_col, st.container(key="input-shell"):
    st.subheader("Input")
    # Both widgets are bound to session-state keys, so the Example and Clear
    # actions can prefill or reset them on a later rerun.
    st.text_input(
        "Article title",
        key="title_input",
        placeholder="For example: Graph Neural Networks for Molecular Property Prediction",
        help="You can enter only the title if the abstract is unavailable.",
    )
    st.text_area(
        "Abstract",
        key="abstract_input",
        placeholder="For example: We propose a new method for predicting molecular properties...",
        height=220,
        help="If the abstract is empty, classification will be based on the title only.",
    )
    predict_col, clear_col = st.columns(2)
    with predict_col:
        predict_button = st.button(
            "Predict topic and quartile",
            type="primary",
            use_container_width=True,
        )
    with clear_col:
        clear_button = st.button(
            "Clear fields",
            type="tertiary",
            use_container_width=True,
        )
    # Slot for the transient "Running classification..." message.
    status_placeholder = st.empty()
    if clear_button:
        # The text widgets already exist in this run, so their keys cannot be
        # reassigned here. Flag the reset and let the next run apply it.
        st.session_state["pending_clear"] = True
        st.rerun()

# Example gallery: two buttons per row. Clicking one schedules its text to be
# loaded into the input widgets on the next rerun.
with right_col, st.container(key="examples-shell"):
    st.subheader("Examples")
    for i in range(0, len(EXAMPLES), 2):
        cols = st.columns(2, gap="small")
        for col, example in zip(cols, EXAMPLES[i : i + 2]):
            with col:
                preview = example["abstract"].strip()
                # Cap the preview at 95 characters including the ellipsis.
                # The ellipsis is a named escape so the source stays
                # ASCII-only and cannot be mis-decoded into mojibake.
                if len(preview) > 95:
                    preview = preview[:94].rstrip() + "\N{HORIZONTAL ELLIPSIS}"
                label = f"**{example['title']}**\n\n{preview or 'Title-only classification'}"
                if st.button(label, key=example["id"], type="secondary", use_container_width=True):
                    # The input widgets already exist in this run, so defer:
                    # the top of the script copies the example into them on
                    # the rerun.
                    st.session_state["pending_example"] = example["id"]
                    st.rerun()

# Inputs for this run, after any Example/Clear handling at the top of the
# script. Surrounding whitespace is ignored.
title = st.session_state["title_input"].strip()
abstract = st.session_state["abstract_input"].strip()

# The model runs only on the rerun triggered by the Predict button.
if predict_button:
    try:
        # At least one field is required; a title alone is accepted.
        if not title and not abstract:
            st.error("Input error: fill in at least one field: 'Article title' or 'Abstract'.")
        else:
            status_placeholder.info("Running classification...")
            # predict_article_cached is wrapped in st.cache_data, so repeated
            # (title, abstract) pairs reuse the earlier result.
            result = predict_article_cached(title, abstract)
            status_placeholder.empty()
            st.success("Classification completed.")

            # Results row: topics (cards + donut chart) on the left, quartile
            # cards on the right.
            result_left, result_right = st.columns(2, gap="medium")
            with result_left:
                with st.container(key="topics-shell"):
                    st.subheader("Article topic")
                    text_col, chart_col = st.columns([1.15, 1.1], gap="medium")
                    with text_col:
                        render_topics(result["top95_fields"])
                    with chart_col:
                        render_pie(result["top95_fields"])

            with result_right:
                with st.container(key="quartile-shell"):
                    st.subheader("Journal quartile")
                    st.markdown(
                        f"**Most likely quartile: {result['best_quartile']['label']}** - "
                        f"{result['best_quartile']['probability'] * 100:.2f}%"
                    )
                    render_quartiles(result["quartiles"])

            # Tell the user which inputs the prediction was based on.
            if abstract:
                st.info("Classification used both the title and the abstract.")
            else:
                st.info("No abstract was provided, so classification used the title only.")

            # NOTE(review): this links to an external YouTube video that appears
            # unrelated to the classification result (an easter egg?). Confirm
            # it is intended to ship.
            st.link_button(
                "Click for something interesting",
                "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
            )

    # Every handler clears the "Running classification..." message first so it
    # does not linger next to the error.
    except FileNotFoundError as e:
        # Raised by load_artifacts when the label file or checkpoint is missing.
        status_placeholder.empty()
        st.error(f"Error loading model files: {e}")
    except RuntimeError as e:
        # e.g. tensor/device errors raised during model execution (assumed
        # source: torch).
        status_placeholder.empty()
        st.error(f"Runtime error during model execution: {e}")
    except Exception as e:
        # Last-resort UI boundary: show the error instead of a raw traceback.
        status_placeholder.empty()
        st.error(f"Unexpected error: {e}")