# research_paper_ai_agent / research_ai_agent.py
# Author: Filip Makraduli
# Last commit: "better nan handling" (e9012d2)
import os
import pandas as pd
import requests
from tqdm import tqdm
import gradio as gr
from datetime import timedelta
from typing import Any, Optional, Dict, List
from abc import ABC, abstractmethod
from sentence_transformers import SentenceTransformer
from openai import OpenAI
import numpy as np
import pickle
import hashlib
# Superlinked import matching reference
import superlinked.framework as sl
# Abstract Tool Class
class Tool(ABC):
    """Interface every agent tool implements: a name, a description, and an action."""

    @abstractmethod
    def name(self) -> str:
        """Return the tool's unique name."""
        ...

    @abstractmethod
    def description(self) -> str:
        """Return a one-line description of what the tool does."""
        ...

    @abstractmethod
    def use(self, *args, **kwargs) -> Any:
        """Execute the tool and return its result."""
        ...
# Initialize OpenAI Client
def get_openai_client():
    """Build an OpenAI client from the OPENAI_API_KEY environment variable.

    Raises:
        ValueError: if the variable is unset or empty.
    """
    # For Hugging Face Spaces, set this as a secret
    key = os.environ.get("OPENAI_API_KEY")
    if key:
        return OpenAI(api_key=key)
    raise ValueError("Please set the OPENAI_API_KEY environment variable in HF Space settings.")
# Download dataset function
def download_dataset(url, filename):
    """Download `url` to `filename` with a progress bar, skipping if already present.

    Fixes: the progress messages had broken f-string placeholders rendered as
    "(unknown)"; they now interpolate `filename`. Also fails fast on HTTP
    errors instead of silently saving an error page as the dataset.

    Returns:
        The local filename (whether freshly downloaded or pre-existing).
    """
    if not os.path.exists(filename):
        print(f"Downloading {filename}...")
        response = requests.get(url, stream=True)
        response.raise_for_status()  # don't write an HTML error page to disk
        total_size = int(response.headers.get('content-length', 0))
        block_size = 1024
        with open(filename, 'wb') as f, tqdm(
            total=total_size,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for data in response.iter_content(block_size):
                size = f.write(data)
                bar.update(size)
        print(f"Download complete: {filename}")
    else:
        print(f"File {filename} already exists.")
    return filename
# Data Loading Function - matching reference exactly
def load_data(filename, sample_size=100):
    """Load research papers dataset and prepare it for search"""
    papers = pd.read_csv(filename).head(sample_size)

    # Normalize types: real timestamps for 'published', plain strings for 'summary'.
    papers['published'] = pd.to_datetime(papers['published'])
    papers['summary'] = papers['summary'].astype(str)

    # Combined field that the similarity index embeds.
    papers['text'] = papers['title'] + " " + papers['summary']

    # Diagnostics: size, schema, and nan counts for the essential columns.
    print(f"Data loaded: {len(papers)} papers")
    print(f"Columns: {list(papers.columns)}")
    print(f"Title nan count: {papers['title'].isna().sum()}")
    print(f"Summary nan count: {papers['summary'].isna().sum()}")
    print(f"Sample titles: {papers['title'].head(3).tolist()}")
    print(f"Sample summaries: {[str(s)[:50] + '...' for s in papers['summary'].head(3)]}")
    return papers
# Superlinked Schema (matching reference exactly)
class PaperSchema(sl.Schema):
    """Superlinked schema for one arXiv paper record.

    Field names must match the DataFrameParser mapping used when indexing
    (see the setup_superlinked_* functions).
    """
    text: sl.String  # concatenated "title + summary"; embedded for similarity search
    published: sl.Timestamp  # This will handle Unix timestamps in seconds
    published_unix: sl.Timestamp  # Explicit Unix seconds field
    entry_id: sl.IdField  # unique paper id (full arXiv URL in this dataset)
    title: sl.String
    summary: sl.String
def get_data_hash(df):
    """Generate a hash of the dataset to detect changes"""
    # Fingerprint = row count + first entry_id + first title; empty pieces for
    # an empty frame. Cheap but sufficient to tell cached indexes apart.
    if len(df) > 0:
        first_id = df['entry_id'].iloc[0]
        first_title = df['title'].iloc[0]
    else:
        first_id = ''
        first_title = ''
    fingerprint = f"{len(df)}_{first_id}_{first_title}"
    return hashlib.md5(fingerprint.encode()).hexdigest()[:8]
def save_superlinked_index(app, query, paper, df, cache_dir="superlinked_cache"):
    """Save the Superlinked index to disk"""
    os.makedirs(cache_dir, exist_ok=True)
    data_hash = get_data_hash(df)
    cache_file = os.path.join(cache_dir, f"index_{data_hash}_{len(df)}.pkl")
    payload = {
        'app': app,
        'query': query,
        'paper': paper,
        'data_hash': data_hash,
        'dataset_size': len(df),
    }
    try:
        with open(cache_file, 'wb') as f:
            pickle.dump(payload, f)
        print(f"βœ… Saved Superlinked index to {cache_file}")
        return True
    except Exception as e:
        # Superlinked objects are not always picklable; the live index still
        # works, it just won't be reused on the next startup.
        print(f"⚠️ Could not save index to cache (pickle error): {str(e)[:100]}...")
        print(" Index will work fine, but won't be cached for next time.")
        return False
def load_superlinked_index(df, cache_dir="superlinked_cache"):
    """Load the Superlinked index from disk if available"""
    # Miss on a missing cache directory.
    if not os.path.exists(cache_dir):
        return None, None, None
    data_hash = get_data_hash(df)
    cache_file = os.path.join(cache_dir, f"index_{data_hash}_{len(df)}.pkl")
    # Miss on a missing cache file for this exact dataset fingerprint.
    if not os.path.exists(cache_file):
        print("No cached index found for this dataset")
        return None, None, None
    try:
        with open(cache_file, 'rb') as f:
            cached = pickle.load(f)
        # Reuse only when fingerprint AND row count both match the current frame.
        if cached['data_hash'] == data_hash and cached['dataset_size'] == len(df):
            print(f"βœ… Loaded cached Superlinked index from {cache_file}")
            return cached['app'], cached['query'], cached['paper']
        print("❌ Cached index doesn't match current dataset")
        return None, None, None
    except Exception as e:
        # Covers unpickling failures and malformed cache payloads alike.
        print(f"❌ Failed to load cached index: {e}")
        return None, None, None
def load_data_efficient(filename, sample_size=100):
    """Load research papers dataset with optimized processing.

    Reads at most `sample_size` rows and only the columns the agent needs,
    drops rows missing a title or summary, parses 'published' (unparseable
    dates become NaT), and derives 'published_unix' in *seconds*.

    Fix: the previous `astype('int64') // 10**9` conversion was not NaT-safe —
    coerced bad dates either raise in modern pandas or silently become a huge
    negative sentinel. NaT now maps to 0; valid rows keep the exact same
    values as before.

    Returns:
        pd.DataFrame with entry_id/title/summary/published/published_unix/text.
    """
    print(f"Loading {sample_size} papers efficiently...")
    # Read only the columns we need
    required_cols = ['entry_id', 'title', 'summary', 'published', 'authors']
    try:
        df = pd.read_csv(
            filename,
            nrows=sample_size,  # only read what we need
            usecols=lambda c: c in required_cols,  # only load required columns
            dtype={
                'entry_id': 'string',
                'title': 'string',
                'summary': 'string',
                'authors': 'string'
            }
        )
    except Exception:
        # Fallback to reading all columns if the expected ones don't exist.
        df = pd.read_csv(filename).head(sample_size)
    # Remove rows with missing essential data.
    df = df.dropna(subset=['title', 'summary'])
    # Parse dates; invalid values become NaT instead of raising.
    df['published'] = pd.to_datetime(df['published'], errors='coerce')
    # Unix timestamp in SECONDS (not nanoseconds), computed NaT-safely:
    # only valid timestamps go through the int64 cast, NaT rows get 0.
    mask = df['published'].notna()
    unix_seconds = pd.Series(0, index=df.index, dtype='int64')
    unix_seconds.loc[mask] = df.loc[mask, 'published'].astype('int64') // 10**9
    df['published_unix'] = unix_seconds
    print(f"πŸ“… Sample publication dates:")
    for i, row in df.head(3).iterrows():
        pub_date = row['published']
        unix_seconds_val = row['published_unix']
        print(f" {pub_date} β†’ {unix_seconds_val} seconds since epoch")
    # Create text column efficiently - only for non-null values
    df['text'] = df['title'].fillna('') + " " + df['summary'].fillna('')
    # Remove any rows that ended up empty
    df = df[df['text'].str.len() > 10]  # at least 10 characters
    print(f"βœ… Loaded {len(df)} valid papers")
    return df
def setup_superlinked_minimal(df):
    """Minimal Superlinked setup focused on speed.

    Builds (or loads from the on-disk cache) an in-memory Superlinked index
    over `df` with two spaces: text similarity (sentence-transformers
    embeddings of title+summary) and recency (Unix-second timestamps).

    NOTE: mutates `df` in place (entry_id/published/published_unix columns).

    Returns:
        (app, knowledgebase_query, paper): running executor app, the
        parameterized query object, and the schema instance.
    """
    # Try cache first
    cached_app, cached_query, cached_paper = load_superlinked_index(df)
    if cached_app is not None:
        print("πŸš€ Using cached index!")
        return cached_app, cached_query, cached_paper
    print(f"Building minimal Superlinked index for {len(df)} papers...")
    paper = PaperSchema()
    # Ultra-minimal data prep - no copying, just ensure types
    df['entry_id'] = df['entry_id'].astype(str)
    df['published'] = pd.to_datetime(df['published'], errors='coerce')
    # CRITICAL: Use Unix seconds for Superlinked (not nanoseconds)
    # NOTE(review): this int cast is not NaT-safe — assumes 'published' was
    # cleaned upstream (load_data_efficient drops/handles bad dates); confirm.
    df['published_unix'] = df['published'].astype('int64') // 10**9
    print(f"πŸ• Timestamp debugging:")
    print(f"Sample published_unix values: {df['published_unix'].head(3).tolist()}")
    # Calculate actual date ranges in the dataset
    min_date = df['published'].min()
    max_date = df['published'].max()
    print(f"πŸ“… Dataset date range: {min_date} to {max_date}")
    # Text similarity space over the combined title+summary field
    text_space = sl.TextSimilaritySpace(
        text=sl.chunk(paper.text, chunk_size=1000, chunk_overlap=0),
        model="sentence-transformers/all-mpnet-base-v2"
    )
    # Recency space: graded periods out to ~32 years so a 1993-2025 historical
    # dataset still differentiates old papers instead of clamping them all to
    # the negative filter.
    recency_space = sl.RecencySpace(
        timestamp=paper.published_unix,
        period_time_list=[
            sl.PeriodTime(timedelta(days=365)),  # papers within 1 year
            sl.PeriodTime(timedelta(days=2*365)),  # papers within 2 years
            sl.PeriodTime(timedelta(days=3*365)),  # papers within 3 years
            sl.PeriodTime(timedelta(days=5*365)),  # papers within 5 years
            sl.PeriodTime(timedelta(days=10*365)),  # papers within 10 years
            sl.PeriodTime(timedelta(days=15*365)),  # papers within 15 years
            sl.PeriodTime(timedelta(days=20*365)),  # papers within 20 years
            sl.PeriodTime(timedelta(days=25*365)),  # papers within 25 years
            sl.PeriodTime(timedelta(days=30*365)),  # papers within 30 years
            sl.PeriodTime(timedelta(days=31*365)),  # papers within 31 years
            sl.PeriodTime(timedelta(days=31*365 + 120)),  # ~31.33 years (includes Feb 1994)
            sl.PeriodTime(timedelta(days=32*365)),  # papers within 32 years (includes both)
        ],
        negative_filter=-0.25
    )
    # Index over both spaces; query-time weights decide their balance.
    paper_index = sl.Index([text_space, recency_space])
    # Parser: maps DataFrame columns onto the schema fields.
    parser = sl.DataFrameParser(
        paper,
        mapping={
            paper.entry_id: "entry_id",
            paper.published_unix: "published_unix",  # Use Unix seconds
            paper.text: "text",
            paper.title: "title",
            paper.summary: "summary",
        }
    )
    # Setup and load: in-memory source + executor, then ingest in one batch.
    source = sl.InMemorySource(paper, parser=parser)
    executor = sl.InMemoryExecutor(sources=[source], indices=[paper_index])
    app = executor.run()
    # Single batch load
    print("Loading data...")
    source.put([df])
    # Query with query-time weights: relevance_weight/recency_weight/search_query/limit
    # are bound per call via sl.Param.
    knowledgebase_query = (
        sl.Query(
            paper_index,
            weights={
                text_space: sl.Param("relevance_weight"),
                recency_space: sl.Param("recency_weight"),
            }
        )
        .find(paper)
        .similar(text_space, sl.Param("search_query"))
        .select(paper.entry_id, paper.title, paper.summary, paper.published, paper.text)
        .limit(sl.Param("limit"))
    )
    print("βœ… Minimal index built!")
    # Best-effort: cache the built index for the next startup.
    save_superlinked_index(app, knowledgebase_query, paper, df)
    return app, knowledgebase_query, paper
def setup_superlinked_ultrafast(df):
    """Ultra-fast Superlinked setup with query-time weights.

    Same pipeline as setup_superlinked_minimal (cache check, text + recency
    spaces, in-memory executor) with extra progress logging and timing around
    the ingestion step.

    NOTE: mutates `df` in place (entry_id/published/published_unix columns).

    Returns:
        (app, knowledgebase_query, paper)
    """
    # Try to load from cache first
    print("Checking for cached Superlinked index...")
    cached_app, cached_query, cached_paper = load_superlinked_index(df)
    if cached_app is not None:
        print("πŸš€ Using cached index!")
        return cached_app, cached_query, cached_paper
    # Check if we built this dataset before.
    # NOTE(review): check_dataset_built_before/save_dataset_info are defined
    # elsewhere in this file (not visible in this chunk) — presumably they
    # persist a record of previously indexed datasets; confirm.
    was_built_before = check_dataset_built_before(df)
    print(f"Building ULTRA-FAST index for {len(df)} papers (following user examples)...")
    if was_built_before:
        print("πŸ“‹ Same dataset was built before - should be faster this time!")
    paper = PaperSchema()
    # Minimal data prep
    df['entry_id'] = df['entry_id'].astype(str)
    df['published'] = pd.to_datetime(df['published'], errors='coerce')
    # CRITICAL: Use Unix seconds for Superlinked (not nanoseconds)
    # NOTE(review): not NaT-safe — assumes dates were cleaned upstream; confirm.
    df['published_unix'] = df['published'].astype('int64') // 10**9
    print(f"πŸ• Timestamp debugging:")
    print(f"Sample published_unix values: {df['published_unix'].head(3).tolist()}")
    # Calculate actual date ranges in the dataset
    min_date = df['published'].min()
    max_date = df['published'].max()
    print(f"πŸ“… Dataset date range: {min_date} to {max_date}")
    # Text similarity space
    text_space = sl.TextSimilaritySpace(
        text=sl.chunk(paper.text, chunk_size=1000, chunk_overlap=0),
        model="sentence-transformers/all-mpnet-base-v2"
    )
    # Recency space expanded for a historical dataset (1993-2025 = ~32 years),
    # with granular 30-32 year periods to differentiate 1993 vs 1994 papers.
    recency_space = sl.RecencySpace(
        timestamp=paper.published_unix,
        period_time_list=[
            sl.PeriodTime(timedelta(days=365)),  # papers within 1 year
            sl.PeriodTime(timedelta(days=2*365)),  # papers within 2 years
            sl.PeriodTime(timedelta(days=3*365)),  # papers within 3 years
            sl.PeriodTime(timedelta(days=5*365)),  # papers within 5 years
            sl.PeriodTime(timedelta(days=10*365)),  # papers within 10 years
            sl.PeriodTime(timedelta(days=15*365)),  # papers within 15 years
            sl.PeriodTime(timedelta(days=20*365)),  # papers within 20 years
            sl.PeriodTime(timedelta(days=25*365)),  # papers within 25 years
            sl.PeriodTime(timedelta(days=30*365)),  # papers within 30 years
            sl.PeriodTime(timedelta(days=31*365)),  # papers within 31 years
            sl.PeriodTime(timedelta(days=31*365 + 120)),  # ~31.33 years (includes Feb 1994)
            sl.PeriodTime(timedelta(days=32*365)),  # papers within 32 years (includes both)
        ],
        negative_filter=-0.25
    )
    # Create index with both spaces - following query_time_weights.ipynb pattern
    paper_index = sl.Index([text_space, recency_space])
    # Parser: maps DataFrame columns onto the schema fields.
    parser = sl.DataFrameParser(
        paper,
        mapping={
            paper.entry_id: "entry_id",
            paper.published_unix: "published_unix",  # Use Unix seconds
            paper.text: "text",
            paper.title: "title",
            paper.summary: "summary",
        }
    )
    # Setup and load
    source = sl.InMemorySource(paper, parser=parser)
    executor = sl.InMemoryExecutor(sources=[source], indices=[paper_index])
    app = executor.run()
    # Single batch load with debugging; ingestion embeds all texts, so this is
    # the slow step.
    print("Loading data ultra-fast...")
    print(f"About to load {len(df)} papers into Superlinked...")
    print("This may take 30-60 seconds for text processing and vector generation...")
    import time
    start_time = time.time()
    try:
        source.put([df])
        elapsed = time.time() - start_time
        print(f"βœ… Data loaded successfully in {elapsed:.1f} seconds!")
    except Exception as e:
        elapsed = time.time() - start_time
        print(f"❌ Error after {elapsed:.1f} seconds: {e}")
        raise e
    # Query with query-time weights bound per call via sl.Param.
    knowledgebase_query = (
        sl.Query(
            paper_index,
            weights={
                text_space: sl.Param("relevance_weight"),
                recency_space: sl.Param("recency_weight"),
            }
        )
        .find(paper)
        .similar(text_space, sl.Param("search_query"))
        .select(paper.entry_id, paper.title, paper.summary, paper.published, paper.text)
        .limit(sl.Param("limit"))
    )
    print("βœ… Ultra-fast index built (following user examples)!")
    # Try to save to cache, but don't worry if it fails
    save_superlinked_index(app, knowledgebase_query, paper, df)
    # Save dataset info for future reference
    save_dataset_info(df)
    return app, knowledgebase_query, paper
# Tool Implementations (matching reference exactly)
class RetrievalTool(Tool):
    """Finds relevant papers via the Superlinked index, then re-joins the hits
    with the original DataFrame so every result has complete metadata."""

    def __init__(self, df, app, knowledgebase_query, client, model):
        # df: original papers DataFrame — source of truth for title/summary/published
        self.df = df
        # app / knowledgebase_query: running Superlinked executor + parameterized query
        self.app = app
        self.knowledgebase_query = knowledgebase_query
        # client/model: kept for interface parity with the other tools (unused here)
        self.client = client
        self.model = model

    def name(self) -> str:
        return "RetrievalTool"

    def description(self) -> str:
        return "Retrieves a list of relevant papers based on a query using Superlinked with query-time weights."

    def use(self, query: str, relevance_weight: float = 1.0, recency_weight: float = 0.5) -> pd.DataFrame:
        """Run the weighted Superlinked query and return up to 5 enriched rows.

        Returns an empty DataFrame when Superlinked yields nothing or when no
        result row survives the merge with the original dataset.
        """
        print(f"πŸ” Superlinked query: '{query}' with weights (relevance={relevance_weight}, recency={recency_weight})")
        # Execute the Superlinked query with query-time weights
        result = self.app.query(
            self.knowledgebase_query,
            relevance_weight=relevance_weight,
            recency_weight=recency_weight,
            search_query=query,
            limit=5
        )
        print(f"πŸ” Raw Superlinked result type: {type(result)}")
        print(f"πŸ” Raw result content: {result}")
        # Convert to pandas DataFrame
        df_result = sl.PandasConverter.to_pandas(result)
        print(f"πŸ“Š Superlinked returned {len(df_result)} results")
        print(f"πŸ“‹ Result columns: {list(df_result.columns)}")
        # Remove duplicates if they exist (based on entry_id or id)
        if 'id' in df_result.columns:
            initial_count = len(df_result)
            df_result = df_result.drop_duplicates(subset=['id'], keep='first')
            if len(df_result) < initial_count:
                print(f"πŸ”„ Removed {initial_count - len(df_result)} duplicate results")
        # DEBUG: Show sample data to understand what's coming back
        if len(df_result) > 0:
            print(f"πŸ” Sample result data with scores:")
            for i, row in df_result.head(5).iterrows():
                pub_date = row.get('published', 'MISSING')
                similarity = row.get('similarity_score', 'MISSING')
                title = str(row.get('title', 'MISSING'))[:50]
                print(f" Row {i}: published={pub_date}, similarity_score={similarity}")
                print(f" title={title}...")
        else:
            print("⚠️ No results returned from Superlinked!")
            return pd.DataFrame()  # Return empty dataframe
        # ALWAYS merge with original dataframe to ensure complete data
        if self.df is not None:
            # Handle both 'entry_id' and 'id' column names from Superlinked
            merge_column = None
            if 'entry_id' in df_result.columns:
                merge_column = 'entry_id'
            elif 'id' in df_result.columns:
                # Use the full URL directly as entry_id (should match original dataset format)
                df_result['entry_id'] = df_result['id']
                merge_column = 'entry_id'
                print(f"πŸ”— Using full URLs as entry_ids: {df_result['entry_id'].head(3).tolist()}")
                print(f"πŸ” Original entry_ids sample: {self.df['entry_id'].head(3).tolist()}")
            if merge_column:
                print(f"πŸ”„ Merging with original dataframe using column '{merge_column}'...")
                # Debug: Show what we're trying to merge
                print(f"πŸ” Sample merge keys from results: {df_result['entry_id'].head(2).tolist()}")
                print(f"πŸ” Sample merge keys from original: {self.df['entry_id'].head(2).tolist()}")
                # Left-join so every Superlinked hit is kept even if unmatched;
                # matched rows gain *_orig copies of the metadata columns.
                df_result = df_result.merge(
                    self.df[['entry_id', 'title', 'summary', 'published']],
                    on='entry_id',
                    how='left',
                    suffixes=('', '_orig')
                )
                # Backfill: where the Superlinked copy is missing or the literal
                # string 'nan', fall back to the original dataset's values.
                for col in ['title', 'summary', 'published']:
                    if f'{col}_orig' in df_result.columns:
                        mask = df_result[col].isna() | (df_result[col].astype(str).str.lower() == 'nan')
                        df_result.loc[mask, col] = df_result.loc[mask, f'{col}_orig']
                # After merge, check if we have publication dates for debugging recency
                if 'published' in df_result.columns:
                    print(f"πŸ“… Publication dates after merge:")
                    for i, row in df_result.head(5).iterrows():
                        pub_date = row.get('published', 'MISSING')
                        similarity = row.get('similarity_score', 'MISSING')
                        print(f" Row {i}: {pub_date} (score: {similarity})")
                print(f"βœ… After merge: {list(df_result.columns)}")
                print(f"πŸ“Š Merge success: {df_result['title'].notna().sum()}/{len(df_result)} papers have titles")
            else:
                print("⚠️ No suitable ID column found for merging!")
        # If no papers have titles after merge, return empty dataframe
        if len(df_result) > 0 and 'title' in df_result.columns:
            valid_papers = df_result['title'].notna().sum()
            if valid_papers == 0:
                print("⚠️ No papers found with valid data after merge - returning empty results")
                return pd.DataFrame()
        return df_result
class SummarizationTool(Tool):
    """Summarizes a set of papers (looked up by entry_id) with an LLM."""

    def __init__(self, df, client, model):
        self.df = df          # full papers DataFrame to look ids up in
        self.client = client  # OpenAI-compatible client
        self.model = model    # chat model name

    def name(self) -> str:
        return "SummarizationTool"

    def description(self) -> str:
        return "Generates a concise summary of specified papers using an LLM."

    def use(self, query: str, paper_ids: list) -> str:
        """Return an LLM-written digest of the papers whose entry_id is in `paper_ids`."""
        selected = self.df[self.df['entry_id'].isin(paper_ids)]
        if selected.empty:
            return "No papers found with the given IDs."
        joined = "\n\n".join(selected['summary'].tolist())
        prompt = f"""
        Summarize the following paper summaries:\n\n{joined}\n\nProvide a concise summary.
        """
        reply = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7,
            max_tokens=500
        )
        return reply.choices[0].message.content.strip()
class QuestionAnsweringTool(Tool):
    """Answers research questions, grounding on retrieved summaries when available."""

    def __init__(self, retrieval_tool, client, model):
        self.retrieval_tool = retrieval_tool  # used to fetch context papers
        self.client = client
        self.model = model

    def name(self) -> str:
        return "QuestionAnsweringTool"

    def description(self) -> str:
        return "Answers questions about research topics using retrieved paper summaries or general knowledge if no specific context is available."

    def use(self, query: str, relevance_weight: float = 1.0, recency_weight: float = 0.5) -> str:
        """Retrieve context for `query`, then ask the LLM; falls back to general knowledge."""
        hits = self.retrieval_tool.use(query, relevance_weight, recency_weight)
        if 'summary' in hits.columns:
            # Ground the answer in the retrieved paper summaries.
            context_str = "\n\n".join(hits['summary'].tolist())
            prompt = f"""
        You are a research assistant. Use the following paper summaries to answer the user's question. If you don't know the answer based on the summaries, say 'I don't know.'
        Paper summaries:
        {context_str}
        User's question: {query}
        """
        else:
            # No usable context: tag as a general-knowledge question.
            prompt = f"""
        You are a knowledgeable research assistant. This is a general question tagged as [GENERAL]. Answer based on your broad knowledge, not limited to specific paper summaries. If you don't know the answer, provide a brief explanation of why.
        User's question: {query}
        """
        reply = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7,
            max_tokens=500
        )
        return reply.choices[0].message.content.strip()
# Kernel Agent (matching reference exactly)
class KernelAgent:
    """Routes a user prompt to the right tool via an LLM classification step."""

    def __init__(self, retrieval_tool: RetrievalTool, summarization_tool: SummarizationTool, question_answering_tool: QuestionAnsweringTool, client, model):
        self.retrieval_tool = retrieval_tool
        self.summarization_tool = summarization_tool
        self.question_answering_tool = question_answering_tool
        self.client = client
        self.model = model

    def classify_query(self, query: str) -> str:
        """Ask the LLM to label the prompt: retrieval / summarization / question_answering."""
        prompt = f"""
        Classify the following user prompt into one of the three categories:
        - retrieval: The user wants to find a list of papers based on some criteria (e.g., 'Find papers on AI ethics from 2020').
        - summarization: The user wants to summarize a list of papers (e.g., 'Summarize papers with entry_id 123, 456, 789').
        - question_answering: The user wants to ask a question about research topics and get an answer (e.g., 'What is the latest development in AI ethics?').
        User prompt: {query}
        Respond with only the category name (retrieval, summarization, question_answering).
        If unsure, respond with 'unknown'.
        """
        reply = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7,
            max_tokens=10
        )
        label = reply.choices[0].message.content.strip().lower()
        print(f"Query type: {label}")
        return label

    def process_query(self, query: str, params: Optional[Dict] = None, relevance_weight: float = 1.0, recency_weight: float = 0.5) -> str:
        """Dispatch `query` to the classified tool and return its textual answer."""
        kind = self.classify_query(query)
        if kind == 'retrieval':
            hits = self.retrieval_tool.use(query, relevance_weight, recency_weight)
            pieces = ["Here are the top papers:\n"]
            for i, row in hits.iterrows():
                # Guard against NaN summaries before slicing.
                summary = str(row['summary']) if pd.notna(row['summary']) else ""
                pieces.append(f"{i+1}. {row['title']} \nSummary: {summary[:200]}...\n\n")
            return "".join(pieces)
        if kind == 'summarization':
            if not params or 'paper_ids' not in params:
                return "Error: Summarization query requires a 'paper_ids' parameter with a list of entry_ids."
            return self.summarization_tool.use(query, params['paper_ids'])
        if kind == 'question_answering':
            return self.question_answering_tool.use(query, relevance_weight, recency_weight)
        return "Error: Unable to classify query as 'retrieval', 'summarization', or 'question_answering'."
def setup_agent(dataset_size, fast_mode=True):
    """Set up the research agent with optional fast mode.

    Downloads the arXiv CSV (if missing), loads `dataset_size` papers, builds
    the Superlinked index ("ultrafast" variant only for fast-mode runs of
    <= 20 papers, "minimal" otherwise), and wires up the three tools plus the
    KernelAgent.

    Returns:
        (kernel_agent, df)
    """
    sample_size = int(dataset_size)
    # Download and load data
    dataset_url = "https://drive.google.com/uc?export=download&id=1FCR3TW5yLjGhEmm-Uclw0_5PWVEaLk1j"
    filename = "arxiv_ai_data.csv"
    download_dataset(dataset_url, filename)
    df = load_data_efficient(filename, sample_size=sample_size)
    # Debug: Show sample data to understand the structure
    print("Dataset loaded successfully!")
    print(f"Dataset shape: {df.shape}")
    print(f"Dataset columns: {list(df.columns)}")
    print("\nSample entry_ids:")
    print(df['entry_id'].head(3).tolist())
    print("\nSample titles:")
    for i, title in enumerate(df['title'].head(3)):
        print(f"{i+1}. {title}")
    # Set up Superlinked - choose mode based on dataset size
    print(f"\nSetting up Superlinked vector database (fast_mode={fast_mode})...")
    if fast_mode and sample_size <= 20:
        app, knowledgebase_query, paper = setup_superlinked_ultrafast(df)
    else:
        app, knowledgebase_query, paper = setup_superlinked_minimal(df)
    # Initialize OpenAI
    client = get_openai_client()
    model = "gpt-4o-mini"
    # Initialize tools (following reference pattern)
    retrieval_tool = RetrievalTool(df, app, knowledgebase_query, client, model)
    summarization_tool = SummarizationTool(df, client, model)
    question_answering_tool = QuestionAnsweringTool(retrieval_tool, client, model)
    # Initialize KernelAgent
    kernel_agent = KernelAgent(retrieval_tool, summarization_tool, question_answering_tool, client, model)
    return kernel_agent, df
def create_paper_cards(df_result):
    """Create beautiful HTML cards for paper results similar to Daily Papers.

    Fixes:
      * the result-count banner was inside a plain (non-f) string, so the UI
        literally displayed "{len(valid_results)}"; the banner is now built
        with an f-string appended after the CSS block.
      * `replace('nan', 'Unknown Title')` corrupted any title containing the
        substring "nan" (e.g. "nanotechnology"); the fallback now only fires
        when the whole title is the string 'nan'.
    """
    if len(df_result) == 0:
        return "<p>No papers found for your query.</p>"
    print(f"🎨 Creating cards for {len(df_result)} results...")
    print(f"πŸ“‹ Available columns: {list(df_result.columns)}")
    # Show sample data for debugging
    if len(df_result) > 0:
        print(f"πŸ” Sample card data:")
        for i, row in df_result.head(2).iterrows():
            title = row.get('title', 'MISSING')
            summary = str(row.get('summary', 'MISSING'))[:50]
            published = row.get('published', 'MISSING')
            print(f" Row {i}: title='{title}', summary='{summary}...', published={published}")
    # MINIMAL filtering - only remove completely broken rows
    valid_results = []
    for i, row in df_result.iterrows():
        # Accept ALL rows that have any title data - don't be picky
        title = str(row.get('title', ''))
        if title and title.lower() not in ['nan', 'none', 'null', '']:
            valid_results.append(row)
        else:
            print(f"⚠️ Skipping row {i} with invalid title: '{title}'")
    print(f"πŸ“Š Original results: {len(df_result)}, Valid results after filtering: {len(valid_results)}")
    if len(valid_results) == 0:
        return "<div style='padding: 20px; text-align: center; color: #666;'>No valid papers found for your query. Try a different search term.</div>"
    # Static CSS (kept as a plain string so the braces need no escaping).
    cards_html = """
    <style>
    .paper-grid {
        display: grid;
        grid-template-columns: repeat(auto-fit, minmax(320px, 1fr));
        gap: 16px;
        margin: 20px 0;
        padding: 0 4px;
    }
    .paper-card {
        background: #1a1a1a;
        border: 1px solid #333;
        border-radius: 12px;
        padding: 20px;
        color: #e0e0e0;
        box-shadow: 0 4px 12px rgba(0,0,0,0.3);
        transition: all 0.2s ease;
        position: relative;
        overflow: hidden;
    }
    .paper-card:hover {
        transform: translateY(-2px);
        box-shadow: 0 8px 24px rgba(0,0,0,0.4);
        border-color: #555;
    }
    .paper-card::before {
        content: '';
        position: absolute;
        top: 0;
        left: 0;
        right: 0;
        height: 2px;
        background: linear-gradient(90deg, #666, #888, #666);
    }
    .paper-title {
        font-size: 16px;
        font-weight: 600;
        margin-bottom: 12px;
        line-height: 1.4;
        color: #ffffff;
        padding-right: 40px;
        word-wrap: break-word;
        overflow-wrap: break-word;
    }
    .paper-meta {
        display: flex;
        justify-content: space-between;
        align-items: center;
        margin-bottom: 12px;
        flex-wrap: wrap;
        gap: 8px;
    }
    .paper-date {
        background: #2a2a2a;
        border: 1px solid #444;
        padding: 4px 8px;
        border-radius: 6px;
        font-size: 11px;
        font-weight: 500;
        color: #ccc;
    }
    .paper-relevance {
        background: #2a2a2a;
        border: 1px solid #444;
        padding: 3px 6px;
        border-radius: 4px;
        font-size: 10px;
        font-weight: 600;
        color: #fff;
    }
    .paper-summary {
        font-size: 13px;
        line-height: 1.5;
        color: #ccc;
        margin-bottom: 16px;
        background: #222;
        padding: 12px;
        border-radius: 8px;
        border: 1px solid #333;
    }
    .paper-id {
        font-size: 10px;
        color: #888;
        font-family: 'SF Mono', 'Monaco', 'Inconsolata', 'Roboto Mono', monospace;
        background: #1e1e1e;
        padding: 6px 8px;
        border-radius: 4px;
        word-break: break-all;
        border: 1px solid #333;
    }
    .rank-badge {
        position: absolute;
        top: 16px;
        right: 16px;
        background: #2a2a2a;
        border: 1px solid #444;
        color: #fff;
        width: 28px;
        height: 28px;
        border-radius: 50%;
        display: flex;
        align-items: center;
        justify-content: center;
        font-weight: 700;
        font-size: 12px;
    }
    .search-stats {
        background: #1a1a1a;
        border: 1px solid #333;
        color: #e0e0e0;
        padding: 16px 20px;
        border-radius: 8px;
        margin-bottom: 20px;
        text-align: center;
    }
    .search-stats h3 {
        margin: 0 0 6px 0;
        font-size: 16px;
        font-weight: 600;
        color: #fff;
    }
    .search-stats p {
        margin: 0;
        color: #aaa;
        font-size: 13px;
    }
    @media (max-width: 768px) {
        .paper-grid {
            grid-template-columns: 1fr;
            gap: 12px;
            margin: 16px 0;
        }
        .paper-card {
            padding: 16px;
        }
        .paper-title {
            font-size: 15px;
            padding-right: 35px;
        }
    }
    </style>
    """
    # FIX: the count banner must be an f-string so the result count actually
    # interpolates (it previously rendered the literal "{len(valid_results)}").
    cards_html += f"""
    <div class="search-stats">
        <h3>πŸ” Superlinked Search Results</h3>
        <p>Found {len(valid_results)} relevant papers β€’ Ranked by semantic similarity + recency</p>
    </div>
    <div class="paper-grid">
    """
    for i, row in enumerate(valid_results):
        # Extract data safely with better fallbacks
        title = str(row.get('title', 'Unknown Title'))
        # Try to get summary from multiple possible columns
        summary = None
        for summary_col in ['summary', 'summary_orig']:
            if summary_col in row and pd.notna(row.get(summary_col)):
                candidate_summary = str(row[summary_col])
                if candidate_summary.lower() not in ['nan', 'none', 'null', '']:
                    summary = candidate_summary
                    break
        if not summary:
            summary = "Summary not available in search results"
        # Try multiple ways to get entry_id
        entry_id = None
        for id_col in ['entry_id', 'id', 'paper_id']:
            if id_col in row and pd.notna(row.get(id_col)):
                entry_id = str(row[id_col])
                break
        if not entry_id:
            entry_id = f"Paper_{i+1}"  # Fallback ID
        # Handle summary display
        if len(summary) > 200:
            summary_display = summary[:200] + "..."
        else:
            summary_display = summary
        # Format publication date - try multiple columns
        formatted_date = "Date not available"
        for date_col in ['published', 'published_orig']:
            if date_col in row and pd.notna(row.get(date_col)):
                try:
                    pub_date = pd.to_datetime(row[date_col])
                    formatted_date = pub_date.strftime('%B %Y')  # e.g., "March 2023"
                    break
                except Exception:
                    formatted_date = str(row[date_col])[:10]
                    break
        # LONGER title display - don't truncate so aggressively
        title_display = title[:120] + "..." if len(title) > 120 else title
        # FIX: only treat the title as missing when it is literally 'nan' — a
        # blind substring replace would mangle words like "nanotechnology".
        if title_display.strip().lower() == 'nan':
            title_display = 'Unknown Title'
        # Relevance indicator
        if i == 0:
            relevance_text = "Most Relevant"
        elif i == 1:
            relevance_text = "Highly Relevant"
        else:
            relevance_text = f"Rank {i+1}"
        card_html = f"""
        <div class="paper-card">
            <div class="rank-badge">#{i+1}</div>
            <div class="paper-title">{title_display}</div>
            <div class="paper-meta">
                <div class="paper-date">πŸ“… {formatted_date}</div>
                <div class="paper-relevance">{relevance_text}</div>
            </div>
            <div class="paper-summary">{summary_display}</div>
            <div class="paper-id">πŸ”— {entry_id}</div>
        </div>
        """
        cards_html += card_html
    cards_html += "</div>"
    return cards_html
# Gradio Interface Functions
def show_loading_state(query, relevance_weight, recency_weight):
    """Return the interim 'searching…' HTML shown as soon as search is clicked."""
    # Banner: query + the two query-time weights, plus a CSS spinner.
    banner = f"""
    <div style='text-align: center; padding: 40px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 12px; color: white; margin: 20px 0;'>
        <div style='font-size: 24px; margin-bottom: 15px;'>
            πŸ” <strong>Searching with Superlinked...</strong>
        </div>
        <div style='font-size: 16px; margin-bottom: 20px; opacity: 0.9;'>
            Query: "<em>{query}</em>"
        </div>
        <div style='font-size: 14px; margin-bottom: 20px; opacity: 0.8;'>
            Relevance Weight: <strong>{relevance_weight}</strong> | Recency Weight: <strong>{recency_weight}</strong>
        </div>
        <div style='font-size: 14px; opacity: 0.7;'>
            ⚑ Performing semantic similarity + recency search...
        </div>
        <div style='margin-top: 20px;'>
            <div style='display: inline-block; width: 20px; height: 20px; border: 3px solid rgba(255,255,255,0.3); border-radius: 50%; border-top-color: white; animation: spin 1s ease-in-out infinite;'></div>
        </div>
    </div>
    """
    # Keyframes for the spinner (plain string: no interpolation needed).
    spinner_css = """
    <style>
    @keyframes spin {
        to { transform: rotate(360deg); }
    }
    </style>
    """
    return banner + spinner_css
def process_query(query, relevance_weight, recency_weight):
    """Gradio callback: run a direct Superlinked retrieval and render result cards.

    Requires the module-level `agent` to have been set by initialization.
    """
    if agent is None:
        return "❌ **Agent not initialized!** Please click 'Initialize Agent' first to load the dataset and set up the search system."
    # Deliberately skips LLM classification — always a plain retrieval search.
    print(f"πŸ” Performing Superlinked search with weights: relevance={relevance_weight}, recency={recency_weight}")
    try:
        hits = agent.retrieval_tool.use(query, relevance_weight, recency_weight)
        print(f"βœ… Found {len(hits)} results")
        return create_paper_cards(hits)
    except Exception as e:
        error_msg = f"❌ Search failed: {str(e)}"
        print(error_msg)
        return f"<div style='padding: 20px; background: #ffe6e6; border-radius: 8px; color: #d63031;'>{error_msg}</div>"
def get_paper_list():
    """Return a newline-separated listing of loaded papers ("entry_id - title")."""
    # `df` is populated by initialize_agent_and_enable_buttons().
    if df is None:
        return "❌ **Agent not initialized!** Please click 'Initialize Agent' first to load the dataset."
    entries = [f"{row['entry_id']} - {row['title']}" for _, row in df.iterrows()]
    return "\n".join(entries)
# Function to initialize the agent and enable buttons
def initialize_agent_and_enable_buttons(dataset_size):
    """Build the search agent for ``dataset_size`` papers and toggle UI buttons.

    Args:
        dataset_size: Number of papers to load (may arrive as a string from the UI).

    Returns:
        list: [status HTML, search-button update, paper-list-button update];
        buttons are enabled only when initialization succeeds.
    """
    global agent, df
    # Convert to int in case it comes as string from UI
    sample_size = int(dataset_size)
    try:
        # Set up the agent using the reference pattern
        agent, df = setup_agent(sample_size)
    except Exception as exc:
        failure_html = f"""
        <div style="text-align: center; margin-bottom: 10px;">
            <p>❌ <b>Initialization failed!</b></p>
            <p>Error: {str(exc)}</p>
        </div>
        """
        return [
            failure_html,
            gr.update(interactive=False),
            gr.update(interactive=False),
        ]
    success_html = f"""
    <div style="text-align: center; margin-bottom: 10px;">
        <p>βœ… <b>Agent initialized successfully!</b></p>
        <p>Loaded {sample_size} papers from the dataset.</p>
    </div>
    """
    # Enable the search and paper-list buttons.
    return [
        success_html,
        gr.update(interactive=True),
        gr.update(interactive=True),
    ]
# Global variables for the agent and DataFrame.
# Both stay None until "Initialize Agent" is clicked, which runs
# initialize_agent_and_enable_buttons() and assigns them via `global`.
agent = None  # search agent exposing `retrieval_tool` (see process_query)
df = None  # pandas DataFrame of loaded papers; 'entry_id' and 'title' are read by get_paper_list
# Create Gradio interface.
# Built at import time so `demo` is available to the __main__ guard
# (and to hosting environments that import this module).
with gr.Blocks(title="Research AI Agent") as demo:
    gr.Markdown("# πŸ”¬ Research AI Agent")
    gr.Markdown("""**Powered by Superlinked Vector Search** πŸš€
This app demonstrates **direct Superlinked vector search** for AI research papers.
Search uses semantic similarity combined with publication recency, with adjustable query-time weights.
*Enter any research topic and get relevant papers ranked by both content similarity and recency.*""")
    # Add initialization message and UI controls
    with gr.Row():
        with gr.Column(scale=3):
            # Welcome banner; replaced with a success/failure message after init.
            init_message = gr.HTML(
                """
                <div style="background-color: #fff3cd; border: 1px solid #ffeaa7; border-radius: 8px; padding: 15px; margin-bottom: 15px;">
                    <h3 style="margin: 0 0 10px 0; color: #856404;">πŸš€ Welcome to Superlinked Search!</h3>
                    <p style="margin: 0; color: #856404;"><strong>Step 1:</strong> Select how many papers to load (10-100)</p>
                    <p style="margin: 5px 0 0 0; color: #856404;"><strong>Step 2:</strong> Click "Initialize Agent" to build the vector index</p>
                    <p style="margin: 5px 0 0 0; color: #856404;"><strong>Step 3:</strong> Search for papers using semantic similarity + recency!</p>
                </div>
                """
            )
        with gr.Column(scale=1):
            dataset_size = gr.Slider(
                minimum=10,
                maximum=100,
                value=50,
                step=10,
                label="Number of Papers to Load",
                info="Recommended: 50+ papers for better search results"
            )
            init_button = gr.Button("Initialize Agent", variant="primary")
    # Superlinked filter controls — passed as query-time weights to process_query.
    with gr.Row():
        with gr.Column():
            relevance_weight = gr.Slider(
                minimum=0.0,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="Relevance Weight",
                info="Higher values prioritize content similarity"
            )
        with gr.Column():
            recency_weight = gr.Slider(
                minimum=0.0,
                maximum=2.0,
                value=0.5,
                step=0.1,
                label="Recency Weight",
                info="Higher values prioritize recent papers"
            )
    # Simple search interface with better styling
    with gr.Row():
        with gr.Column(scale=4):
            query_input = gr.Textbox(
                label="Superlinked Search Query",
                placeholder="Example: quantum computing, machine learning, neural networks, transformer models...",
                lines=2
            )
        with gr.Column(scale=1):
            # Starts disabled; enabled by initialize_agent_and_enable_buttons on success.
            search_button = gr.Button(
                "πŸ” Search Papers",
                variant="primary",
                interactive=False,
                size="lg"
            )
    output = gr.HTML(label="Results")
    # Smaller button for viewing available papers - placed lower
    with gr.Row():
        with gr.Column():
            gr.Markdown("")  # Spacer
        with gr.Column(scale=1):
            # Also disabled until initialization succeeds.
            paper_list_button = gr.Button(
                "πŸ“‹ View Available Papers",
                interactive=False,
                size="sm",
                variant="secondary"
            )
        with gr.Column():
            gr.Markdown("")  # Spacer
    # Hidden until the paper-list button is clicked.
    paper_list_output = gr.Textbox(label="Available Papers", lines=10, visible=False)
    # Connect the initialize button to the function with the dataset size parameter
    init_button.click(
        initialize_agent_and_enable_buttons,
        inputs=[dataset_size],
        outputs=[
            init_message,
            search_button,
            paper_list_button
        ]
    )
    # Connect search button to show loading first, then results:
    # show_loading_state renders a spinner immediately, then the chained
    # process_query overwrites the same `output` component with real results.
    search_button.click(
        show_loading_state,
        inputs=[query_input, relevance_weight, recency_weight],
        outputs=output
    ).then(
        process_query,
        inputs=[query_input, relevance_weight, recency_weight],
        outputs=output
    )
    # NOTE(review): the same component appears twice in `outputs` — the first
    # slot receives the paper list text, the second the visibility update;
    # confirm this gradio version accepts duplicate output components.
    paper_list_button.click(
        lambda: (get_paper_list(), gr.update(visible=True)),
        outputs=[paper_list_output, paper_list_output]
    )
def save_dataset_info(df, cache_dir="superlinked_cache"):
    """Write a small JSON marker describing the dataset that was just indexed.

    The marker records only lightweight facts (hash, row count, a few sample
    titles, timestamp) so a later run can detect that the same dataset was
    built before — without pickling complex Superlinked objects.

    Args:
        df: The papers DataFrame that was indexed.
        cache_dir: Directory where the marker file is stored.

    Returns:
        bool: True if the marker was written, False on any failure (best-effort).
    """
    import json

    try:
        os.makedirs(cache_dir, exist_ok=True)
        fingerprint = get_data_hash(df)
        # Save just basic info, not the complex Superlinked objects.
        marker = {
            'data_hash': fingerprint,
            'dataset_size': len(df),
            'sample_titles': df['title'].head(3).tolist(),
            'build_time': pd.Timestamp.now().isoformat(),
        }
        marker_path = os.path.join(
            cache_dir, f"dataset_info_{fingerprint}_{len(df)}.json"
        )
        with open(marker_path, 'w') as fh:
            json.dump(marker, fh)
        print(f"βœ… Saved dataset info for future reference")
        return True
    except Exception as e:
        print(f"⚠️ Could not save dataset info: {e}")
        return False
def check_dataset_built_before(df, cache_dir="superlinked_cache"):
    """Check if this exact dataset (same content hash and row count) was built before.

    Looks for the JSON marker written by ``save_dataset_info`` and verifies
    that its recorded hash and size match the current DataFrame.

    Args:
        df: The papers DataFrame to check against the cached marker.
        cache_dir: Directory where markers are stored.

    Returns:
        bool: True if a matching marker exists, False otherwise (including on
        any error — this check is strictly best-effort).
    """
    import json

    try:
        if not os.path.exists(cache_dir):
            return False
        data_hash = get_data_hash(df)
        info_file = os.path.join(cache_dir, f"dataset_info_{data_hash}_{len(df)}.json")
        if not os.path.exists(info_file):
            return False
        with open(info_file, 'r') as f:
            info_data = json.load(f)
        # Double-check the recorded hash/size against the live DataFrame.
        if info_data['data_hash'] == data_hash and info_data['dataset_size'] == len(df):
            print(f"πŸ“‹ This dataset was built before on {info_data['build_time'][:19]}")
            return True
        return False
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still
        # propagate; a corrupt or unreadable marker just means "not built before".
        return False
if __name__ == "__main__":
    # Launch with minimal cache to reduce memory usage.
    # NOTE(review): `cache_examples` is an Interface/Examples option, not a
    # documented `Blocks.launch()` keyword — confirm this gradio version
    # accepts it here rather than raising a TypeError.
    demo.launch(share=False, cache_examples=False)