# sci_classifier / src /streamlit_app.py
# nolongerlaugh's picture
# Update src/streamlit_app.py
# e14f368 verified
import json
from pathlib import Path
import matplotlib.pyplot as plt
import streamlit as st
import torch
from transformers import AutoTokenizer
from model import ArticleClassifier
# Browser-tab title, icon and wide layout for the whole app. This is the
# first Streamlit call in the script.
st.set_page_config(
    layout="wide",
    page_icon="💖",
    page_title="Scientific Article Classification",
)
# Model artifacts are looked up next to this script, so the app does not
# depend on the current working directory.
APP_DIR = Path(__file__).resolve().parent
CHECKPOINT_PATH = APP_DIR.joinpath("scibert_full.pt")
FIELD_LABELS_PATH = APP_DIR.joinpath("field_labels.json")

# Hugging Face identifier handed to both the tokenizer and the classifier.
MODEL_NAME = "allenai/scibert_scivocab_cased"

# Sequence length the tokenizer pads/truncates every input to.
MAX_LEN = 256

# Quartile class index -> label shown in the UI.
QUARTILE_LABELS = dict(enumerate(("Q1", "Q2", "Q3", "Q4")))

# Palette for topic markers and pie wedges; consumers wrap around with a
# modulo when there are more topics than colours.
TOPIC_COLORS = [
    "#FF4FA3",
    "#FF7A7A",
    "#FFAA5B",
    "#FFD45C",
    "#7FE1D8",
    "#53C8F2",
    "#7C8CFF",
    "#B784F7",
    "#FF8FCF",
    "#F06292",
]
# Ready-made inputs offered as one-click example buttons in the UI. Each
# entry carries a unique "id" (also used as the button key), a "title" and an
# "abstract".
_EXAMPLE_ALPHAFOLD = {
    "id": "alphafold",
    "title": "Highly accurate protein structure prediction with AlphaFold",
    "abstract": (
        "Proteins are central to biology, but experimentally determining their three-"
        "dimensional structures remains far slower than the growth of sequence databases. "
        "This paper presents a redesigned AlphaFold system, a neural network that combines "
        "evolutionary information from multiple-sequence alignments with geometric and "
        "physical constraints on protein structure. Evaluated in the blind CASP14 benchmark, "
        "the method reaches near-experimental accuracy for many targets, including cases "
        "without close structural homologues, and substantially outperforms competing "
        "approaches. The model predicts full atomic coordinates, scales to long proteins, "
        "and reports confidence estimates that help users judge which regions of a predicted "
        "structure are reliable. By turning sequence data into high-quality structural "
        "models at scale, the work positions deep learning as a practical tool for "
        "structural biology, functional annotation, and downstream biological discovery."
    ),
}

_EXAMPLE_WEATHER = {
    "id": "weather",
    "title": "Skilful precipitation nowcasting using deep generative models of radar",
    "abstract": (
        "Short-range precipitation nowcasting is critical for emergency response, energy "
        "operations, transport, flood warning, and other weather-sensitive decisions, yet "
        "traditional radar-advection systems struggle with nonlinear events such as "
        "convective initiation. Earlier deep learning approaches improve some low-intensity "
        "forecasts but often become blurry at longer lead times and perform poorly on rarer, "
        "heavier rain. This paper introduces a probabilistic deep generative model that "
        "predicts future radar fields directly from recent radar observations while "
        "preserving realistic spatial and temporal structure. Across statistical, economic, "
        "and expert-evaluation measures, the model improves forecast quality, consistency, "
        "and operational value. It generates realistic predictions over regions up to 1536 "
        "by 1280 kilometres and for lead times from 5 to 90 minutes. In evaluations "
        "involving more than fifty meteorologists, it ranked first against strong baseline "
        "systems in most cases, showing that generative models can produce useful "
        "high-resolution nowcasts without relying on blur."
    ),
}

_EXAMPLE_MATERIALS = {
    "id": "materials",
    "title": "Experimental search for high-temperature ferroelectric perovskites guided by two-step machine learning",
    "abstract": (
        "Searching for high-temperature ferroelectric perovskites is difficult because the "
        "chemical space is enormous, experimental guidance is limited, and many candidate "
        "compositions fail to form the desired phase. This study proposes a two-step machine "
        "learning workflow to guide synthesis in xBi(Me' y Me'' 1-y)O3-(1-x)PbTiO3-type "
        "systems. A classification model first screens compositions likely to crystallize as "
        "perovskites, and a regression model coupled with active learning then prioritizes "
        "candidates with high Curie temperature for experimental testing. The search spans "
        "roughly 61,500 possible compositions, whereas only 167 had been characterized "
        "beforehand. By iterating between prediction, synthesis, and feedback from both "
        "successful and failed experiments, the authors efficiently refine the models and "
        "focus the search. Out of ten newly synthesized candidates, six form perovskites, "
        "including three previously unexplored cation pairs, and one composition reaches a "
        "measured Curie temperature of 898 K. The work shows how machine learning can "
        "meaningfully reduce experimental burden in materials discovery."
    ),
}

_EXAMPLE_CHEM_XAI = {
    "id": "chem_xai",
    "title": "Chemistry-intuitive explanation of graph neural networks for molecular property prediction with substructure masking",
    "abstract": (
        "Graph neural networks are widely used for molecular property prediction, but their "
        "explanations often highlight individual atoms, bonds, or fragments that do not "
        "match how chemists reason about structure-property relationships. This paper "
        "introduces Substructure Mask Explanation, an interpretation method built around "
        "chemically meaningful molecular segmentations such as functional groups and other "
        "established fragment schemes. Instead of assigning importance to isolated nodes or "
        "edges, the method reveals which substructures drive a graph model's prediction in a "
        "way that is easier for domain experts to inspect. The authors apply the approach to "
        "models for aqueous solubility, genotoxicity, cardiotoxicity, and blood-brain "
        "barrier permeation in small molecules. Across these cases, the explanations align "
        "better with chemical intuition, help identify unreliable model behaviour, and offer "
        "actionable guidance for structural optimization. The paper argues that "
        "substructure-level explainability can make graph neural networks more useful for "
        "medicinal chemistry and drug-discovery workflows."
    ),
}

# Display order of the example buttons.
EXAMPLES = [
    _EXAMPLE_ALPHAFOLD,
    _EXAMPLE_WEATHER,
    _EXAMPLE_MATERIALS,
    _EXAMPLE_CHEM_XAI,
]
# Stylesheet for the whole app. The `.st-key-*` selectors target the keyed
# containers created in the page body; `.topic-card` / `.quartile-card` style
# the HTML snippets emitted by the render helpers.
_APP_CSS = """
<style>
.stApp {
color: #7d2b58;
background: #fff4fa;
}
.stApp h1 {
color: #d63384;
font-size: clamp(2.1rem, 3vw, 3rem);
font-weight: 900;
letter-spacing: 0.01em;
white-space: nowrap;
}
.st-key-input-shell, .st-key-examples-shell, .st-key-topics-shell, .st-key-quartile-shell {
border: 1px solid rgba(255, 109, 177, 0.22);
border-radius: 18px;
padding: 1rem 1.1rem;
background: white;
box-shadow: 0 6px 14px rgba(214, 51, 132, 0.06);
}
div.stButton > button {
border-radius: 14px;
}
div.stButton > button[kind="primary"] {
border: none;
background: #ff4fa3;
color: white;
font-weight: 800;
}
div.stButton > button[kind="secondary"] {
min-height: 140px;
text-align: left;
white-space: normal;
border: 1px solid rgba(255, 118, 186, 0.22);
background: #fff8fc;
}
div.stButton > button[kind="secondary"] p { margin: 0; }
div.stButton > button[kind="secondary"] p strong { color: #d63384; }
div.stButton > button[kind="tertiary"] {
border: 1px solid rgba(255, 116, 184, 0.44);
background: white;
color: #d63384;
font-weight: 700;
}
div.stLinkButton > a {
border-radius: 14px;
border: none;
background: #ff4fa3;
color: white;
font-weight: 800;
}
div.stLinkButton > a:visited {
color: white;
}
div.stLinkButton > a:hover {
background: #f33f97;
color: white;
border: none;
}
div.stLinkButton > a:active {
color: white;
}
.topic-card, .quartile-card {
border-radius: 14px;
border: 1px solid rgba(255, 127, 186, 0.18);
background: white;
box-shadow: 0 6px 14px rgba(214, 51, 132, 0.05);
}
.topic-card {
padding: 16px 18px 10px;
margin-bottom: 10px;
}
.quartile-card {
padding: 26px 14px 20px;
min-height: 145px;
text-align: center;
}
.quartile-title {
color: #d63384;
font-size: 2.1rem;
font-weight: 800;
line-height: 1;
margin-bottom: 12px;
}
div[data-testid="stProgressBar"] > div {
background-color: rgba(255, 201, 227, 0.42);
border-radius: 999px;
}
div[data-testid="stProgressBar"] div[role="progressbar"] {
background: #ff4fa3;
border-radius: 999px;
}
</style>
"""


def inject_styles():
    """Emit the app's custom ``<style>`` block into the page.

    The stylesheet is raw HTML, so it is passed to ``st.markdown`` with
    ``unsafe_allow_html=True``.
    """
    st.markdown(_APP_CSS, unsafe_allow_html=True)
def get_top_95_classes(probabilities, id2label, *, threshold=0.95):
    """Return the most probable classes whose cumulative mass reaches *threshold*.

    Classes are visited from most to least probable (ties keep their original
    index order) and collected until the running sum of their probabilities
    is at least ``threshold``. The class that crosses the threshold is
    included in the result.

    Args:
        probabilities: Iterable of per-class probabilities, indexed by class id.
        id2label: Mapping from class id (position in ``probabilities``) to the
            label reported for that class.
        threshold: Cumulative probability at which collection stops. Defaults
            to 0.95, the cut-off this function was originally hard-wired to.

    Returns:
        A list of ``{"label": ..., "probability": float}`` dicts ordered by
        descending probability. Empty when ``probabilities`` is empty; contains
        every class when the total never reaches ``threshold``.

    Raises:
        KeyError: If a visited class id is missing from ``id2label``.
    """
    ranked = sorted(enumerate(probabilities), key=lambda pair: pair[1], reverse=True)

    selected = []
    cumulative = 0.0
    for class_id, prob in ranked:
        prob = float(prob)  # normalise tensor/numpy scalars to plain floats
        selected.append({"label": id2label[class_id], "probability": prob})
        cumulative += prob
        if cumulative >= threshold:
            break
    return selected
@st.cache_resource(show_spinner=False)
def load_artifacts():
    """Load the tokenizer, classifier weights and label mapping.

    Wrapped in ``st.cache_resource`` so the heavy objects are built once and
    shared instead of being rebuilt on every script rerun.

    Returns:
        A 4-tuple ``(tokenizer, model, id2field, device)`` where ``id2field``
        maps class index -> topic label and ``device`` is the device name as a
        string (``"cuda"`` or ``"cpu"``).

    Raises:
        FileNotFoundError: If the topic-labels JSON or the checkpoint file is
            missing next to the script.
    """
    # Check both files up front so a missing artifact gives a clear message
    # before any expensive loading starts.
    required_files = (
        ("Topic labels file", FIELD_LABELS_PATH),
        ("Model checkpoint", CHECKPOINT_PATH),
    )
    for description, path in required_files:
        if not path.exists():
            raise FileNotFoundError(f"{description} not found: {path}")

    # The JSON maps label -> class index; it is inverted further down.
    field_labels = json.loads(FIELD_LABELS_PATH.read_text(encoding="utf-8"))

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    classifier = ArticleClassifier(
        num_fields=len(field_labels),
        num_quartiles=4,
        model_name=MODEL_NAME,
    )

    # NOTE(review): torch.load unpickles the file, so the checkpoint must come
    # from a trusted source.
    weights = torch.load(CHECKPOINT_PATH, map_location="cpu")
    # Drop the two "criterion_*" entries before loading (presumably loss
    # weights saved with the training module — confirm against the trainer).
    for loss_key in ("criterion_field.weight", "criterion_quartile.weight"):
        weights.pop(loss_key, None)
    classifier.load_state_dict(weights)
    classifier.eval()

    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    classifier.to(target_device)

    index_to_label = {index: name for name, index in field_labels.items()}
    return tokenizer, classifier, index_to_label, str(target_device)
@st.cache_data(show_spinner=False)
def predict_article_cached(title, abstract):
    """Classify one article and return topic and quartile predictions.

    Results are cached by ``st.cache_data`` on the ``(title, abstract)`` pair.

    Args:
        title: Article title; surrounding whitespace is stripped.
        abstract: Article abstract. When falsy, only the title is fed to the
            model.

    Returns:
        A dict with three keys:
        ``"top95_fields"`` - output of ``get_top_95_classes`` for the topic head;
        ``"quartiles"`` - ``{"label", "probability"}`` dicts for every quartile,
        sorted by descending probability;
        ``"best_quartile"`` - the first (most probable) entry of ``"quartiles"``.
    """
    tokenizer, model, id2field, device_str = load_artifacts()
    device = torch.device(device_str)

    # Prompt layout: "Title: ... Abstract: ..." or just "Title: ...".
    if abstract:
        text = f"Title: {title.strip()} Abstract: {abstract.strip()}"
    else:
        text = f"Title: {title.strip()}"

    encoded = tokenizer(
        text,
        max_length=MAX_LEN,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    with torch.no_grad():
        output = model(input_ids=input_ids, attention_mask=attention_mask)

    def _first_row_probabilities(logits):
        # Softmax over the class axis, then take row 0 (a single text was
        # encoded) as a plain Python list.
        return torch.softmax(logits, dim=-1)[0].cpu().tolist()

    field_probs = _first_row_probabilities(output.logits)
    quartile_probs = _first_row_probabilities(output.quartile_logits)

    quartiles = sorted(
        (
            {"label": QUARTILE_LABELS[index], "probability": float(value)}
            for index, value in enumerate(quartile_probs)
        ),
        key=lambda entry: entry["probability"],
        reverse=True,
    )

    return {
        "top95_fields": get_top_95_classes(field_probs, id2field),
        "quartiles": quartiles,
        "best_quartile": quartiles[0],
    }
# HTML for one topic card; filled in with str.format by render_topics.
_TOPIC_CARD_HTML = """
<div class="topic-card">
<div style="display:flex; gap:12px; align-items:flex-start;">
<div style="width:14px;height:14px;min-width:14px;border-radius:50%;background:{color};margin-top:8px;"></div>
<div>
<div style="color:#d63384;font-weight:800;margin-bottom:6px;">{rank}. {label}</div>
<div style="color:#7d2b58;">Probability: {percent:.2f}%</div>
</div>
</div>
</div>
"""


def render_topics(top95_fields):
    """Render a ranked card and progress bar for every topic.

    Args:
        top95_fields: Sequence of ``{"label", "probability"}`` dicts, already
            in display order.

    After the cards, a caption reports the summed probability of everything
    that was displayed.
    """
    running_total = 0.0
    palette_size = len(TOPIC_COLORS)

    for index, entry in enumerate(top95_fields):
        probability = entry["probability"]
        running_total += probability

        card_html = _TOPIC_CARD_HTML.format(
            color=TOPIC_COLORS[index % palette_size],  # wrap around the palette
            rank=index + 1,
            label=entry["label"],
            percent=probability * 100,
        )
        st.markdown(card_html, unsafe_allow_html=True)
        # Cap the bar value at 1.0.
        st.progress(min(probability, 1.0))

    st.caption(f"Total probability of displayed topics: {running_total * 100:.2f}%")
def render_pie(top95_fields):
    """Draw a donut chart of the displayed topic probabilities.

    Args:
        top95_fields: Sequence of ``{"label", "probability"}`` dicts; each
            probability becomes one wedge, coloured from ``TOPIC_COLORS`` in
            order (wrapping around when there are more topics than colours).

    The figure is always closed, even if drawing or ``st.pyplot`` raises:
    pyplot keeps every figure it creates registered until it is closed, so a
    failure would otherwise leave the figure alive in the server process.
    """
    sizes = [item["probability"] for item in top95_fields]
    colors = [TOPIC_COLORS[i % len(TOPIC_COLORS)] for i in range(len(top95_fields))]

    fig, ax = plt.subplots(figsize=(6.2, 6.2), facecolor="#fff7fb")
    try:
        ax.set_facecolor("#fff7fb")
        ax.pie(
            sizes,
            labels=None,
            colors=colors,
            # Only label wedges of at least 4% to avoid cramped text.
            autopct=lambda p: f"{p:.1f}%" if p >= 4 else "",
            startangle=90,
            pctdistance=0.72,
            # A wedge width below the radius turns the pie into a donut.
            wedgeprops={"width": 0.42, "edgecolor": "#fff7fb", "linewidth": 2.8},
            textprops={"color": "#8f2f67", "fontweight": "semibold"},
        )
        ax.text(0, 0, "TOPICS", ha="center", va="center", fontsize=15, fontweight="bold", color="#c43b82")
        ax.axis("equal")
        st.pyplot(fig, clear_figure=True)
    finally:
        plt.close(fig)
# HTML for one quartile card; filled in with str.format by render_quartiles.
_QUARTILE_CARD_HTML = """
<div class="quartile-card" style="background: linear-gradient(180deg, rgba(255,255,255,0.96), rgba(255,79,163,{alpha:.3f}));">
<div class="quartile-title">{label}</div>
<div style="color:#7d2b58;font-weight:700;">{percent:.2f}%</div>
</div>
"""


def render_quartiles(quartiles):
    """Render one card per quartile across a row of four columns.

    Args:
        quartiles: Sequence of ``{"label", "probability"}`` dicts. Entries are
            paired with the four columns in order; any beyond the fourth are
            not shown.

    The card's gradient opacity scales with the entry's probability relative
    to the largest one, from 0.18 up to 0.60.
    """
    peak = max(entry["probability"] for entry in quartiles) if quartiles else 1.0
    columns = st.columns(4)

    for column, entry in zip(columns, quartiles):
        # Guard against a zero peak to avoid dividing by zero.
        share = entry["probability"] / peak if peak else 0
        alpha = 0.18 + 0.42 * share
        with column:
            st.markdown(
                _QUARTILE_CARD_HTML.format(
                    alpha=alpha,
                    label=entry["label"],
                    percent=entry["probability"] * 100,
                ),
                unsafe_allow_html=True,
            )
# Apply the custom stylesheet before anything else is drawn.
inject_styles()

# Seed every session-state key the script relies on (existing values win).
_SESSION_DEFAULTS = {
    "title_input": "",
    "abstract_input": "",
    "pending_example": None,
    "pending_clear": False,
}
for _state_key, _default_value in _SESSION_DEFAULTS.items():
    st.session_state.setdefault(_state_key, _default_value)

# The example and clear buttons further down only record a "pending" request
# and rerun. The request is applied here, before the input widgets bound to
# "title_input" / "abstract_input" are created in this run.
_requested_example_id = st.session_state["pending_example"]
if _requested_example_id is not None:
    _requested_example = next(
        (entry for entry in EXAMPLES if entry["id"] == _requested_example_id),
        None,
    )
    if _requested_example is not None:
        st.session_state["title_input"] = _requested_example["title"]
        st.session_state["abstract_input"] = _requested_example["abstract"]
    # Consume the request whether or not the id matched an example.
    st.session_state["pending_example"] = None

if st.session_state["pending_clear"]:
    st.session_state["title_input"] = ""
    st.session_state["abstract_input"] = ""
    st.session_state["pending_clear"] = False
# Page heading followed by a short description of what the app does.
_INTRO_TEXT = (
    "The app takes an article title and abstract, identifies the most likely topics, "
    "and also predicts the journal quartile."
)
st.title("💖 Top-95% Scientific Article Classification")
st.write(_INTRO_TEXT)
# Two-column page body: the input form on the left, example shortcuts on the
# right.
left_col, right_col = st.columns([1.0, 0.97], gap="medium")
with left_col:
    # The container key pairs with the `.st-key-input-shell` rule in the
    # injected stylesheet.
    with st.container(key="input-shell"):
        st.subheader("Input")
        # Both widgets are bound to session-state keys, so their contents can
        # be pre-filled or reset through session state before they are built.
        st.text_input(
            "Article title",
            key="title_input",
            placeholder="For example: Graph Neural Networks for Molecular Property Prediction",
            help="You can enter only the title if the abstract is unavailable.",
        )
        st.text_area(
            "Abstract",
            key="abstract_input",
            placeholder="For example: We propose a new method for predicting molecular properties...",
            height=220,
            help="If the abstract is empty, classification will be based on the title only.",
        )
        c1, c2 = st.columns(2)
        with c1:
            predict_button = st.button("Predict topic and quartile", type="primary", use_container_width=True)
        with c2:
            clear_button = st.button("Clear fields", type="tertiary", use_container_width=True)
        # Empty slot that later shows the "Running classification..." notice.
        # NOTE(review): indentation was lost in this copy of the file; placing
        # the placeholder inside the input container is an assumption — confirm.
        status_placeholder = st.empty()
# Clearing is deferred: record the request and rerun, so the fields are reset
# at the top of the next run before the widgets are created.
if clear_button:
    st.session_state["pending_clear"] = True
    st.rerun()
with right_col:
    # The container key pairs with the `.st-key-examples-shell` CSS rule.
    with st.container(key="examples-shell"):
        st.subheader("Examples")
        # Lay the example buttons out two per row.
        for row_start in range(0, len(EXAMPLES), 2):
            row_examples = EXAMPLES[row_start : row_start + 2]
            for slot, sample in zip(st.columns(2, gap="small"), row_examples):
                with slot:
                    # Show at most 95 characters of the abstract on the button.
                    snippet = sample["abstract"].strip()
                    if len(snippet) > 95:
                        snippet = snippet[:94].rstrip() + "…"
                    caption = f"**{sample['title']}**\n\n{snippet or 'Title-only classification'}"
                    was_clicked = st.button(
                        caption,
                        key=sample["id"],
                        type="secondary",
                        use_container_width=True,
                    )
                    if was_clicked:
                        # Defer filling the inputs to the top of the next run.
                        st.session_state["pending_example"] = sample["id"]
                        st.rerun()
# Read the current inputs back from session state, trimmed of whitespace.
title = st.session_state["title_input"].strip()
abstract = st.session_state["abstract_input"].strip()
# Results are only rendered during the run triggered by the predict button.
if predict_button:
    try:
        # At least one of the two fields must be non-empty.
        if not title and not abstract:
            st.error("Input error: fill in at least one field: 'Article title' or 'Abstract'.")
        else:
            # Show a transient notice while the (cached) prediction runs.
            status_placeholder.info("Running classification...")
            result = predict_article_cached(title, abstract)
            status_placeholder.empty()
            st.success("Classification completed.")
            # Topics on the left, quartile prediction on the right.
            result_left, result_right = st.columns(2, gap="medium")
            with result_left:
                with st.container(key="topics-shell"):
                    st.subheader("Article topic")
                    text_col, chart_col = st.columns([1.15, 1.1], gap="medium")
                    with text_col:
                        render_topics(result["top95_fields"])
                    with chart_col:
                        render_pie(result["top95_fields"])
            with result_right:
                with st.container(key="quartile-shell"):
                    st.subheader("Journal quartile")
                    st.markdown(
                        f"**Most likely quartile: {result['best_quartile']['label']}** - "
                        f"{result['best_quartile']['probability'] * 100:.2f}%"
                    )
                    render_quartiles(result["quartiles"])
            # NOTE(review): indentation was lost in this copy of the file; the
            # nesting level of the info note and link button below (full width
            # under the result columns vs. inside the quartile container) is
            # an assumption — confirm against the deployed layout.
            if abstract:
                st.info("Classification used both the title and the abstract.")
            else:
                st.info("No abstract was provided, so classification used the title only.")
            st.link_button(
                "Click for something interesting",
                "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
            )
    # Each handler removes the "running" notice and reports the failure in
    # the page instead of letting the exception propagate.
    except FileNotFoundError as e:
        # Raised by load_artifacts when the labels file or checkpoint is missing.
        status_placeholder.empty()
        st.error(f"Error loading model files: {e}")
    except RuntimeError as e:
        status_placeholder.empty()
        st.error(f"Runtime error during model execution: {e}")
    except Exception as e:
        # Catch-all for any other failure while predicting or rendering.
        status_placeholder.empty()
        st.error(f"Unexpected error: {e}")