Spaces:
Running
Running
Updated app.py
Browse filesAdded tf-idf method for better semantic search
app.py
CHANGED
|
@@ -49,40 +49,108 @@ from smolagents import CodeAgent, HfApiModel, tool
|
|
| 49 |
# print(f"ERROR: {str(e)}") # Debug errors
|
| 50 |
# return [f"Error fetching research papers: {str(e)}"]
|
| 51 |
|
| 52 |
-
from rank_bm25 import BM25Okapi
|
| 53 |
-
import nltk
|
| 54 |
|
| 55 |
-
|
| 56 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
|
|
|
|
|
|
| 58 |
|
| 59 |
-
|
| 60 |
-
if os.path.exists(nltk_data_path):
|
| 61 |
-
shutil.rmtree(nltk_data_path) # Remove corrupted version
|
| 62 |
|
| 63 |
-
print("✅ Removed old NLTK 'punkt' data. Reinstalling...")
|
| 64 |
|
| 65 |
-
#
|
| 66 |
-
|
|
|
|
| 67 |
|
| 68 |
-
|
|
|
|
|
|
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
def fetch_latest_arxiv_papers(keywords: list, num_results: int = 5) -> list:
|
| 73 |
-
"""Fetches and ranks arXiv papers using
|
| 74 |
|
| 75 |
Args:
|
| 76 |
keywords: List of keywords for search.
|
| 77 |
num_results: Number of results to return.
|
| 78 |
|
| 79 |
Returns:
|
| 80 |
-
List of the most relevant papers based on
|
| 81 |
"""
|
| 82 |
try:
|
| 83 |
print(f"DEBUG: Searching arXiv papers with keywords: {keywords}")
|
| 84 |
|
| 85 |
-
# Use a general keyword search
|
| 86 |
query = "+AND+".join([f"all:{kw}" for kw in keywords])
|
| 87 |
query_encoded = urllib.parse.quote(query)
|
| 88 |
url = f"http://export.arxiv.org/api/query?search_query={query_encoded}&start=0&max_results=50&sortBy=submittedDate&sortOrder=descending"
|
|
@@ -105,17 +173,22 @@ def fetch_latest_arxiv_papers(keywords: list, num_results: int = 5) -> list:
|
|
| 105 |
if not papers:
|
| 106 |
return [{"error": "No results found. Try different keywords."}]
|
| 107 |
|
| 108 |
-
#
|
| 109 |
-
|
| 110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
-
|
| 113 |
-
|
| 114 |
|
| 115 |
-
#
|
| 116 |
-
ranked_papers = sorted(zip(papers,
|
| 117 |
|
| 118 |
-
# Return the most relevant
|
| 119 |
return [paper[0] for paper in ranked_papers[:num_results]]
|
| 120 |
|
| 121 |
except Exception as e:
|
|
@@ -188,11 +261,11 @@ def search_papers(user_input):
|
|
| 188 |
results = fetch_latest_arxiv_papers(keywords, num_results=3) # Fetch 3 results
|
| 189 |
print(f"DEBUG: Results received - {results}") # Debug function output
|
| 190 |
|
| 191 |
-
#
|
| 192 |
if isinstance(results, list) and len(results) > 0 and "error" in results[0]:
|
| 193 |
return results[0]["error"] # Return the error message directly
|
| 194 |
|
| 195 |
-
#
|
| 196 |
if isinstance(results, list) and results and isinstance(results[0], dict):
|
| 197 |
formatted_results = "\n\n".join([
|
| 198 |
f"---\n\n"
|
|
|
|
| 49 |
# print(f"ERROR: {str(e)}") # Debug errors
|
| 50 |
# return [f"Error fetching research papers: {str(e)}"]
|
| 51 |
|
|
|
|
|
|
|
| 52 |
|
| 53 |
+
#"""------Applied BM25 search for paper retrival------"""
|
| 54 |
+
# from rank_bm25 import BM25Okapi
|
| 55 |
+
# import nltk
|
| 56 |
+
|
| 57 |
+
# import os
|
| 58 |
+
# import shutil
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# nltk_data_path = os.path.join(nltk.data.path[0], "tokenizers", "punkt")
|
| 62 |
+
# if os.path.exists(nltk_data_path):
|
| 63 |
+
# shutil.rmtree(nltk_data_path) # Remove corrupted version
|
| 64 |
+
|
| 65 |
+
# print("Removed old NLTK 'punkt' data. Reinstalling...")
|
| 66 |
|
| 67 |
+
# # Step 2: Download the correct 'punkt' tokenizer
|
| 68 |
+
# nltk.download("punkt_tab")
|
| 69 |
|
| 70 |
+
# print("Successfully installed 'punkt'!")
|
|
|
|
|
|
|
| 71 |
|
|
|
|
| 72 |
|
| 73 |
+
# @tool # Register the function properly as a SmolAgents tool
|
| 74 |
+
# def fetch_latest_arxiv_papers(keywords: list, num_results: int = 5) -> list:
|
| 75 |
+
# """Fetches and ranks arXiv papers using BM25 keyword relevance.
|
| 76 |
|
| 77 |
+
# Args:
|
| 78 |
+
# keywords: List of keywords for search.
|
| 79 |
+
# num_results: Number of results to return.
|
| 80 |
|
| 81 |
+
# Returns:
|
| 82 |
+
# List of the most relevant papers based on BM25 ranking.
|
| 83 |
+
# """
|
| 84 |
+
# try:
|
| 85 |
+
# print(f"DEBUG: Searching arXiv papers with keywords: {keywords}")
|
| 86 |
|
| 87 |
+
# # Use a general keyword search (without `ti:` and `abs:`)
|
| 88 |
+
# query = "+AND+".join([f"all:{kw}" for kw in keywords])
|
| 89 |
+
# query_encoded = urllib.parse.quote(query)
|
| 90 |
+
# url = f"http://export.arxiv.org/api/query?search_query={query_encoded}&start=0&max_results=50&sortBy=submittedDate&sortOrder=descending"
|
| 91 |
+
|
| 92 |
+
# print(f"DEBUG: Query URL - {url}")
|
| 93 |
+
|
| 94 |
+
# feed = feedparser.parse(url)
|
| 95 |
+
# papers = []
|
| 96 |
+
|
| 97 |
+
# # Extract papers from arXiv
|
| 98 |
+
# for entry in feed.entries:
|
| 99 |
+
# papers.append({
|
| 100 |
+
# "title": entry.title,
|
| 101 |
+
# "authors": ", ".join(author.name for author in entry.authors),
|
| 102 |
+
# "year": entry.published[:4],
|
| 103 |
+
# "abstract": entry.summary,
|
| 104 |
+
# "link": entry.link
|
| 105 |
+
# })
|
| 106 |
+
|
| 107 |
+
# if not papers:
|
| 108 |
+
# return [{"error": "No results found. Try different keywords."}]
|
| 109 |
+
|
| 110 |
+
# # Apply BM25 ranking
|
| 111 |
+
# tokenized_corpus = [nltk.word_tokenize(paper["title"].lower() + " " + paper["abstract"].lower()) for paper in papers]
|
| 112 |
+
# bm25 = BM25Okapi(tokenized_corpus)
|
| 113 |
+
|
| 114 |
+
# tokenized_query = nltk.word_tokenize(" ".join(keywords).lower())
|
| 115 |
+
# scores = bm25.get_scores(tokenized_query)
|
| 116 |
+
|
| 117 |
+
# # Sort papers based on BM25 score
|
| 118 |
+
# ranked_papers = sorted(zip(papers, scores), key=lambda x: x[1], reverse=True)
|
| 119 |
+
|
| 120 |
+
# # Return the most relevant ones
|
| 121 |
+
# return [paper[0] for paper in ranked_papers[:num_results]]
|
| 122 |
+
|
| 123 |
+
# except Exception as e:
|
| 124 |
+
# print(f"ERROR: {str(e)}")
|
| 125 |
+
# return [{"error": f"Error fetching research papers: {str(e)}"}]
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
"""------Applied TF-IDF for better semantic search------"""
|
| 129 |
+
import numpy as np
|
| 130 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 131 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 132 |
+
import gradio as gr
|
| 133 |
+
from smolagents import CodeAgent, HfApiModel, tool
|
| 134 |
+
import nltk
|
| 135 |
+
|
| 136 |
+
nltk.download("stopwords")
|
| 137 |
+
from nltk.corpus import stopwords
|
| 138 |
+
|
| 139 |
+
@tool # ✅ Register the function properly as a SmolAgents tool
|
| 140 |
def fetch_latest_arxiv_papers(keywords: list, num_results: int = 5) -> list:
|
| 141 |
+
"""Fetches and ranks arXiv papers using TF-IDF and Cosine Similarity.
|
| 142 |
|
| 143 |
Args:
|
| 144 |
keywords: List of keywords for search.
|
| 145 |
num_results: Number of results to return.
|
| 146 |
|
| 147 |
Returns:
|
| 148 |
+
List of the most relevant papers based on TF-IDF ranking.
|
| 149 |
"""
|
| 150 |
try:
|
| 151 |
print(f"DEBUG: Searching arXiv papers with keywords: {keywords}")
|
| 152 |
|
| 153 |
+
# Use a general keyword search
|
| 154 |
query = "+AND+".join([f"all:{kw}" for kw in keywords])
|
| 155 |
query_encoded = urllib.parse.quote(query)
|
| 156 |
url = f"http://export.arxiv.org/api/query?search_query={query_encoded}&start=0&max_results=50&sortBy=submittedDate&sortOrder=descending"
|
|
|
|
| 173 |
if not papers:
|
| 174 |
return [{"error": "No results found. Try different keywords."}]
|
| 175 |
|
| 176 |
+
# Prepare TF-IDF Vectorization
|
| 177 |
+
corpus = [paper["title"] + " " + paper["abstract"] for paper in papers]
|
| 178 |
+
vectorizer = TfidfVectorizer(stop_words=stopwords.words('english')) # Remove stopwords
|
| 179 |
+
tfidf_matrix = vectorizer.fit_transform(corpus)
|
| 180 |
+
|
| 181 |
+
# Transform Query into TF-IDF Vector
|
| 182 |
+
query_str = " ".join(keywords)
|
| 183 |
+
query_vec = vectorizer.transform([query_str])
|
| 184 |
|
| 185 |
+
#Compute Cosine Similarity
|
| 186 |
+
similarity_scores = cosine_similarity(query_vec, tfidf_matrix).flatten()
|
| 187 |
|
| 188 |
+
#Sort papers based on similarity score
|
| 189 |
+
ranked_papers = sorted(zip(papers, similarity_scores), key=lambda x: x[1], reverse=True)
|
| 190 |
|
| 191 |
+
# Return the most relevant papers
|
| 192 |
return [paper[0] for paper in ranked_papers[:num_results]]
|
| 193 |
|
| 194 |
except Exception as e:
|
|
|
|
| 261 |
results = fetch_latest_arxiv_papers(keywords, num_results=3) # Fetch 3 results
|
| 262 |
print(f"DEBUG: Results received - {results}") # Debug function output
|
| 263 |
|
| 264 |
+
# Check if the API returned an error
|
| 265 |
if isinstance(results, list) and len(results) > 0 and "error" in results[0]:
|
| 266 |
return results[0]["error"] # Return the error message directly
|
| 267 |
|
| 268 |
+
# Format results only if valid papers exist
|
| 269 |
if isinstance(results, list) and results and isinstance(results[0], dict):
|
| 270 |
formatted_results = "\n\n".join([
|
| 271 |
f"---\n\n"
|