# RAG agent over the Physical AI & Humanoid Robotics curriculum.
# NOTE(review): removed non-Python extraction artifacts ("Spaces:", "Build error")
# that preceded the code.
# Standard library
import asyncio
import json
import logging
import os
import time
from typing import Any, Dict, List

# Third party
from dotenv import load_dotenv

# Pull configuration (API keys, endpoints, ...) from a local .env file when present.
load_dotenv()

# Module-wide logger, INFO verbosity by default.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def retrieve_information(query: str, top_k: int = 5, threshold: float = 0.3) -> Dict:
    """Retrieve information from the knowledge base based on a query.

    Args:
        query: Natural-language query text.
        top_k: Maximum number of chunks to return.
        threshold: Minimum similarity score for a chunk to be included.

    Returns:
        A dict with keys ``query``, ``retrieved_chunks``, ``total_results``
        and ``metadata``. On any failure the dict additionally carries an
        ``error`` message, with ``retrieved_chunks`` empty and
        ``total_results`` 0 — this function never raises.
    """
    try:
        # Import and construct INSIDE the try block so a missing module or a
        # failing constructor is reported through the same error-dict
        # contract instead of escaping as an exception to the caller.
        from retrieving import RAGRetriever

        retriever = RAGRetriever()
        # Call the existing retrieve method from the RAGRetriever instance.
        json_response = retriever.retrieve(query_text=query, top_k=top_k, threshold=threshold)
        results = json.loads(json_response)

        # Normalize the raw results into the shape the assistant expects.
        formatted_results = []
        for result in results.get('results', []):
            formatted_results.append({
                'content': result['content'],
                'url': result['url'],
                'position': result['position'],
                'similarity_score': result['similarity_score'],
                'chunk_id': result.get('chunk_id', ''),
                'created_at': result.get('created_at', ''),
            })

        return {
            'query': query,
            'retrieved_chunks': formatted_results,
            'total_results': len(formatted_results),
            'metadata': results.get('metadata', {}),
        }
    except Exception as e:
        # getLogger(__name__) returns the same module logger; resolving it
        # here keeps the error path self-contained.
        logging.getLogger(__name__).error(f"Error in retrieve_information: {e}")
        return {
            'query': query,
            'retrieved_chunks': [],
            'total_results': 0,
            'error': str(e),
            'metadata': {},
        }
class RAGAgent:
    """Thin orchestration layer: retrieves curriculum chunks via
    retrieve_information() and assembles a structured answer payload."""

    def __init__(self):
        # The retrieval function is called directly; a real deployment would
        # initialize dedicated retrieval/generation components here.
        logger.info("RAG Agent initialized with retrieval and generation components")

    def query_agent(self, query_text: str, session_id: str = None, query_type: str = "global", selected_text: str = None) -> Dict:
        """Process a query through the RAG system and return a structured response.

        Args:
            query_text: The user's question.
            session_id: Optional session identifier, echoed back in the response.
            query_type: Caller-supplied query category, echoed back.
            selected_text: Currently unused; reserved for selection-scoped queries.

        Returns:
            A dict with 'answer', 'sources', 'matched_chunks', 'citations',
            'query_time_ms', 'session_id', 'query_type', 'error' and — on
            success — 'confidence'. Never raises.
        """
        start_time = time.time()
        logger.info(f"Processing query through RAG system: '{query_text[:50]}...'")
        try:
            # Retrieve relevant information using our retrieval system.
            retrieval_result = retrieve_information(query_text, top_k=5, threshold=0.3)

            if retrieval_result.get('error'):
                return self._fallback_response(
                    "Sorry, I encountered an error retrieving information.",
                    retrieval_result['error'], start_time, session_id, query_type)

            retrieved_chunks = retrieval_result.get('retrieved_chunks', [])
            if not retrieved_chunks:
                return self._fallback_response(
                    "I couldn't find relevant information in the Physical AI & Humanoid Robotics curriculum to answer your question. Please try asking about specific topics from the curriculum like ROS 2, Digital Twins, AI-Brain, or VLA.",
                    None, start_time, session_id, query_type)

            # Stitch a simple answer from the top chunks; a real implementation
            # would hand these to a response generator instead.
            answer_parts = ["Based on the Physical AI & Humanoid Robotics curriculum:"]
            for chunk in retrieved_chunks[:2]:  # use top 2 chunks
                content = chunk.get('content', '')[:300]  # limit content length
                answer_parts.append(f"{content}...")
            answer = " ".join(answer_parts)

            # Create citations from the retrieved chunks.
            citations = []
            for chunk in retrieved_chunks:
                citations.append({
                    "document_id": chunk.get('chunk_id', ''),
                    "title": chunk.get('url', ''),
                    "chapter": "",
                    "section": "",
                    "page_reference": "",
                })

            query_time_ms = (time.time() - start_time) * 1000
            response = {
                "answer": answer,
                "sources": [chunk.get('url', '') for chunk in retrieved_chunks if chunk.get('url')],
                "matched_chunks": retrieved_chunks,
                "citations": citations,
                "query_time_ms": query_time_ms,
                "session_id": session_id,
                "query_type": query_type,
                "confidence": self._calculate_confidence(retrieved_chunks),
                "error": None,
            }
            logger.info(f"Query processed in {query_time_ms:.2f}ms")
            return response
        except Exception as e:
            logger.error(f"Error processing query: {e}")
            return self._fallback_response(
                "Sorry, I encountered an error processing your request.",
                str(e), start_time, session_id, query_type)

    def _fallback_response(self, answer: str, error, start_time: float, session_id, query_type) -> Dict:
        """Build the common empty-result payload shared by all failure paths.

        Previously this dict was hand-built three times in query_agent;
        centralizing it keeps the response shape consistent.
        """
        return {
            "answer": answer,
            "sources": [],
            "matched_chunks": [],
            "citations": [],
            "error": error,
            "query_time_ms": (time.time() - start_time) * 1000,
            "session_id": session_id,
            "query_type": query_type,
        }

    def _calculate_confidence(self, matched_chunks: List[Dict]) -> str:
        """Map the mean similarity score of the matches to low/medium/high."""
        if not matched_chunks:
            return "low"
        avg_score = sum(chunk.get('similarity_score', 0.0) for chunk in matched_chunks) / len(matched_chunks)
        if avg_score >= 0.7:
            return "high"
        elif avg_score >= 0.4:
            return "medium"
        return "low"
def query_agent(query_text: str) -> Dict:
    """Convenience wrapper: create a fresh RAGAgent and run a single query."""
    return RAGAgent().query_agent(query_text)
def run_agent_sync(query_text: str) -> Dict:
    """Synchronously run the RAG agent for a single query.

    Bug fix: the previous implementation awaited a non-existent
    ``RAGAgent._async_query_agent`` coroutine — every call raised
    AttributeError — and spun up an event loop (plus a worker thread when a
    loop was already running) to do it. ``RAGAgent.query_agent`` is already
    synchronous, so we simply delegate to it directly.

    Args:
        query_text: The user's question.

    Returns:
        The structured response dict produced by ``RAGAgent.query_agent``.
    """
    return RAGAgent().query_agent(query_text)
def main():
    """Demonstrate the RAG agent by running a batch of sample queries."""
    logger.info("Initializing RAG Agent...")
    agent = RAGAgent()

    # Sample questions spanning the main curriculum topics.
    sample_queries = [
        "What is ROS2?",
        "Explain humanoid design principles",
        "How does VLA work?",
        "What are simulation techniques?",
        "Explain AI control systems"
    ]

    print("RAG Agent - Testing Queries")
    print("=" * 50)

    total = len(sample_queries)
    for idx, question in enumerate(sample_queries, 1):
        print(f"\nQuery {idx}: {question}")
        print("-" * 30)

        # Run the query through the agent and report the structured result.
        result = agent.query_agent(question)
        print(f"Answer: {result['answer']}")

        sources = result.get('sources')
        if sources:
            print(f"Sources: {len(sources)} documents")
            for src in sources[:3]:  # show first 3 sources
                print(f" - {src}")

        chunks = result.get('matched_chunks')
        if chunks:
            print(f"Matched chunks: {len(chunks)}")
            for pos, chunk in enumerate(chunks[:2], 1):  # show first 2 chunks
                text = chunk['content']
                if len(text) > 100:
                    preview = text[:100] + "..."
                else:
                    preview = text
                print(f" Chunk {pos}: {preview}")
                print(f" Source: {chunk['url']}")
                print(f" Score: {chunk['similarity_score']:.3f}")

        print(f"Query time: {result['query_time_ms']:.2f}ms")
        print(f"Confidence: {result.get('confidence', 'unknown')}")

        # Brief pause between queries; skipped after the last one.
        if idx < total:
            time.sleep(1)


if __name__ == "__main__":
    main()