import gradio as gr
from llama_index.core import VectorStoreIndex
from llama_index.core import (
    StorageContext,
    load_index_from_storage,
)
from llama_index.tools.arxiv import ArxivToolSpec
from llama_index.core import Settings
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
from typing import Optional, List, Dict, Any
from pathlib import Path
import aiohttp
import json
import os
import asyncio
import traceback  # used by the error handlers below
from gradio_client import Client, handle_file

HF_TOKEN = os.environ.get('HF_TOKEN')

##### LLM #####
openai_api_key = os.environ.get('OPENAI_API_KEY')

llm = OpenAI(
    model="gpt-4.1",
    api_key=openai_api_key,
)
embed_model = OpenAIEmbedding(
    model="text-embedding-ada-002",
    api_key=openai_api_key,
)

Settings.llm = llm
Settings.embed_model = embed_model
##### END LLM #####

##### LOAD RETRIEVERS #####
DOCUMENTS_BASE_PATH = "./"
RETRIEVERS_JSON_PATH = Path("./retrievers.json")

# Load metadata
def load_retrievers_metadata():
    try:
        with open(RETRIEVERS_JSON_PATH, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception as e:
        print(f"Error loading retrievers.json: {str(e)}")
        print(f"Error details: {traceback.format_exc()}")
        return {}

retrievers_metadata = load_retrievers_metadata()
SOURCES = {source: f"{source.lower()}/" for source in retrievers_metadata.keys()}

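# Illustrative sketch (not part of the original file): the code below expects
# retrievers.json to map each source to its indexes. The field names are
# inferred from the metadata.get(...) calls in list_retrievers and
# retrieve_docs; the values shown here are placeholders.
#
# {
#   "SourceA": {
#     "index_name_1": {
#       "title": "Human-readable title",
#       "description": "What this index contains",
#       "last_updated": "2024-01-01"
#     }
#   }
# }
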
# Load indexes
indices: Dict[str, VectorStoreIndex] = {}
for source, rel_path in SOURCES.items():
    full_path = os.path.join(DOCUMENTS_BASE_PATH, rel_path)
    if not os.path.exists(full_path):
        print(f"Warning: Path not found for {source}")
        continue
    for root, dirs, files in os.walk(full_path):
        if "storage_nodes" in dirs:
            index_name = os.path.basename(root)
            try:
                storage_path = os.path.join(root, "storage_nodes")
                storage_context = StorageContext.from_defaults(persist_dir=storage_path)
                indices[index_name] = load_index_from_storage(storage_context)  # , index_id="vector_index"
                print(f"Index loaded successfully: {index_name}")
            except Exception as e:
                print(f"Error loading index {index_name}: {str(e)}")
                print(f"Error details: {traceback.format_exc()}")

##### ARXIV INSTANCE #####
arxiv_tool = ArxivToolSpec(max_results=5).to_tool_list()[0]
arxiv_tool.return_direct = True

##### MCP TOOLS #####
async def search_arxiv(
    query: str,
    max_results: int = 5
) -> Dict[str, Any]:
    """
    Searches for academic papers on ArXiv.

    Args:
        query: Search terms (e.g. "deep learning")
        max_results: Maximum number of results (1-10, default 5)

    Returns:
        Dict: Search results with paper metadata
    """
    try:
        # Clamp the number of results to the 1-10 range
        max_results = min(max(1, max_results), 10)
        # Rebuild the tool with the requested max_results so the limit actually
        # applies (setting arxiv_tool.metadata.max_results does not reach the spec)
        tool = ArxivToolSpec(max_results=max_results).to_tool_list()[0]

        # Execute search and get results
        tool_output = tool(query=query)

        # Process documents (raw_output holds the Document objects returned by the tool)
        papers = []
        for doc in tool_output.raw_output:
            content = doc.text_resource.text.split('\n')
            papers.append({
                'title': content[0].split(': ')[1] if ': ' in content[0] else content[0],
                'abstract': '\n'.join(content[1:]).strip(),
                'pdf_url': content[0].split(': ')[0].replace('http://', 'https://'),
                'arxiv_id': content[0].split(': ')[0].split('/')[-1].replace('v1', '')
            })

        return {
            'papers': papers,
            'count': len(papers),
            'query': query,
            'status': 'success'
        }
    except Exception as e:
        return {
            'papers': [],
            'count': 0,
            'query': query,
            'status': 'error',
            'error': str(e)
        }

async def list_retrievers(source: Optional[str] = None) -> dict:
    """
    Returns the list of available retrievers.
    If a source is specified and exists, filters by it; if it doesn't exist, returns all.

    Args:
        source (str, optional): Source to filter by. If it doesn't exist, it is ignored. Defaults to None.

    Returns:
        dict: {
            "retrievers": List of retrievers (filtered or complete),
            "count": Total count,
            "status": "success"|"error",
            "source_requested": source,   # What was requested
            "source_used": "all"|source   # What was actually used
        }
    """
    try:
        available = []
        source_exists = source in retrievers_metadata if source else False

        for current_source, indexes in retrievers_metadata.items():
            # Only filter if the source exists; otherwise show everything
            if source_exists and current_source != source:
                continue
            for index_name, metadata in indexes.items():
                available.append({
                    "name": index_name,
                    "source": current_source,
                    "title": metadata.get("title", ""),
                    "description": metadata.get("description", "")
                })

        return {
            "retrievers": available,
            "count": len(available),
            "status": "success",
            "source_requested": source,
            "source_used": source if source_exists else "all"
        }
    except Exception as e:
        return {
            "retrievers": [],
            "count": 0,
            "status": "error",
            "error": str(e),
            "source_requested": source,
            "source_used": "none"
        }

def retrieve_docs(
    query: str,
    retrievers: List[str],
    top_k: int = 3
) -> dict:
    """
    Performs semantic search on indexed documents.

    Parameters:
        query (str): Search text (required)
        retrievers (List[str]): Names of retrievers to query (required)
        top_k (int): Number of results per retriever (optional, default=3)
    """
    print(f"Starting search for query: '{query}'")
    print(f"Parameters - retrievers: {retrievers}, top_k: {top_k}")

    results = {}
    invalid = []
    for name in retrievers:
        if name not in indices:
            print(f"Retriever not found: {name}")
            invalid.append(name)
            continue
        try:
            print(f"Processing retriever: {name}")
            # 1. Run the semantic retrieval against the selected index
            retriever = indices[name].as_retriever(similarity_top_k=top_k)
            nodes = retriever.retrieve(query)
            print(f"Retrieved {len(nodes)} documents from {name}")

            # 2. Look up the full metadata for this retriever
            metadata = {}
            source = "unknown"
            for src, indexes in retrievers_metadata.items():
                if name in indexes:
                    metadata = indexes[name]
                    source = src
                    break
            print(f"Metadata found for {name}: {metadata.keys()}")

            # 3. Build response
            results[name] = {
                "title": metadata.get("title", name),
                "documents": [
                    {
                        "content": node.get_content(),
                        "metadata": node.metadata,
                        "score": node.score
                    }
                    for node in nodes
                ],
                "description": metadata.get("description", ""),
                "source": source,
                "last_updated": metadata.get("last_updated", "")
            }
            print(f"Retriever {name} processed successfully")
        except Exception as e:
            print(f"Error processing retriever {name}: {str(e)}")
            print(f"Error details: {traceback.format_exc()}")
            results[name] = {
                "error": str(e),
                "retriever": name
            }

    # Build final response
    response = {
        "query": query,
        "results": results,
        "top_k": top_k,
    }
    if invalid:
        print(f"Invalid retrievers: {invalid}. Valid options: {list(indices.keys())}")
        response["warnings"] = {
            "invalid_retrievers": invalid,
            "valid_options": list(indices.keys())
        }

    print(f"Search completed. Total results: {len(results)}")
    return response

async def search_tavily(
    query: str,
    days: int = 7,
    max_results: int = 1,
    include_answer: bool = False
) -> dict:
    """Perform a web search using the Tavily API.

    Args:
        query: Search query string (required)
        days: Restrict search to the last N days (default: 7)
        max_results: Maximum results to return (default: 1)
        include_answer: Include a direct answer only when requested by the user (default: False)

    Returns:
        dict: Search results from Tavily
    """
    # Get API key from environment variables
    tavily_api_key = os.environ.get('TAVILY_API_KEY')
    if not tavily_api_key:
        raise ValueError("TAVILY_API_KEY environment variable not set")

    headers = {
        "Authorization": f"Bearer {tavily_api_key}",
        "Content-Type": "application/json"
    }
    payload = {
        "query": query,
        "search_depth": "basic",
        "max_results": max_results,
        "days": days if days else None,
        "include_answer": include_answer
    }
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(
                "https://api.tavily.com/search",
                headers=headers,
                json=payload
            ) as response:
                response.raise_for_status()
                result = await response.json()
                return result
    except Exception as e:
        return {
            "error": str(e),
            "status": "failed",
            "query": query
        }

##### EVALS #####
async def evaluate_answer_relevancy(
    query: str,
    response: str,
) -> float:
    """Evaluate how relevant the answer is to the query using AnswerRelevancyEvaluator.

    Args:
        query: Original user query (required)
        response: Generated response to evaluate (required)

    Returns:
        float: Relevancy score between 0 and 1 (higher is better)
    """
    try:
        from llama_index.core.evaluation import AnswerRelevancyEvaluator

        # Initialize the evaluator
        evaluator = AnswerRelevancyEvaluator(llm=llm)
        # Perform the evaluation
        eval_result = evaluator.evaluate(query=query, response=response)
        # Return the score as a float
        return float(eval_result.score)
    except Exception as e:
        # In case of error, return 0.0 (minimum score) and log the error
        print(f"Error in relevancy evaluation: {str(e)}")
        return 0.0


async def evaluate_context_relevancy(
    context: str,
    query: str,
    response: str
) -> float:
    """Evaluates the relevance of the response considering both the query and the context.

    Args:
        context: Contextual information / knowledge base (required)
        query: Original user query (required)
        response: Generated response to evaluate (required)

    Returns:
        float: Relevance score between 0 and 1 (higher is better)
    """
    try:
        from llama_index.core.evaluation import ContextRelevancyEvaluator

        # Initialize the relevancy evaluator with context
        evaluator = ContextRelevancyEvaluator(llm=llm)
        # Perform the evaluation (adapted to handle context)
        eval_result = evaluator.evaluate(
            query=query,
            response=response,
            contexts=[context]
        )
        return float(eval_result.score)
    except Exception as e:
        print(f"Error during context relevancy evaluation: {str(e)}")
        return 0.0


async def evaluate_faithfulness(
    query: str,
    response: str,
    context: str
) -> float:
    """Evaluate how faithful (factually consistent) the response is to the provided context.

    Args:
        query: Original user query (required)
        response: Generated response to evaluate (required)
        context: Source context/knowledge base used for the response (required)

    Returns:
        float: Faithfulness score between 0 and 1 (higher is better)
    """
    try:
        from llama_index.core.evaluation import FaithfulnessEvaluator

        # Initialize evaluator
        evaluator = FaithfulnessEvaluator(llm=llm)
        # Perform evaluation
        eval_result = evaluator.evaluate(
            query=query,
            response=response,
            contexts=[context]
        )
        # Return score as float
        return float(eval_result.score)
    except Exception as e:
        # On error, return 0.0 (minimum score) and log the error
        print(f"Error in faithfulness evaluation: {str(e)}")
        return 0.0

# Gradio interface
with gr.Blocks(title="MCP Tools", theme=gr.themes.Base()) as arxiv_tab:
    arxiv_interface = gr.Interface(
        fn=search_arxiv,
        inputs=[
            gr.Textbox(label="Search terms", placeholder="E.g.: deep learning"),
            gr.Slider(1, 10, value=5, step=1, label="Maximum number of results")
        ],
        outputs=gr.JSON(label="Search results"),
        title="ArXiv Search",
        description="Search for academic papers on ArXiv using keywords.",
        api_name="_search_arxiv"
    )

with gr.Blocks(title="MCP Tools", theme=gr.themes.Base()) as list_retrievers_tab:
    retrievers_interface = gr.Interface(
        fn=list_retrievers,
        inputs=gr.Textbox(label="Source (optional)", placeholder="Leave empty to list all"),
        outputs=gr.JSON(label="List of retrievers"),
        title="List of Retrievers",
        description="Shows available retrievers, optionally filtered by source.",
        api_name="_list_retrievers"
    )

with gr.Blocks(title="MCP Tools", theme=gr.themes.Base()) as tavily_tab:
    tavily_interface = gr.Interface(
        fn=search_tavily,
        inputs=[
            gr.Textbox(label="Search query", placeholder="E.g.: latest news about AI"),
            gr.Slider(1, 30, value=7, step=1, label="Last N days (0 for no limit)"),
            gr.Slider(1, 10, value=1, step=1, label="Maximum results"),
            gr.Checkbox(label="Include direct answer", value=False)
        ],
        outputs=gr.JSON(label="Tavily results"),
        title="Web Search (Tavily)",
        description="Perform web searches using the Tavily API.",
        api_name="_search_tavily"
    )

with gr.Blocks(title="MCP Tools", theme=gr.themes.Base()) as retrieve_tab:
    # Interface for retrieve_docs
    retrieve_interface = gr.Interface(
        fn=retrieve_docs,
        inputs=[
            gr.Textbox(label="Query", placeholder="Enter your question or search terms..."),
            gr.Dropdown(
                choices=list(indices.keys()),
                label="Retrievers",
                multiselect=True,
                info="Select one or more retrievers"
            ),
            gr.Slider(1, 10, value=3, step=1, label="Number of results per retriever (top_k)")
        ],
        outputs=gr.JSON(label="Semantic search results"),
        title="Semantic Document Search",
        description="""Perform semantic search on indexed documents using retrievers.
        Select available retrievers and adjust the number of results.""",
        api_name="_retrieve"
    )

with gr.Blocks(title="MCP Tools", theme=gr.themes.Base()) as asw_relevance_tab:
    relevancy_interface = gr.Interface(
        fn=evaluate_answer_relevancy,
        inputs=[
            gr.Textbox(label="Original Query", placeholder="E.g.: How does photosynthesis work?"),
            gr.Textbox(label="Answer to Evaluate", placeholder="Paste the generated answer here", lines=5),
        ],
        outputs=gr.Number(label="Relevancy Score (0-1)", precision=3),
        title="Relevancy Evaluator (Query-Answer)",
        description="Evaluates how relevant an answer is to the original query (1 = perfectly relevant).",
        api_name="_evaluate_relevancy"
    )

with gr.Blocks(title="MCP Tools", theme=gr.themes.Base()) as ctx_relevance_tab:
    context_relevancy_interface = gr.Interface(
        fn=evaluate_context_relevancy,
        inputs=[
            gr.Textbox(label="Context", placeholder="Relevant text / knowledge base", lines=3),
            gr.Textbox(label="Original Query", placeholder="What question is being answered?"),
            gr.Textbox(label="Generated Answer", placeholder="The answer to evaluate", lines=5),
        ],
        outputs=gr.Number(label="Relevancy Score (0-1)", precision=3),
        title="Relevancy Evaluator (Context-Query-Answer)",
        description="Evaluates how relevant the answer is considering both the query and the reference context.",
        api_name="_evaluate_context_relevancy"
    )

with gr.Blocks(title="MCP Tools", theme=gr.themes.Base()) as faithfulness_tab:
    faithfulness_interface = gr.Interface(
        fn=evaluate_faithfulness,
        inputs=[
            gr.Textbox(label="Original Query", placeholder="E.g.: What are the causes of climate change?"),
            gr.Textbox(label="Answer to Evaluate", placeholder="Paste the generated answer here", lines=5),
            gr.Textbox(label="Context", placeholder="Reference text / knowledge base", lines=3),
        ],
        outputs=gr.Number(label="Faithfulness Score (0-1)", precision=3),
        title="Faithfulness Evaluator",
        description="Evaluates how faithful/factually consistent the answer is with respect to the provided context (1 = perfectly faithful).",
        api_name="_evaluate_faithfulness"
    )

# Create the interface with separate tabs
demo = gr.TabbedInterface(
    [arxiv_tab, tavily_tab, list_retrievers_tab, retrieve_tab, asw_relevance_tab, ctx_relevance_tab, faithfulness_tab],
    ["ArXiv", "Tavily", "List Retrievers", "Retrieve", "Answer Relevance", "Context Relevance", "Faithfulness"],
    theme=gr.themes.Base(),
)

demo.launch(mcp_server=True)
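
# Illustrative sketch (not part of the original app): gradio_client is imported
# above, so a remote caller could query these endpoints roughly like this.
# "your-username/your-space" is a hypothetical Space id; the api_name values
# follow the gr.Interface definitions above.
#
# client = Client("your-username/your-space", hf_token=HF_TOKEN)
# result = client.predict(
#     "quantum computing",  # query
#     3,                    # max_results
#     api_name="/_search_arxiv",
# )
# print(result)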