import gradio as gr
from llama_index.core import (
    Settings,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.tools.arxiv import ArxivToolSpec
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
from typing import Optional, List, Dict, Any
from pathlib import Path
import aiohttp
import json
import os
import re
import traceback

HF_TOKEN = os.environ.get('HF_TOKEN')
##### LLM #####
openai_api_key = os.environ.get('OPENAI_API_KEY')
llm = OpenAI(
model="gpt-4.1",
api_key=openai_api_key,
)
embed_model = OpenAIEmbedding(
model="text-embedding-ada-002",
api_key=openai_api_key,
)
Settings.llm = llm
Settings.embed_model = embed_model
##### END LLM #####
##### LOAD RETRIEVERS #####
DOCUMENTS_BASE_PATH = "./"
RETRIEVERS_JSON_PATH = Path("./retrievers.json")
# Load metadata
def load_retrievers_metadata():
try:
with open(RETRIEVERS_JSON_PATH, 'r', encoding='utf-8') as f:
return json.load(f)
    except Exception as e:
        print(f"Error loading retrievers.json: {str(e)}")
        print(f"Error details: {traceback.format_exc()}")
        return {}
retrievers_metadata = load_retrievers_metadata()
SOURCES = {source: f"{source.lower()}/" for source in retrievers_metadata.keys()}
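# Expected shape of retrievers.json — a sketch inferred from how the file is
# read below (source/index names and the date format are illustrative):
# {
#   "SourceName": {
#     "index_name": {
#       "title": "...",
#       "description": "...",
#       "last_updated": "YYYY-MM-DD"
#     }
#   }
# }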
# Load indexes
indices: Dict[str, VectorStoreIndex] = {}
for source, rel_path in SOURCES.items():
full_path = os.path.join(DOCUMENTS_BASE_PATH, rel_path)
if not os.path.exists(full_path):
print(f"Warning: Path not found for {source}")
continue
    for root, dirs, files in os.walk(full_path):
        if "storage_nodes" in dirs:
            # Compute index_name before the try block so the error handler
            # can reference it even if loading fails early
            index_name = os.path.basename(root)
            try:
                storage_path = os.path.join(root, "storage_nodes")
                storage_context = StorageContext.from_defaults(persist_dir=storage_path)
                indices[index_name] = load_index_from_storage(storage_context)  # , index_id="vector_index"
                print(f"Index loaded successfully: {index_name}")
            except Exception as e:
                print(f"Error loading index {index_name}: {str(e)}")
                print(f"Error details: {traceback.format_exc()}")
##### ARXIV INSTANCE #####
# Keep a handle on the tool spec so search_arxiv can adjust max_results per call
arxiv_spec = ArxivToolSpec(max_results=5)
arxiv_tool = arxiv_spec.to_tool_list()[0]
arxiv_tool.return_direct = True
##### MCP TOOLS #####
async def search_arxiv(
query: str,
max_results: int = 5
) -> Dict[str, Any]:
"""
Searches for academic papers on ArXiv.
Args:
query: Search terms (e.g. "deep learning")
max_results: Maximum number of results (1-10, default 5)
Returns:
Dict: Search results with paper metadata
"""
    try:
        # Clamp the result count and apply it on the tool spec itself;
        # setting it on the tool's metadata has no effect on the search
        max_results = min(max(1, max_results), 10)
        arxiv_spec.max_results = max_results
        # Execute the search
        tool_output = arxiv_tool(query=query)
        # Process documents
        papers = []
        for doc in tool_output.raw_output:
            content = doc.text_resource.text.split('\n')
            # The first line of each result is "<entry URL>: <title>";
            # the abstract follows on the remaining lines
            entry_url, _, title = content[0].partition(': ')
            papers.append({
                'title': title if title else content[0],
                'abstract': '\n'.join(content[1:]).strip(),
                'pdf_url': entry_url.replace('http://', 'https://'),
                # Strip a trailing version suffix (v1, v2, ...) from the ID
                'arxiv_id': re.sub(r'v\d+$', '', entry_url.split('/')[-1])
            })
return {
'papers': papers,
'count': len(papers),
'query': query,
'status': 'success'
}
except Exception as e:
return {
'papers': [],
'count': 0,
'query': query,
'status': 'error',
'error': str(e)
}
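# Example usage (hypothetical query; from sync code, wrap with asyncio.run):
#   import asyncio
#   result = asyncio.run(search_arxiv("graph neural networks", max_results=3))
#   print(result["count"], [p["title"] for p in result["papers"]])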
async def list_retrievers(source: Optional[str] = None) -> dict:
"""
Returns the list of available retrievers.
If a source is specified and exists, filters by it; if it doesn't exist, returns all.
Args:
source (str, optional): Source to filter by. If it doesn't exist, it will be ignored. Defaults to None.
Returns:
dict: {
"retrievers": List of retrievers (filtered or complete),
"count": Total count,
"status": "success"|"error",
"source_requested": source, # Shows what was requested
"source_used": "all"|source # Shows what was actually used
}
"""
try:
available = []
source_exists = source in retrievers_metadata if source else False
for current_source, indexes in retrievers_metadata.items():
# Only filter if source exists, otherwise show all
if source_exists and current_source != source:
continue
for index_name, metadata in indexes.items():
available.append({
"name": index_name,
"source": current_source,
"title": metadata.get("title", ""),
"description": metadata.get("description", "")
})
return {
"retrievers": available,
"count": len(available),
"status": "success",
"source_requested": source,
"source_used": source if source_exists else "all"
}
except Exception as e:
return {
"retrievers": [],
"count": 0,
"status": "error",
"error": str(e),
"source_requested": source,
"source_used": "none"
}
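# Example usage (the source name is hypothetical):
#   await list_retrievers()            # every retriever, all sources
#   await list_retrievers("Physics")   # only that source, if it exists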
def retrieve_docs(
query: str,
retrievers: List[str],
top_k: int = 3
) -> dict:
"""
Performs semantic search on indexed documents.
    Args:
query (str): Search text (required)
retrievers (List[str]): Names of retrievers to query (required)
top_k (int): Number of results per retriever (optional, default=3)
"""
print(f"Starting search for query: '{query}'")
print(f"Parameters - retrievers: {retrievers}, top_k: {top_k}")
results = {}
invalid = []
for name in retrievers:
if name not in indices:
print(f"Retriever not found: {name}")
invalid.append(name)
continue
try:
print(f"Processing retriever: {name}")
retriever = indices[name].as_retriever(similarity_top_k=top_k)
nodes = retriever.retrieve(query)
print(f"Retrieved {len(nodes)} documents from {name}")
            # Look up the index's full metadata and its owning source
metadata = {}
source = "unknown"
for src, indexes in retrievers_metadata.items():
if name in indexes:
metadata = indexes[name]
source = src
break
print(f"Metadata found for {name}: {metadata.keys()}")
            # Build the per-retriever response entry
results[name] = {
"title": metadata.get("title", name),
"documents": [
{
"content": node.get_content(),
"metadata": node.metadata,
"score": node.score
}
for node in nodes
],
"description": metadata.get("description", ""),
"source": source,
"last_updated": metadata.get("last_updated", "")
}
print(f"Retriever {name} processed successfully")
        except Exception as e:
            print(f"Error processing retriever {name}: {str(e)}")
            print(f"Error details: {traceback.format_exc()}")
results[name] = {
"error": str(e),
"retriever": name
}
# Build final response
response = {
"query": query,
"results": results,
"top_k": top_k,
}
if invalid:
print(f"Invalid retrievers: {invalid}. Valid options: {list(indices.keys())}")
response["warnings"] = {
"invalid_retrievers": invalid,
"valid_options": list(indices.keys())
}
print(f"Search completed. Total results: {len(results)}")
return response
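# Example usage (the retriever name is hypothetical; see list_retrievers for
# the real options):
#   retrieve_docs("How does X work?", retrievers=["my_index"], top_k=3)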
async def search_tavily(
query: str,
days: int = 7,
max_results: int = 1,
include_answer: bool = False
) -> dict:
"""Perform a web search using the Tavily API.
Args:
query: Search query string (required)
days: Restrict search to last N days (default: 7)
max_results: Maximum results to return (default: 1)
        include_answer: Include a direct answer in the results; intended to be set only when the user explicitly asks for one (default: False)
Returns:
dict: Search results from Tavily
"""
# Get API key from environment variables
tavily_api_key = os.environ.get('TAVILY_API_KEY')
if not tavily_api_key:
raise ValueError("TAVILY_API_KEY environment variable not set")
headers = {
"Authorization": f"Bearer {tavily_api_key}",
"Content-Type": "application/json"
}
    payload = {
        "query": query,
        "search_depth": "basic",
        "max_results": max_results,
        "include_answer": include_answer
    }
    # Only send "days" when a limit was requested (0 means no restriction)
    if days:
        payload["days"] = days
try:
async with aiohttp.ClientSession() as session:
async with session.post(
"https://api.tavily.com/search",
headers=headers,
json=payload
) as response:
response.raise_for_status()
result = await response.json()
return result
except Exception as e:
return {
"error": str(e),
"status": "failed",
"query": query
}
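# Example usage (requires TAVILY_API_KEY in the environment):
#   await search_tavily("latest AI news", days=3, max_results=2)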
##### EVALS #####
async def evaluate_answer_relevancy(
query: str,
response: str,
) -> float:
"""Evaluate how relevant the answer is to the query using AnswerRelevancyEvaluator.
Args:
query: Original user query (required)
response: Generated response to evaluate (required)
Returns:
float: Relevancy score between 0 and 1 (higher is better)
"""
try:
from llama_index.core.evaluation import AnswerRelevancyEvaluator
# Initialize the evaluator
evaluator = AnswerRelevancyEvaluator(llm=llm)
        # Use the async evaluation API: the sync evaluate() starts its own
        # event loop, which fails inside Gradio's already-running loop
        eval_result = await evaluator.aevaluate(query=query, response=response)
# Return the score as a float
return float(eval_result.score)
except Exception as e:
# In case of error, return 0.0 (minimum score) and log the error
print(f"Error in relevancy evaluation: {str(e)}")
return 0.0
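# Example usage (strings are illustrative):
#   score = await evaluate_answer_relevancy(
#       query="What is photosynthesis?",
#       response="Photosynthesis is how plants convert light into energy.",
#   )  # -> a float in [0, 1], e.g. 1.0 if fully relevant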
async def evaluate_context_relevancy(
context: str,
query: str,
response: str
) -> float:
"""Evaluates the relevance of the response considering both the query and the context.
Args:
context: Contextual information / knowledge base (required)
query: Original user query (required)
response: Generated response to evaluate (required)
Returns:
float: Relevance score between 0 and 1 (higher is better)
"""
try:
from llama_index.core.evaluation import ContextRelevancyEvaluator
# Initialize the relevancy evaluator with context
evaluator = ContextRelevancyEvaluator(llm=llm)
        # Use the async evaluation API (see note in evaluate_answer_relevancy)
        eval_result = await evaluator.aevaluate(
            query=query,
            response=response,
            contexts=[context]
        )
return float(eval_result.score)
except Exception as e:
print(f"Error during context relevancy evaluation: {str(e)}")
return 0.0
async def evaluate_faithfulness(
query: str,
response: str,
context: str
) -> float:
"""Evaluate how faithful (factually consistent) the response is to the provided context.
Args:
query: Original user query (required)
response: Generated response to evaluate (required)
context: Source context/knowledge base used for the response (required)
Returns:
float: Faithfulness score between 0 and 1 (higher is better)
"""
try:
from llama_index.core.evaluation import FaithfulnessEvaluator
# Initialize evaluator
evaluator = FaithfulnessEvaluator(llm=llm)
        # Use the async evaluation API (see note in evaluate_answer_relevancy)
        eval_result = await evaluator.aevaluate(
            query=query,
            response=response,
            contexts=[context]
        )
# Return score as float
return float(eval_result.score)
except Exception as e:
# On error, return 0.0 (minimum score) and log the error
print(f"Error in faithfulness evaluation: {str(e)}")
return 0.0
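# Example usage (the same pattern applies to evaluate_context_relevancy;
# strings are illustrative):
#   score = await evaluate_faithfulness(
#       query="What causes tides?",
#       response="Tides are caused mainly by the Moon's gravity.",
#       context="Tides result from the gravitational pull of the Moon and Sun.",
#   )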
# Gradio interface
with gr.Blocks(title="MCP Tools", theme=gr.themes.Base()) as arxiv_tab:
arxiv_interface = gr.Interface(
fn=search_arxiv,
inputs=[
gr.Textbox(label="Search terms", placeholder="E.g.: deep learning"),
gr.Slider(1, 10, value=5, step=1, label="Maximum number of results")
],
outputs=gr.JSON(label="Search results"),
title="ArXiv Search",
description="Search for academic papers on ArXiv using keywords.",
api_name="_search_arxiv"
)
with gr.Blocks(title="MCP Tools", theme=gr.themes.Base()) as list_retrievers_tab:
retrievers_interface = gr.Interface(
fn=list_retrievers,
inputs=gr.Textbox(label="Source (optional)", placeholder="Leave empty to list all"),
outputs=gr.JSON(label="List of retrievers"),
title="List of Retrievers",
description="Shows available retrievers, optionally filtered by source.",
api_name="_list_retrievers"
)
with gr.Blocks(title="MCP Tools", theme=gr.themes.Base()) as tavily_tab:
tavily_interface = gr.Interface(
fn=search_tavily,
inputs=[
gr.Textbox(label="Search query", placeholder="E.g.: latest news about AI"),
            gr.Slider(0, 30, value=7, step=1, label="Last N days (0 for no limit)"),
gr.Slider(1, 10, value=1, step=1, label="Maximum results"),
gr.Checkbox(label="Include direct answer", value=False)
],
outputs=gr.JSON(label="Tavily results"),
title="Web Search (Tavily)",
description="Perform web searches using the Tavily API.",
api_name="_search_tavily"
)
with gr.Blocks(title="MCP Tools", theme=gr.themes.Base()) as retrieve_tab:
# Interface for retrieve_docs
retrieve_interface = gr.Interface(
fn=retrieve_docs,
inputs=[
gr.Textbox(label="Query", placeholder="Enter your question or search terms..."),
gr.Dropdown(
choices=list(indices.keys()),
label="Retrievers",
multiselect=True,
info="Select one or more retrievers"
),
gr.Slider(1, 10, value=3, step=1, label="Number of results per retriever (top_k)")
],
outputs=gr.JSON(label="Semantic search results"),
title="Semantic Document Search",
description="""Perform semantic search on indexed documents using retrievers.
Select available retrievers and adjust the number of results.""",
api_name="_retrieve"
)
with gr.Blocks(title="MCP Tools", theme=gr.themes.Base()) as asw_relevance_tab:
relevancy_interface = gr.Interface(
fn=evaluate_answer_relevancy,
inputs=[
gr.Textbox(label="Original Query", placeholder="E.g.: How does photosynthesis work?"),
gr.Textbox(label="Answer to Evaluate", placeholder="Paste the generated answer here", lines=5),
],
outputs=gr.Number(label="Relevancy Score (0-1)", precision=3),
title="Relevancy Evaluator (Query-Answer)",
description="Evaluates how relevant an answer is to the original query (1 = perfectly relevant).",
api_name="_evaluate_relevancy"
)
with gr.Blocks(title="MCP Tools", theme=gr.themes.Base()) as ctx_relevance_tab:
context_relevancy_interface = gr.Interface(
fn=evaluate_context_relevancy,
inputs=[
gr.Textbox(label="Context", placeholder="Relevant text / knowledge base", lines=3),
gr.Textbox(label="Original Query", placeholder="What question is being answered?"),
gr.Textbox(label="Generated Answer", placeholder="The answer to evaluate", lines=5),
],
outputs=gr.Number(label="Relevancy Score (0-1)", precision=3),
title="Relevancy Evaluator (Context-Query-Answer)",
description="Evaluates how relevant the answer is considering both the query and the reference context.",
api_name="_evaluate_context_relevancy"
)
with gr.Blocks(title="MCP Tools", theme=gr.themes.Base()) as faithfulness_tab:
faithfulness_interface = gr.Interface(
fn=evaluate_faithfulness,
inputs=[
gr.Textbox(label="Original Query", placeholder="E.g.: What are the causes of climate change?"),
gr.Textbox(label="Answer to Evaluate", placeholder="Paste the generated answer here", lines=5),
gr.Textbox(label="Context", placeholder="Reference text / knowledge base", lines=3),
],
outputs=gr.Number(label="Faithfulness Score (0-1)", precision=3),
title="Faithfulness Evaluator",
description="Evaluates how faithful/factually consistent the answer is with respect to the provided context (1 = perfectly faithful).",
api_name="_evaluate_faithfulness"
)
# Create the interface with separate tabs
demo = gr.TabbedInterface(
[arxiv_tab, tavily_tab, list_retrievers_tab, retrieve_tab, asw_relevance_tab, ctx_relevance_tab, faithfulness_tab],
["ArXiv", "Tavily", "List Retrievers", "Retrieve", "Answer Relevance", "Context Relevance", "Faithfulness"],
theme=gr.themes.Base(),
)
demo.launch(mcp_server=True)
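# With mcp_server=True, recent Gradio versions also expose each api_name'd
# endpoint over MCP (by default at <app URL>/gradio_api/mcp/sse), so these
# tools can be called from MCP clients as well as from the web UI.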