# (Hugging Face Spaces page residue — "Spaces: / Build error" — not part of the program source)
| import os | |
| import pandas as pd | |
| import requests | |
| from tqdm import tqdm | |
| import gradio as gr | |
| from datetime import timedelta | |
| from typing import Any, Optional, Dict, List | |
| from abc import ABC, abstractmethod | |
| from sentence_transformers import SentenceTransformer | |
| from openai import OpenAI | |
| import numpy as np | |
| import pickle | |
| import hashlib | |
| # Superlinked import matching reference | |
| import superlinked.framework as sl | |
| # Abstract Tool Class | |
class Tool(ABC):
    """Abstract base class for agent tools.

    Each tool exposes a name, a human-readable description (used by the
    kernel agent when routing queries), and a ``use`` entry point.

    FIX: the methods were plain ``pass`` stubs without ``@abstractmethod``,
    so the "abstract" base was instantiable and subclasses were never forced
    to implement anything.
    """

    @abstractmethod
    def name(self) -> str:
        """Return the tool's unique name."""

    @abstractmethod
    def description(self) -> str:
        """Return a one-line description of what the tool does."""

    @abstractmethod
    def use(self, *args, **kwargs) -> Any:
        """Execute the tool and return its result."""
| # Initialize OpenAI Client | |
def get_openai_client():
    """Build an OpenAI client from the ``OPENAI_API_KEY`` environment variable.

    On Hugging Face Spaces the key is configured as a repository secret.

    Raises:
        ValueError: if the environment variable is not set.
    """
    key = os.environ.get("OPENAI_API_KEY")
    if key:
        return OpenAI(api_key=key)
    raise ValueError("Please set the OPENAI_API_KEY environment variable in HF Space settings.")
| # Download dataset function | |
def download_dataset(url, filename):
    """Download ``url`` to ``filename`` with a progress bar, skipping if present.

    Args:
        url: source URL (streamed).
        filename: local destination path.

    Returns:
        str: ``filename``, whether freshly downloaded or already on disk.

    FIX: the progress messages printed the literal text "(unknown)" —
    corrupted f-string placeholders — instead of the filename; also added
    ``raise_for_status()`` so an HTTP error page is not silently written
    to disk as the dataset.
    """
    if not os.path.exists(filename):
        print(f"Downloading {filename}...")
        response = requests.get(url, stream=True)
        response.raise_for_status()  # fail fast on 4xx/5xx instead of saving an error body
        total_size = int(response.headers.get('content-length', 0))
        block_size = 1024
        with open(filename, 'wb') as f, tqdm(
            total=total_size,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for data in response.iter_content(block_size):
                size = f.write(data)
                bar.update(size)
        print(f"Download complete: {filename}")
    else:
        print(f"File {filename} already exists.")
    return filename
| # Data Loading Function - matching reference exactly | |
def load_data(filename, sample_size=100):
    """Load research papers dataset and prepare it for search"""
    papers = pd.read_csv(filename).head(sample_size)
    # Normalize dtypes: real datetimes for recency, plain strings for text ops.
    papers['published'] = pd.to_datetime(papers['published'])
    papers['summary'] = papers['summary'].astype(str)
    # Combined field consumed by the similarity index.
    papers['text'] = papers['title'] + " " + papers['summary']
    # Diagnostics: missing-value counts and a small preview of the data.
    print(f"Data loaded: {len(papers)} papers")
    print(f"Columns: {list(papers.columns)}")
    print(f"Title nan count: {papers['title'].isna().sum()}")
    print(f"Summary nan count: {papers['summary'].isna().sum()}")
    print(f"Sample titles: {papers['title'].head(3).tolist()}")
    print(f"Sample summaries: {[str(s)[:50] + '...' for s in papers['summary'].head(3)]}")
    return papers
| # Superlinked Schema (matching reference exactly) | |
class PaperSchema(sl.Schema):
    """Superlinked schema for one arXiv paper record.

    Field names must match the ``DataFrameParser`` mapping used when the
    index is built. NOTE(review): both ``published`` and ``published_unix``
    are declared, but only ``published_unix`` is mapped by the parsers below;
    ``published`` is still selected in queries — verify Superlinked tolerates
    the unmapped field.
    """
    text: sl.String  # title + " " + summary; the field the text space embeds
    published: sl.Timestamp # This will handle Unix timestamps in seconds
    published_unix: sl.Timestamp # Explicit Unix seconds field
    entry_id: sl.IdField  # arXiv URL, used as the unique document id
    title: sl.String
    summary: sl.String
def get_data_hash(df):
    """Generate a hash of the dataset to detect changes"""
    # Fingerprint = row count + first entry_id + first title; cheap but
    # sufficient to tell two differently-sampled datasets apart.
    if len(df) > 0:
        signature = f"{len(df)}_{df['entry_id'].iloc[0]}_{df['title'].iloc[0]}"
    else:
        signature = f"{len(df)}__"
    return hashlib.md5(signature.encode()).hexdigest()[:8]
def save_superlinked_index(app, query, paper, df, cache_dir="superlinked_cache"):
    """Persist the Superlinked app/query/schema to disk, keyed by dataset hash.

    Returns True on success, False when pickling fails (the in-memory index
    keeps working either way; only caching for the next run is lost).
    """
    os.makedirs(cache_dir, exist_ok=True)
    fingerprint = get_data_hash(df)
    payload = {
        'app': app,
        'query': query,
        'paper': paper,
        'data_hash': fingerprint,
        'dataset_size': len(df),
    }
    cache_file = os.path.join(cache_dir, f"index_{fingerprint}_{len(df)}.pkl")
    try:
        with open(cache_file, 'wb') as f:
            pickle.dump(payload, f)
    except Exception as e:
        # Superlinked objects are not always picklable; degrade gracefully.
        print(f"β οΈ Could not save index to cache (pickle error): {str(e)[:100]}...")
        print(" Index will work fine, but won't be cached for next time.")
        return False
    print(f"β Saved Superlinked index to {cache_file}")
    return True
def load_superlinked_index(df, cache_dir="superlinked_cache"):
    """Load a previously cached Superlinked index matching *df*, if any.

    Returns ``(app, query, paper)`` on a cache hit, else ``(None, None, None)``.
    """
    miss = (None, None, None)
    if not os.path.exists(cache_dir):
        return miss
    fingerprint = get_data_hash(df)
    cache_file = os.path.join(cache_dir, f"index_{fingerprint}_{len(df)}.pkl")
    if not os.path.exists(cache_file):
        print(f"No cached index found for this dataset")
        return miss
    try:
        with open(cache_file, 'rb') as f:
            payload = pickle.load(f)
    except Exception as e:
        print(f"β Failed to load cached index: {e}")
        return miss
    # Double-check the cached payload really belongs to this dataset.
    if payload['data_hash'] == fingerprint and payload['dataset_size'] == len(df):
        print(f"β Loaded cached Superlinked index from {cache_file}")
        return payload['app'], payload['query'], payload['paper']
    print(f"β Cached index doesn't match current dataset")
    return miss
def load_data_efficient(filename, sample_size=100):
    """Load research papers dataset with optimized processing.

    Reads at most ``sample_size`` rows and only the columns the pipeline
    needs, then derives ``published_unix`` (epoch **seconds**) and the
    combined ``text`` field used for similarity search.

    Returns:
        pd.DataFrame: cleaned papers with valid title/summary/published.

    FIXES: the fallback used a bare ``except:`` (masked KeyboardInterrupt /
    SystemExit); rows whose date failed ``errors='coerce'`` kept a NaT that
    the ``astype('int64')`` cast turned into a huge negative sentinel,
    corrupting the recency signal — they are now dropped first.
    """
    print(f"Loading {sample_size} papers efficiently...")
    # Read only the columns we need
    required_cols = ['entry_id', 'title', 'summary', 'published', 'authors']
    try:
        df = pd.read_csv(
            filename,
            nrows=sample_size,  # Only read what we need
            usecols=lambda x: x in required_cols,  # Only load required columns
            dtype={
                'entry_id': 'string',
                'title': 'string',
                'summary': 'string',
                'authors': 'string'
            }
        )
    except Exception:
        # Fallback to reading all columns if specific ones don't exist
        df = pd.read_csv(filename).head(sample_size)
    # Remove rows with missing essential data
    df = df.dropna(subset=['title', 'summary'])
    df['published'] = pd.to_datetime(df['published'], errors='coerce')
    # Drop unparseable dates BEFORE the int64 cast (see FIXES above).
    df = df.dropna(subset=['published'])
    # Convert to Unix timestamp in SECONDS (not nanoseconds)
    df['published_unix'] = df['published'].astype('int64') // 10**9
    print(f"π Sample publication dates:")
    for i, row in df.head(3).iterrows():
        pub_date = row['published']
        unix_seconds = row['published_unix']
        print(f" {pub_date} β {unix_seconds} seconds since epoch")
    # Create text column efficiently - only for non-null values
    df['text'] = df['title'].fillna('') + " " + df['summary'].fillna('')
    # Remove any rows that ended up empty
    df = df[df['text'].str.len() > 10]  # At least 10 characters
    print(f"β Loaded {len(df)} valid papers")
    return df
def setup_superlinked_minimal(df):
    """Build (or load from cache) a minimal Superlinked index over *df*.

    Returns:
        tuple: ``(app, knowledgebase_query, paper)`` — the running in-memory
        executor app, the parameterized query object, and the schema instance.

    NOTE: mutates *df* in place (entry_id / published / published_unix columns).
    """
    # Try cache first
    cached_app, cached_query, cached_paper = load_superlinked_index(df)
    if cached_app is not None:
        print("π Using cached index!")
        return cached_app, cached_query, cached_paper
    print(f"Building minimal Superlinked index for {len(df)} papers...")
    paper = PaperSchema()
    # Ultra-minimal data prep - no copying, just ensure types
    df['entry_id'] = df['entry_id'].astype(str)
    df['published'] = pd.to_datetime(df['published'], errors='coerce')
    # CRITICAL: Use Unix seconds for Superlinked (not nanoseconds)
    df['published_unix'] = df['published'].astype('int64') // 10**9
    print(f"π Timestamp debugging:")
    print(f"Sample published_unix values: {df['published_unix'].head(3).tolist()}")
    # Calculate actual date ranges in the dataset
    min_date = df['published'].min()
    max_date = df['published'].max()
    print(f"π Dataset date range: {min_date} to {max_date}")
    # Text similarity space: embeds the combined title+summary field in chunks.
    text_space = sl.TextSimilaritySpace(
        text=sl.chunk(paper.text, chunk_size=1000, chunk_overlap=0),
        model="sentence-transformers/all-mpnet-base-v2"
    )
    # CORRECT RECENCY: Following the official example pattern
    # Expanded for historical dataset (1993-2025 = ~32 years)
    recency_space = sl.RecencySpace(
        timestamp=paper.published_unix,
        period_time_list=[
            sl.PeriodTime(timedelta(days=365)),  # papers within 1 year
            sl.PeriodTime(timedelta(days=2*365)),  # papers within 2 years
            sl.PeriodTime(timedelta(days=3*365)),  # papers within 3 years
            sl.PeriodTime(timedelta(days=5*365)),  # papers within 5 years
            sl.PeriodTime(timedelta(days=10*365)),  # papers within 10 years
            sl.PeriodTime(timedelta(days=15*365)),  # papers within 15 years
            sl.PeriodTime(timedelta(days=20*365)),  # papers within 20 years
            sl.PeriodTime(timedelta(days=25*365)),  # papers within 25 years
            sl.PeriodTime(timedelta(days=30*365)),  # papers within 30 years
            sl.PeriodTime(timedelta(days=31*365)),  # papers within 31 years
            sl.PeriodTime(timedelta(days=31*365 + 120)),  # papers within 31.33 years (includes Feb 1994)
            sl.PeriodTime(timedelta(days=32*365)),  # papers within 32 years (includes both)
        ],
        negative_filter=-0.25  # penalty applied to papers older than every period
    )
    # Create index over both spaces
    paper_index = sl.Index([text_space, recency_space])
    # Parser: maps dataframe columns onto schema fields.
    # NOTE(review): paper.published is NOT mapped here even though the query
    # selects it below — verify Superlinked tolerates the unmapped field.
    parser = sl.DataFrameParser(
        paper,
        mapping={
            paper.entry_id: "entry_id",
            paper.published_unix: "published_unix",  # Use Unix seconds
            paper.text: "text",
            paper.title: "title",
            paper.summary: "summary",
        }
    )
    # Setup and load
    source = sl.InMemorySource(paper, parser=parser)
    executor = sl.InMemoryExecutor(sources=[source], indices=[paper_index])
    app = executor.run()
    # Single batch load
    print("Loading data...")
    source.put([df])
    # Query with query-time weights - more robust configuration.
    # Weights are sl.Params so callers can rebalance relevance vs recency per query.
    knowledgebase_query = (
        sl.Query(
            paper_index,
            weights={
                text_space: sl.Param("relevance_weight"),
                recency_space: sl.Param("recency_weight"),
            }
        )
        .find(paper)
        .similar(text_space, sl.Param("search_query"))
        .select(paper.entry_id, paper.title, paper.summary, paper.published, paper.text)  # Reordered for clarity
        .limit(sl.Param("limit"))
    )
    print("β Minimal index built!")
    save_superlinked_index(app, knowledgebase_query, paper, df)
    return app, knowledgebase_query, paper
def setup_superlinked_ultrafast(df):
    """Ultra-fast Superlinked setup following the user's examples with proper query-time weights.

    Same pipeline as :func:`setup_superlinked_minimal` (cache check, schema,
    text+recency spaces, in-memory executor, parameterized query) plus load
    timing and extra progress messages.

    Returns:
        tuple: ``(app, knowledgebase_query, paper)``.

    NOTE: mutates *df* in place.
    NOTE(review): ``check_dataset_built_before`` and ``save_dataset_info`` are
    not defined in this portion of the file — presumably defined later; verify.
    """
    # Try to load from cache first
    print("Checking for cached Superlinked index...")
    cached_app, cached_query, cached_paper = load_superlinked_index(df)
    if cached_app is not None:
        print("π Using cached index!")
        return cached_app, cached_query, cached_paper
    # Check if we built this dataset before
    was_built_before = check_dataset_built_before(df)
    print(f"Building ULTRA-FAST index for {len(df)} papers (following user examples)...")
    if was_built_before:
        print("π Same dataset was built before - should be faster this time!")
    paper = PaperSchema()
    # Minimal data prep
    df['entry_id'] = df['entry_id'].astype(str)
    df['published'] = pd.to_datetime(df['published'], errors='coerce')
    # CRITICAL: Use Unix seconds for Superlinked (not nanoseconds)
    df['published_unix'] = df['published'].astype('int64') // 10**9
    print(f"π Timestamp debugging:")
    print(f"Sample published_unix values: {df['published_unix'].head(3).tolist()}")
    # Calculate actual date ranges in the dataset
    min_date = df['published'].min()
    max_date = df['published'].max()
    print(f"π Dataset date range: {min_date} to {max_date}")
    # Text similarity space
    text_space = sl.TextSimilaritySpace(
        text=sl.chunk(paper.text, chunk_size=1000, chunk_overlap=0),
        model="sentence-transformers/all-mpnet-base-v2"
    )
    # CORRECT RECENCY: Following the official example pattern
    # Expanded for historical dataset (1993-2025 = ~32 years)
    # Added granular periods for 30-32 year range to differentiate 1993 vs 1994
    recency_space = sl.RecencySpace(
        timestamp=paper.published_unix,
        period_time_list=[
            sl.PeriodTime(timedelta(days=365)),  # papers within 1 year
            sl.PeriodTime(timedelta(days=2*365)),  # papers within 2 years
            sl.PeriodTime(timedelta(days=3*365)),  # papers within 3 years
            sl.PeriodTime(timedelta(days=5*365)),  # papers within 5 years
            sl.PeriodTime(timedelta(days=10*365)),  # papers within 10 years
            sl.PeriodTime(timedelta(days=15*365)),  # papers within 15 years
            sl.PeriodTime(timedelta(days=20*365)),  # papers within 20 years
            sl.PeriodTime(timedelta(days=25*365)),  # papers within 25 years
            sl.PeriodTime(timedelta(days=30*365)),  # papers within 30 years
            sl.PeriodTime(timedelta(days=31*365)),  # papers within 31 years
            sl.PeriodTime(timedelta(days=31*365 + 120)),  # papers within 31.33 years (includes Feb 1994)
            sl.PeriodTime(timedelta(days=32*365)),  # papers within 32 years (includes both)
        ],
        negative_filter=-0.25
    )
    # Create index with both spaces - following query_time_weights.ipynb pattern
    paper_index = sl.Index([text_space, recency_space])
    # Parser
    # NOTE(review): paper.published is not mapped, but is selected in the query.
    parser = sl.DataFrameParser(
        paper,
        mapping={
            paper.entry_id: "entry_id",
            paper.published_unix: "published_unix",  # Use Unix seconds
            paper.text: "text",
            paper.title: "title",
            paper.summary: "summary",
        }
    )
    # Setup and load
    source = sl.InMemorySource(paper, parser=parser)
    executor = sl.InMemoryExecutor(sources=[source], indices=[paper_index])
    app = executor.run()
    # Single batch load with debugging
    print("Loading data ultra-fast...")
    print(f"About to load {len(df)} papers into Superlinked...")
    print("This may take 30-60 seconds for text processing and vector generation...")
    import time
    start_time = time.time()
    try:
        source.put([df])
        elapsed = time.time() - start_time
        print(f"β Data loaded successfully in {elapsed:.1f} seconds!")
    except Exception as e:
        elapsed = time.time() - start_time
        print(f"β Error after {elapsed:.1f} seconds: {e}")
        raise e
    # Query with query-time weights - more robust configuration
    knowledgebase_query = (
        sl.Query(
            paper_index,
            weights={
                text_space: sl.Param("relevance_weight"),
                recency_space: sl.Param("recency_weight"),
            }
        )
        .find(paper)
        .similar(text_space, sl.Param("search_query"))
        .select(paper.entry_id, paper.title, paper.summary, paper.published, paper.text)  # Reordered for clarity
        .limit(sl.Param("limit"))
    )
    print("β Ultra-fast index built (following user examples)!")
    # Try to save to cache, but don't worry if it fails
    save_superlinked_index(app, knowledgebase_query, paper, df)
    # Save dataset info for future reference
    save_dataset_info(df)
    return app, knowledgebase_query, paper
| # Tool Implementations (matching reference exactly) | |
class RetrievalTool(Tool):
    """Retrieves relevant papers via a weighted Superlinked query, then
    re-joins the results with the original dataframe so title / summary /
    publication date are always populated."""

    def __init__(self, df, app, knowledgebase_query, client, model):
        self.df = df  # original papers dataframe; source of truth for metadata
        self.app = app  # running Superlinked in-memory app
        self.knowledgebase_query = knowledgebase_query  # parameterized query object
        # NOTE(review): client/model are stored but never used in this class —
        # presumably kept for interface symmetry with the other tools; verify.
        self.client = client
        self.model = model

    def name(self) -> str:
        return "RetrievalTool"

    def description(self) -> str:
        return "Retrieves a list of relevant papers based on a query using Superlinked with query-time weights."

    def use(self, query: str, relevance_weight: float = 1.0, recency_weight: float = 0.5) -> pd.DataFrame:
        """Run the similarity+recency query (top 5) and return a dataframe.

        Returns an empty dataframe when Superlinked yields no rows or when
        no row survives the metadata merge with valid data.
        """
        print(f"π Superlinked query: '{query}' with weights (relevance={relevance_weight}, recency={recency_weight})")
        # Execute the Superlinked query with query-time weights - following user's examples
        result = self.app.query(
            self.knowledgebase_query,
            relevance_weight=relevance_weight,
            recency_weight=recency_weight,
            search_query=query,
            limit=5
        )
        print(f"π Raw Superlinked result type: {type(result)}")
        print(f"π Raw result content: {result}")
        # Convert to pandas DataFrame
        df_result = sl.PandasConverter.to_pandas(result)
        print(f"π Superlinked returned {len(df_result)} results")
        print(f"π Result columns: {list(df_result.columns)}")
        # Remove duplicates if they exist (based on entry_id or id)
        if 'id' in df_result.columns:
            initial_count = len(df_result)
            df_result = df_result.drop_duplicates(subset=['id'], keep='first')
            if len(df_result) < initial_count:
                print(f"π Removed {initial_count - len(df_result)} duplicate results")
        # DEBUG: Show sample data to understand what's coming back
        if len(df_result) > 0:
            print(f"π Sample result data with scores:")
            for i, row in df_result.head(5).iterrows():
                pub_date = row.get('published', 'MISSING')
                similarity = row.get('similarity_score', 'MISSING')
                title = str(row.get('title', 'MISSING'))[:50]
                print(f" Row {i}: published={pub_date}, similarity_score={similarity}")
                print(f" title={title}...")
        else:
            print("β οΈ No results returned from Superlinked!")
            return pd.DataFrame()  # Return empty dataframe
        # ALWAYS merge with original dataframe to ensure complete data
        if self.df is not None:
            # Handle both 'entry_id' and 'id' column names from Superlinked
            merge_column = None
            if 'entry_id' in df_result.columns:
                merge_column = 'entry_id'
            elif 'id' in df_result.columns:
                # Use the full URL directly as entry_id (should match original dataset format)
                df_result['entry_id'] = df_result['id']
                merge_column = 'entry_id'
                print(f"π Using full URLs as entry_ids: {df_result['entry_id'].head(3).tolist()}")
                print(f"π Original entry_ids sample: {self.df['entry_id'].head(3).tolist()}")
            if merge_column:
                print(f"π Merging with original dataframe using column '{merge_column}'...")
                # Debug: Show what we're trying to merge
                print(f"π Sample merge keys from results: {df_result['entry_id'].head(2).tolist()}")
                print(f"π Sample merge keys from original: {self.df['entry_id'].head(2).tolist()}")
                # ALWAYS merge to get complete data including publication dates
                df_result = df_result.merge(
                    self.df[['entry_id', 'title', 'summary', 'published']],
                    on='entry_id',
                    how='left',
                    suffixes=('', '_orig')
                )
                # CRITICAL FIX: Use original data when current data is missing or nan
                for col in ['title', 'summary', 'published']:
                    if f'{col}_orig' in df_result.columns:
                        # Fill missing/nan values with original data
                        # (catches both real NaN and the literal string "nan")
                        mask = df_result[col].isna() | (df_result[col].astype(str).str.lower() == 'nan')
                        df_result.loc[mask, col] = df_result.loc[mask, f'{col}_orig']
                # After merge, check if we have publication dates for debugging recency
                if 'published' in df_result.columns:
                    print(f"π Publication dates after merge:")
                    for i, row in df_result.head(5).iterrows():
                        pub_date = row.get('published', 'MISSING')
                        similarity = row.get('similarity_score', 'MISSING')
                        print(f" Row {i}: {pub_date} (score: {similarity})")
                print(f"β After merge: {list(df_result.columns)}")
                print(f"π Merge success: {df_result['title'].notna().sum()}/{len(df_result)} papers have titles")
            else:
                print("β οΈ No suitable ID column found for merging!")
        # If no papers have titles after merge, return empty dataframe
        if len(df_result) > 0 and 'title' in df_result.columns:
            valid_papers = df_result['title'].notna().sum()
            if valid_papers == 0:
                print("β οΈ No papers found with valid data after merge - returning empty results")
                return pd.DataFrame()
        return df_result
class SummarizationTool(Tool):
    """Summarizes a chosen set of papers (looked up by entry_id) with the LLM."""

    def __init__(self, df, client, model):
        self.df = df          # full papers dataframe used for id lookup
        self.client = client  # OpenAI-compatible chat client
        self.model = model    # chat model name

    def name(self) -> str:
        return "SummarizationTool"

    def description(self) -> str:
        return "Generates a concise summary of specified papers using an LLM."

    def use(self, query: str, paper_ids: list) -> str:
        """Return an LLM-written summary of the papers whose entry_id is in *paper_ids*."""
        selected = self.df[self.df['entry_id'].isin(paper_ids)]
        if selected.empty:
            return "No papers found with the given IDs."
        summary_str = "\n\n".join(selected['summary'].tolist())
        prompt = f"""
        Summarize the following paper summaries:\n\n{summary_str}\n\nProvide a concise summary.
        """
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7,
            max_tokens=500
        )
        return completion.choices[0].message.content.strip()
class QuestionAnsweringTool(Tool):
    """Answers research questions, grounding on retrieved paper summaries
    when available and falling back to general knowledge otherwise."""

    def __init__(self, retrieval_tool, client, model):
        self.retrieval_tool = retrieval_tool  # used to fetch context papers
        self.client = client
        self.model = model

    def name(self) -> str:
        return "QuestionAnsweringTool"

    def description(self) -> str:
        return "Answers questions about research topics using retrieved paper summaries or general knowledge if no specific context is available."

    def use(self, query: str, relevance_weight: float = 1.0, recency_weight: float = 0.5) -> str:
        """Retrieve context for *query*, then ask the LLM for an answer."""
        retrieved = self.retrieval_tool.use(query, relevance_weight, recency_weight)
        if 'summary' in retrieved.columns:
            # Ground the answer on the retrieved paper summaries.
            context_str = "\n\n".join(retrieved['summary'].tolist())
            prompt = f"""
        You are a research assistant. Use the following paper summaries to answer the user's question. If you don't know the answer based on the summaries, say 'I don't know.'
        Paper summaries:
        {context_str}
        User's question: {query}
        """
        else:
            # No context available: answer as a general question.
            prompt = f"""
        You are a knowledgeable research assistant. This is a general question tagged as [GENERAL]. Answer based on your broad knowledge, not limited to specific paper summaries. If you don't know the answer, provide a brief explanation of why.
        User's question: {query}
        """
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7,
            max_tokens=500
        )
        return completion.choices[0].message.content.strip()
| # Kernel Agent (matching reference exactly) | |
class KernelAgent:
    """Routes a user query to the right tool (retrieval / summarization /
    question answering) by first classifying it with the LLM."""

    def __init__(self, retrieval_tool: RetrievalTool, summarization_tool: SummarizationTool, question_answering_tool: QuestionAnsweringTool, client, model):
        self.retrieval_tool = retrieval_tool
        self.summarization_tool = summarization_tool
        self.question_answering_tool = question_answering_tool
        self.client = client
        self.model = model

    def classify_query(self, query: str) -> str:
        """Ask the LLM to label *query*; returns the lowercase category name
        ('retrieval' | 'summarization' | 'question_answering' | other)."""
        prompt = f"""
        Classify the following user prompt into one of the three categories:
        - retrieval: The user wants to find a list of papers based on some criteria (e.g., 'Find papers on AI ethics from 2020').
        - summarization: The user wants to summarize a list of papers (e.g., 'Summarize papers with entry_id 123, 456, 789').
        - question_answering: The user wants to ask a question about research topics and get an answer (e.g., 'What is the latest development in AI ethics?').
        User prompt: {query}
        Respond with only the category name (retrieval, summarization, question_answering).
        If unsure, respond with 'unknown'.
        """
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7,
            max_tokens=10
        )
        classification = response.choices[0].message.content.strip().lower()
        print(f"Query type: {classification}")
        return classification

    def process_query(self, query: str, params: Optional[Dict] = None, relevance_weight: float = 1.0, recency_weight: float = 0.5) -> str:
        """Classify *query* and dispatch to the matching tool.

        Args:
            params: for summarization queries, must contain 'paper_ids'
                (list of entry_ids to summarize).

        Returns:
            str: the tool's response, or an error message.
        """
        query_type = self.classify_query(query)
        if query_type == 'retrieval':
            df_result = self.retrieval_tool.use(query, relevance_weight, recency_weight)
            response = "Here are the top papers:\n"
            # FIX: number results with enumerate instead of the dataframe index —
            # after RetrievalTool's merges/dedup the index is not guaranteed to
            # be 0..n-1, so "{i+1}." could print wrong ranks.
            for rank, (_, row) in enumerate(df_result.iterrows(), start=1):
                # Ensure summary is a string and handle empty cases
                summary = str(row['summary']) if pd.notna(row['summary']) else ""
                response += f"{rank}. {row['title']} \nSummary: {summary[:200]}...\n\n"
            return response
        elif query_type == 'summarization':
            if not params or 'paper_ids' not in params:
                return "Error: Summarization query requires a 'paper_ids' parameter with a list of entry_ids."
            return self.summarization_tool.use(query, params['paper_ids'])
        elif query_type == 'question_answering':
            return self.question_answering_tool.use(query, relevance_weight, recency_weight)
        else:
            return "Error: Unable to classify query as 'retrieval', 'summarization', or 'question_answering'."
def setup_agent(dataset_size, fast_mode=True):
    """Set up the research agent with optional fast mode.

    Downloads the arXiv CSV if missing, loads ``dataset_size`` rows, builds
    the Superlinked index (the "ultrafast" builder only for samples <= 20),
    and wires the three tools behind a :class:`KernelAgent`.

    Returns:
        tuple: ``(kernel_agent, df)``.
    """
    sample_size = int(dataset_size)
    # Download and load data
    dataset_url = "https://drive.google.com/uc?export=download&id=1FCR3TW5yLjGhEmm-Uclw0_5PWVEaLk1j"
    filename = "arxiv_ai_data.csv"
    download_dataset(dataset_url, filename)
    df = load_data_efficient(filename, sample_size=sample_size)
    # Debug: Show sample data to understand the structure
    print("Dataset loaded successfully!")
    print(f"Dataset shape: {df.shape}")
    print(f"Dataset columns: {list(df.columns)}")
    print("\nSample entry_ids:")
    print(df['entry_id'].head(3).tolist())
    print("\nSample titles:")
    for i, title in enumerate(df['title'].head(3)):
        print(f"{i+1}. {title}")
    # Set up Superlinked - choose mode based on dataset size
    print(f"\nSetting up Superlinked vector database (fast_mode={fast_mode})...")
    # NOTE: `paper` (the schema handle) is returned by both builders but not
    # used further here.
    if fast_mode and sample_size <= 20:
        app, knowledgebase_query, paper = setup_superlinked_ultrafast(df)
    else:
        app, knowledgebase_query, paper = setup_superlinked_minimal(df)
    # Initialize OpenAI
    client = get_openai_client()
    model = "gpt-4o-mini"
    # Initialize tools (following reference pattern)
    retrieval_tool = RetrievalTool(df, app, knowledgebase_query, client, model)
    summarization_tool = SummarizationTool(df, client, model)
    question_answering_tool = QuestionAnsweringTool(retrieval_tool, client, model)
    # Initialize KernelAgent
    kernel_agent = KernelAgent(retrieval_tool, summarization_tool, question_answering_tool, client, model)
    return kernel_agent, df
| def create_paper_cards(df_result): | |
| """Create beautiful HTML cards for paper results similar to Daily Papers""" | |
| if len(df_result) == 0: | |
| return "<p>No papers found for your query.</p>" | |
| print(f"π¨ Creating cards for {len(df_result)} results...") | |
| print(f"π Available columns: {list(df_result.columns)}") | |
| # Show sample data for debugging | |
| if len(df_result) > 0: | |
| print(f"π Sample card data:") | |
| for i, row in df_result.head(2).iterrows(): | |
| title = row.get('title', 'MISSING') | |
| summary = str(row.get('summary', 'MISSING'))[:50] | |
| published = row.get('published', 'MISSING') | |
| print(f" Row {i}: title='{title}', summary='{summary}...', published={published}") | |
| # MINIMAL filtering - only remove completely broken rows | |
| valid_results = [] | |
| for i, row in df_result.iterrows(): | |
| # Accept ALL rows that have any title data - don't be picky | |
| title = str(row.get('title', '')) | |
| if title and title.lower() not in ['nan', 'none', 'null', '']: | |
| valid_results.append(row) | |
| else: | |
| print(f"β οΈ Skipping row {i} with invalid title: '{title}'") | |
| print(f"π Original results: {len(df_result)}, Valid results after filtering: {len(valid_results)}") | |
| if len(valid_results) == 0: | |
| return "<div style='padding: 20px; text-align: center; color: #666;'>No valid papers found for your query. Try a different search term.</div>" | |
| cards_html = """ | |
| <style> | |
| .paper-grid { | |
| display: grid; | |
| grid-template-columns: repeat(auto-fit, minmax(320px, 1fr)); | |
| gap: 16px; | |
| margin: 20px 0; | |
| padding: 0 4px; | |
| } | |
| .paper-card { | |
| background: #1a1a1a; | |
| border: 1px solid #333; | |
| border-radius: 12px; | |
| padding: 20px; | |
| color: #e0e0e0; | |
| box-shadow: 0 4px 12px rgba(0,0,0,0.3); | |
| transition: all 0.2s ease; | |
| position: relative; | |
| overflow: hidden; | |
| } | |
| .paper-card:hover { | |
| transform: translateY(-2px); | |
| box-shadow: 0 8px 24px rgba(0,0,0,0.4); | |
| border-color: #555; | |
| } | |
| .paper-card::before { | |
| content: ''; | |
| position: absolute; | |
| top: 0; | |
| left: 0; | |
| right: 0; | |
| height: 2px; | |
| background: linear-gradient(90deg, #666, #888, #666); | |
| } | |
| .paper-title { | |
| font-size: 16px; | |
| font-weight: 600; | |
| margin-bottom: 12px; | |
| line-height: 1.4; | |
| color: #ffffff; | |
| padding-right: 40px; | |
| word-wrap: break-word; | |
| overflow-wrap: break-word; | |
| } | |
| .paper-meta { | |
| display: flex; | |
| justify-content: space-between; | |
| align-items: center; | |
| margin-bottom: 12px; | |
| flex-wrap: wrap; | |
| gap: 8px; | |
| } | |
| .paper-date { | |
| background: #2a2a2a; | |
| border: 1px solid #444; | |
| padding: 4px 8px; | |
| border-radius: 6px; | |
| font-size: 11px; | |
| font-weight: 500; | |
| color: #ccc; | |
| } | |
| .paper-relevance { | |
| background: #2a2a2a; | |
| border: 1px solid #444; | |
| padding: 3px 6px; | |
| border-radius: 4px; | |
| font-size: 10px; | |
| font-weight: 600; | |
| color: #fff; | |
| } | |
| .paper-summary { | |
| font-size: 13px; | |
| line-height: 1.5; | |
| color: #ccc; | |
| margin-bottom: 16px; | |
| background: #222; | |
| padding: 12px; | |
| border-radius: 8px; | |
| border: 1px solid #333; | |
| } | |
| .paper-id { | |
| font-size: 10px; | |
| color: #888; | |
| font-family: 'SF Mono', 'Monaco', 'Inconsolata', 'Roboto Mono', monospace; | |
| background: #1e1e1e; | |
| padding: 6px 8px; | |
| border-radius: 4px; | |
| word-break: break-all; | |
| border: 1px solid #333; | |
| } | |
| .rank-badge { | |
| position: absolute; | |
| top: 16px; | |
| right: 16px; | |
| background: #2a2a2a; | |
| border: 1px solid #444; | |
| color: #fff; | |
| width: 28px; | |
| height: 28px; | |
| border-radius: 50%; | |
| display: flex; | |
| align-items: center; | |
| justify-content: center; | |
| font-weight: 700; | |
| font-size: 12px; | |
| } | |
| .search-stats { | |
| background: #1a1a1a; | |
| border: 1px solid #333; | |
| color: #e0e0e0; | |
| padding: 16px 20px; | |
| border-radius: 8px; | |
| margin-bottom: 20px; | |
| text-align: center; | |
| } | |
| .search-stats h3 { | |
| margin: 0 0 6px 0; | |
| font-size: 16px; | |
| font-weight: 600; | |
| color: #fff; | |
| } | |
| .search-stats p { | |
| margin: 0; | |
| color: #aaa; | |
| font-size: 13px; | |
| } | |
| @media (max-width: 768px) { | |
| .paper-grid { | |
| grid-template-columns: 1fr; | |
| gap: 12px; | |
| margin: 16px 0; | |
| } | |
| .paper-card { | |
| padding: 16px; | |
| } | |
| .paper-title { | |
| font-size: 15px; | |
| padding-right: 35px; | |
| } | |
| } | |
| </style> | |
| <div class="search-stats"> | |
| <h3>π Superlinked Search Results</h3> | |
| <p>Found {len(valid_results)} relevant papers β’ Ranked by semantic similarity + recency</p> | |
| </div> | |
| <div class="paper-grid"> | |
| """ | |
| for i, row in enumerate(valid_results): | |
| # Extract data safely with better fallbacks | |
| title = str(row.get('title', 'Unknown Title')) | |
| # Try to get summary from multiple possible columns | |
| summary = None | |
| for summary_col in ['summary', 'summary_orig']: | |
| if summary_col in row and pd.notna(row.get(summary_col)): | |
| candidate_summary = str(row[summary_col]) | |
| if candidate_summary.lower() not in ['nan', 'none', 'null', '']: | |
| summary = candidate_summary | |
| break | |
| if not summary: | |
| summary = "Summary not available in search results" | |
| # Try multiple ways to get entry_id | |
| entry_id = None | |
| for id_col in ['entry_id', 'id', 'paper_id']: | |
| if id_col in row and pd.notna(row.get(id_col)): | |
| entry_id = str(row[id_col]) | |
| break | |
| if not entry_id: | |
| entry_id = f"Paper_{i+1}" # Fallback ID | |
| # Handle summary display | |
| if len(summary) > 200: | |
| summary_display = summary[:200] + "..." | |
| else: | |
| summary_display = summary | |
| # Format publication date - try multiple columns | |
| formatted_date = "Date not available" | |
| for date_col in ['published', 'published_orig']: | |
| if date_col in row and pd.notna(row.get(date_col)): | |
| try: | |
| pub_date = pd.to_datetime(row[date_col]) | |
| formatted_date = pub_date.strftime('%B %Y') # e.g., "March 2023" | |
| break | |
| except: | |
| formatted_date = str(row[date_col])[:10] | |
| break | |
| # LONGER title display - don't truncate so aggressively | |
| title_display = title[:120] + "..." if len(title) > 120 else title | |
| # Clean up the displays | |
| title_display = title_display.replace('nan', 'Unknown Title') | |
| # Relevance indicator | |
| if i == 0: | |
| relevance_text = "Most Relevant" | |
| elif i == 1: | |
| relevance_text = "Highly Relevant" | |
| else: | |
| relevance_text = f"Rank {i+1}" | |
| card_html = f""" | |
| <div class="paper-card"> | |
| <div class="rank-badge">#{i+1}</div> | |
| <div class="paper-title">{title_display}</div> | |
| <div class="paper-meta"> | |
| <div class="paper-date">π {formatted_date}</div> | |
| <div class="paper-relevance">{relevance_text}</div> | |
| </div> | |
| <div class="paper-summary">{summary_display}</div> | |
| <div class="paper-id">π {entry_id}</div> | |
| </div> | |
| """ | |
| cards_html += card_html | |
| cards_html += "</div>" | |
| return cards_html | |
| # Gradio Interface Functions | |
def show_loading_state(query, relevance_weight, recency_weight):
    """Return the HTML placeholder shown immediately when search is clicked.

    The raw query text is interpolated into HTML, so it is escaped first;
    otherwise markup typed into the search box would be injected verbatim
    into the results pane.
    """
    import html  # stdlib; local import keeps this handler self-contained

    safe_query = html.escape(str(query))
    loading_html = f"""
    <div style='text-align: center; padding: 40px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 12px; color: white; margin: 20px 0;'>
        <div style='font-size: 24px; margin-bottom: 15px;'>
            π <strong>Searching with Superlinked...</strong>
        </div>
        <div style='font-size: 16px; margin-bottom: 20px; opacity: 0.9;'>
            Query: "<em>{safe_query}</em>"
        </div>
        <div style='font-size: 14px; margin-bottom: 20px; opacity: 0.8;'>
            Relevance Weight: <strong>{relevance_weight}</strong> | Recency Weight: <strong>{recency_weight}</strong>
        </div>
        <div style='font-size: 14px; opacity: 0.7;'>
            β‘ Performing semantic similarity + recency search...
        </div>
        <div style='margin-top: 20px;'>
            <div style='display: inline-block; width: 20px; height: 20px; border: 3px solid rgba(255,255,255,0.3); border-radius: 50%; border-top-color: white; animation: spin 1s ease-in-out infinite;'></div>
        </div>
    </div>
    <style>
    @keyframes spin {{
        to {{ transform: rotate(360deg); }}
    }}
    </style>
    """
    return loading_html
def process_query(query, relevance_weight, recency_weight):
    """Run a Superlinked vector search for `query` and render result cards.

    Requires the module-level `agent` to have been initialized first. Any
    failure during retrieval or rendering is reported as an inline HTML
    error box instead of raising to the UI.
    """
    if agent is None:
        return "β **Agent not initialized!** Please click 'Initialize Agent' first to load the dataset and set up the search system."
    # Retrieval only -- no query classification; both weights are applied
    # at query time by the retrieval tool.
    print(f"π Performing Superlinked search with weights: relevance={relevance_weight}, recency={recency_weight}")
    try:
        results = agent.retrieval_tool.use(query, relevance_weight, recency_weight)
        print(f"β Found {len(results)} results")
        return create_paper_cards(results)
    except Exception as exc:
        error_msg = f"β Search failed: {str(exc)}"
        print(error_msg)
        return f"<div style='padding: 20px; background: #ffe6e6; border-radius: 8px; color: #d63031;'>{error_msg}</div>"
def get_paper_list():
    """Return one "entry_id - title" line per loaded paper, newline-joined.

    Reads the module-level `df`; returns an error message when the agent
    (and therefore the DataFrame) has not been initialized yet.
    """
    if df is None:
        return "β **Agent not initialized!** Please click 'Initialize Agent' first to load the dataset."
    # join over a generator instead of an append loop (the row index was unused)
    return "\n".join(f"{row['entry_id']} - {row['title']}" for _, row in df.iterrows())
| # Function to initialize the agent and enable buttons | |
def initialize_agent_and_enable_buttons(dataset_size):
    """Build the agent/index for `dataset_size` papers and toggle the buttons.

    Returns [status_html, search_button_update, paper_list_button_update];
    both action buttons become interactive only when setup succeeds.
    """
    global agent, df
    # The UI may hand the slider value over as a string; normalise to int.
    sample_size = int(dataset_size)
    try:
        # Set up the agent using the reference pattern
        agent, df = setup_agent(sample_size)
    except Exception as e:
        failure_html = f"""
        <div style="text-align: center; margin-bottom: 10px;">
            <p>β <b>Initialization failed!</b></p>
            <p>Error: {str(e)}</p>
        </div>
        """
        return [
            failure_html,
            gr.update(interactive=False),
            gr.update(interactive=False)
        ]
    status_html = f"""
    <div style="text-align: center; margin-bottom: 10px;">
        <p>β <b>Agent initialized successfully!</b></p>
        <p>Loaded {sample_size} papers from the dataset.</p>
    </div>
    """
    return [
        status_html,
        gr.update(interactive=True),
        gr.update(interactive=True)
    ]
# Global state shared by the Gradio callbacks: both are populated by
# initialize_agent_and_enable_buttons() once the user clicks "Initialize Agent".
agent = None  # search agent (exposes retrieval_tool.use) built by setup_agent()
df = None  # papers DataFrame loaded alongside the agent; read by get_paper_list()
# Create Gradio interface
with gr.Blocks(title="Research AI Agent") as demo:
    gr.Markdown("# π¬ Research AI Agent")
    gr.Markdown("""**Powered by Superlinked Vector Search** π
This app demonstrates **direct Superlinked vector search** for AI research papers.
Search uses semantic similarity combined with publication recency, with adjustable query-time weights.
*Enter any research topic and get relevant papers ranked by both content similarity and recency.*""")
    # Onboarding banner plus the dataset-size control.
    with gr.Row():
        with gr.Column(scale=3):
            init_message = gr.HTML(
                """
                <div style="background-color: #fff3cd; border: 1px solid #ffeaa7; border-radius: 8px; padding: 15px; margin-bottom: 15px;">
                    <h3 style="margin: 0 0 10px 0; color: #856404;">π Welcome to Superlinked Search!</h3>
                    <p style="margin: 0; color: #856404;"><strong>Step 1:</strong> Select how many papers to load (10-100)</p>
                    <p style="margin: 5px 0 0 0; color: #856404;"><strong>Step 2:</strong> Click "Initialize Agent" to build the vector index</p>
                    <p style="margin: 5px 0 0 0; color: #856404;"><strong>Step 3:</strong> Search for papers using semantic similarity + recency!</p>
                </div>
                """
            )
        with gr.Column(scale=1):
            dataset_size = gr.Slider(
                minimum=10,
                maximum=100,
                value=50,
                step=10,
                label="Number of Papers to Load",
                info="Recommended: 50+ papers for better search results"
            )
            init_button = gr.Button("Initialize Agent", variant="primary")
    # Superlinked query-time weight controls.
    with gr.Row():
        with gr.Column():
            relevance_weight = gr.Slider(
                minimum=0.0,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="Relevance Weight",
                info="Higher values prioritize content similarity"
            )
        with gr.Column():
            recency_weight = gr.Slider(
                minimum=0.0,
                maximum=2.0,
                value=0.5,
                step=0.1,
                label="Recency Weight",
                info="Higher values prioritize recent papers"
            )
    # Search box + button (disabled until the agent is initialized).
    with gr.Row():
        with gr.Column(scale=4):
            query_input = gr.Textbox(
                label="Superlinked Search Query",
                placeholder="Example: quantum computing, machine learning, neural networks, transformer models...",
                lines=2
            )
        with gr.Column(scale=1):
            search_button = gr.Button(
                "π Search Papers",
                variant="primary",
                interactive=False,
                size="lg"
            )
    output = gr.HTML(label="Results")
    # Secondary action: list the loaded papers (also disabled until init).
    with gr.Row():
        with gr.Column():
            gr.Markdown("")  # Spacer
        with gr.Column(scale=1):
            paper_list_button = gr.Button(
                "π View Available Papers",
                interactive=False,
                size="sm",
                variant="secondary"
            )
        with gr.Column():
            gr.Markdown("")  # Spacer
    paper_list_output = gr.Textbox(label="Available Papers", lines=10, visible=False)
    # Initialization updates the status banner and enables both action buttons.
    init_button.click(
        initialize_agent_and_enable_buttons,
        inputs=[dataset_size],
        outputs=[
            init_message,
            search_button,
            paper_list_button
        ]
    )
    # Show the loading placeholder immediately, then replace it with results.
    search_button.click(
        show_loading_state,
        inputs=[query_input, relevance_weight, recency_weight],
        outputs=output
    ).then(
        process_query,
        inputs=[query_input, relevance_weight, recency_weight],
        outputs=output
    )
    # FIX: the original listed `paper_list_output` twice in `outputs` and
    # returned a (value, update) pair; Gradio rejects duplicate output
    # components. A single gr.update setting both value and visibility is
    # the equivalent, supported form.
    paper_list_button.click(
        lambda: gr.update(value=get_paper_list(), visible=True),
        outputs=[paper_list_output]
    )
def save_dataset_info(df, cache_dir="superlinked_cache"):
    """Persist a small JSON fingerprint of `df` so identical datasets can be
    recognized later without rebuilding the Superlinked index.

    Writes dataset hash, row count, a few sample titles, and a build
    timestamp to `<cache_dir>/dataset_info_<hash>_<size>.json`. Best-effort:
    returns True on success, False on any failure (never raises).
    """
    import json  # hoisted from mid-try: the import itself should not be guarded

    try:
        os.makedirs(cache_dir, exist_ok=True)
        data_hash = get_data_hash(df)
        # Only lightweight metadata is stored -- not the Superlinked objects,
        # which are not reliably picklable.
        info_data = {
            'data_hash': data_hash,
            'dataset_size': len(df),
            'sample_titles': df['title'].head(3).tolist(),
            'build_time': pd.Timestamp.now().isoformat()
        }
        info_file = os.path.join(cache_dir, f"dataset_info_{data_hash}_{len(df)}.json")
        with open(info_file, 'w') as f:
            json.dump(info_data, f)
        print(f"β Saved dataset info for future reference")
        return True
    except Exception as e:
        print(f"β οΈ Could not save dataset info: {e}")
        return False
def check_dataset_built_before(df, cache_dir="superlinked_cache"):
    """Check whether this exact dataset (by content hash + size) was indexed before.

    Looks for the info file written by save_dataset_info(). Fails closed:
    any problem (missing cache dir, corrupt JSON, hashing error) returns
    False so the caller simply rebuilds the index.
    """
    import json  # hoisted from the conditional branch

    try:
        if not os.path.exists(cache_dir):
            return False
        data_hash = get_data_hash(df)
        info_file = os.path.join(cache_dir, f"dataset_info_{data_hash}_{len(df)}.json")
        if not os.path.exists(info_file):
            return False
        with open(info_file, 'r') as f:
            info_data = json.load(f)
        if info_data['data_hash'] == data_hash and info_data['dataset_size'] == len(df):
            print(f"π This dataset was built before on {info_data['build_time'][:19]}")
            return True
        return False
    except Exception:
        # Was a bare `except:`, which would also swallow KeyboardInterrupt
        # and SystemExit; limit the catch-all to Exception.
        return False
if __name__ == "__main__":
    # NOTE(review): `cache_examples` is an option of gr.Interface, not of
    # Blocks.launch(); passing it to launch() raises TypeError on current
    # Gradio versions (a likely cause of the Space's build error), so it
    # is dropped here.
    demo.launch(share=False)