Spaces:
Sleeping
Sleeping
import html
import json
from pathlib import Path

import matplotlib.pyplot as plt
import streamlit as st
import torch
from transformers import AutoTokenizer

from model import ArticleClassifier
# Configure the browser tab (title and icon) and switch the page to the wide layout.
st.set_page_config(
    page_title="Scientific Article Classification",
    page_icon="💖",
    layout="wide",
)
# Directory containing this script; the checkpoint and label file are resolved against it.
APP_DIR = Path(__file__).resolve().parent
# Model identifier handed to both AutoTokenizer.from_pretrained and ArticleClassifier.
MODEL_NAME = "allenai/scibert_scivocab_cased"
# Weights loaded into ArticleClassifier by load_artifacts().
CHECKPOINT_PATH = APP_DIR / "scibert_full.pt"
# JSON file read by load_artifacts(); treated there as a label -> index mapping.
FIELD_LABELS_PATH = APP_DIR / "field_labels.json"
# Token length used for truncation and padding when encoding the input text.
MAX_LEN = 256
# Display label for each index of the quartile probabilities.
QUARTILE_LABELS = {0: "Q1", 1: "Q2", 2: "Q3", 3: "Q4"}
# Colour palette for topic markers and pie wedges; indexed modulo its length.
TOPIC_COLORS = [
    "#FF4FA3",
    "#FF7A7A",
    "#FFAA5B",
    "#FFD45C",
    "#7FE1D8",
    "#53C8F2",
    "#7C8CFF",
    "#B784F7",
    "#FF8FCF",
    "#F06292",
]
# Preset articles offered as one-click examples.
# Each entry has:
#   "id"       - unique key; used as the Streamlit button key and as the value stored in
#                st.session_state["pending_example"]
#   "title"    - text copied into the title input when the example is chosen
#   "abstract" - text copied into the abstract input; also truncated for the button preview
EXAMPLES = [
    {
        "id": "alphafold",
        "title": "Highly accurate protein structure prediction with AlphaFold",
        "abstract": (
            "Proteins are central to biology, but experimentally determining their three-"
            "dimensional structures remains far slower than the growth of sequence databases. "
            "This paper presents a redesigned AlphaFold system, a neural network that combines "
            "evolutionary information from multiple-sequence alignments with geometric and "
            "physical constraints on protein structure. Evaluated in the blind CASP14 benchmark, "
            "the method reaches near-experimental accuracy for many targets, including cases "
            "without close structural homologues, and substantially outperforms competing "
            "approaches. The model predicts full atomic coordinates, scales to long proteins, "
            "and reports confidence estimates that help users judge which regions of a predicted "
            "structure are reliable. By turning sequence data into high-quality structural "
            "models at scale, the work positions deep learning as a practical tool for "
            "structural biology, functional annotation, and downstream biological discovery."
        ),
    },
    {
        "id": "weather",
        "title": "Skilful precipitation nowcasting using deep generative models of radar",
        "abstract": (
            "Short-range precipitation nowcasting is critical for emergency response, energy "
            "operations, transport, flood warning, and other weather-sensitive decisions, yet "
            "traditional radar-advection systems struggle with nonlinear events such as "
            "convective initiation. Earlier deep learning approaches improve some low-intensity "
            "forecasts but often become blurry at longer lead times and perform poorly on rarer, "
            "heavier rain. This paper introduces a probabilistic deep generative model that "
            "predicts future radar fields directly from recent radar observations while "
            "preserving realistic spatial and temporal structure. Across statistical, economic, "
            "and expert-evaluation measures, the model improves forecast quality, consistency, "
            "and operational value. It generates realistic predictions over regions up to 1536 "
            "by 1280 kilometres and for lead times from 5 to 90 minutes. In evaluations "
            "involving more than fifty meteorologists, it ranked first against strong baseline "
            "systems in most cases, showing that generative models can produce useful "
            "high-resolution nowcasts without relying on blur."
        ),
    },
    {
        "id": "materials",
        "title": "Experimental search for high-temperature ferroelectric perovskites guided by two-step machine learning",
        "abstract": (
            "Searching for high-temperature ferroelectric perovskites is difficult because the "
            "chemical space is enormous, experimental guidance is limited, and many candidate "
            "compositions fail to form the desired phase. This study proposes a two-step machine "
            "learning workflow to guide synthesis in xBi(Me' y Me'' 1-y)O3-(1-x)PbTiO3-type "
            "systems. A classification model first screens compositions likely to crystallize as "
            "perovskites, and a regression model coupled with active learning then prioritizes "
            "candidates with high Curie temperature for experimental testing. The search spans "
            "roughly 61,500 possible compositions, whereas only 167 had been characterized "
            "beforehand. By iterating between prediction, synthesis, and feedback from both "
            "successful and failed experiments, the authors efficiently refine the models and "
            "focus the search. Out of ten newly synthesized candidates, six form perovskites, "
            "including three previously unexplored cation pairs, and one composition reaches a "
            "measured Curie temperature of 898 K. The work shows how machine learning can "
            "meaningfully reduce experimental burden in materials discovery."
        ),
    },
    {
        "id": "chem_xai",
        "title": "Chemistry-intuitive explanation of graph neural networks for molecular property prediction with substructure masking",
        "abstract": (
            "Graph neural networks are widely used for molecular property prediction, but their "
            "explanations often highlight individual atoms, bonds, or fragments that do not "
            "match how chemists reason about structure-property relationships. This paper "
            "introduces Substructure Mask Explanation, an interpretation method built around "
            "chemically meaningful molecular segmentations such as functional groups and other "
            "established fragment schemes. Instead of assigning importance to isolated nodes or "
            "edges, the method reveals which substructures drive a graph model's prediction in a "
            "way that is easier for domain experts to inspect. The authors apply the approach to "
            "models for aqueous solubility, genotoxicity, cardiotoxicity, and blood-brain "
            "barrier permeation in small molecules. Across these cases, the explanations align "
            "better with chemical intuition, help identify unreliable model behaviour, and offer "
            "actionable guidance for structural optimization. The paper argues that "
            "substructure-level explainability can make graph neural networks more useful for "
            "medicinal chemistry and drug-discovery workflows."
        ),
    },
]
def inject_styles():
    """Emit the app's custom stylesheet into the page.

    The CSS is written through ``st.markdown`` with ``unsafe_allow_html=True`` and
    styles the keyed containers (``input-shell``, ``examples-shell``,
    ``topics-shell``, ``quartile-shell``), the three button kinds, link buttons,
    the topic/quartile cards and the progress bars.
    """
    stylesheet = """
        <style>
        .stApp {
            color: #7d2b58;
            background: #fff4fa;
        }
        .stApp h1 {
            color: #d63384;
            font-size: clamp(2.1rem, 3vw, 3rem);
            font-weight: 900;
            letter-spacing: 0.01em;
            white-space: nowrap;
        }
        .st-key-input-shell, .st-key-examples-shell, .st-key-topics-shell, .st-key-quartile-shell {
            border: 1px solid rgba(255, 109, 177, 0.22);
            border-radius: 18px;
            padding: 1rem 1.1rem;
            background: white;
            box-shadow: 0 6px 14px rgba(214, 51, 132, 0.06);
        }
        div.stButton > button {
            border-radius: 14px;
        }
        div.stButton > button[kind="primary"] {
            border: none;
            background: #ff4fa3;
            color: white;
            font-weight: 800;
        }
        div.stButton > button[kind="secondary"] {
            min-height: 140px;
            text-align: left;
            white-space: normal;
            border: 1px solid rgba(255, 118, 186, 0.22);
            background: #fff8fc;
        }
        div.stButton > button[kind="secondary"] p { margin: 0; }
        div.stButton > button[kind="secondary"] p strong { color: #d63384; }
        div.stButton > button[kind="tertiary"] {
            border: 1px solid rgba(255, 116, 184, 0.44);
            background: white;
            color: #d63384;
            font-weight: 700;
        }
        div.stLinkButton > a {
            border-radius: 14px;
            border: none;
            background: #ff4fa3;
            color: white;
            font-weight: 800;
        }
        div.stLinkButton > a:visited {
            color: white;
        }
        div.stLinkButton > a:hover {
            background: #f33f97;
            color: white;
            border: none;
        }
        div.stLinkButton > a:active {
            color: white;
        }
        .topic-card, .quartile-card {
            border-radius: 14px;
            border: 1px solid rgba(255, 127, 186, 0.18);
            background: white;
            box-shadow: 0 6px 14px rgba(214, 51, 132, 0.05);
        }
        .topic-card {
            padding: 16px 18px 10px;
            margin-bottom: 10px;
        }
        .quartile-card {
            padding: 26px 14px 20px;
            min-height: 145px;
            text-align: center;
        }
        .quartile-title {
            color: #d63384;
            font-size: 2.1rem;
            font-weight: 800;
            line-height: 1;
            margin-bottom: 12px;
        }
        div[data-testid="stProgressBar"] > div {
            background-color: rgba(255, 201, 227, 0.42);
            border-radius: 999px;
        }
        div[data-testid="stProgressBar"] div[role="progressbar"] {
            background: #ff4fa3;
            border-radius: 999px;
        }
        </style>
        """
    st.markdown(stylesheet, unsafe_allow_html=True)
def get_top_95_classes(probabilities, id2label, threshold=0.95):
    """Return the most probable classes whose cumulative probability reaches ``threshold``.

    Classes are visited in descending order of probability and collected until
    the running total is at least ``threshold``. The class that crosses the
    threshold is included. If the threshold is never reached, every class is
    returned.

    Args:
        probabilities: Sequence of per-class probabilities, indexed by class id.
        id2label: Mapping from class id (position in ``probabilities``) to label.
        threshold: Cumulative probability at which to stop. Defaults to 0.95,
            the cutoff this function originally hard-coded.

    Returns:
        A list of ``{"label": ..., "probability": float}`` dicts, most probable
        first. Empty when ``probabilities`` is empty.

    Raises:
        KeyError: If a selected class id is missing from ``id2label``.
    """
    ranked = sorted(enumerate(probabilities), key=lambda pair: pair[1], reverse=True)
    selected = []
    cumulative = 0.0
    for class_id, prob in ranked:
        prob = float(prob)
        selected.append({"label": id2label[class_id], "probability": prob})
        cumulative += prob
        if cumulative >= threshold:
            break
    return selected
@st.cache_resource(show_spinner=False)
def load_artifacts():
    """Load the tokenizer, classifier and label mapping, once per server process.

    Without caching, every prediction re-reads the checkpoint from disk and
    rebuilds the model, because ``predict_article_cached`` calls this function
    on each invocation. ``st.cache_resource`` keeps the returned objects alive
    across reruns and sessions. The built-in spinner is disabled because the
    page shows its own status message while classification runs.

    Returns:
        A 4-tuple ``(tokenizer, model, id2field, device)`` where ``model`` is in
        eval mode and already moved to the chosen device, ``id2field`` maps
        class index to label, and ``device`` is the device name as a string
        (``"cuda"`` or ``"cpu"``).

    Raises:
        FileNotFoundError: If the label file or the checkpoint is missing.
    """
    # Check both files up front so a missing one yields a clear message.
    if not FIELD_LABELS_PATH.exists():
        raise FileNotFoundError(f"Topic labels file not found: {FIELD_LABELS_PATH}")
    if not CHECKPOINT_PATH.exists():
        raise FileNotFoundError(f"Model checkpoint not found: {CHECKPOINT_PATH}")

    with open(FIELD_LABELS_PATH, "r", encoding="utf-8") as f:
        field_labels = json.load(f)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = ArticleClassifier(
        num_fields=len(field_labels),
        num_quartiles=4,
        model_name=MODEL_NAME,
    )

    # NOTE(review): torch.load unpickles the file. The checkpoint ships with the
    # app, but consider weights_only=True if the installed torch supports it.
    state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
    # Drop loss-criterion weights if present; they are removed before loading
    # so load_state_dict does not receive them.
    state_dict.pop("criterion_field.weight", None)
    state_dict.pop("criterion_quartile.weight", None)
    model.load_state_dict(state_dict)
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # The JSON is iterated as label -> index pairs; invert it for decoding.
    id2field = {idx: label for label, idx in field_labels.items()}
    return tokenizer, model, id2field, str(device)
def predict_article_cached(title, abstract):
    """Classify an article and return its likely topics and journal quartile.

    The model artifacts come from ``load_artifacts()``. Inputs are normalized
    before the prompt is built, so a whitespace-only or ``None`` abstract takes
    the title-only path instead of producing a dangling ``"Abstract: "``
    segment (or raising on ``None``).

    Args:
        title: Article title; ``None`` is treated as an empty string.
        abstract: Article abstract; ``None`` or blank means title-only.

    Returns:
        A dict with:
            ``"top95_fields"``: topics from ``get_top_95_classes``, most probable first.
            ``"quartiles"``: ``{"label", "probability"}`` dicts for all quartiles,
                sorted by descending probability.
            ``"best_quartile"``: the first element of ``"quartiles"``.
    """
    tokenizer, model, id2field, device_str = load_artifacts()
    device = torch.device(device_str)

    title = (title or "").strip()
    abstract = (abstract or "").strip()
    if abstract:
        text = f"Title: {title} Abstract: {abstract}"
    else:
        text = f"Title: {title}"

    enc = tokenizer(
        text,
        max_length=MAX_LEN,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        output = model(
            input_ids=enc["input_ids"].to(device),
            attention_mask=enc["attention_mask"].to(device),
        )
        # Single-example batch: take row 0 of each head's softmax.
        field_probs = torch.softmax(output.logits, dim=-1)[0].cpu().tolist()
        quartile_probs = torch.softmax(output.quartile_logits, dim=-1)[0].cpu().tolist()

    quartiles = sorted(
        (
            {"label": QUARTILE_LABELS[i], "probability": float(prob)}
            for i, prob in enumerate(quartile_probs)
        ),
        key=lambda entry: entry["probability"],
        reverse=True,
    )

    return {
        "top95_fields": get_top_95_classes(field_probs, id2field),
        "quartiles": quartiles,
        "best_quartile": quartiles[0],
    }
def render_topics(top95_fields):
    """Render a card and progress bar per topic, then a cumulative-probability caption.

    Labels are interpolated into raw HTML rendered with
    ``unsafe_allow_html=True``, so they are HTML-escaped first. A label
    containing ``&``, ``<`` or ``>`` is then shown literally instead of being
    parsed as markup.

    Args:
        top95_fields: Iterable of dicts with ``"label"`` and ``"probability"``
            keys, in display order.
    """
    total = 0.0
    for i, item in enumerate(top95_fields, start=1):
        probability = item["probability"]
        total += probability
        # Cycle through the palette when there are more topics than colours.
        color = TOPIC_COLORS[(i - 1) % len(TOPIC_COLORS)]
        label = html.escape(str(item["label"]))
        st.markdown(
            f"""
            <div class="topic-card">
            <div style="display:flex; gap:12px; align-items:flex-start;">
            <div style="width:14px;height:14px;min-width:14px;border-radius:50%;background:{color};margin-top:8px;"></div>
            <div>
            <div style="color:#d63384;font-weight:800;margin-bottom:6px;">{i}. {label}</div>
            <div style="color:#7d2b58;">Probability: {probability * 100:.2f}%</div>
            </div>
            </div>
            </div>
            """,
            unsafe_allow_html=True,
        )
        # st.progress takes a float in [0.0, 1.0]; clamp both ends against rounding drift.
        st.progress(min(max(probability, 0.0), 1.0))
    st.caption(f"Total probability of displayed topics: {total * 100:.2f}%")
def render_pie(top95_fields):
    """Draw a donut chart of the topic probabilities and show it in the page.

    Wedges take their colours from ``TOPIC_COLORS`` (cycled), and only slices of
    at least 4% get a percentage label. The figure is closed after rendering.

    Args:
        top95_fields: Sequence of dicts with a ``"probability"`` key.
    """
    background = "#fff7fb"
    shares = [entry["probability"] for entry in top95_fields]
    palette = [TOPIC_COLORS[idx % len(TOPIC_COLORS)] for idx, _ in enumerate(top95_fields)]

    def format_share(pct):
        # Hide labels on thin wedges.
        if pct >= 4:
            return f"{pct:.1f}%"
        return ""

    wedge_style = {"width": 0.42, "edgecolor": background, "linewidth": 2.8}
    text_style = {"color": "#8f2f67", "fontweight": "semibold"}

    fig, ax = plt.subplots(figsize=(6.2, 6.2), facecolor=background)
    ax.set_facecolor(background)
    ax.pie(
        shares,
        labels=None,
        colors=palette,
        autopct=format_share,
        startangle=90,
        pctdistance=0.72,
        wedgeprops=wedge_style,
        textprops=text_style,
    )
    # Caption placed at the centre of the donut hole.
    ax.text(0, 0, "TOPICS", ha="center", va="center", fontsize=15, fontweight="bold", color="#c43b82")
    ax.axis("equal")

    st.pyplot(fig, clear_figure=True)
    plt.close(fig)
def render_quartiles(quartiles):
    """Render one card per quartile across four columns.

    Each card's gradient opacity grows with its probability relative to the
    largest probability in ``quartiles``, so the most likely quartile is the
    most saturated.

    Args:
        quartiles: List of dicts with ``"label"`` and ``"probability"`` keys.
            Entries beyond the fourth are not rendered.
    """
    peak = max((entry["probability"] for entry in quartiles), default=1.0)
    for column, entry in zip(st.columns(4), quartiles):
        share = entry["probability"]
        # Guard against a zero peak before dividing.
        relative = share / peak if peak else 0
        alpha = 0.18 + 0.42 * relative
        with column:
            st.markdown(
                f"""
                <div class="quartile-card" style="background: linear-gradient(180deg, rgba(255,255,255,0.96), rgba(255,79,163,{alpha:.3f}));">
                <div class="quartile-title">{entry["label"]}</div>
                <div style="color:#7d2b58;font-weight:700;">{share * 100:.2f}%</div>
                </div>
                """,
                unsafe_allow_html=True,
            )
# Emit the custom CSS first so it applies to everything rendered below.
inject_styles()

# Make sure every session-state key this script reads exists on the first run.
st.session_state.setdefault("title_input", "")
st.session_state.setdefault("abstract_input", "")
st.session_state.setdefault("pending_example", None)
st.session_state.setdefault("pending_clear", False)

# Apply a deferred example selection. The example buttons below only record the
# chosen id and rerun; the values are written here, before the text widgets bound
# to "title_input" / "abstract_input" are created on this run.
if st.session_state["pending_example"] is not None:
    for example in EXAMPLES:
        if example["id"] == st.session_state["pending_example"]:
            st.session_state["title_input"] = example["title"]
            st.session_state["abstract_input"] = example["abstract"]
            break
    st.session_state["pending_example"] = None

# Apply a deferred "Clear fields" request in the same way.
if st.session_state["pending_clear"]:
    st.session_state["title_input"] = ""
    st.session_state["abstract_input"] = ""
    st.session_state["pending_clear"] = False

st.title("💖 Top-95% Scientific Article Classification")
st.write(
    "The app takes an article title and abstract, identifies the most likely topics, "
    "and also predicts the journal quartile."
)

# Two-column page: input form on the left, example cards on the right.
left_col, right_col = st.columns([1.0, 0.97], gap="medium")

with left_col:
    # The container key becomes the ".st-key-input-shell" class targeted by inject_styles().
    with st.container(key="input-shell"):
        st.subheader("Input")
        st.text_input(
            "Article title",
            key="title_input",
            placeholder="For example: Graph Neural Networks for Molecular Property Prediction",
            help="You can enter only the title if the abstract is unavailable.",
        )
        st.text_area(
            "Abstract",
            key="abstract_input",
            placeholder="For example: We propose a new method for predicting molecular properties...",
            height=220,
            help="If the abstract is empty, classification will be based on the title only.",
        )
        c1, c2 = st.columns(2)
        with c1:
            predict_button = st.button("Predict topic and quartile", type="primary", use_container_width=True)
        with c2:
            clear_button = st.button("Clear fields", type="tertiary", use_container_width=True)
        # Placeholder filled with a status message while classification runs.
        status_placeholder = st.empty()
        if clear_button:
            # Defer the actual clearing to the top of the next run (see above).
            st.session_state["pending_clear"] = True
            st.rerun()

with right_col:
    with st.container(key="examples-shell"):
        st.subheader("Examples")
        # Lay the examples out two per row.
        for i in range(0, len(EXAMPLES), 2):
            cols = st.columns(2, gap="small")
            for col, example in zip(cols, EXAMPLES[i : i + 2]):
                with col:
                    # Button label: bold title plus an abstract preview capped at 95 characters.
                    preview = example["abstract"].strip()
                    preview = preview if len(preview) <= 95 else preview[:94].rstrip() + "…"
                    label = f"**{example['title']}**\n\n{preview or 'Title-only classification'}"
                    if st.button(label, key=example["id"], type="secondary", use_container_width=True):
                        # Record the choice and rerun; the inputs are filled at the top of the script.
                        st.session_state["pending_example"] = example["id"]
                        st.rerun()
# Current input values with surrounding whitespace removed.
title = st.session_state["title_input"].strip()
abstract = st.session_state["abstract_input"].strip()

# Run classification only on the run triggered by the predict button.
if predict_button:
    try:
        if title or abstract:
            status_placeholder.info("Running classification...")
            result = predict_article_cached(title, abstract)
            status_placeholder.empty()
            st.success("Classification completed.")

            # Results: topics (list + donut chart) on the left, quartile cards on the right.
            topics_col, quartile_col = st.columns(2, gap="medium")

            with topics_col:
                with st.container(key="topics-shell"):
                    st.subheader("Article topic")
                    list_col, donut_col = st.columns([1.15, 1.1], gap="medium")
                    with list_col:
                        render_topics(result["top95_fields"])
                    with donut_col:
                        render_pie(result["top95_fields"])

            with quartile_col:
                with st.container(key="quartile-shell"):
                    st.subheader("Journal quartile")
                    best = result["best_quartile"]
                    st.markdown(
                        f"**Most likely quartile: {best['label']}** - "
                        f"{best['probability'] * 100:.2f}%"
                    )
                    render_quartiles(result["quartiles"])

            # Tell the user which inputs the prediction was based on.
            source_note = (
                "Classification used both the title and the abstract."
                if abstract
                else "No abstract was provided, so classification used the title only."
            )
            st.info(source_note)

            st.link_button(
                "Click for something interesting",
                "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
            )
        else:
            st.error("Input error: fill in at least one field: 'Article title' or 'Abstract'.")
    except FileNotFoundError as e:
        # Raised by load_artifacts() when the label file or checkpoint is missing.
        status_placeholder.empty()
        st.error(f"Error loading model files: {e}")
    except RuntimeError as e:
        status_placeholder.empty()
        st.error(f"Runtime error during model execution: {e}")
    except Exception as e:
        # Last-resort handler so any other failure is shown in the page.
        status_placeholder.empty()
        st.error(f"Unexpected error: {e}")