Spaces:
Running
Running
| import streamlit as st | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer, AutoModel | |
| import numpy as np | |
| import requests | |
| import os | |
| from PyPDF2 import PdfReader | |
# Browser tab title/icon and page layout; kept ahead of all other Streamlit calls.
# (The previous icon value was a mis-encoded emoji fragment, not a valid emoji.)
st.set_page_config(
    page_title="Citation Impact Predictor",
    page_icon="📚",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Pretrained encoder the checkpoint was fine-tuned from.
MODEL_NAME = "allenai/specter2_base"
# Fine-tuned weights, loaded with load_state_dict(); path is relative to the CWD.
CHECKPOINT_PATH = "best_model.pt"
# Number of scalar metadata features produced by compute_meta_from_inputs().
N_META = 5
# Class boundaries on the predicted log1p(5-year citations) score.
THRESHOLDS_5Y = [1.5, 3.5, 5.5]
# One class below the first threshold, one between each pair, one above the last.
N_CLASSES = len(THRESHOLDS_5Y) + 1
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Per-class display strings; index i corresponds to class i.
CLASS_NAMES = ["🗑️ Low (0)", "📄 Medium (1)", "🌟 High (2)", "🏆 Top (3)"]
CLASS_LABELS = ["Low", "Medium", "High", "Top"]
CLASS_COLORS = ["#e74c3c", "#f39c12", "#3498db", "#2ecc71"]
class CitationPredictor(nn.Module):
    """Regress a log-scale citation score from paper text plus metadata.

    The first-token ([CLS]) embedding of a pretrained transformer encoder is
    concatenated with a 64-d projection of the metadata vector and passed
    through a small MLP; a linear head produces one scalar per example.

    The submodule names (``encoder``, ``meta_proj``, ``layer``, ``head``) and
    their layouts must not change: they are the keys of the checkpoint that
    is loaded with ``load_state_dict``.
    """

    def __init__(self, model_name: str, n_meta: int):
        """Load the encoder *model_name* and build heads for *n_meta* features."""
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        # Take the embedding width from the encoder config rather than
        # hard-coding 768, so encoders with another hidden size also work.
        hidden_size = self.encoder.config.hidden_size
        self.meta_proj = nn.Sequential(nn.Linear(n_meta, 64), nn.GELU())
        self.layer = nn.Sequential(
            nn.Linear(hidden_size + 64, 256),
            nn.GELU(),
            nn.LayerNorm(256),
            nn.Dropout(0.2),
        )
        self.head = nn.Linear(256, 1)

    def forward(self, input_ids, attention_mask, meta):
        """Return a 1-D tensor with one raw score per batch element.

        Args:
            input_ids: token ids from the tokenizer.
            attention_mask: matching attention mask.
            meta: float tensor whose last dimension is ``n_meta``
                (compute_meta_from_inputs() builds a (1, n_meta) tensor).
        """
        hidden_states = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        cls_emb = hidden_states[:, 0]  # first token's embedding
        m_emb = self.meta_proj(meta)
        feat = self.layer(torch.cat([cls_emb, m_emb], dim=-1))
        return self.head(feat).squeeze(-1)
| def to_class(pred: float) -> int: | |
| if pred < THRESHOLDS_5Y[0]: return 0 | |
| if pred < THRESHOLDS_5Y[1]: return 1 | |
| if pred < THRESHOLDS_5Y[2]: return 2 | |
| return 3 | |
def noise_score(text: str) -> float:
    """Return the share of alphabetic characters in *text* (0.0 for "").

    A simple proxy for how meaningful the text is. Despite the name, higher
    values mean cleaner text: digits, symbols and whitespace lower the ratio.
    """
    alpha_count = len([ch for ch in text if ch.isalpha()])
    return alpha_count / (len(text) or 1)
def compute_meta_from_inputs(
    publication_year: int,
    abstract: str,
    title: str,
    author_count: int,
) -> torch.Tensor:
    """Build the (1, 5) float metadata tensor for the model's meta branch.

    Feature order (presumably identical to training — confirm before changing):
      0. publication year / 2026
      1. log1p(number of characters in the abstract)
      2. log1p(number of characters in the title)
      3. log1p(author count, capped at 200)
      4. noise_score of "title abstract" (share of alphabetic characters)
    """
    combined = f"{title} {abstract}".strip()
    capped_authors = min(author_count, 200)
    features = [
        float(publication_year) / 2026,
        float(np.log1p(len(abstract))),
        float(np.log1p(len(title))),
        float(np.log1p(capped_authors)),
        noise_score(combined),  # text "meaningfulness" proxy
    ]
    return torch.tensor([features], dtype=torch.float)
| def fetch_openalex_by_doi(doi: str) -> dict | None: | |
| clean = doi.strip().replace("https://doi.org/", "").replace("http://doi.org/", "") | |
| url = f"https://api.openalex.org/works/doi:{clean}" | |
| params = { | |
| "select": "title,abstract_inverted_index,publication_year,authorships", | |
| "mailto": "demo@example.com", | |
| } | |
| try: | |
| r = requests.get(url, params=params, timeout=15) | |
| if r.status_code == 200: | |
| return r.json() | |
| except Exception: | |
| pass | |
| return None | |
def decode_abstract(inv_idx: dict) -> str:
    """Rebuild plain text from an OpenAlex-style inverted index.

    *inv_idx* maps each word to the list of positions where it occurs; the
    words are placed in position order and joined with single spaces. An
    empty or missing index yields "".
    """
    if not inv_idx:
        return ""
    placed = sorted(
        (position, token)
        for token, positions in inv_idx.items()
        for position in positions
    )
    return " ".join(token for _, token in placed).strip()
@st.cache_resource(show_spinner="Loading model…")
def _load_model_cached():
    """Build the tokenizer and fine-tuned model once and reuse them.

    Streamlit re-executes the whole script on every interaction; caching with
    st.cache_resource keeps these objects across reruns and sessions instead
    of re-downloading the encoder and re-reading the checkpoint each time.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = CitationPredictor(MODEL_NAME, N_META).to(DEVICE)
    state = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    model.load_state_dict(state)
    model.eval()
    return model, tokenizer


def load_model():
    """Return ``(model, tokenizer)``, stopping the app if the checkpoint is missing.

    The cheap existence check runs on every rerun, so a checkpoint uploaded
    later is picked up; only the expensive load is cached.
    """
    if not os.path.exists(CHECKPOINT_PATH):
        st.error(
            f"`{CHECKPOINT_PATH}` not found. "
            "Make sure it is uploaded to the Space root directory."
        )
        st.stop()
    return _load_model_cached()


model, tokenizer = load_model()
def extract_text_from_pdf(file) -> str:
    """Return the text of every page of a PDF, or "" if it cannot be read.

    Args:
        file: anything PdfReader accepts (here: the st.file_uploader object).

    Pages are joined with a newline so the last word of one page is not glued
    to the first word of the next. Pages without extractable text contribute
    an empty line.
    """
    try:
        reader = PdfReader(file)
        pages = [page.extract_text() or "" for page in reader.pages]
    except Exception:
        # Best effort at the UI boundary: parsing can fail in many ways, and
        # the caller reports "" as "Could not extract text from PDF."
        return ""
    return "\n".join(pages).strip()
def predict(title: str, abstract: str, meta_tensor: torch.Tensor):
    """Score one paper with the global model and map the score to a class.

    Args:
        title: paper title.
        abstract: paper abstract; joined to the title with a literal "[SEP]".
        meta_tensor: metadata features from compute_meta_from_inputs().

    Returns:
        ``(pred_class, score, est_citations)``: the class index from
        to_class(), the raw regression output, and ``expm1(max(score, 0))``.
        That last value is the score mapped back from log1p space, with
        negative scores clamped to zero citations.
    """
    text = f"{title} [SEP] {abstract}"
    # A single sequence needs no padding. The attention mask already excludes
    # padded positions, so padding to 512 only added encoder compute.
    # Truncation still caps the input at 512 tokens.
    enc = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    with torch.no_grad():
        raw = model(
            enc["input_ids"].to(DEVICE),
            enc["attention_mask"].to(DEVICE),
            meta_tensor.to(DEVICE),
        )
    score = raw.item()
    pred_class = to_class(score)
    est_citations = float(np.expm1(max(score, 0)))  # inverse of log1p
    return pred_class, score, est_citations
st.title("📚 Citation Impact Predictor")
# The model only receives title/abstract text, publication year and author
# count (see compute_meta_from_inputs), so the description lists only those.
st.markdown("""
### 🤔 Зачем это нужно?

Узнать заранее по названию, абстракту, числу авторов и году выхода, стоит ли вообще тратить время на изучение статьи.

Мы делим работы на 4 категории:

- 🗑️ **Мусор** — не стоит читать
- 📄 **Середняк** — можно читать, если это ваша область и более сильных работ сейчас нет
- 🌟 **Сильная работа** — стоит обратить внимание
- 🏆 **Топ** — читать обязательно

💡 Это не заменяет экспертную оценку —
но помогает быстро отфильтровать поток научных работ.
""")
st.divider()
st.sidebar.header("📥 Paper Input")
input_mode = st.sidebar.radio(
    "Input method",
    ["Manual text", "Fetch by DOI", "Upload PDF"],
    help="Choose how to provide the paper.",
)

# Filled in by whichever input mode is active in the main panel.
title = ""
abstract = ""

# The DOI field lives in the sidebar; text and PDF inputs live in the main area.
doi_input = (
    st.sidebar.text_input("DOI", placeholder="10.1234/example")
    if input_mode == "Fetch by DOI"
    else ""
)

st.sidebar.divider()
st.sidebar.header("🔢 Metadata")
pub_year = st.sidebar.number_input("Publication year", 2000, 2024, 2020)
author_count = st.sidebar.number_input("Author count", min_value=1, max_value=200, value=3)
# ── Main panel: wide left column for input, narrow right column for the button ──


@st.cache_data(ttl=3600, show_spinner=False)
def _lookup_doi_cached(doi: str) -> dict:
    """Fetch *doi* from OpenAlex, caching successful lookups for an hour.

    Streamlit reruns the whole script on every interaction (including the
    Predict click); without caching each rerun repeated the HTTP request.
    A failed lookup raises LookupError, and exceptions are not cached, so a
    transient failure can be retried.
    """
    paper = fetch_openalex_by_doi(doi)
    if paper is None:
        raise LookupError(doi)
    return paper


col_left, col_right = st.columns([4, 1])
with col_left:
    if input_mode == "Manual text":
        title = st.text_input("Title", placeholder="e.g. Attention Is All You Need")
        abstract = st.text_area("Abstract", height=250, placeholder="Paste the abstract here…")
    elif input_mode == "Fetch by DOI":
        if doi_input:
            with st.spinner("Fetching metadata from OpenAlex…"):
                try:
                    paper = _lookup_doi_cached(doi_input)
                except LookupError:
                    paper = None
            if paper:
                title = paper.get("title") or ""
                abstract = decode_abstract(paper.get("abstract_inverted_index") or {})
                pub_year = paper.get("publication_year") or 2020
                # "authorships" is requested from OpenAlex; use its length
                # instead of the manual sidebar value whenever it is non-empty.
                author_count = len(paper.get("authorships") or []) or author_count
                st.sidebar.success("✅ Paper found!")
                st.success(f"**{title}**")
                st.caption(f"Publication year: {pub_year} · Authors: {author_count}")
                st.markdown(abstract[:800] + ("…" if len(abstract) > 800 else ""))
            else:
                st.error("Could not fetch paper. Check the DOI.")
        else:
            st.info("Enter a DOI in the sidebar to fetch paper metadata.")
    elif input_mode == "Upload PDF":
        uploaded_file = st.file_uploader("Upload PDF", type=["pdf"])
        if uploaded_file is not None:
            with st.spinner("Extracting text from PDF…"):
                text = extract_text_from_pdf(uploaded_file)
            if text:
                # Crude heuristic: first line is the title, the rest the "abstract".
                lines = text.split("\n")
                title = lines[0][:300]
                abstract = " ".join(lines[1:])[:3000]
                st.success("✅ PDF processed")
                st.markdown(f"**{title}**")
                st.markdown(abstract[:800] + ("…" if len(abstract) > 800 else ""))
            else:
                st.error("Could not extract text from PDF.")
with col_right:
    st.markdown("<br><br>", unsafe_allow_html=True)
    run = st.button("🚀 Predict", use_container_width=True, type="primary")
if run:
    # Whitespace-only input is treated as empty.
    if not (title.strip() or abstract.strip()):
        st.warning("Please provide at least a title or abstract.")
    else:
        meta_tensor = compute_meta_from_inputs(
            publication_year=int(pub_year),
            abstract=abstract,
            title=title,
            author_count=int(author_count),
        )
        with st.spinner("Running inference…"):
            pred_class, raw_score, est_citations = predict(title, abstract, meta_tensor)

        st.divider()
        st.subheader("📊 Prediction Results")

        # Main result badge. The HTML is built without leading indentation,
        # because Markdown renders lines indented by 4+ spaces as a code block.
        # predict() clamps negative scores to 0 before expm1, hence max(score, 0).
        color = CLASS_COLORS[pred_class]
        label = CLASS_LABELS[pred_class]
        badge_html = (
            f'<div style="background:{color}22; border-left:6px solid {color}; '
            'padding:1rem 1.5rem; border-radius:8px; margin-bottom:1rem;">'
            f'<h2 style="margin:0; color:{color}">Class {pred_class} — {label}</h2>'
            '<p style="margin:0.4rem 0 0; color:#555; font-size:1.1rem;">'
            "Estimated citations in the first 5 years: "
            f'<strong style="font-size:1.3rem;">~{est_citations:.0f}</strong>'
            "</p>"
            '<p style="margin:0.15rem 0 0; color:#aaa; font-size:0.85rem;">'
            f"(raw log-score: {raw_score:.3f} → e^max(score, 0) − 1 = {est_citations:.1f})"
            "</p>"
            "</div>"
        )
        st.markdown(badge_html, unsafe_allow_html=True)

        st.markdown("**Score vs. class thresholds**")
        # One card per class with its real score range. The lowest and highest
        # classes are open-ended: to_class has no lower or upper bound.
        n_thresholds = len(THRESHOLDS_5Y)
        for i, (col, name, cls_color) in enumerate(
            zip(st.columns(len(CLASS_LABELS)), CLASS_LABELS, CLASS_COLORS)
        ):
            if i == 0:
                score_range = f"&lt; {THRESHOLDS_5Y[0]:.1f}"
            elif i >= n_thresholds:
                score_range = f"≥ {THRESHOLDS_5Y[-1]:.1f}"
            else:
                score_range = f"{THRESHOLDS_5Y[i - 1]:.1f} – {THRESHOLDS_5Y[i]:.1f}"
            active = pred_class == i
            background = f"{cls_color}33" if active else "#2222"
            border = cls_color if active else "#ccc"
            marker = "<br>✅" if active else ""
            col.markdown(
                f'<div style="background:{background}; border:2px solid {border}; '
                'border-radius:6px; padding:0.5rem; text-align:center;">'
                f'<b style="color:{cls_color}">{name}</b><br>'
                f'<small style="color:#888">{score_range}</small>'
                f"{marker}"
                "</div>",
                unsafe_allow_html=True,
            )

        st.divider()
        interpretations = {
            0: "This paper is predicted to receive **very few citations** in its first 5 years — typical of niche, incremental, or low-visibility work.",
            1: "This paper is predicted to receive a **moderate number of citations** — solid work with a reasonable audience.",
            2: "This paper is predicted to receive a **high number of citations** — likely a meaningful contribution to its field.",
            3: "This paper is predicted to be a **top-cited paper** — potentially a landmark contribution with broad impact.",
        }
        st.markdown(f"💡 **Interpretation:** {interpretations[pred_class]}")

# Footer, shown on every run. Built from the constants so it cannot drift.
st.divider()
st.caption(
    f"Model: fine-tuned `{MODEL_NAME}` · "
    f"Classes defined by log1p(5-year citations) thresholds {THRESHOLDS_5Y} · "
    "© 2026 Citation Predictor"
)