import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
import numpy as np
import requests
import os
from PyPDF2 import PdfReader
# Browser-tab title/icon and overall layout of the Streamlit page.
_PAGE_SETTINGS = {
    "page_title": "Citation Impact Predictor",
    "page_icon": "📊",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_SETTINGS)
# --- Model / checkpoint configuration ----------------------------------------
MODEL_NAME = "allenai/specter2_base"  # identifier handed to from_pretrained()
CHECKPOINT_PATH = "best_model.pt"  # weights file read by load_model()
N_META = 5  # number of metadata features built by compute_meta_from_inputs()
N_CLASSES = 4

# Cut points that to_class() applies to the raw model output.
THRESHOLDS_5Y = [1.5, 3.5, 5.5]

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# --- Display metadata, indexed by class id (0..3) ----------------------------
CLASS_NAMES = [
    "🗑️ Low (0)",
    "📄 Medium (1)",
    "📈 High (2)",
    "🏆 Top (3)",
]
CLASS_LABELS = [
    "Low",
    "Medium",
    "High",
    "Top",
]
CLASS_COLORS = [
    "#e74c3c",
    "#f39c12",
    "#3498db",
    "#2ecc71",
]
class CitationPredictor(nn.Module):
    """Regression model that fuses a transformer text encoder with paper metadata.

    The encoder's first-token embedding is concatenated with a 64-dimensional
    projection of the metadata vector, passed through one hidden block
    (Linear -> GELU -> LayerNorm -> Dropout) and reduced to a single scalar
    per sample.
    """

    def __init__(self, model_name: str, n_meta: int):
        """Build the encoder and the fusion head.

        Args:
            model_name: Identifier passed to ``AutoModel.from_pretrained``.
            n_meta: Number of metadata features per sample.
        """
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        # Take the embedding width from the encoder itself so a different
        # backbone does not silently mismatch the first Linear layer; 768 (the
        # previously hard-coded value) is kept as the fallback.
        hidden_size = getattr(self.encoder.config, "hidden_size", 768)
        # NOTE: attribute names and layer order are part of the checkpoint's
        # state_dict keys and must not be changed.
        self.meta_proj = nn.Sequential(nn.Linear(n_meta, 64), nn.GELU())
        self.layer = nn.Sequential(
            nn.Linear(hidden_size + 64, 256),
            nn.GELU(),
            nn.LayerNorm(256),
            nn.Dropout(0.2),
        )
        self.head = nn.Linear(256, 1)

    def forward(self, input_ids, attention_mask, meta):
        """Return one score per sample.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.
            meta: Metadata features whose last dimension is ``n_meta``
                (assumed shape ``(batch, n_meta)`` -- confirm against callers).

        Returns:
            The head output with its trailing size-1 dimension squeezed away.
        """
        encoder_out = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        )
        # Embedding of the first token of every sequence.
        cls_emb = encoder_out.last_hidden_state[:, 0]
        m_emb = self.meta_proj(meta)
        feat = self.layer(torch.cat([cls_emb, m_emb], dim=-1))
        return self.head(feat).squeeze(-1)
def to_class(pred: float) -> int:
    """Map a raw model score to a class index in 0..3.

    The result is the index of the first of the three ``THRESHOLDS_5Y`` cut
    points that ``pred`` is strictly below, or 3 when it is below none of them.
    """
    for class_idx in range(3):
        if pred < THRESHOLDS_5Y[class_idx]:
            return class_idx
    return 3
def noise_score(text: str) -> float:
    """Return the fraction of characters in ``text`` that are letters.

    Used as a simple measure of how meaningful the text is. An empty string
    yields 0.0.
    """
    if not text:
        return 0.0
    alpha_count = len([ch for ch in text if ch.isalpha()])
    return alpha_count / len(text)
def compute_meta_from_inputs(
    publication_year: int,
    abstract: str,
    title: str,
    author_count: int,
) -> torch.Tensor:
    """Build the metadata feature tensor for a single paper.

    Features, in order:
        1. publication year divided by 2026
        2. log1p of the abstract length in characters
        3. log1p of the title length in characters
        4. log1p of the author count, capped at 200
        5. letter fraction of "title abstract" (see ``noise_score``)

    Returns:
        A float tensor of shape (1, 5).
    """
    combined_text = " ".join((title, abstract)).strip()

    year_feature = float(publication_year) / 2026
    abstract_len_feature = float(np.log1p(len(abstract)))
    title_len_feature = float(np.log1p(len(title)))
    author_feature = float(np.log1p(min(author_count, 200)))
    # How meaningful the text looks (share of alphabetic characters).
    text_quality_feature = noise_score(combined_text)

    features = [
        year_feature,
        abstract_len_feature,
        title_len_feature,
        author_feature,
        text_quality_feature,
    ]
    return torch.tensor([features], dtype=torch.float)
def fetch_openalex_by_doi(doi: str) -> dict | None:
clean = doi.strip().replace("https://doi.org/", "").replace("http://doi.org/", "")
url = f"https://api.openalex.org/works/doi:{clean}"
params = {
"select": "title,abstract_inverted_index,publication_year,authorships",
"mailto": "demo@example.com",
}
try:
r = requests.get(url, params=params, timeout=15)
if r.status_code == 200:
return r.json()
except Exception:
pass
return None
def decode_abstract(inv_idx: dict) -> str:
    """Rebuild plain text from a ``word -> [positions]`` inverted index.

    Every (position, word) occurrence is collected, ordered by position (ties
    broken by the word itself) and joined with single spaces. A missing or
    empty index yields an empty string.
    """
    if not inv_idx:
        return ""
    placed = [
        (position, token)
        for token, position_list in inv_idx.items()
        for position in position_list
    ]
    placed.sort()
    return " ".join(token for _, token in placed).strip()
@st.cache_resource(show_spinner="Loading model weights…")
def load_model():
    """Load the tokenizer and the fine-tuned ``CitationPredictor``.

    The result is cached by ``st.cache_resource``. When the checkpoint file is
    missing, an error is shown and the script run is halted via ``st.stop()``.

    Returns:
        ``(model, tokenizer)`` with the model moved to ``DEVICE`` and switched
        to eval mode.
    """
    if not os.path.exists(CHECKPOINT_PATH):
        st.error(
            f"`{CHECKPOINT_PATH}` not found. "
            "Make sure it is uploaded to the Space root directory."
        )
        st.stop()
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = CitationPredictor(MODEL_NAME, N_META).to(DEVICE)
    # The checkpoint is consumed directly as a state_dict, so restrict
    # unpickling to tensors and plain containers rather than arbitrary objects.
    state = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state)
    model.eval()
    return model, tokenizer


# Module-level handles used by predict().
model, tokenizer = load_model()
def extract_text_from_pdf(file) -> str:
    """Extract the text of every page of a PDF, best-effort.

    Args:
        file: Anything ``PdfReader`` accepts (e.g. an uploaded file object).

    Returns:
        The text of all non-empty pages joined by newlines and stripped, or an
        empty string when the PDF cannot be read or contains no text.
    """
    try:
        reader = PdfReader(file)
        page_texts = [page.extract_text() or "" for page in reader.pages]
    except Exception:
        # Deliberately broad: any parsing failure on a malformed upload is
        # reported to the caller as "no text" rather than crashing the app.
        return ""
    # Separate pages with a newline so the last word of one page is not glued
    # to the first word of the next; pages without text are skipped.
    return "\n".join(chunk for chunk in page_texts if chunk).strip()
def predict(title: str, abstract: str, meta_tensor: torch.Tensor):
    """Score one paper with the module-level model and tokenizer.

    The title and abstract are joined as ``"<title> [SEP] <abstract>"``,
    tokenized to a fixed length of 512 and run through the model without
    gradient tracking.

    Returns:
        ``(pred_class, score, est_citations)``: the class index from
        ``to_class``, the raw model output, and ``expm1`` of the score clamped
        at zero (the inverse of log1p).
    """
    encoded = tokenizer(
        f"{title} [SEP] {abstract}",
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding="max_length",
    )
    input_ids = encoded["input_ids"].to(DEVICE)
    attention_mask = encoded["attention_mask"].to(DEVICE)
    meta_on_device = meta_tensor.to(DEVICE)

    with torch.no_grad():
        score = model(input_ids, attention_mask, meta_on_device).item()

    # Negative scores are clamped so the citation estimate never drops below 0.
    clamped_score = max(score, 0)
    est_citations = float(np.expm1(clamped_score))  # inverse of log1p
    return to_class(score), score, est_citations
# Introductory copy (Russian) rendered right below the page title.
_INTRO_MARKDOWN = """
### 🤔 Зачем это нужно?
Узнать заранее по названию абстракту, числу авторов, году выхода и наличию открытого доступа стоит ли вообще тратить время на изучение статьи
Мы делим работы на 4 категории:
- 🗑️ **Мусор** — не стоит читать
- 📄 **Середняк** — можно читать, если это ваша область и более сильных работ сейчас нет
- 📈 **Сильная работа** — стоит обратить внимание
- 🏆 **Топ** — читать обязательно
💡 Это не заменяет экспертную оценку —
но помогает быстро отфильтровать поток научных работ.
"""

st.title("📊 Citation Impact Predictor")
st.markdown(_INTRO_MARKDOWN)
st.divider()
# ── Sidebar: input method ─────────────────────────────────────────────────────
st.sidebar.header("📥 Paper Input")
input_mode = st.sidebar.radio(
    "Input method",
    ["Manual text", "Fetch by DOI", "Upload PDF"],
    help="Choose how to provide the paper.",
)

# Defaults that stay in effect when the chosen input mode supplies nothing.
title, abstract = "", ""
pub_year = 2020

# Only the DOI field lives in the sidebar; free-text input is drawn in the
# main area.
doi_input = (
    st.sidebar.text_input("DOI", placeholder="10.1234/example")
    if input_mode == "Fetch by DOI"
    else ""
)

# ── Sidebar: metadata ─────────────────────────────────────────────────────────
st.sidebar.divider()
st.sidebar.header("🔢 Metadata")
pub_year = st.sidebar.number_input(
    "Publication year", min_value=2000, max_value=2024, value=2020
)
author_count = st.sidebar.number_input(
    "Author count", min_value=1, max_value=200, value=3
)
# ── Main panel: wide left for input, narrow right for button ──────────────────
col_left, col_right = st.columns([4, 1])
with col_left:
if input_mode == "Manual text":
title = st.text_input("Title", placeholder="e.g. Attention Is All You Need")
abstract = st.text_area("Abstract", height=250, placeholder="Paste the abstract here…")
elif input_mode == "Fetch by DOI":
if doi_input:
with st.spinner("Fetching metadata from OpenAlex…"):
paper = fetch_openalex_by_doi(doi_input)
if paper:
title = paper.get("title") or ""
abstract = decode_abstract(paper.get("abstract_inverted_index") or {})
pub_year = paper.get("publication_year") or 2020
st.sidebar.success("✅ Paper found!")
st.success(f"**{title}**")
st.markdown(abstract[:800] + ("…" if len(abstract) > 800 else ""))
else:
st.error("Could not fetch paper. Check the DOI.")
else:
st.info("Enter a DOI in the sidebar to fetch paper metadata.")
elif input_mode == "Upload PDF":
uploaded_file = st.file_uploader("Upload PDF", type=["pdf"])
if uploaded_file is not None:
with st.spinner("Extracting text from PDF…"):
text = extract_text_from_pdf(uploaded_file)
if text:
lines = text.split("\n")
title = lines[0][:300]
abstract = " ".join(lines[1:])[:3000]
st.success("✅ PDF processed")
st.markdown(f"**{title}**")
st.markdown(abstract[:800] + ("…" if len(abstract) > 800 else ""))
else:
st.error("Could not extract text from PDF.")
with col_right:
    # Spacer above the button. The original string literal was split across
    # two lines, which is a syntax error in Python.
    # NOTE(review): "<br>" is assumed from unsafe_allow_html=True; the original
    # markup appears to have been lost -- confirm the intended spacer.
    st.markdown("<br>", unsafe_allow_html=True)
    # `run` is True only on the script rerun triggered by clicking the button.
    run = st.button("🔍 Predict", use_container_width=True, type="primary")
if run:
if not title and not abstract:
st.warning("Please provide at least a title or abstract.")
else:
text = (title + " " + abstract).strip()
meta_tensor = compute_meta_from_inputs(
publication_year=int(pub_year),
abstract=abstract,
title=title,
author_count=int(author_count)
)
with st.spinner("Running inference…"):
pred_class, raw_score, est_citations = predict(title, abstract, meta_tensor)
st.divider()
st.subheader("📊 Prediction Results")
# Main result badge
color = CLASS_COLORS[pred_class]
label = CLASS_LABELS[pred_class]
st.markdown(
f"""
Estimated citations in the first 5 years: ~{est_citations:.0f}
(raw log-score: {raw_score:.3f} → e^score − 1 = {est_citations:.1f})