File size: 9,390 Bytes
4711fe9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a426a46
4711fe9
 
 
 
 
 
 
 
a426a46
4711fe9
 
 
 
 
 
 
 
 
a426a46
 
 
4711fe9
 
 
 
 
a426a46
 
4711fe9
 
 
 
 
 
 
a426a46
 
4711fe9
 
 
 
a426a46
 
 
 
4711fe9
a426a46
4711fe9
a426a46
4711fe9
 
 
a426a46
4711fe9
 
a426a46
 
4711fe9
a426a46
 
 
 
 
 
 
 
 
 
 
4711fe9
a426a46
 
 
4711fe9
a426a46
4711fe9
a426a46
4711fe9
 
a426a46
 
 
 
 
4711fe9
 
a426a46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4711fe9
 
 
 
 
 
a426a46
 
 
 
4711fe9
a426a46
 
 
 
4711fe9
 
 
 
 
 
a426a46
 
 
 
 
 
 
 
 
 
 
4711fe9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
import os
import json
import logging
from typing import Dict, List, Any
from dotenv import load_dotenv
import asyncio
import time

# Load environment variables via python-dotenv. This runs at import time,
# before any of the definitions below are created.
load_dotenv()

# Configure logging: set up the root logger at INFO level and create the
# module-level logger used by every function and class in this file.
# NOTE(review): calling basicConfig at import time also affects applications
# that import this module - confirm that is intended.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def retrieve_information(query: str, top_k: int = 5, threshold: float = 0.3) -> Dict:
    """
    Look up knowledge-base chunks matching ``query`` and normalise the result.

    A fresh ``RAGRetriever`` is created on every call. Its ``retrieve`` method
    is expected to return a JSON string; the decoded ``results`` entries are
    copied into plain dicts for the agent.

    Args:
        query: Text to search for.
        top_k: Forwarded to ``RAGRetriever.retrieve`` as ``top_k``.
        threshold: Forwarded to ``RAGRetriever.retrieve`` as ``threshold``.

    Returns:
        A dict with ``query``, ``retrieved_chunks``, ``total_results`` and
        ``metadata``. If retrieval, decoding or formatting raises, the error is
        logged and the dict carries an empty chunk list plus an ``error`` string.
        Failures while importing or constructing the retriever are not caught
        here and propagate to the caller.
    """
    from retrieving import RAGRetriever
    rag_retriever = RAGRetriever()

    try:
        # The retriever answers with a JSON document; decode it first
        raw_json = rag_retriever.retrieve(query_text=query, top_k=top_k, threshold=threshold)
        payload = json.loads(raw_json)

        # content/url/position/similarity_score are mandatory on each hit;
        # chunk_id and created_at fall back to an empty string
        chunks = [
            {
                'content': hit['content'],
                'url': hit['url'],
                'position': hit['position'],
                'similarity_score': hit['similarity_score'],
                'chunk_id': hit.get('chunk_id', ''),
                'created_at': hit.get('created_at', ''),
            }
            for hit in payload.get('results', [])
        ]

        return {
            'query': query,
            'retrieved_chunks': chunks,
            'total_results': len(chunks),
            'metadata': payload.get('metadata', {}),
        }
    except Exception as e:
        logger.error(f"Error in retrieve_information: {e}")
        return {
            'query': query,
            'retrieved_chunks': [],
            'total_results': 0,
            'error': str(e),
            'metadata': {},
        }

class RAGAgent:
    """
    Retrieval-backed agent that answers questions from the knowledge base.

    No generation model is involved: the answer is stitched together from the
    leading retrieved chunks. Every public call returns a response dict with
    the same keys (``answer``, ``sources``, ``matched_chunks``, ``citations``,
    ``query_time_ms``, ``session_id``, ``query_type``, ``confidence``,
    ``error``), whether the query succeeded, found nothing, or failed.
    """

    # Retrieval parameters passed to retrieve_information
    TOP_K = 5
    SIMILARITY_THRESHOLD = 0.3

    # How much retrieved text is quoted in the answer
    ANSWER_CHUNK_COUNT = 2
    ANSWER_CHUNK_CHARS = 300

    # Average-similarity cut-offs used by _calculate_confidence
    HIGH_CONFIDENCE = 0.7
    MEDIUM_CONFIDENCE = 0.4

    def __init__(self):
        # The agent holds no state; retrieval goes through the module-level
        # retrieve_information function on every query
        logger.info("RAG Agent initialized with retrieval and generation components")

    def query_agent(self, query_text: str, session_id: str = None, query_type: str = "global", selected_text: str = None) -> Dict:
        """
        Process a query through the RAG system and return a structured response.

        Args:
            query_text: The user's question. Must be a non-empty string.
            session_id: Opaque identifier echoed back in the response.
            query_type: Label echoed back in the response.
            selected_text: Accepted for interface compatibility; currently unused.

        Returns:
            The response dict described on the class. ``error`` is ``None`` on
            success and when nothing relevant was found; otherwise it holds the
            failure message. This method does not raise for invalid queries or
            retrieval failures - they are reported through ``error``.
        """
        start_time = time.time()

        # Validate before the log line below slices query_text: a None or
        # non-string query would otherwise raise outside the try block and
        # escape as an exception instead of a structured error response.
        if not isinstance(query_text, str) or not query_text.strip():
            return self._build_response(
                start_time, session_id, query_type,
                answer="Sorry, I encountered an error processing your request.",
                error="Query text must be a non-empty string.",
            )

        logger.info(f"Processing query through RAG system: '{query_text[:50]}...'")

        try:
            retrieval_result = retrieve_information(
                query_text, top_k=self.TOP_K, threshold=self.SIMILARITY_THRESHOLD
            )

            if retrieval_result.get('error'):
                return self._build_response(
                    start_time, session_id, query_type,
                    answer="Sorry, I encountered an error retrieving information.",
                    error=retrieval_result['error'],
                )

            retrieved_chunks = retrieval_result.get('retrieved_chunks', [])

            if not retrieved_chunks:
                return self._build_response(
                    start_time, session_id, query_type,
                    answer="I couldn't find relevant information in the Physical AI & Humanoid Robotics curriculum to answer your question. Please try asking about specific topics from the curriculum like ROS 2, Digital Twins, AI-Brain, or VLA.",
                )

            response = self._build_response(
                start_time, session_id, query_type,
                answer=self._compose_answer(retrieved_chunks),
                chunks=retrieved_chunks,
            )

            logger.info(f"Query processed in {response['query_time_ms']:.2f}ms")
            return response

        except Exception as e:
            logger.error(f"Error processing query: {e}")
            return self._build_response(
                start_time, session_id, query_type,
                answer="Sorry, I encountered an error processing your request.",
                error=str(e),
            )

    def _compose_answer(self, chunks: List[Dict]) -> str:
        """
        Build the answer text from the first ANSWER_CHUNK_COUNT chunks, each
        cut to ANSWER_CHUNK_CHARS characters and followed by an ellipsis.
        """
        answer_parts = ["Based on the Physical AI & Humanoid Robotics curriculum:"]
        answer_parts.extend(
            f"{chunk.get('content', '')[:self.ANSWER_CHUNK_CHARS]}..."
            for chunk in chunks[:self.ANSWER_CHUNK_COUNT]
        )
        return " ".join(answer_parts)

    def _build_response(self, start_time: float, session_id, query_type, answer: str,
                        chunks: List[Dict] = None, error: str = None) -> Dict:
        """
        Assemble the response dict in one place so every outcome has the same keys.

        Sources, citations and confidence are derived from ``chunks``; with no
        chunks they come out as empty lists and "low" respectively.
        ``query_time_ms`` is measured from ``start_time`` at the moment of the call.
        """
        matched = chunks or []

        # Only the chunk id and URL are known; the remaining citation fields
        # are emitted as empty placeholders
        citations = [
            {
                "document_id": chunk.get('chunk_id', ''),
                "title": chunk.get('url', ''),
                "chapter": "",
                "section": "",
                "page_reference": "",
            }
            for chunk in matched
        ]

        return {
            "answer": answer,
            "sources": [chunk.get('url', '') for chunk in matched if chunk.get('url')],
            "matched_chunks": matched,
            "citations": citations,
            "query_time_ms": (time.time() - start_time) * 1000,
            "session_id": session_id,
            "query_type": query_type,
            "confidence": self._calculate_confidence(matched),
            "error": error,
        }

    def _calculate_confidence(self, matched_chunks: List[Dict]) -> str:
        """
        Classify confidence from the mean similarity score of the matches.

        Returns "high" for a mean of at least HIGH_CONFIDENCE, "medium" for at
        least MEDIUM_CONFIDENCE, and "low" otherwise or when there are no
        matches. A missing or None score counts as 0.0.
        """
        if not matched_chunks:
            return "low"

        # `or 0.0` covers both an absent key and an explicit None value
        scores = [chunk.get('similarity_score') or 0.0 for chunk in matched_chunks]
        avg_score = sum(scores) / len(scores)

        if avg_score >= self.HIGH_CONFIDENCE:
            return "high"
        if avg_score >= self.MEDIUM_CONFIDENCE:
            return "medium"
        return "low"

def query_agent(query_text: str) -> Dict:
    """
    One-shot helper: create a throwaway RAGAgent and run ``query_text``
    through it, returning the agent's response dict unchanged.
    """
    return RAGAgent().query_agent(query_text)

def run_agent_sync(query_text: str) -> Dict:
    """
    Run a query through a fresh RAGAgent and block until the response is ready.

    Safe to call both from plain synchronous code and from code that is already
    running inside an asyncio event loop.

    Args:
        query_text: The question to send to the agent.

    Returns:
        The agent's response dict.
    """
    import asyncio
    import concurrent.futures
    import inspect

    agent = RAGAgent()

    # RAGAgent exposes a synchronous query_agent; awaiting a coroutine method
    # that does not exist would raise AttributeError on every call. Use the
    # async entry point only when the agent really provides one.
    async_query = getattr(agent, "_async_query_agent", None)
    if not inspect.iscoroutinefunction(async_query):
        return agent.query_agent(query_text)

    # Keep the try body to the loop probe alone, so a RuntimeError raised by
    # the query itself is not mistaken for "no running loop".
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running loop, safe to use asyncio.run
        return asyncio.run(async_query(query_text))

    # A loop is already running in this thread and asyncio.run cannot nest,
    # so drive the coroutine on its own loop in a worker thread.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(lambda: asyncio.run(async_query(query_text)))
        return future.result()

def main():
    """
    Demo driver: send a fixed list of sample questions through a RAGAgent and
    print each answer with its sources, top matched chunks, timing and confidence.
    """
    logger.info("Initializing RAG Agent...")

    rag_agent = RAGAgent()

    # Sample questions used to exercise the system end to end
    sample_queries = [
        "What is ROS2?",
        "Explain humanoid design principles",
        "How does VLA work?",
        "What are simulation techniques?",
        "Explain AI control systems",
    ]
    total = len(sample_queries)

    print("RAG Agent - Testing Queries")
    print("=" * 50)

    for number, question in enumerate(sample_queries, 1):
        print(f"\nQuery {number}: {question}")
        print("-" * 30)

        result = rag_agent.query_agent(question)

        print(f"Answer: {result['answer']}")

        sources = result.get('sources')
        if sources:
            print(f"Sources: {len(sources)} documents")
            # Only the first three sources are listed
            for url in sources[:3]:
                print(f"  - {url}")

        matches = result.get('matched_chunks')
        if matches:
            print(f"Matched chunks: {len(matches)}")
            # Only the first two chunks are previewed
            for rank, match in enumerate(matches[:2], 1):
                text = match['content']
                if len(text) > 100:
                    preview = text[:100] + "..."
                else:
                    preview = text
                print(f"  Chunk {rank}: {preview}")
                print(f"    Source: {match['url']}")
                print(f"    Score: {match['similarity_score']:.3f}")

        print(f"Query time: {result['query_time_ms']:.2f}ms")
        print(f"Confidence: {result.get('confidence', 'unknown')}")

        # Pause briefly between queries, but not after the final one
        if number < total:
            time.sleep(1)

# Run the demo queries only when this file is executed as a script; importing
# the module does not trigger main().
if __name__ == "__main__":
    main()