"""Streamlit app that predicts a paper's 5-year citation impact class.

Inputs come from manual text, an OpenAlex DOI lookup, or an uploaded PDF.
The model is a fine-tuned SPECTER2 encoder plus a small metadata MLP whose
regression output is bucketed into four classes by THRESHOLDS_5Y.
"""

import bisect
import os
from urllib.parse import quote

import numpy as np
import requests
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from PyPDF2 import PdfReader
from transformers import AutoModel, AutoTokenizer

# Must be the first Streamlit command executed by the script.
st.set_page_config(
    page_title="Citation Impact Predictor",
    page_icon="📊",
    layout="wide",
    initial_sidebar_state="expanded",
)

MODEL_NAME = "allenai/specter2_base"
CHECKPOINT_PATH = "best_model.pt"
N_META = 5  # length of the vector built by compute_meta_from_inputs
N_CLASSES = 4
# Class boundaries on the model's raw output, a log1p(5-year citations) score.
THRESHOLDS_5Y = [1.5, 3.5, 5.5]
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

CLASS_NAMES = ["🗑️ Low (0)", "📄 Medium (1)", "📈 High (2)", "🏆 Top (3)"]
CLASS_LABELS = ["Low", "Medium", "High", "Top"]
CLASS_COLORS = ["#e74c3c", "#f39c12", "#3498db", "#2ecc71"]

OPENALEX_WORKS_URL = "https://api.openalex.org/works/"
# Accepted DOI spellings; matched case-insensitively and stripped as a prefix only.
_DOI_PREFIXES = (
    "https://doi.org/",
    "http://doi.org/",
    "https://dx.doi.org/",
    "http://dx.doi.org/",
    "doi:",
)


class CitationPredictor(nn.Module):
    """Regress a log1p(5-year citations) score from text plus metadata.

    The encoder's first-token ([CLS]) embedding is concatenated with a 64-d
    projection of the metadata vector and fed through a small MLP that emits
    one scalar per batch row.
    """

    def __init__(self, model_name: str, n_meta: int):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        # Read the width from the config instead of hard-coding 768.
        # For specter2_base it is 768, so checkpoint shapes are unchanged.
        hidden_size = self.encoder.config.hidden_size
        self.meta_proj = nn.Sequential(nn.Linear(n_meta, 64), nn.GELU())
        self.layer = nn.Sequential(
            nn.Linear(hidden_size + 64, 256),
            nn.GELU(),
            nn.LayerNorm(256),
            nn.Dropout(0.2),
        )
        self.head = nn.Linear(256, 1)

    def forward(self, input_ids, attention_mask, meta):
        """Return a 1-D tensor holding one raw log-score per batch row."""
        cls_emb = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state[:, 0]
        m_emb = self.meta_proj(meta)
        feat = self.layer(torch.cat([cls_emb, m_emb], dim=-1))
        return self.head(feat).squeeze(-1)


def to_class(pred: float) -> int:
    """Map a raw log-score to a class index in ``0..len(THRESHOLDS_5Y)``.

    A score lying exactly on a threshold belongs to the upper class.
    """
    return bisect.bisect_right(THRESHOLDS_5Y, pred)


def noise_score(text: str) -> float:
    """Return the fraction of alphabetic characters in *text*.

    This is a simple "meaningfulness" metric; an empty string scores 0.0.
    """
    letters = sum(c.isalpha() for c in text)
    return letters / max(len(text), 1)


def compute_meta_from_inputs(
    publication_year: int,
    abstract: str,
    title: str,
    author_count: int,
) -> torch.Tensor:
    """Build the (1, N_META) float metadata tensor fed to CitationPredictor.

    Features, in order:
    - year / 2026
    - log1p(len(abstract))
    - log1p(len(title))
    - log1p(author count capped at 200)
    - noise_score of title + abstract

    NOTE(review): order and scaling presumably mirror the training pipeline.
    Confirm before changing either.
    """
    text = (title + " " + abstract).strip()
    meta = [
        float(publication_year) / 2026,
        float(np.log1p(len(abstract))),
        float(np.log1p(len(title))),
        float(np.log1p(min(author_count, 200))),
        noise_score(text),  # text "meaningfulness"
    ]
    return torch.tensor([meta], dtype=torch.float)


def _normalize_doi(doi: str) -> str:
    """Strip whitespace and one leading doi.org / dx.doi.org / ``doi:`` prefix."""
    clean = doi.strip()
    lowered = clean.lower()
    for prefix in _DOI_PREFIXES:
        if lowered.startswith(prefix):
            return clean[len(prefix):].strip()
    return clean


@st.cache_data(show_spinner=False, ttl=24 * 60 * 60)
def _openalex_request(clean_doi: str) -> dict | None:
    """Fetch one OpenAlex work by an already-normalized DOI (cached per DOI).

    Returns the decoded JSON, or None when OpenAlex answers 404. Other HTTP
    and network errors propagate, so the failure is not stored in the cache
    and a later rerun retries.
    """
    # URL-quote the DOI: characters such as '#' or '?' would otherwise be
    # parsed as a URL fragment / query string.
    url = f"{OPENALEX_WORKS_URL}doi:{quote(clean_doi, safe='/')}"
    params = {
        "select": "title,abstract_inverted_index,publication_year,authorships",
        "mailto": "demo@example.com",
    }
    response = requests.get(url, params=params, timeout=15)
    if response.status_code == 404:
        return None
    response.raise_for_status()
    return response.json()


def fetch_openalex_by_doi(doi: str) -> dict | None:
    """Look up *doi* on OpenAlex and return the work's JSON, or None on failure.

    Accepts bare DOIs as well as doi.org / dx.doi.org URLs and ``doi:`` forms.
    Successful lookups are cached, so Streamlit reruns do not re-hit the API.
    """
    clean = _normalize_doi(doi)
    if not clean:
        return None
    try:
        return _openalex_request(clean)
    except (requests.RequestException, ValueError):
        # Network/HTTP error or a non-JSON body: the UI shows "could not fetch".
        return None


def decode_abstract(inv_idx: dict) -> str:
    """Rebuild abstract text from an OpenAlex inverted index (word -> positions)."""
    if not inv_idx:
        return ""
    positioned = sorted(
        (pos, word) for word, positions in inv_idx.items() for pos in positions
    )
    return " ".join(word for _, word in positioned).strip()


@st.cache_resource(show_spinner="Loading model weights…")
def load_model():
    """Load the tokenizer and fine-tuned model once per server process.

    Stops the Streamlit script with an error message if the checkpoint file
    is missing.
    """
    if not os.path.exists(CHECKPOINT_PATH):
        st.error(
            f"`{CHECKPOINT_PATH}` not found. "
            "Make sure it is uploaded to the Space root directory."
        )
        st.stop()

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = CitationPredictor(MODEL_NAME, N_META).to(DEVICE)
    # The checkpoint is a plain state dict, so restrict unpickling to tensors.
    state = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state)
    model.eval()
    return model, tokenizer


model, tokenizer = load_model()


def extract_text_from_pdf(file) -> str:
    """Return the concatenated text of all PDF pages, or "" if parsing fails.

    *file* is anything PyPDF2's ``PdfReader`` accepts (here a Streamlit upload).
    """
    try:
        reader = PdfReader(file)
        pages = [page.extract_text() or "" for page in reader.pages]
    except Exception:  # best effort: malformed PDFs raise many error types
        return ""
    # Separate pages so the last line of one page is not glued to the next.
    return "\n".join(pages).strip()


def _split_pdf_text(text: str) -> tuple[str, str]:
    """Guess (title, abstract) from raw PDF text.

    The first non-blank line is taken as the title; the rest becomes the abstract.
    """
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    if not lines:
        return "", ""
    return lines[0][:300], " ".join(lines[1:])[:3000]


def predict(title: str, abstract: str, meta_tensor: torch.Tensor):
    """Run the model on one paper.

    Returns ``(pred_class, raw_log_score, estimated_5y_citations)``.
    """
    text = f"{title} [SEP] {abstract}"
    enc = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding="max_length",  # presumably matches training preprocessing
    )
    with torch.no_grad():
        raw = model(
            enc["input_ids"].to(DEVICE),
            enc["attention_mask"].to(DEVICE),
            meta_tensor.to(DEVICE),
        )
    score = raw.item()
    pred_class = to_class(score)
    est_citations = float(np.expm1(max(score, 0)))  # inverse of log1p, floored at 0
    return pred_class, score, est_citations


# ── Page header ───────────────────────────────────────────────────────────────
st.title("📊 Citation Impact Predictor")
st.markdown("""
### 🤔 Зачем это нужно?

Узнать заранее — по названию, абстракту, числу авторов и году выхода, — стоит ли вообще тратить время на изучение статьи.

Мы делим работы на 4 категории:
- 🗑️ **Мусор** — не стоит читать
- 📄 **Середняк** — можно читать, если это ваша область и более сильных работ сейчас нет
- 📈 **Сильная работа** — стоит обратить внимание
- 🏆 **Топ** — читать обязательно

💡 Это не заменяет экспертную оценку, но помогает быстро отфильтровать поток научных работ.
""")
st.divider()

# ── Sidebar ───────────────────────────────────────────────────────────────────
st.sidebar.header("📥 Paper Input")
input_mode = st.sidebar.radio(
    "Input method",
    ["Manual text", "Fetch by DOI", "Upload PDF"],
    help="Choose how to provide the paper.",
)

title = ""
abstract = ""

# The DOI input stays in the sidebar; text and PDF inputs live in the main area.
if input_mode == "Fetch by DOI":
    doi_input = st.sidebar.text_input("DOI", placeholder="10.1234/example")
else:
    doi_input = ""

st.sidebar.divider()
st.sidebar.header("🔢 Metadata")
pub_year = st.sidebar.number_input("Publication year", 2000, 2024, 2020)
author_count = st.sidebar.number_input(
    "Author count", min_value=1, max_value=200, value=3
)

# ── Main panel: wide left for input, narrow right for button ──────────────────
col_left, col_right = st.columns([4, 1])

with col_left:
    if input_mode == "Manual text":
        title = st.text_input("Title", placeholder="e.g. Attention Is All You Need")
        abstract = st.text_area(
            "Abstract", height=250, placeholder="Paste the abstract here…"
        )

    elif input_mode == "Fetch by DOI":
        if doi_input:
            with st.spinner("Fetching metadata from OpenAlex…"):
                paper = fetch_openalex_by_doi(doi_input)
            if paper:
                title = paper.get("title") or ""
                abstract = decode_abstract(paper.get("abstract_inverted_index") or {})
                # Prefer OpenAlex metadata; fall back to the sidebar values if missing.
                pub_year = paper.get("publication_year") or pub_year
                author_count = len(paper.get("authorships") or []) or author_count
                st.sidebar.success("✅ Paper found!")
                st.success(f"**{title}**")
                st.caption(f"Year: {pub_year} · Authors: {author_count}")
                st.markdown(abstract[:800] + ("…" if len(abstract) > 800 else ""))
            else:
                st.error("Could not fetch paper. Check the DOI.")
        else:
            st.info("Enter a DOI in the sidebar to fetch paper metadata.")

    elif input_mode == "Upload PDF":
        uploaded_file = st.file_uploader("Upload PDF", type=["pdf"])
        if uploaded_file is not None:
            with st.spinner("Extracting text from PDF…"):
                text = extract_text_from_pdf(uploaded_file)
            if text:
                title, abstract = _split_pdf_text(text)
                st.success("✅ PDF processed")
                st.markdown(f"**{title}**")
                st.markdown(abstract[:800] + ("…" if len(abstract) > 800 else ""))
            else:
                st.error("Could not extract text from PDF.")

with col_right:
    # Vertical spacer so the button lines up with the input widgets.
    st.markdown('<div style="height: 1.75rem"></div>', unsafe_allow_html=True)
    run = st.button("🔍 Predict", use_container_width=True, type="primary")

# ── Results ───────────────────────────────────────────────────────────────────
if run:
    if not title and not abstract:
        st.warning("Please provide at least a title or abstract.")
    else:
        meta_tensor = compute_meta_from_inputs(
            publication_year=int(pub_year),
            abstract=abstract,
            title=title,
            author_count=int(author_count),
        )

        with st.spinner("Running inference…"):
            pred_class, raw_score, est_citations = predict(title, abstract, meta_tensor)

        st.divider()
        st.subheader("📊 Prediction Results")

        # Main result badge (only program-generated values are interpolated).
        color = CLASS_COLORS[pred_class]
        label = CLASS_LABELS[pred_class]
        st.markdown(
            f'<div style="border-left: 6px solid {color}; border-radius: 8px; '
            f'padding: 1rem 1.25rem; margin-bottom: 1rem; background: {color}1a;">'
            f'<div style="font-size: 1.6rem; font-weight: 700; color: {color};">'
            f"Class {pred_class} — {label}</div>"
            f"<div>Estimated citations in the first 5 years: ~{est_citations:.0f}</div>"
            '<div style="font-size: 0.85rem; opacity: 0.7;">'
            f"(raw log-score: {raw_score:.3f} → e^score − 1 = {est_citations:.1f})"
            "</div></div>",
            unsafe_allow_html=True,
        )

        st.markdown("**Score vs. class thresholds**")
        # Lower/upper edge of each class band; 8 caps the open-ended top band for display.
        boundaries = [0, *THRESHOLDS_5Y, 8]
        thresh_cols = st.columns(N_CLASSES)
        for i, (col, name, class_color) in enumerate(
            zip(thresh_cols, CLASS_LABELS, CLASS_COLORS)
        ):
            lo, hi = boundaries[i], boundaries[i + 1]
            active = pred_class == i
            border = f"3px solid {class_color}" if active else "1px solid #dddddd"
            marker = "<br>✅" if active else ""
            col.markdown(
                f'<div style="border: {border}; border-radius: 8px; '
                'padding: 0.5rem; text-align: center;">'
                f'<b style="color: {class_color};">{name}</b><br>'
                f"{lo:.1f} – {hi:.1f}{marker}</div>",
                unsafe_allow_html=True,
            )

        st.divider()
        interpretations = {
            0: "This paper is predicted to receive **very few citations** in its first 5 years — typical of niche, incremental, or low-visibility work.",
            1: "This paper is predicted to receive a **moderate number of citations** — solid work with a reasonable audience.",
            2: "This paper is predicted to receive a **high number of citations** — likely a meaningful contribution to its field.",
            3: "This paper is predicted to be a **top-cited paper** — potentially a landmark contribution with broad impact.",
        }
        st.markdown(f"💡 **Interpretation:** {interpretations[pred_class]}")

st.divider()
# Built from the constants so the footer cannot drift from the actual model setup.
st.caption(
    f"Model: fine-tuned `{MODEL_NAME}` · "
    f"Classes defined by log1p(5-year citations) thresholds {THRESHOLDS_5Y} · "
    "© 2026 Citation Predictor"
)