Claude committed on
Commit 7afbd6c (parent: 4dec3fa)

feat: Deploy Physical AI & Humanoid Robotics RAG backend

- Add RAG backend with Cohere and Qdrant integration
- Include API endpoints for chat functionality
- Configure Dockerfile for Hugging Face Spaces deployment
- Add proper health check endpoint
- Set up requirements and dependencies

Files changed (13)
  1. .env.example +26 -0
  2. .gitignore +26 -0
  3. Dockerfile +44 -0
  4. README.md +76 -10
  5. agent.py +254 -0
  6. api.py +215 -0
  7. app.py +8 -0
  8. main.py +322 -0
  9. pyproject.toml +27 -0
  10. requirements.txt +22 -0
  11. retrieving.py +267 -0
  12. sdk.md +935 -0
  13. uv.lock +0 -0
.env.example ADDED
@@ -0,0 +1,26 @@
+ # OpenRouter API Configuration
+ OPENROUTER_API_KEY=your_openrouter_api_key_here
+
+ # Qdrant Vector Database Configuration
+ QDRANT_URL=your_qdrant_url_here
+ QDRANT_API_KEY=your_qdrant_api_key_here
+ QDRANT_CLUSTER_ID=your_qdrant_cluster_id_here
+
+ # Neon PostgreSQL Database Configuration
+ NEON_DATABASE_URL=your_neon_database_url_here
+
+ # Cohere API Key (if needed)
+ COHERE_API_KEY=your_cohere_api_key_here
+
+ # Backend API Key
+ BACKEND_API_KEY=your_backend_api_key_here
+
+ # Target URL for Docusaurus site
+ TARGET_URL=your_vercel_url_here
+
+ # Application Configuration
+ DEBUG=False
+ LOG_LEVEL=INFO
+ MAX_CONTENT_LENGTH=5000
+ RATE_LIMIT_REQUESTS=100
+ RATE_LIMIT_WINDOW=60
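As a sketch of how application code might consume these settings once `.env` is loaded (the variable names come from this file; the `load_settings` helper itself is an illustrative assumption, not code from this commit):

```python
import os

def load_settings() -> dict:
    # Read the variables defined in .env.example, with safe fallbacks;
    # the localhost Qdrant default mirrors the one used in main.py
    return {
        "cohere_api_key": os.environ.get("COHERE_API_KEY", ""),
        "qdrant_url": os.environ.get("QDRANT_URL", "http://localhost:6333"),
        "debug": os.environ.get("DEBUG", "False").lower() == "true",
        "log_level": os.environ.get("LOG_LEVEL", "INFO"),
        "max_content_length": int(os.environ.get("MAX_CONTENT_LENGTH", "5000")),
    }

settings = load_settings()
```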
.gitignore ADDED
@@ -0,0 +1,26 @@
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+ .Python
+ env/
+ venv/
+ .venv/
+ pip-log.txt
+ pip-delete-this-directory.txt
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.log
+ .git/
+ .DS_Store
+ .DS_Store?
+ ._*
+ .Spotlight-V100
+ .Trashes
+ ehthumbs.db
+ Thumbs.db
Dockerfile ADDED
@@ -0,0 +1,44 @@
+ # Use Python 3.11 slim image as base
+ FROM python:3.11-slim
+
+ # Set environment variables
+ ENV PYTHONDONTWRITEBYTECODE=1 \
+     PYTHONUNBUFFERED=1 \
+     PYTHONPATH=/app \
+     PORT=7860
+
+ # Set work directory
+ WORKDIR /app
+
+ # Install system dependencies
+ RUN apt-get update \
+     && apt-get install -y --no-install-recommends \
+         build-essential \
+         gcc \
+         curl \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy requirements first to leverage the Docker layer cache
+ COPY requirements.txt .
+
+ # Install Python dependencies
+ RUN pip install --no-cache-dir --upgrade pip \
+     && pip install --no-cache-dir -r requirements.txt
+
+ # Copy the rest of the application
+ COPY . .
+
+ # Create a non-root user and set permissions
+ RUN adduser --disabled-password --gecos '' appuser \
+     && chown -R appuser:appuser /app
+ USER appuser
+
+ # Expose port (Hugging Face Spaces typically uses port 7860 or 8080)
+ EXPOSE $PORT
+
+ # Health check against the /health endpoint
+ HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
+     CMD curl -f http://localhost:$PORT/health || exit 1
+
+ # Run the application (app.py starts uvicorn on $PORT)
+ CMD ["sh", "-c", "python app.py"]
README.md CHANGED
@@ -1,10 +1,76 @@
- ---
- title: Chat Bot
- emoji: 🌖
- colorFrom: pink
- colorTo: red
- sdk: docker
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Docusaurus Embedding Pipeline
+
+ This project extracts text from deployed Docusaurus URLs, generates embeddings using Cohere, and stores them in Qdrant for RAG-based retrieval.
+
+ ## Features
+
+ - Crawls Docusaurus sites to extract all accessible URLs
+ - Extracts and cleans text content from each page
+ - Chunks large documents to optimize embedding quality
+ - Generates vector embeddings using Cohere's API
+ - Stores embeddings in the Qdrant vector database with metadata
+ - Supports similarity search for RAG applications
+
+ ## Prerequisites
+
+ - Python 3.9+
+ - UV package manager (`pip install uv`)
+ - Cohere API key
+ - Qdrant instance (local or cloud)
+
+ ## Setup
+
+ 1. Clone the repository and navigate to the backend directory
+ 2. Install the UV package manager:
+ ```bash
+ pip install uv
+ ```
+
+ 3. Install dependencies:
+ ```bash
+ cd backend
+ uv sync  # or: uv pip install -r requirements.txt
+ ```
+
+ 4. Set up environment variables:
+ ```bash
+ cp .env.example .env
+ # Edit .env with your Cohere API key and Qdrant configuration
+ ```
+
+ ## Configuration
+
+ The pipeline can be configured via environment variables in the `.env` file:
+
+ - `COHERE_API_KEY`: Your Cohere API key
+ - `QDRANT_URL`: URL of your Qdrant instance
+ - `QDRANT_API_KEY`: API key for Qdrant (if required)
+ - `TARGET_URL`: The Docusaurus site to process
+
+ ## Usage
+
+ Run the complete pipeline:
+ ```bash
+ uv run main.py
+ ```
+
+ ## Architecture
+
+ The pipeline consists of these main functions:
+
+ 1. `get_all_urls()` - Extracts all URLs from the target Docusaurus site
+ 2. `extract_text_from_url()` - Cleans and extracts text content from a URL
+ 3. `chunk_text()` - Splits large documents into manageable chunks
+ 4. `embed()` - Generates vector embeddings using Cohere
+ 5. `create_collection()` - Sets up the Qdrant collection
+ 6. `save_chunk_to_qdrant()` - Stores embeddings with metadata in Qdrant
+
+ The main function orchestrates the complete workflow from crawling to storage.
+
+ ## Output
+
+ The pipeline stores document chunks as vectors in a Qdrant collection named "rag_embedding" with the following metadata:
+ - Content text
+ - Source URL
+ - Position in original document
+ - Creation timestamp
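The `get_all_urls()` step described above starts from the site's `sitemap.xml` and collects every `<loc>` entry. A minimal standalone sketch of that parsing (the inline XML is an illustrative fixture, not data from the project):

```python
import xml.etree.ElementTree as ET

# Sitemap namespace used by Docusaurus-generated sitemaps
SITEMAP_NS = "{http://www.sitemaps.org/schemas/sitemap/0.9}"

def urls_from_sitemap(xml_text: str) -> list:
    # Collect every <loc> entry from a (non-index) sitemap document
    root = ET.fromstring(xml_text)
    return [elem.text for elem in root.findall(".//" + SITEMAP_NS + "loc")]

sample = """<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
  <url><loc>https://example.com/docs/intro</loc></url>
  <url><loc>https://example.com/docs/ros2</loc></url>
</urlset>"""

print(urls_from_sitemap(sample))
# ['https://example.com/docs/intro', 'https://example.com/docs/ros2']
```

The real pipeline additionally handles sitemap-index files (a sitemap of sitemaps) and falls back to crawling anchor tags when no sitemap is served.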
agent.py ADDED
@@ -0,0 +1,254 @@
+ import os
+ import json
+ import logging
+ from typing import Dict, List
+ from dotenv import load_dotenv
+ import time
+
+ # Load environment variables
+ load_dotenv()
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ def retrieve_information(query: str, top_k: int = 5, threshold: float = 0.3) -> Dict:
+     """
+     Retrieve information from the knowledge base based on a query
+     """
+     from retrieving import RAGRetriever
+     retriever = RAGRetriever()
+
+     try:
+         # Call the existing retrieve method from the RAGRetriever instance
+         json_response = retriever.retrieve(query_text=query, top_k=top_k, threshold=threshold)
+         results = json.loads(json_response)
+
+         # Format the results for the assistant
+         formatted_results = []
+         for result in results.get('results', []):
+             formatted_results.append({
+                 'content': result['content'],
+                 'url': result['url'],
+                 'position': result['position'],
+                 'similarity_score': result['similarity_score'],
+                 'chunk_id': result.get('chunk_id', ''),
+                 'created_at': result.get('created_at', '')
+             })
+
+         return {
+             'query': query,
+             'retrieved_chunks': formatted_results,
+             'total_results': len(formatted_results),
+             'metadata': results.get('metadata', {})
+         }
+     except Exception as e:
+         logger.error(f"Error in retrieve_information: {e}")
+         return {
+             'query': query,
+             'retrieved_chunks': [],
+             'total_results': 0,
+             'error': str(e),
+             'metadata': {}
+         }
+
+ class RAGAgent:
+     def __init__(self):
+         # Initialize the RAG system components.
+         # For now, the retrieval function is used directly; a fuller
+         # implementation would initialize the existing RAG components here.
+         logger.info("RAG Agent initialized with retrieval and generation components")
+
+     def query_agent(self, query_text: str, session_id: str = None, query_type: str = "global", selected_text: str = None) -> Dict:
+         """
+         Process a query through the RAG system and return a structured response
+         """
+         start_time = time.time()
+
+         logger.info(f"Processing query through RAG system: '{query_text[:50]}...'")
+
+         try:
+             # Retrieve relevant information using the retrieval system
+             retrieval_result = retrieve_information(query_text, top_k=5, threshold=0.3)
+
+             if retrieval_result.get('error'):
+                 return {
+                     "answer": "Sorry, I encountered an error retrieving information.",
+                     "sources": [],
+                     "matched_chunks": [],
+                     "citations": [],
+                     "error": retrieval_result['error'],
+                     "query_time_ms": (time.time() - start_time) * 1000,
+                     "session_id": session_id,
+                     "query_type": query_type
+                 }
+
+             # Format the retrieved information for response generation.
+             # A fuller implementation would hand this to a response generator.
+             retrieved_chunks = retrieval_result.get('retrieved_chunks', [])
+
+             if not retrieved_chunks:
+                 return {
+                     "answer": "I couldn't find relevant information in the Physical AI & Humanoid Robotics curriculum to answer your question. Please try asking about specific topics from the curriculum like ROS 2, Digital Twins, AI-Brain, or VLA.",
+                     "sources": [],
+                     "matched_chunks": [],
+                     "citations": [],
+                     "error": None,
+                     "query_time_ms": (time.time() - start_time) * 1000,
+                     "session_id": session_id,
+                     "query_type": query_type
+                 }
+
+             # Generate a simple response based on the retrieved chunks
+             answer_parts = ["Based on the Physical AI & Humanoid Robotics curriculum:"]
+
+             # Include content from the most relevant chunks
+             for i, chunk in enumerate(retrieved_chunks[:2]):  # Use top 2 chunks
+                 content = chunk.get('content', '')[:300]  # Limit content length
+                 answer_parts.append(f"{content}...")
+
+             answer = " ".join(answer_parts)
+
+             # Create citations from the retrieved chunks
+             citations = []
+             for chunk in retrieved_chunks:
+                 citation = {
+                     "document_id": chunk.get('chunk_id', ''),
+                     "title": chunk.get('url', ''),
+                     "chapter": "",
+                     "section": "",
+                     "page_reference": ""
+                 }
+                 citations.append(citation)
+
+             # Calculate query time
+             query_time_ms = (time.time() - start_time) * 1000
+
+             # Format the response
+             response = {
+                 "answer": answer,
+                 "sources": [chunk.get('url', '') for chunk in retrieved_chunks if chunk.get('url')],
+                 "matched_chunks": retrieved_chunks,
+                 "citations": citations,
+                 "query_time_ms": query_time_ms,
+                 "session_id": session_id,
+                 "query_type": query_type,
+                 "confidence": self._calculate_confidence(retrieved_chunks),
+                 "error": None
+             }
+
+             logger.info(f"Query processed in {query_time_ms:.2f}ms")
+             return response
+
+         except Exception as e:
+             logger.error(f"Error processing query: {e}")
+             return {
+                 "answer": "Sorry, I encountered an error processing your request.",
+                 "sources": [],
+                 "matched_chunks": [],
+                 "citations": [],
+                 "error": str(e),
+                 "query_time_ms": (time.time() - start_time) * 1000,
+                 "session_id": session_id,
+                 "query_type": query_type
+             }
+
+     def _calculate_confidence(self, matched_chunks: List[Dict]) -> str:
+         """
+         Calculate confidence level based on similarity scores and number of matches
+         """
+         if not matched_chunks:
+             return "low"
+
+         avg_score = sum(chunk.get('similarity_score', 0.0) for chunk in matched_chunks) / len(matched_chunks)
+
+         if avg_score >= 0.7:
+             return "high"
+         elif avg_score >= 0.4:
+             return "medium"
+         else:
+             return "low"
+
+ def query_agent(query_text: str) -> Dict:
+     """
+     Convenience function to query the RAG agent
+     """
+     agent = RAGAgent()
+     return agent.query_agent(query_text)
+
+ def run_agent_sync(query_text: str) -> Dict:
+     """
+     Synchronous function to run the agent for direct usage
+     """
+     # RAGAgent.query_agent is synchronous, so no event-loop
+     # handling is needed here
+     agent = RAGAgent()
+     return agent.query_agent(query_text)
+
+ def main():
+     """
+     Main function to demonstrate the RAG agent functionality
+     """
+     logger.info("Initializing RAG Agent...")
+
+     # Initialize the agent
+     agent = RAGAgent()
+
+     # Example queries to test the system
+     test_queries = [
+         "What is ROS2?",
+         "Explain humanoid design principles",
+         "How does VLA work?",
+         "What are simulation techniques?",
+         "Explain AI control systems"
+     ]
+
+     print("RAG Agent - Testing Queries")
+     print("=" * 50)
+
+     for i, query in enumerate(test_queries, 1):
+         print(f"\nQuery {i}: {query}")
+         print("-" * 30)
+
+         # Process query through agent
+         response = agent.query_agent(query)
+
+         # Print formatted results
+         print(f"Answer: {response['answer']}")
+
+         if response.get('sources'):
+             print(f"Sources: {len(response['sources'])} documents")
+             for source in response['sources'][:3]:  # Show first 3 sources
+                 print(f"  - {source}")
+
+         if response.get('matched_chunks'):
+             print(f"Matched chunks: {len(response['matched_chunks'])}")
+             for j, chunk in enumerate(response['matched_chunks'][:2], 1):  # Show first 2 chunks
+                 content_preview = chunk['content'][:100] + "..." if len(chunk['content']) > 100 else chunk['content']
+                 print(f"  Chunk {j}: {content_preview}")
+                 print(f"  Source: {chunk['url']}")
+                 print(f"  Score: {chunk['similarity_score']:.3f}")
+
+         print(f"Query time: {response['query_time_ms']:.2f}ms")
+         print(f"Confidence: {response.get('confidence', 'unknown')}")
+
+         if i < len(test_queries):  # Don't sleep after the last query
+             time.sleep(1)  # Small delay between queries
+
+ if __name__ == "__main__":
+     main()
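The `_calculate_confidence` heuristic in `agent.py` reduces to a small pure function over the similarity scores; restated standalone with the same thresholds (the function name here is illustrative):

```python
def calculate_confidence(scores: list) -> str:
    # Average similarity score mapped to a coarse label,
    # mirroring RAGAgent._calculate_confidence (thresholds 0.7 and 0.4)
    if not scores:
        return "low"
    avg = sum(scores) / len(scores)
    if avg >= 0.7:
        return "high"
    elif avg >= 0.4:
        return "medium"
    return "low"

print(calculate_confidence([0.82, 0.75]))  # high
print(calculate_confidence([0.5, 0.45]))   # medium
```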
api.py ADDED
@@ -0,0 +1,215 @@
+ import os
+ import logging
+ from datetime import datetime
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+ from typing import List, Optional, Dict
+ from dotenv import load_dotenv
+
+ # Load environment variables
+ load_dotenv()
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Import the existing RAG agent functionality
+ from agent import RAGAgent
+
+ # Create FastAPI app
+ app = FastAPI(
+     title="RAG Agent API",
+     description="API for RAG Agent with document retrieval and question answering",
+     version="1.0.0"
+ )
+
+ # Add CORS middleware for development
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # In production, replace with specific origins
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Pydantic models
+ class QueryRequest(BaseModel):
+     query: str
+
+ class ChatRequest(BaseModel):
+     query: str
+     message: str
+     session_id: str
+     selected_text: Optional[str] = None
+     query_type: str = "global"
+     top_k: int = 5
+
+ class MatchedChunk(BaseModel):
+     content: str
+     url: str
+     position: int
+     similarity_score: float
+
+ class QueryResponse(BaseModel):
+     answer: str
+     sources: List[str]
+     matched_chunks: List[MatchedChunk]
+     error: Optional[str] = None
+     status: str  # "success", "error", "empty"
+     query_time_ms: Optional[float] = None
+     confidence: Optional[str] = None
+
+ class ChatResponse(BaseModel):
+     response: str
+     citations: List[Dict[str, str]]
+     session_id: str
+     query_type: str
+     timestamp: str
+
+ class HealthResponse(BaseModel):
+     status: str
+     message: str
+
+ # Global RAG agent instance
+ rag_agent = None
+
+ @app.on_event("startup")
+ async def startup_event():
+     """Initialize the RAG agent on startup"""
+     global rag_agent
+     logger.info("Initializing RAG Agent...")
+     try:
+         rag_agent = RAGAgent()
+         logger.info("RAG Agent initialized successfully")
+     except Exception as e:
+         logger.error(f"Failed to initialize RAG Agent: {e}")
+         raise
+
+ @app.post("/ask", response_model=QueryResponse)
+ async def ask_rag(request: QueryRequest):
+     """
+     Process a user query through the RAG agent and return the response
+     """
+     logger.info(f"Processing query: {request.query[:50]}...")
+
+     try:
+         # Validate input
+         if not request.query or len(request.query.strip()) == 0:
+             raise HTTPException(status_code=400, detail="Query cannot be empty")
+
+         if len(request.query) > 2000:
+             raise HTTPException(status_code=400, detail="Query too long, maximum 2000 characters")
+
+         # Process query through RAG agent
+         response = rag_agent.query_agent(request.query)
+
+         # Format response
+         formatted_response = QueryResponse(
+             answer=response.get("answer", ""),
+             sources=response.get("sources", []),
+             matched_chunks=[
+                 MatchedChunk(
+                     content=chunk.get("content", ""),
+                     url=chunk.get("url", ""),
+                     position=chunk.get("position", 0),
+                     similarity_score=chunk.get("similarity_score", 0.0)
+                 )
+                 for chunk in response.get("matched_chunks", [])
+             ],
+             error=response.get("error"),
+             status="error" if response.get("error") else "success",
+             query_time_ms=response.get("query_time_ms"),
+             confidence=response.get("confidence")
+         )
+
+         logger.info(f"Query processed successfully in {response.get('query_time_ms', 0):.2f}ms")
+         return formatted_response
+
+     except HTTPException:
+         raise
+     except Exception as e:
+         logger.error(f"Error processing query: {e}")
+         return QueryResponse(
+             answer="",
+             sources=[],
+             matched_chunks=[],
+             error=str(e),
+             status="error"
+         )
+
+ @app.post("/api", response_model=ChatResponse)
+ async def chat_endpoint(request: ChatRequest):
+     """
+     Main chat endpoint that handles conversation with RAG capabilities
+     """
+     logger.info(f"Processing chat query: {request.query[:50]}...")
+
+     try:
+         # Validate input
+         if not request.query or len(request.query.strip()) == 0:
+             raise HTTPException(status_code=400, detail="Query cannot be empty")
+
+         if not request.session_id or len(request.session_id.strip()) == 0:
+             raise HTTPException(status_code=400, detail="Session ID cannot be empty")
+
+         if len(request.query) > 2000:
+             raise HTTPException(status_code=400, detail="Query too long, maximum 2000 characters")
+
+         # Process query through RAG agent
+         response = rag_agent.query_agent(request.query)
+
+         # Format response to match expected structure
+         timestamp = datetime.utcnow().isoformat()
+
+         # Convert matched chunks to citations format
+         citations = []
+         for chunk in response.get("matched_chunks", []):
+             citation = {
+                 "document_id": "",
+                 "title": chunk.get("url", ""),
+                 "chapter": "",
+                 "section": "",
+                 "page_reference": ""
+             }
+             citations.append(citation)
+
+         formatted_response = ChatResponse(
+             response=response.get("answer", ""),
+             citations=citations,
+             session_id=request.session_id,
+             query_type=request.query_type,
+             timestamp=timestamp
+         )
+
+         logger.info("Chat query processed successfully")
+         return formatted_response
+
+     except HTTPException:
+         raise
+     except Exception as e:
+         logger.error(f"Error processing chat query: {e}")
+         return ChatResponse(
+             response="",
+             citations=[],
+             session_id=request.session_id,
+             query_type=request.query_type,
+             timestamp=datetime.utcnow().isoformat()
+         )
+
+ @app.get("/health", response_model=HealthResponse)
+ async def health_check():
+     """
+     Health check endpoint
+     """
+     return HealthResponse(
+         status="healthy",
+         message="RAG Agent API is running"
+     )
+
+ # For running with uvicorn
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8000)
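The chunk-to-citation mapping used in the `/api` handler is a plain data transform and can be exercised standalone (field names taken from the handler; the helper name is illustrative):

```python
def chunks_to_citations(matched_chunks: list) -> list:
    # Mirror of the mapping in chat_endpoint: only the chunk URL is
    # carried over (as "title"); the remaining fields are left blank
    return [
        {
            "document_id": "",
            "title": chunk.get("url", ""),
            "chapter": "",
            "section": "",
            "page_reference": "",
        }
        for chunk in matched_chunks
    ]

citations = chunks_to_citations([{"url": "https://example.com/docs/ros2", "content": "..."}])
print(citations[0]["title"])  # https://example.com/docs/ros2
```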
app.py ADDED
@@ -0,0 +1,8 @@
+ # app.py - Entry point for Hugging Face Spaces
+ import os
+ import uvicorn
+ from api import app
+
+ if __name__ == "__main__":
+     port = int(os.environ.get("PORT", 7860))
+     uvicorn.run(app, host="0.0.0.0", port=port)
main.py ADDED
@@ -0,0 +1,322 @@
+ import os
+ import requests
+ from bs4 import BeautifulSoup
+ import xml.etree.ElementTree as ET
+ from typing import List, Dict, Any
+ import cohere
+ from qdrant_client import QdrantClient
+ from qdrant_client.http import models
+ from qdrant_client.models import PointStruct
+ import logging
+ from urllib.parse import urljoin, urlparse
+ import time
+ import uuid
+ from dotenv import load_dotenv
+
+ # Load environment variables
+ load_dotenv()
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ class DocusaurusEmbeddingPipeline:
+     def __init__(self):
+         # Initialize Cohere client
+         self.cohere_client = cohere.Client(api_key=os.getenv("COHERE_API_KEY"))
+
+         # Initialize Qdrant client
+         qdrant_url = os.getenv("QDRANT_URL", "http://localhost:6333")
+         qdrant_api_key = os.getenv("QDRANT_API_KEY")
+
+         if qdrant_api_key:
+             self.qdrant_client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
+         else:
+             self.qdrant_client = QdrantClient(url=qdrant_url)
+
+         # Target URL for the Docusaurus site - configurable via environment variable
+         self.target_url = os.getenv("TARGET_URL", "https://your-vercel-url.vercel.app/")
+
+     def get_all_urls(self, base_url: str) -> List[str]:
+         """
+         Extract all URLs from a deployed Docusaurus site using the sitemap
+         """
+         urls = []
+
+         try:
+             # Try to get URLs from the sitemap first
+             sitemap_url = urljoin(base_url, "sitemap.xml")
+             response = requests.get(sitemap_url)
+
+             if response.status_code == 200:
+                 root = ET.fromstring(response.content)
+
+                 # Handle both sitemap index and regular sitemap
+                 if root.tag.endswith('sitemapindex'):
+                     # If it's a sitemap index, get individual sitemaps
+                     for sitemap in root.findall('.//{http://www.sitemaps.org/schemas/sitemap/0.9}loc'):
+                         sitemap_response = requests.get(sitemap.text)
+                         if sitemap_response.status_code == 200:
+                             sitemap_root = ET.fromstring(sitemap_response.content)
+                             for url_elem in sitemap_root.findall('.//{http://www.sitemaps.org/schemas/sitemap/0.9}loc'):
+                                 urls.append(url_elem.text)
+                 else:
+                     # Regular sitemap
+                     for url_elem in root.findall('.//{http://www.sitemaps.org/schemas/sitemap/0.9}loc'):
+                         urls.append(url_elem.text)
+             else:
+                 # Fallback: try to crawl the site by looking for links
+                 logger.info(f"Sitemap not found at {sitemap_url}, attempting to crawl...")
+
+                 # Get the main page and extract links
+                 response = requests.get(base_url)
+                 soup = BeautifulSoup(response.content, 'html.parser')
+
+                 # Find all links within the page
+                 for link in soup.find_all('a', href=True):
+                     href = link['href']
+                     full_url = urljoin(base_url, href)
+
+                     # Only add URLs from the same domain
+                     if urlparse(full_url).netloc == urlparse(base_url).netloc:
+                         if full_url not in urls and full_url.startswith(base_url):
+                             urls.append(full_url)
+
+         except Exception as e:
+             logger.error(f"Error getting URLs from {base_url}: {e}")
+
+         return urls
+
+     def extract_text_from_url(self, url: str) -> str:
+         """
+         Extract and clean text from a single URL
+         """
+         try:
+             response = requests.get(url)
+             response.raise_for_status()
+
+             soup = BeautifulSoup(response.content, 'html.parser')
+
+             # Remove script and style elements
+             for script in soup(["script", "style"]):
+                 script.decompose()
+
+             # Look for main content containers typically used in Docusaurus;
+             # try multiple selectors to find the main content
+             content_selectors = [
+                 'article',               # Main article content
+                 '.markdown',             # Docusaurus markdown content
+                 '.theme-doc-markdown',   # Docusaurus theme markdown
+                 '.main-wrapper',         # Main content wrapper
+                 'main',                  # Main content area
+                 '.container',            # Container with content
+                 '[role="main"]'          # Main role
+             ]
+
+             content = ""
+             for selector in content_selectors:
+                 elements = soup.select(selector)
+                 if elements:
+                     for element in elements:
+                         # Get text but try to preserve some structure
+                         text = element.get_text(separator=' ', strip=True)
+                         if len(text) > len(content):
+                             content = text
+                     break
+
+             # If no specific content found, get all body text
+             if not content:
+                 body = soup.find('body')
+                 if body:
+                     content = body.get_text(separator=' ', strip=True)
+
+             # Clean up the text
+             lines = (line.strip() for line in content.splitlines())
+             chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
+             content = ' '.join(chunk for chunk in chunks if chunk)
+
+             return content
+
+         except Exception as e:
+             logger.error(f"Error extracting text from {url}: {e}")
+             return ""
+
+     def chunk_text(self, text: str, chunk_size: int = 1000, overlap: int = 100) -> List[str]:
+         """
+         Split text into chunks with overlap to preserve context
+         """
+         if len(text) <= chunk_size:
+             return [text]
+
+         chunks = []
+         start = 0
+
+         while start < len(text):
+             end = start + chunk_size
+             chunk = text[start:end]
+             chunks.append(chunk)
+
+             # Move start position by chunk_size - overlap
+             start = end - overlap
+
+             # If remaining text is less than chunk_size, add it as the final chunk
+             if len(text) - start < chunk_size:
+                 if start < len(text):
+                     final_chunk = text[start:]
+                     if final_chunk not in chunks:  # Avoid duplicate chunks
+                         chunks.append(final_chunk)
+                 break
+
+         return chunks
+
+     def embed(self, text: str) -> List[float]:
+         """
+         Generate an embedding for text using Cohere
+         """
+         try:
+             response = self.cohere_client.embed(
+                 texts=[text],
+                 model="embed-multilingual-v3.0",  # Using multilingual model
+                 input_type="search_document"      # Optimize for search
+             )
+             return response.embeddings[0]  # Return the first (and only) embedding
+         except Exception as e:
+             logger.error(f"Error generating embedding for text: {e}")
+             return []
+
+     def create_collection(self, collection_name: str = "rag_embedding"):
+         """
+         Create a Qdrant collection for storing embeddings
+         """
+         try:
+             # Check if collection already exists
+             collections = self.qdrant_client.get_collections()
+             collection_names = [col.name for col in collections.collections]
+
+             if collection_name in collection_names:
+                 logger.info(f"Collection {collection_name} already exists")
+                 return
+
+             # Create collection with appropriate vector size (1024 for Cohere embeddings)
+             self.qdrant_client.create_collection(
+                 collection_name=collection_name,
+                 vectors_config=models.VectorParams(size=1024, distance=models.Distance.COSINE)
+             )
+
+             logger.info(f"Created collection {collection_name} with 1024-dimension vectors")
+
+         except Exception as e:
+             logger.error(f"Error creating collection {collection_name}: {e}")
+             raise
+
+     def save_chunk_to_qdrant(self, content: str, url: str, embedding: List[float], position: int, collection_name: str = "rag_embedding"):
+         """
+         Save a text chunk with its embedding to Qdrant
+         """
+         try:
+             # Generate a unique ID for the point
+             point_id = str(uuid.uuid4())
+
+             # Prepare the payload with metadata
+             payload = {
+                 "content": content,
+                 "url": url,
+                 "position": position,
+                 "created_at": time.time()
+             }
+
+             # Create and upload the point to Qdrant
+             points = [PointStruct(
+                 id=point_id,
+                 vector=embedding,
+                 payload=payload
+             )]
+
+             self.qdrant_client.upsert(
+                 collection_name=collection_name,
+                 points=points
+             )
+
+             logger.info(f"Saved chunk to Qdrant: {url} (position {position})")
+             return True
+
+         except Exception as e:
+             logger.error(f"Error saving chunk to Qdrant: {e}")
+             return False
+
+ def main():
+     """
+     Main function to execute the complete pipeline
+     """
+     logger.info("Starting Docusaurus Embedding Pipeline...")
+
+     # Initialize the pipeline
+     pipeline = DocusaurusEmbeddingPipeline()
+
+     try:
+         # Step 1: Create the Qdrant collection
+         logger.info("Creating Qdrant collection...")
+         pipeline.create_collection("rag_embedding")
+
+         # Step 2: Get all URLs from the target Docusaurus site
+         logger.info(f"Extracting URLs from {pipeline.target_url}...")
+         urls = pipeline.get_all_urls(pipeline.target_url)
+
+         if not urls:
+             logger.warning(f"No URLs found at {pipeline.target_url}")
+             return
+
+         logger.info(f"Found {len(urls)} URLs to process")
+
+         # Step 3: Process each URL
+         total_chunks = 0
+         for i, url in enumerate(urls):
+             logger.info(f"Processing URL {i+1}/{len(urls)}: {url}")
+
+             # Extract text from the URL
+             text_content = pipeline.extract_text_from_url(url)
+
+             if not text_content:
+                 logger.warning(f"No content extracted from {url}")
+                 continue
+
+             logger.info(f"Extracted {len(text_content)} characters from {url}")
+
+             # Chunk the text
+             chunks = pipeline.chunk_text(text_content)
+             logger.info(f"Created {len(chunks)} chunks from {url}")
+
+             # Process each chunk
+             for j, chunk in enumerate(chunks):
+                 if not chunk.strip():
+                     continue
+
+                 # Generate embedding
+                 embedding = pipeline.embed(chunk)
+
+                 if not embedding:
+                     logger.error(f"Failed to generate embedding for chunk {j} of {url}")
+                     continue
+
+                 # Save to Qdrant
+                 success = pipeline.save_chunk_to_qdrant(
+                     content=chunk,
+                     url=url,
+                     embedding=embedding,
+                     position=j
+                 )
+
+                 if success:
+                     total_chunks += 1
+                     logger.info(f"Successfully saved chunk {j} of {url} to Qdrant")
+                 else:
+                     logger.error(f"Failed to save chunk {j} of {url} to Qdrant")
+
+         logger.info(f"Pipeline completed successfully! Total chunks saved: {total_chunks}")
+
+     except Exception as e:
+         logger.error(f"Pipeline failed with error: {e}")
+         raise
+
+ if __name__ == "__main__":
+     main()
pyproject.toml ADDED
@@ -0,0 +1,27 @@
+ [build-system]
+ requires = ["setuptools>=61.0"]
+ build-backend = "setuptools.build_meta"
+ 
+ [project]
+ name = "docusaurus-embedding-pipeline"
+ version = "0.1.0"
+ description = "Pipeline to extract text from Docusaurus sites, generate embeddings, and store in Qdrant"
+ readme = "README.md"
+ requires-python = ">=3.9"
+ dependencies = [
+     "requests>=2.31.0",
+     "beautifulsoup4>=4.12.2",
+     "cohere>=4.9.0",
+     "qdrant-client>=1.9.0",
+     "python-dotenv>=1.0.0"
+ ]
+ 
+ [project.optional-dependencies]
+ dev = [
+     "pytest>=7.0",
+     "black>=23.0",
+     "flake8>=6.0"
+ ]
+ 
+ [tool.setuptools.packages.find]
+ where = ["."]
requirements.txt ADDED
@@ -0,0 +1,22 @@
+ fastapi==0.104.1
+ uvicorn[standard]==0.24.0
+ python-dotenv==1.0.0
+ qdrant-client==1.9.1
+ httpx==0.25.2
+ psycopg2-binary==2.9.9
+ sqlalchemy==2.0.23
+ pydantic==2.5.0
+ pydantic-settings==2.1.0
+ openai==1.3.6
+ tiktoken==0.5.2
+ markdown==3.5.1
+ python-multipart==0.0.6
+ python-jose[cryptography]==3.3.0
+ passlib[bcrypt]==1.7.4
+ python-slugify==8.0.1
+ asyncpg==0.29.0
+ alembic==1.13.1
+ beautifulsoup4==4.12.2
+ scikit-learn==1.3.2
+ requests>=2.31.0
+ cohere>=4.9.0
retrieving.py ADDED
@@ -0,0 +1,267 @@
+ import os
+ import json
+ from typing import List, Dict, Any
+ import cohere
+ from qdrant_client import QdrantClient
+ from qdrant_client.http import models
+ import logging
+ from dotenv import load_dotenv
+ import time
+ 
+ # Load environment variables
+ load_dotenv()
+ 
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+ 
+ class RAGRetriever:
+     def __init__(self):
+         # Initialize Cohere client
+         self.cohere_client = cohere.Client(api_key=os.getenv("COHERE_API_KEY"))
+ 
+         # Initialize Qdrant client
+         qdrant_url = os.getenv("QDRANT_URL", "http://localhost:6333")
+         qdrant_api_key = os.getenv("QDRANT_API_KEY")
+ 
+         if qdrant_api_key:
+             self.qdrant_client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
+         else:
+             self.qdrant_client = QdrantClient(url=qdrant_url)
+ 
+         # Default collection name
+         self.collection_name = "rag_embedding"
+ 
+     def get_embedding(self, text: str) -> List[float]:
+         """
+         Generate an embedding for query text using Cohere.
+         """
+         try:
+             response = self.cohere_client.embed(
+                 texts=[text],
+                 model="embed-multilingual-v3.0",  # Same model used at storage time
+                 input_type="search_query"  # Optimize for search queries
+             )
+             return response.embeddings[0]  # Return the first (and only) embedding
+         except Exception as e:
+             logger.error(f"Error generating embedding for query: {e}")
+             return []
+ 
+     def query_qdrant(self, query_embedding: List[float], top_k: int = 5, threshold: float = 0.0) -> List[Dict]:
+         """
+         Query Qdrant for similar vectors and return results with metadata.
+         """
+         try:
+             # Perform similarity search in Qdrant
+             search_results = self.qdrant_client.search(
+                 collection_name=self.collection_name,
+                 query_vector=query_embedding,
+                 limit=top_k,
+                 score_threshold=threshold,
+                 with_payload=True  # Include metadata with results
+             )
+ 
+             # Format results
+             formatted_results = []
+             for result in search_results:
+                 formatted_result = {
+                     "content": result.payload.get("content", ""),
+                     "url": result.payload.get("url", ""),
+                     "position": result.payload.get("position", 0),
+                     "similarity_score": result.score,
+                     "chunk_id": result.id,
+                     "created_at": result.payload.get("created_at", "")
+                 }
+                 formatted_results.append(formatted_result)
+ 
+             return formatted_results
+ 
+         except Exception as e:
+             logger.error(f"Error querying Qdrant: {e}")
+             return []
+ 
+     def verify_content_accuracy(self, retrieved_chunks: List[Dict]) -> bool:
+         """
+         Verify that retrieved content matches the original stored text (basic validation).
+         """
+         # In a real implementation, this would compare against original sources.
+         # For now, validate that required fields exist and have content.
+         for chunk in retrieved_chunks:
+             if not chunk.get("content") or not chunk.get("url"):
+                 logger.warning(f"Missing content or URL in chunk: {chunk.get('chunk_id')}")
+                 return False
+ 
+         # Additional validation could include checking content length, URL format, etc.
+         return True
+ 
+     def format_json_response(self, results: List[Dict], query: str, query_time_ms: float) -> str:
+         """
+         Format retrieval results into a clean JSON response.
+         """
+         response = {
+             "query": query,
+             "results": results,
+             "metadata": {
+                 "query_time_ms": query_time_ms,
+                 "total_results": len(results),
+                 "timestamp": time.time(),
+                 "collection_name": self.collection_name
+             }
+         }
+ 
+         return json.dumps(response, indent=2)
+ 
+     def retrieve(self, query_text: str, top_k: int = 5, threshold: float = 0.0, include_metadata: bool = True) -> str:
+         """
+         Main retrieval function that orchestrates the complete workflow.
+         """
+         start_time = time.time()
+ 
+         logger.info(f"Processing retrieval request for query: '{query_text[:50]}...'")
+ 
+         # Step 1: Convert query text to an embedding
+         query_embedding = self.get_embedding(query_text)
+         if not query_embedding:
+             error_response = {
+                 "query": query_text,
+                 "results": [],
+                 "error": "Failed to generate query embedding",
+                 "metadata": {
+                     "query_time_ms": (time.time() - start_time) * 1000,
+                     "timestamp": time.time()
+                 }
+             }
+             return json.dumps(error_response, indent=2)
+ 
+         # Step 2: Query Qdrant for similar vectors
+         raw_results = self.query_qdrant(query_embedding, top_k, threshold)
+ 
+         if not raw_results:
+             logger.warning("No results returned from Qdrant")
+ 
+         # Step 3: Verify content accuracy (optional)
+         if include_metadata:
+             is_accurate = self.verify_content_accuracy(raw_results)
+             if not is_accurate:
+                 logger.warning("Content accuracy verification failed for some results")
+ 
+         # Step 4: Calculate total query time
+         query_time_ms = (time.time() - start_time) * 1000
+ 
+         # Step 5: Format the response as JSON
+         json_response = self.format_json_response(raw_results, query_text, query_time_ms)
+ 
+         logger.info(f"Retrieval completed in {query_time_ms:.2f}ms, {len(raw_results)} results returned")
+ 
+         return json_response
+ 
+ def retrieve_all_data():
+     """
+     Retrieve and display all data from the Qdrant collection.
+     """
+     logger.info("Initializing RAG Retriever to fetch all data...")
+ 
+     # Initialize the retriever
+     retriever = RAGRetriever()
+ 
+     print("RAG Retrieval System - All Stored Data")
+     print("=" * 50)
+ 
+     try:
+         # Get all points from the collection using scroll
+         points = []
+         offset = None
+         while True:
+             # Scroll through the collection to get all points
+             batch, next_offset = retriever.qdrant_client.scroll(
+                 collection_name=retriever.collection_name,
+                 limit=1000,  # Fetch up to 1000 points at a time
+                 offset=offset,
+                 with_payload=True,
+                 with_vectors=False
+             )
+ 
+             points.extend(batch)
+ 
+             # If next_offset is None, we've reached the end
+             if next_offset is None:
+                 break
+ 
+             offset = next_offset
+ 
+         print(f"Total stored chunks: {len(points)}")
+         print("-" * 50)
+ 
+         for i, point in enumerate(points, 1):
+             payload = point.payload
+             # Keep only single-byte characters so the preview prints safely on any console
+             content_preview = ''.join(char for char in payload.get("content", "")[:200] if ord(char) < 256)
+ 
+             print(f"Chunk {i}:")
+             print(f"  ID: {point.id}")
+             print(f"  URL: {payload.get('url', 'N/A')}")
+             print(f"  Position: {payload.get('position', 'N/A')}")
+             print(f"  Content Preview: {content_preview}...")
+             print(f"  Created At: {payload.get('created_at', 'N/A')}")
+             print("-" * 30)
+ 
+     except Exception as e:
+         logger.error(f"Error retrieving all data: {e}")
+         print(f"Error retrieving all data: {e}")
+ 
+ def main():
+     """
+     Main function to demonstrate the retrieval functionality.
+     """
+     import sys
+ 
+     logger.info("Initializing RAG Retriever...")
+ 
+     # Check whether the user wants to retrieve all data or run test queries
+     if len(sys.argv) > 1 and sys.argv[1] == "all":
+         retrieve_all_data()
+         return
+ 
+     # Initialize the retriever
+     retriever = RAGRetriever()
+ 
+     # Example queries to test the system
+     test_queries = [
+         "What is ROS2?",
+         "Explain humanoid design principles",
+         "How does VLA work?",
+         "What are simulation techniques?",
+         "Explain AI control systems"
+     ]
+ 
+     print("RAG Retrieval System - Testing Queries")
+     print("=" * 50)
+ 
+     for i, query in enumerate(test_queries, 1):
+         print(f"\nQuery {i}: {query}")
+         print("-" * 30)
+ 
+         # Retrieve results
+         json_response = retriever.retrieve(query, top_k=3)
+         response_dict = json.loads(json_response)
+ 
+         # Print formatted results
+         results = response_dict.get("results", [])
+         if results:
+             for j, result in enumerate(results, 1):
+                 print(f"Result {j} (Score: {result['similarity_score']:.3f}):")
+                 print(f"  URL: {result['url']}")
+                 # Safely print the content preview by removing multi-byte characters
+                 safe_content = ''.join(char for char in result['content'][:100] if ord(char) < 256)
+                 print(f"  Content Preview: {safe_content}...")
+                 print(f"  Position: {result['position']}")
+                 print()
+         else:
+             print("No results found for this query.")
+ 
+         print(f"Query time: {response_dict['metadata']['query_time_ms']:.2f}ms")
+         print(f"Total results: {response_dict['metadata']['total_results']}")
+ 
+ if __name__ == "__main__":
+     main()
sdk.md ADDED
@@ -0,0 +1,935 @@
1
+ # OpenAI Agents SDK
2
+
3
+ The OpenAI Agents SDK is a lightweight Python framework for building production-ready agentic AI applications with minimal abstractions. It provides a streamlined upgrade from the experimental Swarm framework, offering essential primitives like agents with instructions and tools, handoffs for task delegation between agents, guardrails for input/output validation, and sessions for automatic conversation history management. The SDK emphasizes ease of use while maintaining enough power to express complex multi-agent relationships, making it suitable for real-world applications without requiring mastery of complex frameworks.
4
+
5
+ Built on core design principles of simplicity and customization, the SDK includes an automatic agent loop handling tool calls and LLM interactions, Python-first orchestration without new abstractions to learn, built-in tracing for visualization and debugging, and automatic schema generation with Pydantic validation for function tools. It supports multiple model providers through OpenAI's Responses API and Chat Completions API, with native integration for LiteLLM and custom providers. Whether building single-agent assistants or complex multi-agent workflows with specialized roles, the SDK provides the necessary features to move quickly from prototype to production.
6
+
7
+ ## Installation and Setup
8
+
9
+ Install the SDK and configure your environment
10
+
11
+ ```bash
12
+ pip install openai-agents
13
+
14
+ export OPENAI_API_KEY=sk-...
15
+ ```
16
+
17
+ ## Creating a Basic Agent
18
+
19
+ Define an agent with name and instructions
20
+
21
+ ```python
22
+ from agents import Agent, Runner
23
+
24
+ agent = Agent(
25
+ name="Math Tutor",
26
+ instructions="You provide help with math problems. Explain your reasoning at each step and include examples"
27
+ )
28
+
29
+ result = Runner.run_sync(agent, "What is 15% of 80?")
30
+ print(result.final_output)
31
+ # 15% of 80 is 12. To calculate: 0.15 × 80 = 12
32
+ ```
33
+
34
+ ## Running Agents Asynchronously
35
+
36
+ Execute agent with async/await pattern
37
+
38
+ ```python
39
+ import asyncio
40
+ from agents import Agent, Runner
41
+
42
+ async def main():
43
+ agent = Agent(
44
+ name="Assistant",
45
+ instructions="Reply very concisely."
46
+ )
47
+
48
+ result = await Runner.run(agent, "What city is the Golden Gate Bridge in?")
49
+ print(result.final_output)
50
+ # San Francisco
51
+
52
+ asyncio.run(main())
53
+ ```
54
+
55
+ ## Agent with Function Tools
56
+
57
+ Decorate Python functions to create tools with automatic schema generation
58
+
59
+ ```python
60
+ from agents import Agent, Runner, function_tool
61
+ import asyncio
62
+
63
+ @function_tool
64
+ async def get_weather(city: str) -> str:
65
+ """Fetch the weather for a given location.
66
+
67
+ Args:
68
+ city: The city to fetch weather for.
69
+ """
70
+ # In production, call actual weather API
71
+ return f"The weather in {city} is sunny and 72°F"
72
+
73
+ @function_tool
74
+ def calculate_sum(a: int, b: int) -> int:
75
+ """Add two numbers together.
76
+
77
+ Args:
78
+ a: First number
79
+ b: Second number
80
+ """
81
+ return a + b
82
+
83
+ agent = Agent(
84
+ name="Assistant",
85
+ instructions="Use the provided tools to help the user",
86
+ tools=[get_weather, calculate_sum]
87
+ )
88
+
89
+ async def main():
90
+ result = await Runner.run(agent, "What's the weather in Seattle?")
91
+ print(result.final_output)
92
+ # The weather in Seattle is sunny and 72°F
93
+
94
+ asyncio.run(main())
95
+ ```
96
+
97
+ ## Agent with Hosted Tools
98
+
99
+ Use OpenAI's built-in tools for web search and file retrieval
100
+
101
+ ```python
102
+ from agents import Agent, Runner, WebSearchTool, FileSearchTool
103
+ import asyncio
104
+
105
+ agent = Agent(
106
+ name="Research Assistant",
107
+ instructions="Use web search and file search to answer questions thoroughly",
108
+ tools=[
109
+ WebSearchTool(),
110
+ FileSearchTool(
111
+ max_num_results=5,
112
+ vector_store_ids=["vs_abc123"]
113
+ )
114
+ ]
115
+ )
116
+
117
+ async def main():
118
+ result = await Runner.run(
119
+ agent,
120
+ "What are the latest developments in quantum computing?"
121
+ )
122
+ print(result.final_output)
123
+
124
+ asyncio.run(main())
125
+ ```
126
+
127
+ ## Multi-Agent Handoffs
128
+
129
+ Create specialized agents that delegate to each other
130
+
131
+ ```python
132
+ from agents import Agent, Runner
133
+ import asyncio
134
+
135
+ billing_agent = Agent(
136
+ name="Billing Agent",
137
+ handoff_description="Specialist for billing questions and payment issues",
138
+ instructions="You handle billing inquiries. Check account status and process refunds."
139
+ )
140
+
141
+ technical_agent = Agent(
142
+ name="Technical Agent",
143
+ handoff_description="Specialist for technical support and troubleshooting",
144
+ instructions="You handle technical issues. Diagnose problems and provide solutions."
145
+ )
146
+
147
+ triage_agent = Agent(
148
+ name="Triage Agent",
149
+ instructions=(
150
+ "Determine which specialist agent should handle the user's request. "
151
+ "Hand off to the appropriate agent based on the question type."
152
+ ),
153
+ handoffs=[billing_agent, technical_agent]
154
+ )
155
+
156
+ async def main():
157
+ result = await Runner.run(
158
+ triage_agent,
159
+ "I was charged twice for my subscription this month"
160
+ )
161
+ print(result.final_output)
162
+ # Output from billing_agent after handoff
163
+
164
+ asyncio.run(main())
165
+ ```
166
+
167
+ ## Custom Handoff with Input Data
168
+
169
+ Configure handoffs with structured input and callbacks
170
+
171
+ ```python
172
+ from agents import Agent, Runner, handoff, RunContextWrapper
173
+ from pydantic import BaseModel
174
+ import asyncio
175
+
176
+ class EscalationData(BaseModel):
177
+ reason: str
178
+ severity: str
179
+
180
+ async def on_escalation(ctx: RunContextWrapper[None], input_data: EscalationData):
181
+ print(f"Escalated: {input_data.reason} (severity: {input_data.severity})")
182
+ # Log to monitoring system, send alert, etc.
183
+
184
+ escalation_agent = Agent(
185
+ name="Manager",
186
+ instructions="Handle escalated customer issues with priority"
187
+ )
188
+
189
+ support_agent = Agent(
190
+ name="Support Agent",
191
+ instructions="Help customers. Escalate to manager if issue is severe.",
192
+ handoffs=[
193
+ handoff(
194
+ agent=escalation_agent,
195
+ on_handoff=on_escalation,
196
+ input_type=EscalationData,
197
+ tool_description_override="Escalate urgent issues to management"
198
+ )
199
+ ]
200
+ )
201
+
202
+ async def main():
203
+ result = await Runner.run(
204
+ support_agent,
205
+ "This is completely unacceptable! I demand to speak to a manager!"
206
+ )
207
+ print(result.final_output)
208
+
209
+ asyncio.run(main())
210
+ ```
211
+
212
+ ## Input Guardrails
213
+
214
+ Validate user input before processing with the main agent
215
+
216
+ ```python
217
+ from agents import Agent, Runner, input_guardrail, GuardrailFunctionOutput
218
+ from agents import InputGuardrailTripwireTriggered, RunContextWrapper, TResponseInputItem
219
+ from pydantic import BaseModel
220
+ import asyncio
221
+
222
+ class HomeworkCheck(BaseModel):
223
+ is_homework: bool
224
+ reasoning: str
225
+
226
+ guardrail_agent = Agent(
227
+ name="Homework Detector",
228
+ instructions="Determine if the user is asking for homework help",
229
+ output_type=HomeworkCheck
230
+ )
231
+
232
+ @input_guardrail
233
+ async def homework_guardrail(
234
+ ctx: RunContextWrapper[None],
235
+ agent: Agent,
236
+ input_data: str | list[TResponseInputItem]
237
+ ) -> GuardrailFunctionOutput:
238
+ result = await Runner.run(guardrail_agent, input_data, context=ctx.context)
239
+
240
+ return GuardrailFunctionOutput(
241
+ output_info=result.final_output,
242
+ tripwire_triggered=result.final_output.is_homework
243
+ )
244
+
245
+ tutoring_agent = Agent(
246
+ name="Tutoring Service",
247
+ instructions="You help students understand concepts, not do their homework",
248
+ input_guardrails=[homework_guardrail]
249
+ )
250
+
251
+ async def main():
252
+ try:
253
+ result = await Runner.run(
254
+ tutoring_agent,
255
+ "Can you solve this equation for me: 2x + 5 = 15?"
256
+ )
257
+ print(result.final_output)
258
+ except InputGuardrailTripwireTriggered as e:
259
+ print("Request blocked: This appears to be homework help")
260
+
261
+ asyncio.run(main())
262
+ ```
263
+
264
+ ## Output Guardrails
265
+
266
+ Validate agent responses before returning to user
267
+
268
+ ```python
269
+ from agents import Agent, Runner, output_guardrail, GuardrailFunctionOutput
270
+ from agents import OutputGuardrailTripwireTriggered, RunContextWrapper
271
+ from pydantic import BaseModel
272
+ import asyncio
273
+
274
+ class ToxicityCheck(BaseModel):
275
+ is_toxic: bool
276
+ confidence: float
277
+
278
+ class AgentResponse(BaseModel):
279
+ message: str
280
+
281
+ toxicity_checker = Agent(
282
+ name="Toxicity Detector",
283
+ instructions="Analyze if the message contains toxic or harmful content",
284
+ output_type=ToxicityCheck
285
+ )
286
+
287
+ @output_guardrail
288
+ async def toxicity_guardrail(
289
+ ctx: RunContextWrapper[None],
290
+ agent: Agent,
291
+ output: AgentResponse
292
+ ) -> GuardrailFunctionOutput:
293
+ result = await Runner.run(toxicity_checker, output.message, context=ctx.context)
294
+
295
+ return GuardrailFunctionOutput(
296
+ output_info=result.final_output,
297
+ tripwire_triggered=result.final_output.is_toxic and result.final_output.confidence > 0.8
298
+ )
299
+
300
+ chatbot = Agent(
301
+ name="Chatbot",
302
+ instructions="You are a friendly assistant",
303
+ output_guardrails=[toxicity_guardrail],
304
+ output_type=AgentResponse
305
+ )
306
+
307
+ async def main():
308
+ try:
309
+ result = await Runner.run(chatbot, "Tell me about your day")
310
+ print(result.final_output.message)
311
+ except OutputGuardrailTripwireTriggered:
312
+ print("Response blocked by content filter")
313
+
314
+ asyncio.run(main())
315
+ ```
316
+
317
+ ## Sessions for Conversation Memory
318
+
319
+ Automatically maintain conversation history across multiple turns
320
+
321
+ ```python
322
+ from agents import Agent, Runner, SQLiteSession
323
+ import asyncio
324
+
325
+ async def main():
326
+ agent = Agent(
327
+ name="Assistant",
328
+ instructions="Reply concisely and remember previous context"
329
+ )
330
+
331
+ # Create persistent session with SQLite backend
332
+ session = SQLiteSession("user_123", "conversations.db")
333
+
334
+ # First turn
335
+ result = await Runner.run(
336
+ agent,
337
+ "What city is the Golden Gate Bridge in?",
338
+ session=session
339
+ )
340
+ print(result.final_output)
341
+ # San Francisco
342
+
343
+ # Second turn - agent remembers previous context
344
+ result = await Runner.run(
345
+ agent,
346
+ "What state is it in?",
347
+ session=session
348
+ )
349
+ print(result.final_output)
350
+ # California
351
+
352
+ # Third turn - continuing the conversation
353
+ result = await Runner.run(
354
+ agent,
355
+ "What's the population?",
356
+ session=session
357
+ )
358
+ print(result.final_output)
359
+ # Approximately 39 million
360
+
361
+ asyncio.run(main())
362
+ ```
363
+
364
+ ## Session Management Operations
365
+
366
+ Manipulate conversation history programmatically
367
+
368
+ ```python
369
+ from agents import Agent, Runner, SQLiteSession
370
+ import asyncio
371
+
372
+ async def main():
373
+ session = SQLiteSession("conversation_456", "chats.db")
374
+
375
+ # Get all conversation items
376
+ items = await session.get_items()
377
+ print(f"Total messages: {len(items)}")
378
+
379
+ # Add items manually
380
+ await session.add_items([
381
+ {"role": "user", "content": "Hello"},
382
+ {"role": "assistant", "content": "Hi! How can I help?"}
383
+ ])
384
+
385
+ # Remove last item (useful for corrections)
386
+ agent = Agent(name="Assistant")
387
+
388
+ result = await Runner.run(agent, "What's 2 + 2?", session=session)
389
+ print(result.final_output)
390
+
391
+ # User wants to correct their question
392
+ await session.pop_item() # Remove assistant response
393
+ await session.pop_item() # Remove user question
394
+
395
+ result = await Runner.run(agent, "What's 2 + 3?", session=session)
396
+ print(result.final_output)
397
+
398
+ # Clear entire session
399
+ await session.clear_session()
400
+
401
+ asyncio.run(main())
402
+ ```
403
+
404
+ ## OpenAI Conversations Session
405
+
406
+ Use OpenAI-hosted conversation storage
407
+
408
+ ```python
409
+ from agents import Agent, Runner, OpenAIConversationsSession
410
+ import asyncio
411
+
412
+ async def main():
413
+ agent = Agent(name="Assistant")
414
+
415
+ # Create new conversation or resume existing one
416
+ session = OpenAIConversationsSession()
417
+ # Or with existing conversation ID:
418
+ # session = OpenAIConversationsSession(conversation_id="conv_abc123")
419
+
420
+ result = await Runner.run(
421
+ agent,
422
+ "Remember that my favorite color is blue",
423
+ session=session
424
+ )
425
+
426
+ # Later conversation with same session
427
+ result = await Runner.run(
428
+ agent,
429
+ "What's my favorite color?",
430
+ session=session
431
+ )
432
+ print(result.final_output)
433
+ # Your favorite color is blue
434
+
435
+ asyncio.run(main())
436
+ ```
437
+
438
+ ## Structured Outputs
439
+
440
+ Force agents to return specific data types with validation
441
+
442
+ ```python
443
+ from agents import Agent, Runner
444
+ from pydantic import BaseModel
445
+ import asyncio
446
+
447
+ class CalendarEvent(BaseModel):
448
+ title: str
449
+ date: str
450
+ participants: list[str]
451
+ location: str | None = None
452
+
453
+ agent = Agent(
454
+ name="Calendar Parser",
455
+ instructions="Extract calendar event information from text",
456
+ output_type=CalendarEvent
457
+ )
458
+
459
+ async def main():
460
+ text = "Schedule a team meeting on March 15th with John, Sarah, and Mike"
461
+ result = await Runner.run(agent, text)
462
+
463
+ event = result.final_output_as(CalendarEvent)
464
+ print(f"Event: {event.title}")
465
+ print(f"Date: {event.date}")
466
+ print(f"Attendees: {', '.join(event.participants)}")
467
+ # Event: Team Meeting
468
+ # Date: March 15th
469
+ # Attendees: John, Sarah, Mike
470
+
471
+ asyncio.run(main())
472
+ ```
473
+
474
+ ## Agent Context and Dependency Injection
475
+
476
+ Pass custom context objects to agents and tools
477
+
478
+ ```python
479
+ from dataclasses import dataclass
480
+ from agents import Agent, Runner, RunContextWrapper, function_tool
481
+ import asyncio
482
+
483
+ @dataclass
484
+ class UserContext:
485
+ user_id: str
486
+ is_premium: bool
487
+ api_token: str
488
+
489
+ @function_tool
490
+ async def get_user_data(ctx: RunContextWrapper[UserContext]) -> str:
491
+ """Fetch user-specific data using context."""
492
+ user_id = ctx.context.user_id
493
+ is_premium = ctx.context.is_premium
494
+
495
+ if is_premium:
496
+ return f"Premium user {user_id} has access to all features"
497
+ return f"User {user_id} has basic access"
498
+
499
+ agent = Agent[UserContext](
500
+ name="Account Manager",
501
+ instructions="Provide user information based on their account status",
502
+ tools=[get_user_data]
503
+ )
504
+
505
+ async def main():
506
+ context = UserContext(
507
+ user_id="user_789",
508
+ is_premium=True,
509
+ api_token="secret_token"
510
+ )
511
+
512
+ result = await Runner.run(agent, "What's my account status?", context=context)
513
+ print(result.final_output)
514
+
515
+ asyncio.run(main())
516
+ ```
517
+
518
+ ## Dynamic Instructions
519
+
520
+ Generate agent instructions at runtime based on context
521
+
522
+ ```python
523
+ from agents import Agent, Runner, RunContextWrapper
524
+ from dataclasses import dataclass
525
+ import asyncio
526
+
527
+ @dataclass
528
+ class AppContext:
529
+ username: str
530
+ language: str
531
+ timezone: str
532
+
533
+ def dynamic_instructions(
534
+ context: RunContextWrapper[AppContext],
535
+ agent: Agent[AppContext]
536
+ ) -> str:
537
+ user = context.context
538
+ return f"""You are a helpful assistant for {user.username}.
539
+ - Respond in {user.language}
540
+ - Use {user.timezone} timezone for all time references
541
+ - Be friendly and personalized"""
542
+
543
+ agent = Agent[AppContext](
544
+ name="Personal Assistant",
545
+ instructions=dynamic_instructions
546
+ )
547
+
548
+ async def main():
549
+ context = AppContext(
550
+ username="Alice",
551
+ language="Spanish",
552
+ timezone="PST"
553
+ )
554
+
555
+ result = await Runner.run(agent, "What time is it?", context=context)
556
+ print(result.final_output)
557
+
558
+ asyncio.run(main())
559
+ ```
560
+
561
+ ## Streaming Agent Responses
562
+
563
+ Stream token-by-token responses from the agent
564
+
565
+ ```python
566
+ from agents import Agent, Runner
567
+ from openai.types.responses import ResponseTextDeltaEvent
568
+ import asyncio
569
+
570
+ async def main():
571
+ agent = Agent(
572
+ name="Storyteller",
573
+ instructions="Tell engaging short stories"
574
+ )
575
+
576
+ result = Runner.run_streamed(agent, "Tell me a story about a robot")
577
+
578
+ print("Streaming response: ", end="", flush=True)
579
+ async for event in result.stream_events():
580
+ if event.type == "raw_response_event":
581
+ if isinstance(event.data, ResponseTextDeltaEvent):
582
+ print(event.data.delta, end="", flush=True)
583
+
584
+ print("\n\nFinal output:", result.final_output)
585
+
586
+ asyncio.run(main())
587
+ ```
588
+
589
## Streaming with Item-Level Events

Stream higher-level events like tool calls and messages

```python
from agents import Agent, Runner, ItemHelpers, function_tool
import asyncio
import random

@function_tool
def roll_dice(sides: int = 6) -> int:
    """Roll a dice with specified number of sides."""
    return random.randint(1, sides)

async def main():
    agent = Agent(
        name="Game Master",
        instructions="Use the dice rolling tool when asked",
        tools=[roll_dice]
    )

    result = Runner.run_streamed(agent, "Roll two dice for me")

    async for event in result.stream_events():
        if event.type == "raw_response_event":
            continue  # Skip token-level events
        elif event.type == "agent_updated_stream_event":
            print(f"Agent: {event.new_agent.name}")
        elif event.type == "run_item_stream_event":
            if event.item.type == "tool_call_item":
                print("🔧 Tool called")
            elif event.item.type == "tool_call_output_item":
                print(f"📤 Tool result: {event.item.output}")
            elif event.item.type == "message_output_item":
                text = ItemHelpers.text_message_output(event.item)
                print(f"💬 Agent: {text}")

asyncio.run(main())
```

## Agents as Tools Pattern

Use specialized agents as tools in a central orchestrator

```python
from agents import Agent, Runner
import asyncio

translation_agent = Agent(
    name="Translator",
    instructions="Translate the user's message to the specified language"
)

summarization_agent = Agent(
    name="Summarizer",
    instructions="Create a concise summary of the provided text"
)

orchestrator = Agent(
    name="Orchestrator",
    instructions="Use the available tools to process user requests efficiently",
    tools=[
        translation_agent.as_tool(
            tool_name="translate_text",
            tool_description="Translate text to another language"
        ),
        summarization_agent.as_tool(
            tool_name="summarize_text",
            tool_description="Generate a summary of long text"
        )
    ]
)

async def main():
    result = await Runner.run(
        orchestrator,
        "Translate 'Hello, how are you?' to French and Spanish"
    )
    print(result.final_output)

asyncio.run(main())
```

## Custom Model Configuration

Configure model settings and use different models per agent

```python
from agents import Agent, Runner, ModelSettings
from openai.types.shared import Reasoning
import asyncio

reasoning_agent = Agent(
    name="Deep Thinker",
    instructions="Analyze complex problems thoroughly",
    model="gpt-5",
    model_settings=ModelSettings(
        # Reasoning models like gpt-5 do not accept a sampling temperature
        reasoning=Reasoning(effort="high"),
        verbosity="high"
    )
)

fast_agent = Agent(
    name="Quick Responder",
    instructions="Provide rapid responses",
    model="gpt-5-nano",
    model_settings=ModelSettings(
        reasoning=Reasoning(effort="minimal"),
        verbosity="low"
    )
)

triage_agent = Agent(
    name="Router",
    instructions="Route complex problems to the deep thinker, simple ones to the quick responder",
    handoffs=[reasoning_agent, fast_agent]
)

async def main():
    result = await Runner.run(
        triage_agent,
        "Explain quantum entanglement in simple terms"
    )
    print(result.final_output)

asyncio.run(main())
```

## MCP Hosted Tool Integration

Use Model Context Protocol servers as hosted tools

```python
from agents import Agent, Runner, HostedMCPTool
import asyncio

async def main():
    agent = Agent(
        name="Code Assistant",
        instructions="Help with repository questions using git tools",
        tools=[
            HostedMCPTool(
                tool_config={
                    "type": "mcp",
                    "server_label": "gitmcp",
                    "server_url": "https://gitmcp.io/openai/codex",
                    "require_approval": "never"
                }
            )
        ]
    )

    result = await Runner.run(
        agent,
        "What programming languages are used in this repository?"
    )
    print(result.final_output)

asyncio.run(main())
```

## MCP Server with Streamable HTTP

Connect to local or remote MCP servers via HTTP

```python
from agents import Agent, Runner, ModelSettings
from agents.mcp import MCPServerStreamableHttp
import asyncio
import os

async def main():
    token = os.environ["MCP_SERVER_TOKEN"]

    async with MCPServerStreamableHttp(
        name="Calculator Server",
        params={
            "url": "http://localhost:8000/mcp",
            "headers": {"Authorization": f"Bearer {token}"},
            "timeout": 10
        },
        cache_tools_list=True,
        max_retry_attempts=3
    ) as server:
        agent = Agent(
            name="Math Assistant",
            instructions="Use MCP tools to perform calculations",
            mcp_servers=[server],
            model_settings=ModelSettings(tool_choice="required")
        )

        result = await Runner.run(agent, "Calculate 47 + 89")
        print(result.final_output)

asyncio.run(main())
```

## MCP stdio Server

Launch local MCP server processes

```python
from agents import Agent, Runner
from agents.mcp import MCPServerStdio
from pathlib import Path
import asyncio

async def main():
    samples_dir = Path(__file__).parent / "sample_files"

    async with MCPServerStdio(
        name="Filesystem Server",
        params={
            "command": "npx",
            "args": ["-y", "@modelcontextprotocol/server-filesystem", str(samples_dir)]
        }
    ) as server:
        agent = Agent(
            name="File Assistant",
            instructions="Help users work with files in the sample directory",
            mcp_servers=[server]
        )

        result = await Runner.run(agent, "List all files in the directory")
        print(result.final_output)

asyncio.run(main())
```

## Tracing and Monitoring

Built-in tracing for debugging and monitoring agent workflows

```python
from agents import Agent, Runner, trace
import asyncio

async def main():
    agent = Agent(
        name="Research Agent",
        instructions="Research topics thoroughly"
    )

    # Trace multiple runs under a single workflow
    with trace(
        workflow_name="Research Workflow",
        group_id="session_123",
        metadata={"user": "alice", "environment": "production"}
    ):
        result1 = await Runner.run(agent, "What is machine learning?")
        print(f"Response 1: {result1.final_output}")

        result2 = await Runner.run(agent, "Explain neural networks")
        print(f"Response 2: {result2.final_output}")

    # View traces at: https://platform.openai.com/traces

asyncio.run(main())
```

## Error Handling

Handle exceptions from agent runs, guardrails, and tool failures

```python
from agents import Agent, Runner, function_tool
from agents.exceptions import (
    MaxTurnsExceeded,
    InputGuardrailTripwireTriggered,
    ModelBehaviorError
)
import asyncio

# failure_error_function=None re-raises tool exceptions instead of
# reporting them back to the model as an error message (the default)
@function_tool(failure_error_function=None)
def risky_operation() -> str:
    """An operation that might fail."""
    raise ValueError("Operation failed!")

agent = Agent(
    name="Assistant",
    instructions="Help users with tasks",
    tools=[risky_operation]
)

async def main():
    try:
        result = await Runner.run(
            agent,
            "Run the risky operation",
            max_turns=5
        )
        print(result.final_output)

    except MaxTurnsExceeded:
        print("Error: Agent exceeded maximum turns")
    except InputGuardrailTripwireTriggered as e:
        print(f"Error: Input blocked by guardrail: {e}")
    except ModelBehaviorError as e:
        print(f"Error: Model produced invalid output: {e}")
    except Exception as e:
        print(f"Unexpected error: {e}")

asyncio.run(main())
```

## Using Alternative Model Providers

Integrate non-OpenAI models via LiteLLM

```bash
pip install "openai-agents[litellm]"
```

```python
from agents import Agent, Runner
import asyncio

async def main():
    # Use Claude via LiteLLM
    claude_agent = Agent(
        name="Claude Assistant",
        instructions="You are a helpful assistant",
        model="litellm/anthropic/claude-3-5-sonnet-20240620"
    )

    # Use Gemini via LiteLLM
    gemini_agent = Agent(
        name="Gemini Assistant",
        instructions="You are a helpful assistant",
        model="litellm/gemini/gemini-2.5-flash-preview-04-17"
    )

    result = await Runner.run(claude_agent, "Explain photosynthesis briefly")
    print(result.final_output)

asyncio.run(main())
```

---

## Summary

The OpenAI Agents SDK provides a comprehensive yet simple framework for building agentic AI applications in Python. Core use cases include single-agent assistants with tool access, multi-agent systems with specialized roles using handoffs, conversational applications with automatic session memory, and workflows with input/output validation via guardrails. The SDK excels at building customer service bots with agent routing, research assistants with web search and file retrieval, code generation tools with MCP integration, and any application requiring LLM orchestration with minimal boilerplate.

Integration patterns follow Python-first principles: native async/await, context managers for resource handling, decorators for function tools, and Pydantic models for structured outputs. The framework supports horizontal scaling through session persistence with SQLite or SQLAlchemy backends, vertical scaling with model mixing (fast models for triage, powerful models for complex tasks), and comprehensive observability through built-in tracing to OpenAI's dashboard or custom processors. Whether building prototypes or production systems, the SDK's balance of simplicity and power makes it an ideal choice for Python developers working with AI agents.