Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import spaces | |
| import torch | |
| from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification | |
| import torch.nn.functional as F | |
| import torch.nn as nn | |
| import re | |
| import requests | |
| from urllib.parse import urlparse | |
| import xml.etree.ElementTree as ET | |
| ################################################## | |
| # Global setup | |
| ################################################## | |
# Identifier handed to every from_pretrained() call in predict().
model_path = "ssocean/NAIP"
# Use CUDA when torch reports a usable GPU, otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Both stay None until predict() runs for the first time and fills them in,
# so the model and tokenizer are loaded at most once per process.
model = None
tokenizer = None
| ################################################## | |
| # Fetch paper info from arXiv | |
| ################################################## | |
# Upper bound (seconds) on the arXiv request so a stalled connection cannot
# block the UI handler indefinitely.
_ARXIV_TIMEOUT_SECONDS = 10


def _extract_arxiv_id(arxiv_input):
    """Pull the bare arXiv identifier out of a URL or a raw ID string."""
    text = (arxiv_input or "").strip()
    if "arxiv.org" not in text:
        return text
    path = urlparse(text).path
    # Prefer everything after /abs/ or /pdf/ so that old-style identifiers
    # such as "hep-th/9901001" keep their category prefix.
    for marker in ("/abs/", "/pdf/"):
        _, found, tail = path.partition(marker)
        if found:
            candidate = tail.strip("/")
            break
    else:
        # Unknown URL layout: fall back to the last non-empty path segment.
        candidate = path.rstrip("/").rsplit("/", 1)[-1]
    return candidate.replace(".pdf", "")


def _arxiv_result(success, message, title="", abstract=""):
    """Build the result dict shared by every fetch_arxiv_paper outcome."""
    return {
        "title": title,
        "abstract": abstract,
        "success": success,
        "message": message,
    }


def fetch_arxiv_paper(arxiv_input):
    """
    Fetch paper title & abstract from an arXiv URL or ID.

    Args:
        arxiv_input: An arXiv identifier (e.g. "2504.11651") or any URL that
            contains "arxiv.org" (abs or pdf form, with or without ".pdf").

    Returns:
        A dict with the keys "title", "abstract", "success" and "message".
        On failure "title" and "abstract" are empty strings, "success" is
        False and "message" describes the problem. This function does not
        raise; every error is reported through the returned dict.
    """
    try:
        arxiv_id = _extract_arxiv_id(arxiv_input)
        if not arxiv_id:
            # Nothing to look up, so skip the network round trip entirely.
            return _arxiv_result(False, "Please enter an arXiv URL or ID")

        # Passing the ID through `params` lets requests URL-encode it.
        resp = requests.get(
            "http://export.arxiv.org/api/query",
            params={"id_list": arxiv_id},
            timeout=_ARXIV_TIMEOUT_SECONDS,
        )
        if resp.status_code != 200:
            return _arxiv_result(False, "Error fetching paper from arXiv API")

        # Parse the raw bytes so the XML declaration decides the encoding
        # instead of the charset guessed from the HTTP headers.
        root = ET.fromstring(resp.content)
        ns = {"arxiv": "http://www.w3.org/2005/Atom"}
        entry = root.find(".//arxiv:entry", ns)
        if entry is None:
            return _arxiv_result(False, "Paper not found")

        # findtext() yields "" for a missing or empty element, and the
        # split/join collapses the line breaks the feed embeds in long fields.
        title = " ".join(entry.findtext("arxiv:title", "", ns).split())
        abstract = " ".join(entry.findtext("arxiv:summary", "", ns).split())
        if not title or not abstract:
            return _arxiv_result(False, "Paper not found")

        return _arxiv_result(True, "Paper fetched successfully!", title, abstract)
    except Exception as e:
        # UI boundary: report network, parsing or input errors as a message
        # rather than letting them escape into the Gradio handler.
        return _arxiv_result(False, f"Error fetching paper: {e}")
| ################################################## | |
| # Prediction function | |
| ################################################## | |
def predict(title, abstract):
    """
    Score a paper's normalized academic impact from its title and abstract.

    The first call loads the model and tokenizer from ``model_path`` and
    stores them in the module-level ``model`` / ``tokenizer`` globals; later
    calls reuse them.

    Returns:
        The sigmoid of the model's single logit plus a fixed 0.05 offset,
        capped at 1.0 and rounded to four decimals. If inference raises, the
        error is printed and 0.0 is returned.
    """
    global model, tokenizer

    if model is None:
        cfg = AutoConfig.from_pretrained(model_path)
        # Drop quantization_config when present (avoids a NoneType error in PEFT).
        if hasattr(cfg, "quantization_config"):
            del cfg.quantization_config
        # Single output label: the model is used as a scalar regressor.
        cfg.num_labels = 1

        net = AutoModelForSequenceClassification.from_pretrained(
            model_path,
            config=cfg,
            torch_dtype=torch.float32,  # float32 for stable cublasLt
            device_map=None,
            low_cpu_mem_usage=False,
        )
        net.to(device)
        net.eval()

        tok = AutoTokenizer.from_pretrained(model_path)

        # Publish both objects together once everything has loaded.
        model, tokenizer = net, tok

    prompt = "\n".join(
        [
            "Given a certain paper,",
            f"Title: {title.strip()}",
            f"Abstract: {abstract.strip()}",
            "Predict its normalized academic impact (0~1):",
        ]
    )

    try:
        encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
        batch = {name: tensor.to(device) for name, tensor in encoded.items()}
        with torch.no_grad():
            result = model(**batch)
        probability = torch.sigmoid(result.logits).item()
        return round(min(1.0, probability + 0.05), 4)
    except Exception as e:
        print("Prediction error:", e)
        return 0.0
| ################################################## | |
| # Grading | |
| ################################################## | |
def get_grade_and_emoji(score):
    """Translate a 0-1 impact score into its letter grade plus emoji label.

    The thresholds are inclusive lower bounds checked from highest to lowest;
    any score below the lowest bound gets the bottom grade.
    """
    grade_floors = (
        (0.900, "AAA π"),
        (0.800, "AA β"),
        (0.650, "A β¨"),
        (0.600, "BBB π΅"),
        (0.550, "BB π"),
        (0.500, "B π"),
        (0.400, "CCC π"),
        (0.300, "CC βοΈ"),
    )
    for floor, label in grade_floors:
        if score >= floor:
            return label
    return "C π"
| ################################################## | |
| # Validation | |
| ################################################## | |
def validate_input(title, abstract):
    """
    Check a title/abstract pair against the app's input rules.

    The rules are evaluated in order and the first violation wins: the title
    needs at least 3 whitespace-separated words, the abstract at least 50,
    and both must consist of ASCII characters only.

    Returns:
        A ``(is_valid, message)`` tuple.
    """
    # Each rule is a lazy predicate so later checks never run once an
    # earlier one has already failed.
    rules = (
        (lambda: len(title.split()) < 3, "Title must be at least 3 words."),
        (lambda: len(abstract.split()) < 50, "Abstract must be at least 50 words."),
        (lambda: not title.isascii(), "Title contains non-ASCII characters."),
        (lambda: not abstract.isascii(), "Abstract contains non-ASCII characters."),
    )
    for is_violated, complaint in rules:
        if is_violated():
            return False, complaint
    return True, "Inputs look good."
def update_button_status(title, abstract):
    """Return updates for the status textbox and the predict button.

    The status text is the validation message (prefixed with "Error: " when
    validation fails) and the button is interactive only for valid input.
    """
    ok, note = validate_input(title, abstract)
    if ok:
        return gr.update(value=note), gr.update(interactive=True)
    return gr.update(value="Error: " + note), gr.update(interactive=False)
| ################################################## | |
| # Process arXiv input | |
| ################################################## | |
def process_arxiv_input(arxiv_input):
    """
    Handler for the 'Fetch Paper Details' button.

    Returns a ``(title, abstract, status_message)`` tuple. Title and abstract
    are empty strings when the input is blank or the lookup fails.
    """
    if arxiv_input.strip() == "":
        return "", "", "Please enter an arXiv URL or ID"

    outcome = fetch_arxiv_paper(arxiv_input)
    if not outcome["success"]:
        return "", "", outcome["message"]
    return outcome["title"], outcome["abstract"], outcome["message"]
| ################################################## | |
| # Custom CSS | |
| ################################################## | |
# Stylesheet handed to gr.Blocks(css=...). The selectors correspond to the
# class names attached through elem_classes and the inline "main-title" div
# in the interface definition.
css = """
.gradio-container { font-family: Arial, sans-serif; }
.main-title {
text-align: center; color: #2563eb; font-size: 2.5rem!important;
margin-bottom:1rem!important;
background: linear-gradient(45deg,#2563eb,#1d4ed8);
-webkit-background-clip: text; -webkit-text-fill-color: transparent;
}
.input-section {
background:#fff; padding:1.5rem; border-radius:0.5rem;
box-shadow:0 4px 6px rgba(0,0,0,0.1);
}
.result-section {
background:#f7f9fc; padding:1.5rem; border-radius:0.5rem;
margin-top:2rem;
}
.grade-display {
font-size:2.5rem; text-align:center; margin-top:1rem;
}
.arxiv-input {
margin-bottom:1.5rem; padding:1rem; background:#f3f4f6;
border-radius:0.5rem;
}
.arxiv-link {
color:#2563eb; text-decoration: underline;
}
"""
| ################################################## | |
| # Example Papers | |
| ################################################## | |
# Static showcase entries rendered in the "Example Papers" section of the UI.
# Each "score" is a hard-coded value stored here; the interface only maps it
# to a grade with get_grade_and_emoji() and never calls predict() for these.
example_papers = [
    {
        "title": "Attention Is All You Need",
        "abstract": (
            "The dominant sequence transduction models are based on complex recurrent or "
            "convolutional neural networks that include an encoder and a decoder. The best performing "
            "models also connect the encoder and decoder through an attention mechanism. We propose a "
            "new simple network architecture, the Transformer, based solely on attention mechanisms, "
            "dispensing with recurrence and convolutions entirely. Experiments on two machine "
            "translation tasks show these models to be superior in quality while being more "
            "parallelizable and requiring significantly less time to train."
        ),
        "score": 0.982,
        "note": "Revolutionary paper that introduced the Transformer architecture."
    },
    {
        "title": "Language Models are Few-Shot Learners",
        "abstract": (
            "Recent work has demonstrated substantial gains on many NLP tasks and benchmarks by "
            "pre-training on a large corpus of text followed by fine-tuning on a specific task. While "
            "typically task-agnostic in architecture, this method still requires task-specific "
            "fine-tuning datasets of thousands or tens of thousands of examples. By contrast, humans "
            "can generally perform a new language task from only a few examples or from simple "
            "instructionsβsomething which current NLP systems still largely struggle to do. Here we "
            "show that scaling up language models greatly improves task-agnostic, few-shot "
            "performance, sometimes even reaching competitiveness with prior state-of-the-art "
            "fine-tuning approaches."
        ),
        "score": 0.956,
        "note": "Groundbreaking GPT-3 paper on few-shot learning."
    },
    {
        "title": "An Empirical Study of Neural Network Training Protocols",
        "abstract": (
            "This paper presents a comparative analysis of different training protocols for neural "
            "networks across various architectures. We examine the effects of learning rate schedules, "
            "batch size selection, and optimization algorithms on model convergence and final "
            "performance. Our experiments span multiple datasets and model sizes, providing practical "
            "insights for deep learning practitioners."
        ),
        "score": 0.623,
        "note": "Solid empirical comparison of training protocols."
    }
]
| ################################################## | |
| # Build the Gradio Interface | |
| ################################################## | |
# Declarative UI definition: components are created in display order, then
# the event handlers are attached while the Blocks context is still open.
# NOTE(review): the source this was recovered from had lost all indentation,
# so the nesting of rows/columns/groups below is a reconstruction - confirm
# against the deployed layout.
with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
    gr.Markdown("<div class='main-title'>Papers Impact: AI-Powered Research Impact Predictor</div>")
    gr.Markdown("**Predict the potential research impact (0β1) from title & abstract.**")
    with gr.Row():
        # Left-hand column: arXiv import plus manual title/abstract entry.
        with gr.Column(elem_classes="input-section"):
            gr.Markdown("### Import from arXiv")
            with gr.Group(elem_classes="arxiv-input"):
                # Pre-filled with a sample ID so the fetch button works immediately.
                arxiv_input = gr.Textbox(
                    lines=1,
                    placeholder="e.g. 2504.11651",
                    label="arXiv URL or ID",
                    value="2504.11651"
                )
                gr.Markdown(
                    """
<p>
Enter an arXiv ID or URL. For example:
<code>2504.11651</code> or <code>https://arxiv.org/pdf/2504.11651</code>
</p>
"""
                )
                fetch_btn = gr.Button("π Fetch Paper Details", variant="secondary")
            gr.Markdown("### Or Enter Manually")
            title_input = gr.Textbox(
                lines=2,
                placeholder="Paper title (β₯3 words)...",
                label="Paper Title"
            )
            abs_input = gr.Textbox(
                lines=5,
                placeholder="Paper abstract (β₯50 words)...",
                label="Paper Abstract"
            )
            # Read-only field that shows validation results and fetch messages.
            status_box = gr.Textbox(label="Validation Status", interactive=False)
            # Starts disabled; update_button_status() enables it once inputs validate.
            predict_btn = gr.Button("π― Predict Impact", interactive=False, variant="primary")
        # Right-hand column: prediction outputs and explanatory material.
        with gr.Column(elem_classes="result-section"):
            score_box = gr.Number(label="Impact Score")
            grade_box = gr.Textbox(label="Grade", elem_classes="grade-display")
            ############## METHODOLOGY EXPLANATION ##############
            gr.Markdown(
                """
### Scientific Methodology
- **Training Data**: Model trained on an extensive dataset of published papers in CS.CV, CS.CL, CS.AI
- **Optimization**: NDCG optimization with Sigmoid activation and MSE loss
- **Validation**: Cross-validated against historical citation data
- **Architecture**: Advanced transformer-based (LLaMA derivative) textual encoder
- **Metrics**: Quantitative analysis of citation patterns and research influence
"""
            )
            ############## RATING SCALE ##############
            # Documentation table for the thresholds implemented in
            # get_grade_and_emoji(); keep the two in sync when either changes.
            gr.Markdown(
                """
### Rating Scale
| Grade | Score Range | Description | Emoji |
|-------|-------------|---------------------|-------|
| AAA | 0.900β1.000 | **Exceptional** | π |
| AA | 0.800β0.899 | **Very High** | β |
| A | 0.650β0.799 | **High** | β¨ |
| BBB | 0.600β0.649 | **Above Average** | π΅ |
| BB | 0.550β0.599 | **Moderate** | π |
| B | 0.500β0.549 | **Average** | π |
| CCC | 0.400β0.499 | **Below Average** | π |
| CC | 0.300β0.399 | **Low** | βοΈ |
| C | <0.300 | **Limited** | π |
"""
            )
            ############## EXAMPLE PAPERS ##############
            gr.Markdown("### Example Papers")
            # One Markdown card per static entry; the grade is derived from the
            # stored score, no model inference happens here.
            for paper in example_papers:
                gr.Markdown(
                    f"**{paper['title']}** \n"
                    f"Score: {paper['score']} | Grade: {get_grade_and_emoji(paper['score'])} \n"
                    f"{paper['abstract']} \n"
                    f"*{paper['note']}*\n---"
                )
    ##################################################
    # Events
    ##################################################
    # Validation triggers: re-validate whenever either text field changes and
    # push the result to the status box and the predict button.
    title_input.change(update_button_status, [title_input, abs_input], [status_box, predict_btn])
    abs_input.change(update_button_status, [title_input, abs_input], [status_box, predict_btn])
    # arXiv fetch: fills title, abstract and the status message from the lookup.
    fetch_btn.click(process_arxiv_input, [arxiv_input], [title_input, abs_input, status_box])
    # Predict handler
    def run_predict(t, a):
        """Run the model on title `t` and abstract `a`; return (score, grade label)."""
        s = predict(t, a)
        return s, get_grade_and_emoji(s)
    predict_btn.click(run_predict, [title_input, abs_input], [score_box, grade_box])
| ################################################## | |
| # Launch | |
| ################################################## | |
# Start the Gradio app only when this file is executed as a script, so that
# importing the module does not launch a server.
if __name__ == "__main__":
    iface.launch()