Spaces:
Sleeping
Sleeping
| """ | |
| PhySH Taxonomy Classifier — Gradio App | |
| Two-stage hierarchical cascade: | |
| Stage 1 → Discipline prediction (18-class multi-label) | |
| Stage 2 → Concept prediction (186-class multi-label, conditioned on discipline probs) | |
| Models were trained on APS PhySH labels with google/embeddinggemma-300m embeddings. | |
| """ | |
| import os | |
| import re | |
| from pathlib import Path | |
| from typing import Dict, List, Tuple | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from sentence_transformers import SentenceTransformer | |
| # --------------------------------------------------------------------------- | |
| # Model definitions (mirror the training code exactly) | |
| # --------------------------------------------------------------------------- | |
class MultiLabelMLP(nn.Module):
    """Plain feed-forward multi-label classification head.

    Stacks ``Linear → ReLU → Dropout`` for each hidden width, followed by a
    final ``Linear`` projection to raw logits — one logit per label; the
    caller applies sigmoid.
    """

    def __init__(self, input_dim: int, output_dim: int,
                 hidden_layers: Tuple[int, ...] = (1024, 512), dropout: float = 0.3):
        super().__init__()
        widths = [input_dim, *hidden_layers]
        modules: List[nn.Module] = []
        # One Linear/ReLU/Dropout block per consecutive pair of widths.
        for src, dst in zip(widths, widths[1:]):
            modules += [nn.Linear(src, dst), nn.ReLU(), nn.Dropout(dropout)]
        modules.append(nn.Linear(widths[-1], output_dim))
        # Attribute name "network" is fixed: checkpoint state_dict keys use it.
        self.network = nn.Sequential(*modules)

    def forward(self, x):
        """Return raw (pre-sigmoid) logits of shape (batch, output_dim)."""
        return self.network(x)
class DisciplineConditionedMLP(nn.Module):
    """Concept head conditioned on stage-1 discipline probabilities.

    The text embedding and the discipline vector are concatenated and pushed
    through an MLP with the same block structure as ``MultiLabelMLP``. When
    ``use_logits`` is set, discipline probabilities are converted to logits
    (clamped away from exactly 0/1 first) before concatenation.
    """

    def __init__(self, embedding_dim: int, discipline_dim: int, output_dim: int,
                 hidden_layers: Tuple[int, ...] = (1024, 512), dropout: float = 0.3,
                 discipline_dropout: float = 0.0, use_logits: bool = False):
        super().__init__()
        self.use_logits = use_logits
        # Separate dropout applied only to the discipline features.
        self.discipline_dropout = nn.Dropout(discipline_dropout)
        widths = [embedding_dim + discipline_dim, *hidden_layers]
        modules: List[nn.Module] = []
        for src, dst in zip(widths, widths[1:]):
            modules += [nn.Linear(src, dst), nn.ReLU(), nn.Dropout(dropout)]
        modules.append(nn.Linear(widths[-1], output_dim))
        # Attribute name "network" is fixed: checkpoint state_dict keys use it.
        self.network = nn.Sequential(*modules)

    def forward(self, embedding: torch.Tensor, discipline_probs: torch.Tensor) -> torch.Tensor:
        """Return concept logits for a batch of (embedding, discipline_probs)."""
        if self.use_logits:
            # logit(p), with p clamped so 0/1 inputs don't produce ±inf.
            clamped = torch.clamp(discipline_probs, 1e-7, 1 - 1e-7)
            disc_features = torch.log(clamped / (1 - clamped))
        else:
            disc_features = discipline_probs
        disc_features = self.discipline_dropout(disc_features)
        return self.network(torch.cat([embedding, disc_features], dim=1))
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# Checkpoints are expected to sit next to this script.
MODELS_DIR = Path(__file__).resolve().parent
DISCIPLINE_MODEL_PATH = MODELS_DIR / "discipline_classifier_gemma_20260130_140842.pt"
CONCEPT_MODEL_PATH = MODELS_DIR / "concept_conditioned_gemma_20260130_140842.pt"
EMBEDDING_MODEL_NAME = "google/embeddinggemma-300m"
# Labels hidden from the discipline ranking in the UI (see predict()).
EXCLUDED_DISCIPLINES = {"Quantum Physics"}
# ---------------------------------------------------------------------------
# Globals (loaded once at startup)
# ---------------------------------------------------------------------------
# NOTE(review): the three model globals are annotated with non-Optional types
# but initialized to None; they are only valid after load_models() has run.
device: str = "cpu"
embedding_model: SentenceTransformer = None
discipline_model: MultiLabelMLP = None
concept_model: DisciplineConditionedMLP = None
discipline_labels: List[Dict] = []
concept_labels: List[Dict] = []
def load_models():
    """Populate the module-level models and label lists (run once at startup).

    Selects the best available device, loads the sentence-transformer
    embedding model, then restores both classifier checkpoints (stage-1
    discipline head and stage-2 concept head) from disk.
    """
    global device, embedding_model, discipline_model, concept_model
    global discipline_labels, concept_labels

    # Best available accelerator: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    print(f"Loading embedding model ({EMBEDDING_MODEL_NAME}) on {device} …")
    embedding_model = SentenceTransformer(
        EMBEDDING_MODEL_NAME,
        device=device,
        token=os.environ.get("HF_TOKEN"),
    )

    # --- Stage 1: discipline classifier ---
    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(DISCIPLINE_MODEL_PATH, map_location=device, weights_only=False)
    cfg = checkpoint["model_config"]
    model = MultiLabelMLP(
        cfg["input_dim"],
        cfg["output_dim"],
        tuple(cfg["hidden_layers"]),
        cfg["dropout"],
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    discipline_model = model.to(device).eval()
    discipline_labels = checkpoint["class_labels"]

    # --- Stage 2: discipline-conditioned concept classifier ---
    checkpoint = torch.load(CONCEPT_MODEL_PATH, map_location=device, weights_only=False)
    cfg = checkpoint["model_config"]
    model = DisciplineConditionedMLP(
        cfg["embedding_dim"],
        cfg["discipline_dim"],
        cfg["output_dim"],
        tuple(cfg["hidden_layers"]),
        cfg["dropout"],
        cfg.get("discipline_dropout", 0.0),
        cfg.get("use_logits", False),
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    concept_model = model.to(device).eval()
    concept_labels = checkpoint["class_labels"]

    print(f"Loaded {len(discipline_labels)} disciplines, {len(concept_labels)} concepts")
| # --------------------------------------------------------------------------- | |
| # Prediction | |
| # --------------------------------------------------------------------------- | |
def clean_text(text: str) -> str:
    """Collapse every run of whitespace to a single space and trim the ends.

    Falsy input (empty string or None) yields "".
    """
    if not text:
        return ""
    collapsed = re.sub(r"\s+", " ", text)
    return collapsed.strip()
def _format_ranked(probs, labels, fallback_prefix, threshold, top_k, excluded=frozenset()):
    """Format the top-k labels by descending probability as markdown lines.

    Args:
        probs: 1-D array of per-label probabilities.
        labels: List of label dicts; each is expected to carry a "label" key
            (missing keys fall back to "<fallback_prefix>_<idx>").
        fallback_prefix: Prefix used when a label dict has no "label" key.
        threshold: Labels at or above this probability are bolded.
        top_k: Maximum number of lines to return.
        excluded: Label names skipped entirely (they don't consume a rank).

    Returns:
        List of markdown strings like "1. **Label** — 93.1%".
    """
    lines = []
    rank = 0
    for idx in np.argsort(probs)[::-1]:
        label = labels[idx].get("label", f"{fallback_prefix}_{idx}")
        if label in excluded:
            continue
        rank += 1
        if rank > top_k:
            break
        prob = probs[idx]
        marker = "**" if prob >= threshold else ""
        lines.append(f"{rank}. {marker}{label}{marker} — {prob:.1%}")
    return lines


def predict(title: str, abstract: str, threshold: float, top_k: int):
    """Run the two-stage cascade and return formatted results.

    Args:
        title: Paper title (may be empty).
        abstract: Paper abstract (may be empty).
        threshold: Probability above which a label is bolded in the output.
        top_k: Number of labels shown per stage.

    Returns:
        Tuple of (discipline markdown, concept markdown). When neither field
        contains text, returns an error message and an empty string.
    """
    title_clean = clean_text(title)
    abs_clean = clean_text(abstract)
    if title_clean and abs_clean:
        combined = f"{title_clean} [SEP] {abs_clean}"
    else:
        combined = title_clean or abs_clean
    if not combined:
        return "Please enter at least a title or abstract.", ""

    # Embed (normalized, matching how the classifiers were trained).
    embedding = embedding_model.encode(
        [combined], normalize_embeddings=True, convert_to_numpy=True,
    )
    emb_tensor = torch.FloatTensor(embedding).to(device)

    with torch.no_grad():
        # Stage 1: discipline probabilities.
        disc_probs_tensor = torch.sigmoid(discipline_model(emb_tensor))
        # Stage 2: concepts conditioned on stage-1 output. Feed the on-device
        # tensor directly instead of round-tripping it through CPU/numpy.
        conc_logits = concept_model(emb_tensor, disc_probs_tensor)
        conc_probs = torch.sigmoid(conc_logits).cpu().numpy()[0]
        disc_probs = disc_probs_tensor.cpu().numpy()[0]

    # Disciplines skip the excluded labels; concepts show everything.
    disc_lines = _format_ranked(
        disc_probs, discipline_labels, "Discipline",
        threshold, top_k, EXCLUDED_DISCIPLINES,
    )
    conc_lines = _format_ranked(
        conc_probs, concept_labels, "Concept", threshold, top_k,
    )
    disc_md = f"### Disciplines (threshold ≥ {threshold:.0%})\n\n" + "\n".join(disc_lines)
    conc_md = f"### Research-Area Concepts (threshold ≥ {threshold:.0%})\n\n" + "\n".join(conc_lines)
    return disc_md, conc_md
| # --------------------------------------------------------------------------- | |
| # Gradio UI | |
| # --------------------------------------------------------------------------- | |
# Pre-filled [title, abstract] pairs for the gr.Examples widget below.
EXAMPLES = [
    [
        "Quantum Computing: Vision and Challenges",
        (
            "The recent development of quantum computing, which uses entanglement, superposition, and other quantum fundamental concepts, "
            "can provide substantial processing advantages over traditional computing. These quantum features help solve many complex "
            "problems that cannot be solved otherwise with conventional computing methods. These problems include modeling quantum mechanics, "
            "logistics, chemical-based advances, drug design, statistical science, sustainable energy, banking, reliable communication, and "
            "quantum chemical engineering. The last few years have witnessed remarkable progress in quantum software and algorithm creation "
            "and quantum hardware research, which has significantly advanced the prospect of realizing quantum computers. It would be helpful "
            "to have comprehensive literature research on this area to grasp the current status and find outstanding problems that require "
            "considerable attention from the research community working in the quantum computing industry. To better understand quantum computing, "
            "this paper examines the foundations and vision based on current research in this area. We discuss cutting-edge developments in quantum "
            "computer hardware advancement and subsequent advances in quantum cryptography, quantum software, and high-scalability quantum computers. "
            "Many potential challenges and exciting new trends for quantum technology research and development are highlighted in this paper for a broader debate."
        ),
    ],
    [
        "Topological Insulators and Superconductors",
        (
            "Topological insulators are electronic materials that have a bulk band gap like an ordinary insulator but have protected conducting states "
            "on their edge or surface. We review the theoretical foundation for topological insulators and superconductors and describe recent experiments."
        ),
    ],
    [
        "Floquet Topological Insulator in Semiconductor Quantum Wells",
        (
            "Topological phase transitions between a conventional insulator and a state of matter with topological properties have been proposed and observed "
            "in mercury telluride - cadmium telluride quantum wells. We show that a topological state can be induced in such a device, initially in the trivial "
            "phase, by irradiation with microwave frequencies, without closing the gap and crossing the phase transition. We show that the quasi-energy spectrum "
            "exhibits a single pair of helical edge states. The velocity of the edge states can be tuned by adjusting the intensity of the microwave radiation. "
            "We discuss the necessary experimental parameters for our proposal. This proposal provides an example and a proof of principle of a new non-equilibrium "
            "topological state, Floquet topological insulator, introduced in this paper."
        ),
    ],
]
def build_app() -> gr.Blocks:
    """Assemble and return the Gradio Blocks UI (models must already be loaded)."""
    with gr.Blocks(
        title="PhySH Taxonomy Classifier",
        theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="slate"),
    ) as demo:
        gr.Markdown(
            "# PhySH Taxonomy Classifier\n"
            "Enter a paper **title** and **abstract** to predict APS PhySH disciplines "
            "and research-area concepts using a two-stage hierarchical cascade.\n\n"
            "Labels above the threshold are **bolded**."
        )
        with gr.Row():
            # Left column: text inputs plus the two tuning sliders.
            with gr.Column(scale=2):
                title_input = gr.Textbox(label="Title", lines=2, placeholder="Paper title …")
                abstract_input = gr.Textbox(label="Abstract", lines=8, placeholder="Paper abstract …")
                with gr.Row():
                    threshold_ctl = gr.Slider(
                        minimum=0.05, maximum=0.95, value=0.35, step=0.05,
                        label="Threshold",
                    )
                    topk_ctl = gr.Slider(
                        minimum=1, maximum=20, value=10, step=1, label="Top-K",
                    )
                classify_btn = gr.Button("Classify", variant="primary", size="lg")
            # Right column: one markdown panel per cascade stage.
            with gr.Column(scale=3):
                disciplines_panel = gr.Markdown(label="Disciplines")
                concepts_panel = gr.Markdown(label="Concepts")
        classify_btn.click(
            fn=predict,
            inputs=[title_input, abstract_input, threshold_ctl, topk_ctl],
            outputs=[disciplines_panel, concepts_panel],
        )
        gr.Examples(
            examples=EXAMPLES,
            inputs=[title_input, abstract_input],
            label="Example papers",
        )
    return demo
| if __name__ == "__main__": | |
| load_models() | |
| app = build_app() | |
| app.launch() | |