"""Streamlit app that classifies a scientific article by topic and journal quartile.

The user enters a title and/or abstract (or picks one of the bundled examples).
A SciBERT-based ``ArticleClassifier`` loaded from a local checkpoint produces a
topic distribution and a journal-quartile distribution. The app shows the
highest-probability topics until they cover 95% of the probability mass, plus
the probability of every quartile.
"""

import json
from pathlib import Path

import matplotlib.pyplot as plt
import streamlit as st
import torch
from transformers import AutoTokenizer

from model import ArticleClassifier

st.set_page_config(
    page_title="Scientific Article Classification",
    page_icon="💖",
    layout="wide",
)

APP_DIR = Path(__file__).resolve().parent
# Hugging Face model id used for the tokenizer and passed to ArticleClassifier.
MODEL_NAME = "allenai/scibert_scivocab_cased"
CHECKPOINT_PATH = APP_DIR / "scibert_full.pt"
# JSON object mapping topic label -> class index.
FIELD_LABELS_PATH = APP_DIR / "field_labels.json"
# Every input is padded/truncated to exactly this many tokens.
MAX_LEN = 256
# Cumulative probability the displayed topic list must cover.
TOP_PROBABILITY_MASS = 0.95
QUARTILE_LABELS = {0: "Q1", 1: "Q2", 2: "Q3", 3: "Q4"}
# Palette shared by the topic list and the donut chart (cycled if needed).
TOPIC_COLORS = [
    "#FF4FA3",
    "#FF7A7A",
    "#FFAA5B",
    "#FFD45C",
    "#7FE1D8",
    "#53C8F2",
    "#7C8CFF",
    "#B784F7",
    "#FF8FCF",
    "#F06292",
]

# One-click sample inputs; "id" is also used as the example button's widget key.
EXAMPLES = [
    {
        "id": "alphafold",
        "title": "Highly accurate protein structure prediction with AlphaFold",
        "abstract": (
            "Proteins are central to biology, but experimentally determining their three-"
            "dimensional structures remains far slower than the growth of sequence databases. "
            "This paper presents a redesigned AlphaFold system, a neural network that combines "
            "evolutionary information from multiple-sequence alignments with geometric and "
            "physical constraints on protein structure. Evaluated in the blind CASP14 benchmark, "
            "the method reaches near-experimental accuracy for many targets, including cases "
            "without close structural homologues, and substantially outperforms competing "
            "approaches. The model predicts full atomic coordinates, scales to long proteins, "
            "and reports confidence estimates that help users judge which regions of a predicted "
            "structure are reliable. By turning sequence data into high-quality structural "
            "models at scale, the work positions deep learning as a practical tool for "
            "structural biology, functional annotation, and downstream biological discovery."
        ),
    },
    {
        "id": "weather",
        "title": "Skilful precipitation nowcasting using deep generative models of radar",
        "abstract": (
            "Short-range precipitation nowcasting is critical for emergency response, energy "
            "operations, transport, flood warning, and other weather-sensitive decisions, yet "
            "traditional radar-advection systems struggle with nonlinear events such as "
            "convective initiation. Earlier deep learning approaches improve some low-intensity "
            "forecasts but often become blurry at longer lead times and perform poorly on rarer, "
            "heavier rain. This paper introduces a probabilistic deep generative model that "
            "predicts future radar fields directly from recent radar observations while "
            "preserving realistic spatial and temporal structure. Across statistical, economic, "
            "and expert-evaluation measures, the model improves forecast quality, consistency, "
            "and operational value. It generates realistic predictions over regions up to 1536 "
            "by 1280 kilometres and for lead times from 5 to 90 minutes. In evaluations "
            "involving more than fifty meteorologists, it ranked first against strong baseline "
            "systems in most cases, showing that generative models can produce useful "
            "high-resolution nowcasts without relying on blur."
        ),
    },
    {
        "id": "materials",
        "title": "Experimental search for high-temperature ferroelectric perovskites guided by two-step machine learning",
        "abstract": (
            "Searching for high-temperature ferroelectric perovskites is difficult because the "
            "chemical space is enormous, experimental guidance is limited, and many candidate "
            "compositions fail to form the desired phase. This study proposes a two-step machine "
            "learning workflow to guide synthesis in xBi(Me' y Me'' 1-y)O3-(1-x)PbTiO3-type "
            "systems. A classification model first screens compositions likely to crystallize as "
            "perovskites, and a regression model coupled with active learning then prioritizes "
            "candidates with high Curie temperature for experimental testing. The search spans "
            "roughly 61,500 possible compositions, whereas only 167 had been characterized "
            "beforehand. By iterating between prediction, synthesis, and feedback from both "
            "successful and failed experiments, the authors efficiently refine the models and "
            "focus the search. Out of ten newly synthesized candidates, six form perovskites, "
            "including three previously unexplored cation pairs, and one composition reaches a "
            "measured Curie temperature of 898 K. The work shows how machine learning can "
            "meaningfully reduce experimental burden in materials discovery."
        ),
    },
    {
        "id": "chem_xai",
        "title": "Chemistry-intuitive explanation of graph neural networks for molecular property prediction with substructure masking",
        "abstract": (
            "Graph neural networks are widely used for molecular property prediction, but their "
            "explanations often highlight individual atoms, bonds, or fragments that do not "
            "match how chemists reason about structure-property relationships. This paper "
            "introduces Substructure Mask Explanation, an interpretation method built around "
            "chemically meaningful molecular segmentations such as functional groups and other "
            "established fragment schemes. Instead of assigning importance to isolated nodes or "
            "edges, the method reveals which substructures drive a graph model's prediction in a "
            "way that is easier for domain experts to inspect. The authors apply the approach to "
            "models for aqueous solubility, genotoxicity, cardiotoxicity, and blood-brain "
            "barrier permeation in small molecules. Across these cases, the explanations align "
            "better with chemical intuition, help identify unreliable model behaviour, and offer "
            "actionable guidance for structural optimization. The paper argues that "
            "substructure-level explainability can make graph neural networks more useful for "
            "medicinal chemistry and drug-discovery workflows."
        ),
    },
]


def inject_styles():
    """Inject the app's custom HTML/CSS markup via ``st.markdown``."""
    # NOTE(review): the markup string is blank as shown, so this call renders
    # nothing — confirm the intended stylesheet is present.
    st.markdown(
        """ """,
        unsafe_allow_html=True,
    )


def get_top_95_classes(probabilities, id2label, threshold=TOP_PROBABILITY_MASS):
    """Return the most probable classes until their cumulative probability reaches *threshold*.

    Args:
        probabilities: Per-class probabilities, indexed by class id.
        id2label: Mapping from class id to its label.
        threshold: Cumulative probability at which to stop (default 0.95).
            The class whose probability crosses the threshold is included.

    Returns:
        A list of ``{"label": ..., "probability": float}`` dicts in descending
        order of probability. Empty if *probabilities* is empty.
    """
    ranked = sorted(enumerate(probabilities), key=lambda pair: pair[1], reverse=True)
    result = []
    total = 0.0
    for class_id, prob in ranked:
        prob = float(prob)
        result.append({"label": id2label[class_id], "probability": prob})
        total += prob
        if total >= threshold:
            break
    return result


@st.cache_resource(show_spinner=False)
def load_artifacts():
    """Load the tokenizer, classifier and topic-label mapping (cached via ``st.cache_resource``).

    Returns:
        ``(tokenizer, model, id2field, device_str)``. ``id2field`` maps class
        index to topic label. ``device_str`` is ``"cuda"`` when CUDA is
        available, else ``"cpu"``; the model has already been moved there.

    Raises:
        FileNotFoundError: If the label file or the checkpoint is missing.
    """
    if not FIELD_LABELS_PATH.exists():
        raise FileNotFoundError(f"Topic labels file not found: {FIELD_LABELS_PATH}")
    if not CHECKPOINT_PATH.exists():
        raise FileNotFoundError(f"Model checkpoint not found: {CHECKPOINT_PATH}")

    with open(FIELD_LABELS_PATH, "r", encoding="utf-8") as f:
        field_labels = json.load(f)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = ArticleClassifier(
        num_fields=len(field_labels),
        num_quartiles=len(QUARTILE_LABELS),
        model_name=MODEL_NAME,
    )

    # A state_dict only needs tensors and plain containers, so restrict
    # unpickling to those (this is also PyTorch's default from 2.6 onward).
    state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)
    # Drop training-time criterion weights (presumably loss class weights)
    # before the strict load below — TODO confirm the model never expects them.
    state_dict.pop("criterion_field.weight", None)
    state_dict.pop("criterion_quartile.weight", None)
    model.load_state_dict(state_dict)
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    id2field = {idx: label for label, idx in field_labels.items()}
    # The device is returned as a string and rebuilt by the caller.
    return tokenizer, model, id2field, str(device)


@st.cache_data(show_spinner=False)
def predict_article_cached(title, abstract):
    """Classify one article; results are cached per ``(title, abstract)`` pair.

    Args:
        title: Article title (may be empty when an abstract is given).
        abstract: Article abstract. ``None``, empty, or whitespace-only means
            title-only classification.

    Returns:
        A dict with:
        - ``"top95_fields"``: output of ``get_top_95_classes``.
        - ``"quartiles"``: every quartile, sorted by descending probability.
        - ``"best_quartile"``: the first entry of ``"quartiles"``.
    """
    tokenizer, model, id2field, device_str = load_artifacts()
    device = torch.device(device_str)

    # Normalize first so a whitespace-only abstract is not treated as present.
    title = (title or "").strip()
    abstract = (abstract or "").strip()
    text = f"Title: {title} Abstract: {abstract}" if abstract else f"Title: {title}"

    enc = tokenizer(
        text,
        max_length=MAX_LEN,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        output = model(
            input_ids=enc["input_ids"].to(device),
            attention_mask=enc["attention_mask"].to(device),
        )

    # A single text was tokenized, so take row 0 of each softmax.
    field_probs = torch.softmax(output.logits, dim=-1)[0].cpu().tolist()
    quartile_probs = torch.softmax(output.quartile_logits, dim=-1)[0].cpu().tolist()

    quartiles = [
        {"label": QUARTILE_LABELS[i], "probability": float(prob)}
        for i, prob in enumerate(quartile_probs)
    ]
    quartiles.sort(key=lambda item: item["probability"], reverse=True)

    return {
        "top95_fields": get_top_95_classes(field_probs, id2field),
        "quartiles": quartiles,
        "best_quartile": quartiles[0],
    }


def render_topics(top95_fields):
    """Render each topic with its probability and a progress bar, then the total shown."""
    total = 0.0
    for rank, item in enumerate(top95_fields, start=1):
        prob = item["probability"]
        total += prob
        # NOTE(review): `color` is not referenced by the markup below as shown —
        # confirm whether the topic card markup should use it.
        color = TOPIC_COLORS[(rank - 1) % len(TOPIC_COLORS)]
        st.markdown(
            f"\n{rank}. {item['label']}\nProbability: {prob * 100:.2f}%\n",
            unsafe_allow_html=True,
        )
        st.progress(min(prob, 1.0))
    st.caption(f"Total probability of displayed topics: {total * 100:.2f}%")


def render_pie(top95_fields):
    """Draw the displayed topics as a donut chart, coloured to match ``render_topics``."""
    if not top95_fields:
        return

    background = "#fff7fb"
    sizes = [item["probability"] for item in top95_fields]
    colors = [TOPIC_COLORS[i % len(TOPIC_COLORS)] for i in range(len(top95_fields))]

    fig, ax = plt.subplots(figsize=(6.2, 6.2), facecolor=background)
    try:
        ax.set_facecolor(background)
        # NOTE(review): matplotlib normalizes the wedges, so the autopct values
        # are shares of the displayed subset, not the absolute probabilities
        # listed by render_topics.
        ax.pie(
            sizes,
            labels=None,
            colors=colors,
            # Leave very small wedges unlabelled so the text does not overlap.
            autopct=lambda p: f"{p:.1f}%" if p >= 4 else "",
            startangle=90,
            pctdistance=0.72,
            wedgeprops={"width": 0.42, "edgecolor": background, "linewidth": 2.8},
            textprops={"color": "#8f2f67", "fontweight": "semibold"},
        )
        ax.text(0, 0, "TOPICS", ha="center", va="center", fontsize=15, fontweight="bold", color="#c43b82")
        ax.axis("equal")
        st.pyplot(fig, clear_figure=True)
    finally:
        # Release the figure even if drawing or st.pyplot fails.
        plt.close(fig)


def render_quartiles(quartiles):
    """Render one card per quartile, side by side, in the order given."""
    if not quartiles:
        return
    max_prob = max(item["probability"] for item in quartiles)
    for col, item in zip(st.columns(len(quartiles)), quartiles):
        # NOTE(review): `alpha` (presumably a card opacity scaled to the top
        # probability) is not referenced by the markup below as shown.
        alpha = 0.18 + 0.42 * (item["probability"] / max_prob if max_prob else 0)
        with col:
            st.markdown(
                f"\n{item['label']}\n{item['probability'] * 100:.2f}%\n",
                unsafe_allow_html=True,
            )


def _apply_pending_actions():
    """Apply an example-fill or clear request queued by a button on the previous run.

    The buttons only set a ``pending_*`` flag and call ``st.rerun()``. The
    input widgets' session-state values are written here, before those
    widgets are created on this run.
    """
    state = st.session_state

    pending_id = state["pending_example"]
    if pending_id is not None:
        example = next((ex for ex in EXAMPLES if ex["id"] == pending_id), None)
        if example is not None:
            state["title_input"] = example["title"]
            state["abstract_input"] = example["abstract"]
        state["pending_example"] = None

    if state["pending_clear"]:
        state["title_input"] = ""
        state["abstract_input"] = ""
        state["pending_clear"] = False


def _example_button_label(example, max_chars=95):
    """Build an example button's Markdown label: bold title plus a truncated abstract preview."""
    preview = example["abstract"].strip()
    if len(preview) > max_chars:
        preview = preview[: max_chars - 1].rstrip() + "…"
    return f"**{example['title']}**\n\n{preview or 'Title-only classification'}"


def _render_results(result, abstract):
    """Show the topic list, donut chart and quartile cards for one prediction."""
    result_left, result_right = st.columns(2, gap="medium")

    with result_left:
        with st.container(key="topics-shell"):
            st.subheader("Article topic")
            text_col, chart_col = st.columns([1.15, 1.1], gap="medium")
            with text_col:
                render_topics(result["top95_fields"])
            with chart_col:
                render_pie(result["top95_fields"])

    with result_right:
        with st.container(key="quartile-shell"):
            st.subheader("Journal quartile")
            best = result["best_quartile"]
            st.markdown(
                f"**Most likely quartile: {best['label']}** - "
                f"{best['probability'] * 100:.2f}%"
            )
            render_quartiles(result["quartiles"])

    if abstract:
        st.info("Classification used both the title and the abstract.")
    else:
        st.info("No abstract was provided, so classification used the title only.")

    # NOTE(review): this external link is unrelated to the prediction — confirm
    # it is meant to ship.
    st.link_button(
        "Click for something interesting",
        "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
    )


# ---------------------------------------------------------------------------
# Page script (re-executed by Streamlit on every interaction)
# ---------------------------------------------------------------------------

inject_styles()

st.session_state.setdefault("title_input", "")
st.session_state.setdefault("abstract_input", "")
st.session_state.setdefault("pending_example", None)
st.session_state.setdefault("pending_clear", False)

_apply_pending_actions()

st.title("💖 Top-95% Scientific Article Classification")
st.write(
    "The app takes an article title and abstract, identifies the most likely topics, "
    "and also predicts the journal quartile."
)

left_col, right_col = st.columns([1.0, 0.97], gap="medium")

with left_col:
    with st.container(key="input-shell"):
        st.subheader("Input")
        st.text_input(
            "Article title",
            key="title_input",
            placeholder="For example: Graph Neural Networks for Molecular Property Prediction",
            help="You can enter only the title if the abstract is unavailable.",
        )
        st.text_area(
            "Abstract",
            key="abstract_input",
            placeholder="For example: We propose a new method for predicting molecular properties...",
            height=220,
            help="If the abstract is empty, classification will be based on the title only.",
        )
        c1, c2 = st.columns(2)
        with c1:
            predict_button = st.button("Predict topic and quartile", type="primary", use_container_width=True)
        with c2:
            clear_button = st.button("Clear fields", type="tertiary", use_container_width=True)
        status_placeholder = st.empty()

        if clear_button:
            # Widget values can't be reset here; queue the reset for the next run.
            st.session_state["pending_clear"] = True
            st.rerun()

with right_col:
    with st.container(key="examples-shell"):
        st.subheader("Examples")
        for start in range(0, len(EXAMPLES), 2):
            cols = st.columns(2, gap="small")
            for col, example in zip(cols, EXAMPLES[start : start + 2]):
                with col:
                    if st.button(
                        _example_button_label(example),
                        key=example["id"],
                        type="secondary",
                        use_container_width=True,
                    ):
                        st.session_state["pending_example"] = example["id"]
                        st.rerun()

title = st.session_state["title_input"].strip()
abstract = st.session_state["abstract_input"].strip()

if predict_button:
    if not title and not abstract:
        st.error("Input error: fill in at least one field: 'Article title' or 'Abstract'.")
    else:
        try:
            status_placeholder.info("Running classification...")
            result = predict_article_cached(title, abstract)
            status_placeholder.empty()
            st.success("Classification completed.")
            _render_results(result, abstract)
        except FileNotFoundError as e:
            status_placeholder.empty()
            st.error(f"Error loading model files: {e}")
        except RuntimeError as e:
            status_placeholder.empty()
            st.error(f"Runtime error during model execution: {e}")
        except Exception as e:
            # Top-level UI boundary: surface any other failure to the user.
            status_placeholder.empty()
            st.error(f"Unexpected error: {e}")