Claude committed on
Commit 7afbd6c · Parent(s): 4dec3fa
feat: Deploy Physical AI & Humanoid Robotics RAG backend
- Add RAG backend with Cohere and Qdrant integration
- Include API endpoints for chat functionality
- Configure Dockerfile for Hugging Face Spaces deployment
- Add proper health check endpoint
- Set up requirements and dependencies
- .env.example +26 -0
- .gitignore +26 -0
- Dockerfile +44 -0
- README.md +76 -10
- agent.py +254 -0
- api.py +215 -0
- app.py +8 -0
- main.py +322 -0
- pyproject.toml +27 -0
- requirements.txt +22 -0
- retrieving.py +267 -0
- sdk.md +935 -0
- uv.lock +0 -0
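
Once the Space is running, the API it deploys can be exercised directly; a minimal sketch with `requests`, where the base URL is a placeholder for your own Space:

```python
import requests

BASE_URL = "https://your-space.hf.space"  # placeholder: substitute your Space URL

# Same endpoint the Dockerfile HEALTHCHECK probes
print(requests.get(f"{BASE_URL}/health", timeout=10).json())
# expected: {'status': 'healthy', 'message': 'RAG Agent API is running'}

# /ask accepts a bare query and returns the answer plus matched chunks
resp = requests.post(f"{BASE_URL}/ask", json={"query": "What is ROS2?"}, timeout=30)
print(resp.json()["answer"])
```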
.env.example
ADDED
@@ -0,0 +1,26 @@
# OpenRouter API Configuration
OPENROUTER_API_KEY=your_openrouter_api_key_here

# Qdrant Vector Database Configuration
QDRANT_URL=your_qdrant_url_here
QDRANT_API_KEY=your_qdrant_api_key_here
QDRANT_CLUSTER_ID=your_qdrant_cluster_id_here

# Neon PostgreSQL Database Configuration
NEON_DATABASE_URL=your_neon_database_url_here

# Cohere API Key (if needed)
COHERE_API_KEY=your_cohere_api_key_here

# Backend API Key
BACKEND_API_KEY=your_backend_api_key_here

# Target URL for Docusaurus site
TARGET_URL=your_vercel_url_here

# Application Configuration
DEBUG=False
LOG_LEVEL=INFO
MAX_CONTENT_LENGTH=5000
RATE_LIMIT_REQUESTS=100
RATE_LIMIT_WINDOW=60
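
These variables are read at import time via `python-dotenv`, the same pattern `main.py` and `retrieving.py` use; a minimal sketch:

```python
import os
from dotenv import load_dotenv

load_dotenv()  # reads .env from the working directory

# Missing optional keys fall back to defaults, mirroring main.py
qdrant_url = os.getenv("QDRANT_URL", "http://localhost:6333")

# Required keys are best checked up front
if not os.getenv("COHERE_API_KEY"):
    raise RuntimeError("COHERE_API_KEY is not set")
```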
.gitignore
ADDED
@@ -0,0 +1,26 @@
__pycache__/
*.pyc
*.pyo
*.pyd
.Python
env/
venv/
.venv/
pip-log.txt
pip-delete-this-directory.txt
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.log
.git/
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
Dockerfile
ADDED
@@ -0,0 +1,44 @@
# Use Python 3.11 slim image as base
FROM python:3.11-slim

# Set environment variables
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    PYTHONPATH=/app \
    PORT=7860

# Set work directory
WORKDIR /app

# Install system dependencies
RUN apt-get update \
    && apt-get install -y --no-install-recommends \
    build-essential \
    gcc \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first to leverage Docker cache
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir --upgrade pip \
    && pip install --no-cache-dir -r requirements.txt

# Copy the rest of the application
COPY . .

# Create a non-root user and set permissions
RUN adduser --disabled-password --gecos '' appuser \
    && chown -R appuser:appuser /app
USER appuser

# Expose port (Hugging Face typically uses port 7860 or 8080)
EXPOSE $PORT

# Health check endpoint
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:$PORT/health || exit 1

# Run the application via app.py, which starts uvicorn on $PORT
CMD ["sh", "-c", "python app.py"]
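
The HEALTHCHECK above shells out to curl inside the container; an equivalent probe in Python, handy for smoke-testing a locally built image (assuming it was started with `-p 7860:7860`):

```python
import sys
import requests

try:
    # Same endpoint and pass/fail criterion as the Dockerfile HEALTHCHECK
    r = requests.get("http://localhost:7860/health", timeout=5)
    r.raise_for_status()
    sys.exit(0)  # healthy
except requests.RequestException:
    sys.exit(1)  # unhealthy, mirrors the `|| exit 1` above
```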
README.md
CHANGED
@@ -1,10 +1,76 @@
# Docusaurus Embedding Pipeline

This project extracts text from deployed Docusaurus URLs, generates embeddings using Cohere, and stores them in Qdrant for RAG-based retrieval.

## Features

- Crawls Docusaurus sites to extract all accessible URLs
- Extracts and cleans text content from each page
- Chunks large documents to optimize embedding quality
- Generates vector embeddings using Cohere's API
- Stores embeddings in Qdrant vector database with metadata
- Supports similarity search for RAG applications

## Prerequisites

- Python 3.9+
- UV package manager (`pip install uv`)
- Cohere API key
- Qdrant instance (local or cloud)

## Setup

1. Clone the repository and navigate to the backend directory
2. Install UV package manager:
```bash
pip install uv
```

3. Install dependencies:
```bash
cd backend
uv sync  # or uv pip install -r requirements.txt
```

4. Set up environment variables:
```bash
cp .env.example .env
# Edit .env with your Cohere API key and Qdrant configuration
```

## Configuration

The pipeline can be configured via environment variables in the `.env` file:

- `COHERE_API_KEY`: Your Cohere API key
- `QDRANT_URL`: URL to your Qdrant instance
- `QDRANT_API_KEY`: API key for Qdrant (if required)
- `TARGET_URL`: The Docusaurus site to process

## Usage

Run the complete pipeline:
```bash
uv run main.py
```

## Architecture

The pipeline consists of these main functions:

1. `get_all_urls()` - Extracts all URLs from the target Docusaurus site
2. `extract_text_from_url()` - Cleans and extracts text content from a URL
3. `chunk_text()` - Splits large documents into manageable chunks
4. `embed()` - Generates vector embeddings using Cohere
5. `create_collection()` - Sets up the Qdrant collection
6. `save_chunk_to_qdrant()` - Stores embeddings with metadata in Qdrant

The main function orchestrates the complete workflow from crawling to storage.

## Output

The pipeline stores document chunks as vectors in a Qdrant collection named "rag_embedding" with the following metadata:
- Content text
- Source URL
- Position in original document
- Creation timestamp
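
The Architecture functions compose linearly; a sketch of that flow using the `DocusaurusEmbeddingPipeline` class from `main.py` below, essentially a trimmed version of its own `main()`:

```python
from main import DocusaurusEmbeddingPipeline

pipeline = DocusaurusEmbeddingPipeline()
pipeline.create_collection("rag_embedding")

# crawl -> extract -> chunk -> embed -> store
for url in pipeline.get_all_urls(pipeline.target_url):
    text = pipeline.extract_text_from_url(url)
    if not text:
        continue
    for position, chunk in enumerate(pipeline.chunk_text(text)):
        vector = pipeline.embed(chunk)
        if vector:
            pipeline.save_chunk_to_qdrant(chunk, url, vector, position)
```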
agent.py
ADDED
@@ -0,0 +1,254 @@
import os
import json
import logging
from typing import Dict, List, Any
from dotenv import load_dotenv
import asyncio
import time

# Load environment variables
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def retrieve_information(query: str, top_k: int = 5, threshold: float = 0.3) -> Dict:
    """
    Retrieve information from the knowledge base based on a query
    """
    from retrieving import RAGRetriever
    retriever = RAGRetriever()

    try:
        # Call the existing retrieve method from the RAGRetriever instance
        json_response = retriever.retrieve(query_text=query, top_k=top_k, threshold=threshold)
        results = json.loads(json_response)

        # Format the results for the assistant
        formatted_results = []
        for result in results.get('results', []):
            formatted_results.append({
                'content': result['content'],
                'url': result['url'],
                'position': result['position'],
                'similarity_score': result['similarity_score'],
                'chunk_id': result.get('chunk_id', ''),
                'created_at': result.get('created_at', '')
            })

        return {
            'query': query,
            'retrieved_chunks': formatted_results,
            'total_results': len(formatted_results),
            'metadata': results.get('metadata', {})
        }
    except Exception as e:
        logger.error(f"Error in retrieve_information: {e}")
        return {
            'query': query,
            'retrieved_chunks': [],
            'total_results': 0,
            'error': str(e),
            'metadata': {}
        }

class RAGAgent:
    def __init__(self):
        # Initialize the RAG system components
        # For now, we'll use the retrieval function directly
        # In a real implementation, you would initialize your existing RAG components
        logger.info("RAG Agent initialized with retrieval and generation components")

    def query_agent(self, query_text: str, session_id: str = None, query_type: str = "global", selected_text: str = None) -> Dict:
        """
        Process a query through the RAG system and return a structured response
        """
        start_time = time.time()

        logger.info(f"Processing query through RAG system: '{query_text[:50]}...'")

        try:
            # Retrieve relevant information using our retrieval system
            retrieval_result = retrieve_information(query_text, top_k=5, threshold=0.3)

            if retrieval_result.get('error'):
                return {
                    "answer": "Sorry, I encountered an error retrieving information.",
                    "sources": [],
                    "matched_chunks": [],
                    "citations": [],
                    "error": retrieval_result['error'],
                    "query_time_ms": (time.time() - start_time) * 1000,
                    "session_id": session_id,
                    "query_type": query_type
                }

            # Format the retrieved information for response generation
            # In a real implementation, you would connect this to your response generator
            retrieved_chunks = retrieval_result.get('retrieved_chunks', [])

            if not retrieved_chunks:
                return {
                    "answer": "I couldn't find relevant information in the Physical AI & Humanoid Robotics curriculum to answer your question. Please try asking about specific topics from the curriculum like ROS 2, Digital Twins, AI-Brain, or VLA.",
                    "sources": [],
                    "matched_chunks": [],
                    "citations": [],
                    "error": None,
                    "query_time_ms": (time.time() - start_time) * 1000,
                    "session_id": session_id,
                    "query_type": query_type
                }

            # Generate a response based on the retrieved information
            # For now, we'll create a simple response based on the retrieved chunks
            answer_parts = ["Based on the Physical AI & Humanoid Robotics curriculum:"]

            # Include content from the most relevant chunks
            for i, chunk in enumerate(retrieved_chunks[:2]):  # Use top 2 chunks
                content = chunk.get('content', '')[:300]  # Limit content length
                answer_parts.append(f"{content}...")

            answer = " ".join(answer_parts)

            # Create citations from the retrieved chunks
            citations = []
            for chunk in retrieved_chunks:
                citation = {
                    "document_id": chunk.get('chunk_id', ''),
                    "title": chunk.get('url', ''),
                    "chapter": "",
                    "section": "",
                    "page_reference": ""
                }
                citations.append(citation)

            # Calculate query time
            query_time_ms = (time.time() - start_time) * 1000

            # Format the response
            response = {
                "answer": answer,
                "sources": [chunk.get('url', '') for chunk in retrieved_chunks if chunk.get('url')],
                "matched_chunks": retrieved_chunks,
                "citations": citations,
                "query_time_ms": query_time_ms,
                "session_id": session_id,
                "query_type": query_type,
                "confidence": self._calculate_confidence(retrieved_chunks),
                "error": None
            }

            logger.info(f"Query processed in {query_time_ms:.2f}ms")
            return response

        except Exception as e:
            logger.error(f"Error processing query: {e}")
            return {
                "answer": "Sorry, I encountered an error processing your request.",
                "sources": [],
                "matched_chunks": [],
                "citations": [],
                "error": str(e),
                "query_time_ms": (time.time() - start_time) * 1000,
                "session_id": session_id,
                "query_type": query_type
            }

    def _calculate_confidence(self, matched_chunks: List[Dict]) -> str:
        """
        Calculate confidence level based on similarity scores and number of matches
        """
        if not matched_chunks:
            return "low"

        avg_score = sum(chunk.get('similarity_score', 0.0) for chunk in matched_chunks) / len(matched_chunks)

        if avg_score >= 0.7:
            return "high"
        elif avg_score >= 0.4:
            return "medium"
        else:
            return "low"

def query_agent(query_text: str) -> Dict:
    """
    Convenience function to query the RAG agent
    """
    agent = RAGAgent()
    return agent.query_agent(query_text)

def run_agent_sync(query_text: str) -> Dict:
    """
    Synchronous function to run the agent for direct usage
    """
    # RAGAgent.query_agent is already synchronous, so no event-loop
    # handling is needed; call it directly.
    agent = RAGAgent()
    return agent.query_agent(query_text)

def main():
    """
    Main function to demonstrate the RAG agent functionality
    """
    logger.info("Initializing RAG Agent...")

    # Initialize the agent
    agent = RAGAgent()

    # Example queries to test the system
    test_queries = [
        "What is ROS2?",
        "Explain humanoid design principles",
        "How does VLA work?",
        "What are simulation techniques?",
        "Explain AI control systems"
    ]

    print("RAG Agent - Testing Queries")
    print("=" * 50)

    for i, query in enumerate(test_queries, 1):
        print(f"\nQuery {i}: {query}")
        print("-" * 30)

        # Process query through agent
        response = agent.query_agent(query)

        # Print formatted results
        print(f"Answer: {response['answer']}")

        if response.get('sources'):
            print(f"Sources: {len(response['sources'])} documents")
            for source in response['sources'][:3]:  # Show first 3 sources
                print(f"  - {source}")

        if response.get('matched_chunks'):
            print(f"Matched chunks: {len(response['matched_chunks'])}")
            for j, chunk in enumerate(response['matched_chunks'][:2], 1):  # Show first 2 chunks
                content_preview = chunk['content'][:100] + "..." if len(chunk['content']) > 100 else chunk['content']
                print(f"  Chunk {j}: {content_preview}")
                print(f"  Source: {chunk['url']}")
                print(f"  Score: {chunk['similarity_score']:.3f}")

        print(f"Query time: {response['query_time_ms']:.2f}ms")
        print(f"Confidence: {response.get('confidence', 'unknown')}")

        if i < len(test_queries):  # Don't sleep after the last query
            time.sleep(1)  # Small delay between queries

if __name__ == "__main__":
    main()
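
For use outside the API layer, the agent can be driven directly; a minimal sketch (error responses omit the confidence field, hence `.get`):

```python
from agent import RAGAgent

agent = RAGAgent()
result = agent.query_agent("What is ROS2?", session_id="demo", query_type="global")

print(result["answer"])
print(result.get("confidence"))   # "high" / "medium" / "low" on success
for url in result["sources"]:
    print("source:", url)
```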
api.py
ADDED
@@ -0,0 +1,215 @@
import os
import asyncio
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional, Dict
from dotenv import load_dotenv
import logging

# Load environment variables
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Import the existing RAG agent functionality
from agent import RAGAgent

# Create FastAPI app
app = FastAPI(
    title="RAG Agent API",
    description="API for RAG Agent with document retrieval and question answering",
    version="1.0.0"
)

# Add CORS middleware for development
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace with specific origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Pydantic models
class QueryRequest(BaseModel):
    query: str

class ChatRequest(BaseModel):
    query: str
    message: str
    session_id: str
    selected_text: Optional[str] = None
    query_type: str = "global"
    top_k: int = 5

class MatchedChunk(BaseModel):
    content: str
    url: str
    position: int
    similarity_score: float

class QueryResponse(BaseModel):
    answer: str
    sources: List[str]
    matched_chunks: List[MatchedChunk]
    error: Optional[str] = None
    status: str  # "success", "error", "empty"
    query_time_ms: Optional[float] = None
    confidence: Optional[str] = None

class ChatResponse(BaseModel):
    response: str
    citations: List[Dict[str, str]]
    session_id: str
    query_type: str
    timestamp: str

class HealthResponse(BaseModel):
    status: str
    message: str

# Global RAG agent instance
rag_agent = None

@app.on_event("startup")
async def startup_event():
    """Initialize the RAG agent on startup"""
    global rag_agent
    logger.info("Initializing RAG Agent...")
    try:
        rag_agent = RAGAgent()
        logger.info("RAG Agent initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize RAG Agent: {e}")
        raise

@app.post("/ask", response_model=QueryResponse)
async def ask_rag(request: QueryRequest):
    """
    Process a user query through the RAG agent and return the response
    """
    logger.info(f"Processing query: {request.query[:50]}...")

    try:
        # Validate input
        if not request.query or len(request.query.strip()) == 0:
            raise HTTPException(status_code=400, detail="Query cannot be empty")

        if len(request.query) > 2000:
            raise HTTPException(status_code=400, detail="Query too long, maximum 2000 characters")

        # Process query through RAG agent
        response = rag_agent.query_agent(request.query)

        # Format response
        formatted_response = QueryResponse(
            answer=response.get("answer", ""),
            sources=response.get("sources", []),
            matched_chunks=[
                MatchedChunk(
                    content=chunk.get("content", ""),
                    url=chunk.get("url", ""),
                    position=chunk.get("position", 0),
                    similarity_score=chunk.get("similarity_score", 0.0)
                )
                for chunk in response.get("matched_chunks", [])
            ],
            error=response.get("error"),
            status="error" if response.get("error") else "success",
            query_time_ms=response.get("query_time_ms"),
            confidence=response.get("confidence")
        )

        logger.info(f"Query processed successfully in {response.get('query_time_ms', 0):.2f}ms")
        return formatted_response

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error processing query: {e}")
        return QueryResponse(
            answer="",
            sources=[],
            matched_chunks=[],
            error=str(e),
            status="error"
        )

@app.post("/api", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """
    Main chat endpoint that handles conversation with RAG capabilities
    """
    logger.info(f"Processing chat query: {request.query[:50]}...")

    try:
        # Validate input
        if not request.query or len(request.query.strip()) == 0:
            raise HTTPException(status_code=400, detail="Query cannot be empty")

        if not request.session_id or len(request.session_id.strip()) == 0:
            raise HTTPException(status_code=400, detail="Session ID cannot be empty")

        if len(request.query) > 2000:
            raise HTTPException(status_code=400, detail="Query too long, maximum 2000 characters")

        # Process query through RAG agent
        response = rag_agent.query_agent(request.query)

        # Format response to match expected structure
        from datetime import datetime
        timestamp = datetime.utcnow().isoformat()

        # Convert matched chunks to citations format
        citations = []
        for chunk in response.get("matched_chunks", []):
            citation = {
                "document_id": "",
                "title": chunk.get("url", ""),
                "chapter": "",
                "section": "",
                "page_reference": ""
            }
            citations.append(citation)

        formatted_response = ChatResponse(
            response=response.get("answer", ""),
            citations=citations,
            session_id=request.session_id,
            query_type=request.query_type,
            timestamp=timestamp
        )

        logger.info("Chat query processed successfully")
        return formatted_response

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error processing chat query: {e}")
        from datetime import datetime
        return ChatResponse(
            response="",
            citations=[],
            session_id=request.session_id,
            query_type=request.query_type,
            timestamp=datetime.utcnow().isoformat()
        )

@app.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Health check endpoint
    """
    return HealthResponse(
        status="healthy",
        message="RAG Agent API is running"
    )

# For running with uvicorn
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
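
The endpoints can be exercised in-process with FastAPI's TestClient; a minimal sketch (Qdrant and Cohere credentials still need to be configured for non-trivial answers):

```python
from fastapi.testclient import TestClient
from api import app

# Startup events (the RAGAgent init) only fire inside the context manager
with TestClient(app) as client:
    assert client.get("/health").json()["status"] == "healthy"

    r = client.post("/api", json={
        "query": "What is ROS2?",
        "message": "What is ROS2?",  # ChatRequest declares both fields as required
        "session_id": "demo-session",
    })
    print(r.json()["response"])
```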
app.py
ADDED
@@ -0,0 +1,8 @@
# app.py - Entry point for Hugging Face Spaces
import os
import uvicorn
from api import app

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
main.py
ADDED
@@ -0,0 +1,322 @@
import os
import requests
from bs4 import BeautifulSoup
import xml.etree.ElementTree as ET
from typing import List, Dict, Any
import cohere
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.models import PointStruct
import logging
from urllib.parse import urljoin, urlparse
import time
import uuid
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class DocusaurusEmbeddingPipeline:
    def __init__(self):
        # Initialize Cohere client
        self.cohere_client = cohere.Client(api_key=os.getenv("COHERE_API_KEY"))

        # Initialize Qdrant client
        qdrant_url = os.getenv("QDRANT_URL", "http://localhost:6333")
        qdrant_api_key = os.getenv("QDRANT_API_KEY")

        if qdrant_api_key:
            self.qdrant_client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
        else:
            self.qdrant_client = QdrantClient(url=qdrant_url)

        # Target URL for the Docusaurus site - configurable via environment variable
        self.target_url = os.getenv("TARGET_URL", "https://your-vercel-url.vercel.app/")

    def get_all_urls(self, base_url: str) -> List[str]:
        """
        Extract all URLs from a deployed Docusaurus site using sitemap
        """
        urls = []

        try:
            # Try to get URLs from sitemap first
            sitemap_url = urljoin(base_url, "sitemap.xml")
            response = requests.get(sitemap_url)

            if response.status_code == 200:
                root = ET.fromstring(response.content)

                # Handle both sitemap index and regular sitemap
                if root.tag.endswith('sitemapindex'):
                    # If it's a sitemap index, get individual sitemaps
                    for sitemap in root.findall('.//{http://www.sitemaps.org/schemas/sitemap/0.9}loc'):
                        sitemap_response = requests.get(sitemap.text)
                        if sitemap_response.status_code == 200:
                            sitemap_root = ET.fromstring(sitemap_response.content)
                            for url_elem in sitemap_root.findall('.//{http://www.sitemaps.org/schemas/sitemap/0.9}loc'):
                                urls.append(url_elem.text)
                else:
                    # Regular sitemap
                    for url_elem in root.findall('.//{http://www.sitemaps.org/schemas/sitemap/0.9}loc'):
                        urls.append(url_elem.text)
            else:
                # Fallback: try to crawl the site by looking for links
                logger.info(f"Sitemap not found at {sitemap_url}, attempting to crawl...")

                # Get the main page and extract links
                response = requests.get(base_url)
                soup = BeautifulSoup(response.content, 'html.parser')

                # Find all links within the page
                for link in soup.find_all('a', href=True):
                    href = link['href']
                    full_url = urljoin(base_url, href)

                    # Only add URLs from the same domain
                    if urlparse(full_url).netloc == urlparse(base_url).netloc:
                        if full_url not in urls and full_url.startswith(base_url):
                            urls.append(full_url)

        except Exception as e:
            logger.error(f"Error getting URLs from {base_url}: {e}")

        return urls

    def extract_text_from_url(self, url: str) -> str:
        """
        Extract and clean text from a single URL
        """
        try:
            response = requests.get(url)
            response.raise_for_status()

            soup = BeautifulSoup(response.content, 'html.parser')

            # Remove script and style elements
            for script in soup(["script", "style"]):
                script.decompose()

            # Look for main content containers typically used in Docusaurus
            # Try multiple selectors to find the main content
            content_selectors = [
                'article',              # Main article content
                '.markdown',            # Docusaurus markdown content
                '.theme-doc-markdown',  # Docusaurus theme markdown
                '.main-wrapper',        # Main content wrapper
                'main',                 # Main content area
                '.container',           # Container with content
                '[role="main"]'         # Main role
            ]

            content = ""
            for selector in content_selectors:
                elements = soup.select(selector)
                if elements:
                    for element in elements:
                        # Get text but try to preserve some structure
                        text = element.get_text(separator=' ', strip=True)
                        if len(text) > len(content):
                            content = text
                    break

            # If no specific content found, get all body text
            if not content:
                body = soup.find('body')
                if body:
                    content = body.get_text(separator=' ', strip=True)

            # Clean up the text
            lines = (line.strip() for line in content.splitlines())
            chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
            content = ' '.join(chunk for chunk in chunks if chunk)

            return content

        except Exception as e:
            logger.error(f"Error extracting text from {url}: {e}")
            return ""

    def chunk_text(self, text: str, chunk_size: int = 1000, overlap: int = 100) -> List[str]:
        """
        Split text into chunks with overlap to preserve context
        """
        if len(text) <= chunk_size:
            return [text]

        chunks = []
        start = 0

        while start < len(text):
            end = start + chunk_size
            chunk = text[start:end]
            chunks.append(chunk)

            # Move start position by chunk_size - overlap
            start = end - overlap

            # If remaining text is less than chunk_size, add it as the final chunk
            if len(text) - start < chunk_size:
                if start < len(text):
                    final_chunk = text[start:]
                    if final_chunk not in chunks:  # Avoid duplicate chunks
                        chunks.append(final_chunk)
                break

        return chunks

    def embed(self, text: str) -> List[float]:
        """
        Generate embedding for text using Cohere
        """
        try:
            response = self.cohere_client.embed(
                texts=[text],
                model="embed-multilingual-v3.0",  # Using multilingual model
                input_type="search_document"      # Optimize for search
            )
            return response.embeddings[0]  # Return the first (and only) embedding
        except Exception as e:
            logger.error(f"Error generating embedding for text: {e}")
            return []

    def create_collection(self, collection_name: str = "rag_embedding"):
        """
        Create a Qdrant collection for storing embeddings
        """
        try:
            # Check if collection already exists
            collections = self.qdrant_client.get_collections()
            collection_names = [col.name for col in collections.collections]

            if collection_name in collection_names:
                logger.info(f"Collection {collection_name} already exists")
                return

            # Create collection with appropriate vector size (1024 for Cohere embeddings)
            self.qdrant_client.create_collection(
                collection_name=collection_name,
                vectors_config=models.VectorParams(size=1024, distance=models.Distance.COSINE)
            )

            logger.info(f"Created collection {collection_name} with 1024-dimension vectors")

        except Exception as e:
            logger.error(f"Error creating collection {collection_name}: {e}")
            raise

    def save_chunk_to_qdrant(self, content: str, url: str, embedding: List[float], position: int, collection_name: str = "rag_embedding"):
        """
        Save a text chunk with its embedding to Qdrant
        """
        try:
            # Generate a unique ID for the point
            point_id = str(uuid.uuid4())

            # Prepare the payload with metadata
            payload = {
                "content": content,
                "url": url,
                "position": position,
                "created_at": time.time()
            }

            # Create and upload the point to Qdrant
            points = [PointStruct(
                id=point_id,
                vector=embedding,
                payload=payload
            )]

            self.qdrant_client.upsert(
                collection_name=collection_name,
                points=points
            )

            logger.info(f"Saved chunk to Qdrant: {url} (position {position})")
            return True

        except Exception as e:
            logger.error(f"Error saving chunk to Qdrant: {e}")
            return False

def main():
    """
    Main function to execute the complete pipeline
    """
    logger.info("Starting Docusaurus Embedding Pipeline...")

    # Initialize the pipeline
    pipeline = DocusaurusEmbeddingPipeline()

    try:
        # Step 1: Create the Qdrant collection
        logger.info("Creating Qdrant collection...")
        pipeline.create_collection("rag_embedding")

        # Step 2: Get all URLs from the target Docusaurus site
        logger.info(f"Extracting URLs from {pipeline.target_url}...")
        urls = pipeline.get_all_urls(pipeline.target_url)

        if not urls:
            logger.warning(f"No URLs found at {pipeline.target_url}")
            return

        logger.info(f"Found {len(urls)} URLs to process")

        # Step 3: Process each URL
        total_chunks = 0
        for i, url in enumerate(urls):
            logger.info(f"Processing URL {i+1}/{len(urls)}: {url}")

            # Extract text from the URL
            text_content = pipeline.extract_text_from_url(url)

            if not text_content:
                logger.warning(f"No content extracted from {url}")
                continue

            logger.info(f"Extracted {len(text_content)} characters from {url}")

            # Chunk the text
            chunks = pipeline.chunk_text(text_content)
            logger.info(f"Created {len(chunks)} chunks from {url}")

            # Process each chunk
            for j, chunk in enumerate(chunks):
                if not chunk.strip():
                    continue

                # Generate embedding
                embedding = pipeline.embed(chunk)

                if not embedding:
                    logger.error(f"Failed to generate embedding for chunk {j} of {url}")
                    continue

                # Save to Qdrant
                success = pipeline.save_chunk_to_qdrant(
                    content=chunk,
                    url=url,
                    embedding=embedding,
                    position=j
                )

                if success:
                    total_chunks += 1
                    logger.info(f"Successfully saved chunk {j} of {url} to Qdrant")
                else:
                    logger.error(f"Failed to save chunk {j} of {url} to Qdrant")

        logger.info(f"Pipeline completed successfully! Total chunks saved: {total_chunks}")

    except Exception as e:
        logger.error(f"Pipeline failed with error: {e}")
        raise

if __name__ == "__main__":
    main()
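
`chunk_text` never touches the Cohere or Qdrant clients, so its overlap behavior can be checked standalone; a small worked example (each new chunk starts chunk_size - overlap = 7 characters after the previous one):

```python
from main import DocusaurusEmbeddingPipeline

text = "abcdefghijklmnopqrst"  # 20 characters

# chunk_text does not use self, so the unbound method runs without client setup
for c in DocusaurusEmbeddingPipeline.chunk_text(None, text, chunk_size=10, overlap=3):
    print(repr(c))
# 'abcdefghij'
# 'hijklmnopq'   <- shares 'hij' with the previous chunk
# 'opqrst'       <- shares 'opq' with the previous chunk
```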
pyproject.toml
ADDED
@@ -0,0 +1,27 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "docusaurus-embedding-pipeline"
version = "0.1.0"
description = "Pipeline to extract text from Docusaurus sites, generate embeddings, and store in Qdrant"
readme = "README.md"
requires-python = ">=3.9"
dependencies = [
    "requests>=2.31.0",
    "beautifulsoup4>=4.12.2",
    "cohere>=4.9.0",
    "qdrant-client>=1.9.0",
    "python-dotenv>=1.0.0"
]

[project.optional-dependencies]
dev = [
    "pytest>=7.0",
    "black>=23.0",
    "flake8>=6.0"
]

[tool.setuptools.packages.find]
where = ["."]
requirements.txt
ADDED
@@ -0,0 +1,22 @@
fastapi==0.104.1
uvicorn[standard]==0.24.0
python-dotenv==1.0.0
qdrant-client==1.9.1
httpx==0.25.2
psycopg2-binary==2.9.9
sqlalchemy==2.0.23
pydantic==2.5.0
pydantic-settings==2.1.0
openai==1.3.6
tiktoken==0.5.2
markdown==3.5.1
python-multipart==0.0.6
python-jose[cryptography]==3.3.0
passlib[bcrypt]==1.7.4
python-slugify==8.0.1
asyncpg==0.29.0
alembic==1.13.1
beautifulsoup4==4.12.2
scikit-learn==1.3.2
requests>=2.31.0
cohere>=4.9.0
retrieving.py
ADDED
@@ -0,0 +1,267 @@
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from typing import List, Dict, Any
|
| 4 |
+
import cohere
|
| 5 |
+
from qdrant_client import QdrantClient
|
| 6 |
+
from qdrant_client.http import models
|
| 7 |
+
import logging
|
| 8 |
+
from dotenv import load_dotenv
|
| 9 |
+
import time
|
| 10 |
+
|
| 11 |
+
# Load environment variables
|
| 12 |
+
load_dotenv()
|
| 13 |
+
|
| 14 |
+
# Configure logging
|
| 15 |
+
logging.basicConfig(level=logging.INFO)
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
class RAGRetriever:
|
| 19 |
+
def __init__(self):
|
| 20 |
+
# Initialize Cohere client
|
| 21 |
+
self.cohere_client = cohere.Client(api_key=os.getenv("COHERE_API_KEY"))
|
| 22 |
+
|
| 23 |
+
# Initialize Qdrant client
|
| 24 |
+
qdrant_url = os.getenv("QDRANT_URL", "http://localhost:6333")
|
| 25 |
+
qdrant_api_key = os.getenv("QDRANT_API_KEY")
|
| 26 |
+
|
| 27 |
+
if qdrant_api_key:
|
| 28 |
+
self.qdrant_client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
|
| 29 |
+
else:
|
| 30 |
+
self.qdrant_client = QdrantClient(url=qdrant_url)
|
| 31 |
+
|
| 32 |
+
# Default collection name
|
| 33 |
+
self.collection_name = "rag_embedding"
|
| 34 |
+
|
| 35 |
+
def get_embedding(self, text: str) -> List[float]:
|
| 36 |
+
"""
|
| 37 |
+
Generate embedding for query text using Cohere
|
| 38 |
+
"""
|
| 39 |
+
try:
|
| 40 |
+
response = self.cohere_client.embed(
|
| 41 |
+
texts=[text],
|
| 42 |
+
model="embed-multilingual-v3.0", # Using same model as storage
|
| 43 |
+
input_type="search_query" # Optimize for search queries
|
| 44 |
+
)
|
| 45 |
+
return response.embeddings[0] # Return the first (and only) embedding
|
| 46 |
+
except Exception as e:
|
| 47 |
+
logger.error(f"Error generating embedding for query: {e}")
|
| 48 |
+
return []
|
| 49 |
+
|
| 50 |
+
def query_qdrant(self, query_embedding: List[float], top_k: int = 5, threshold: float = 0.0) -> List[Dict]:
|
| 51 |
+
"""
|
| 52 |
+
Query Qdrant for similar vectors and return results with metadata
|
| 53 |
+
"""
|
| 54 |
+
try:
|
| 55 |
+
# Perform similarity search in Qdrant
|
| 56 |
+
search_results = self.qdrant_client.search(
|
| 57 |
+
collection_name=self.collection_name,
|
| 58 |
+
query_vector=query_embedding,
|
| 59 |
+
limit=top_k,
|
| 60 |
+
score_threshold=threshold,
|
| 61 |
+
with_payload=True # Include metadata with results
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
# Format results
|
| 65 |
+
formatted_results = []
|
| 66 |
+
for result in search_results:
|
| 67 |
+
formatted_result = {
|
| 68 |
+
"content": result.payload.get("content", ""),
|
| 69 |
+
"url": result.payload.get("url", ""),
|
| 70 |
+
"position": result.payload.get("position", 0),
|
| 71 |
+
"similarity_score": result.score,
|
| 72 |
+
"chunk_id": result.id,
|
| 73 |
+
"created_at": result.payload.get("created_at", "")
|
| 74 |
+
}
|
| 75 |
+
formatted_results.append(formatted_result)
|
| 76 |
+
|
| 77 |
+
return formatted_results
|
| 78 |
+
|
| 79 |
+
except Exception as e:
|
| 80 |
+
logger.error(f"Error querying Qdrant: {e}")
|
| 81 |
+
return []
|
| 82 |
+
|
| 83 |
+
def verify_content_accuracy(self, retrieved_chunks: List[Dict]) -> bool:
|
| 84 |
+
"""
|
| 85 |
+
Verify that retrieved content matches original stored text (basic validation)
|
| 86 |
+
"""
|
| 87 |
+
# In a real implementation, this would compare against original sources
|
| 88 |
+
# For now, we'll validate that required fields exist and have content
|
| 89 |
+
for chunk in retrieved_chunks:
|
| 90 |
+
if not chunk.get("content") or not chunk.get("url"):
|
| 91 |
+
logger.warning(f"Missing content or URL in chunk: {chunk.get('chunk_id')}")
|
| 92 |
+
return False
|
| 93 |
+
|
| 94 |
+
# Additional validation could include checking content length, URL format, etc.
|
| 95 |
+
return True
|
| 96 |
+
|
| 97 |
+
def format_json_response(self, results: List[Dict], query: str, query_time_ms: float) -> str:
|
| 98 |
+
"""
|
| 99 |
+
Format retrieval results into clean JSON response
|
| 100 |
+
"""
|
| 101 |
+
response = {
|
| 102 |
+
"query": query,
|
| 103 |
+
"results": results,
|
| 104 |
+
"metadata": {
|
| 105 |
+
"query_time_ms": query_time_ms,
|
| 106 |
+
"total_results": len(results),
|
| 107 |
+
"timestamp": time.time(),
|
| 108 |
+
"collection_name": self.collection_name
|
| 109 |
+
}
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
return json.dumps(response, indent=2)
|
| 113 |
+
|
| 114 |
+
def retrieve(self, query_text: str, top_k: int = 5, threshold: float = 0.0, include_metadata: bool = True) -> str:
|
| 115 |
+
"""
|
| 116 |
+
Main retrieval function that orchestrates the complete workflow
|
| 117 |
+
"""
|
| 118 |
+
start_time = time.time()
|
| 119 |
+
|
| 120 |
+
logger.info(f"Processing retrieval request for query: '{query_text[:50]}...'")
|
| 121 |
+
|
| 122 |
+
# Step 1: Convert query text to embedding
|
| 123 |
+
query_embedding = self.get_embedding(query_text)
|
| 124 |
+
if not query_embedding:
|
| 125 |
+
error_response = {
|
| 126 |
+
"query": query_text,
|
| 127 |
+
"results": [],
|
| 128 |
+
"error": "Failed to generate query embedding",
|
| 129 |
+
"metadata": {
|
| 130 |
+
"query_time_ms": (time.time() - start_time) * 1000,
|
| 131 |
+
"timestamp": time.time()
|
| 132 |
+
}
|
| 133 |
+
}
|
| 134 |
+
return json.dumps(error_response, indent=2)
|
| 135 |
+
|
| 136 |
+
# Step 2: Query Qdrant for similar vectors
|
| 137 |
+
raw_results = self.query_qdrant(query_embedding, top_k, threshold)
|
| 138 |
+
|
| 139 |
+
if not raw_results:
|
| 140 |
+
logger.warning("No results returned from Qdrant")
|
| 141 |
+
|
| 142 |
+
# Step 3: Verify content accuracy (optional)
|
| 143 |
+
if include_metadata:
|
| 144 |
+
is_accurate = self.verify_content_accuracy(raw_results)
|
| 145 |
+
if not is_accurate:
|
| 146 |
+
logger.warning("Content accuracy verification failed for some results")
|
| 147 |
+
|
| 148 |
+
# Step 4: Calculate total query time
|
| 149 |
+
query_time_ms = (time.time() - start_time) * 1000
|
| 150 |
+
|
| 151 |
+
# Step 5: Format response as JSON
|
| 152 |
+
json_response = self.format_json_response(raw_results, query_text, query_time_ms)
|
| 153 |
+
|
| 154 |
+
logger.info(f"Retrieval completed in {query_time_ms:.2f}ms, {len(raw_results)} results returned")
|
| 155 |
+
|
| 156 |
+
return json_response
|
| 157 |
+
|
| 158 |
+
def retrieve_all_data():
|
| 159 |
+
"""
|
| 160 |
+
Function to retrieve and display all data from Qdrant collection
|
| 161 |
+
"""
|
| 162 |
+
logger.info("Initializing RAG Retriever to fetch all data...")
|
| 163 |
+
|
| 164 |
+
# Initialize the retriever
|
| 165 |
+
retriever = RAGRetriever()
|
| 166 |
+
|
    print("RAG Retrieval System - All Stored Data")
    print("=" * 50)

    try:
        # Get all points from the collection using scroll
        points = []
        offset = None
        while True:
            # Scroll through the collection to get all points
            batch, next_offset = retriever.qdrant_client.scroll(
                collection_name=retriever.collection_name,
                limit=1000,  # Get up to 1000 points at a time
                offset=offset,
                with_payload=True,
                with_vectors=False
            )

            points.extend(batch)

            # If next_offset is None, we've reached the end
            if next_offset is None:
                break

            offset = next_offset

        print(f"Total stored chunks: {len(points)}")
        print("-" * 50)

        for i, point in enumerate(points, 1):
            payload = point.payload
            # Keep the preview printable on consoles with limited encodings
            content_preview = ''.join(char for char in payload.get("content", "")[:200] if ord(char) < 256)

            print(f"Chunk {i}:")
            print(f"  ID: {point.id}")
            print(f"  URL: {payload.get('url', 'N/A')}")
            print(f"  Position: {payload.get('position', 'N/A')}")
            print(f"  Content Preview: {content_preview}...")
            print(f"  Created At: {payload.get('created_at', 'N/A')}")
            print("-" * 30)

    except Exception as e:
        logger.error(f"Error retrieving all data: {e}")
        print(f"Error retrieving all data: {e}")


def main():
    """
    Main function to demonstrate the retrieval functionality
    """
    import sys

    logger.info("Initializing RAG Retriever...")

    # Check if the user wants to retrieve all data or run test queries
    if len(sys.argv) > 1 and sys.argv[1] == "all":
        retrieve_all_data()
        return

    # Initialize the retriever
    retriever = RAGRetriever()

    # Example queries to test the system
    test_queries = [
        "What is ROS2?",
        "Explain humanoid design principles",
        "How does VLA work?",
        "What are simulation techniques?",
        "Explain AI control systems"
    ]

    print("RAG Retrieval System - Testing Queries")
    print("=" * 50)

    for i, query in enumerate(test_queries, 1):
        print(f"\nQuery {i}: {query}")
        print("-" * 30)

        # Retrieve results
        json_response = retriever.retrieve(query, top_k=3)
        response_dict = json.loads(json_response)

        # Print formatted results
        results = response_dict.get("results", [])
        if results:
            for j, result in enumerate(results, 1):
                print(f"Result {j} (Score: {result['similarity_score']:.3f}):")
                print(f"  URL: {result['url']}")
                # Safely print the content preview by removing characters
                # the console may not be able to encode
                content_preview = result['content'][:100].encode('utf-8', errors='ignore').decode('utf-8')
                safe_content = ''.join(char for char in content_preview if ord(char) < 256)
                print(f"  Content Preview: {safe_content}...")
                print(f"  Position: {result['position']}")
                print()
        else:
            print("No results found for this query.")

        print(f"Query time: {response_dict['metadata']['query_time_ms']:.2f}ms")
        print(f"Total results: {response_dict['metadata']['total_results']}")


if __name__ == "__main__":
    main()
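For completeness, a minimal sketch of driving the retriever programmatically instead of through `main()`; it assumes this file is importable as `retrieving` (the module name is an assumption) and that the Qdrant and Cohere credentials are already configured in the environment:

```python
import json

from retrieving import RAGRetriever  # assumed module name for this file

retriever = RAGRetriever()
response = json.loads(retriever.retrieve("What is ROS2?", top_k=3))

# retrieve() returns a JSON string with "results" and "metadata" keys
for hit in response.get("results", []):
    print(f"{hit['similarity_score']:.3f}  {hit['url']}")

meta = response["metadata"]
print(f"{meta['total_results']} results in {meta['query_time_ms']:.2f} ms")
```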
sdk.md
ADDED
@@ -0,0 +1,935 @@
# OpenAI Agents SDK

The OpenAI Agents SDK is a lightweight Python framework for building production-ready agentic AI applications with minimal abstractions. It provides a streamlined upgrade from the experimental Swarm framework, offering essential primitives like agents with instructions and tools, handoffs for task delegation between agents, guardrails for input/output validation, and sessions for automatic conversation history management. The SDK emphasizes ease of use while maintaining enough power to express complex multi-agent relationships, making it suitable for real-world applications without requiring mastery of complex frameworks.

Built on core design principles of simplicity and customization, the SDK includes an automatic agent loop handling tool calls and LLM interactions, Python-first orchestration without new abstractions to learn, built-in tracing for visualization and debugging, and automatic schema generation with Pydantic validation for function tools. It supports multiple model providers through OpenAI's Responses API and Chat Completions API, with native integration for LiteLLM and custom providers. Whether building single-agent assistants or complex multi-agent workflows with specialized roles, the SDK provides the necessary features to move quickly from prototype to production.

## Installation and Setup

Install the SDK and configure your environment

```bash
pip install openai-agents

export OPENAI_API_KEY=sk-...
```
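If you prefer configuring the key in code rather than through the environment, the SDK exposes a helper for that. A minimal sketch; the placeholder key stands in for a value loaded from your secrets store:

```python
from agents import set_default_openai_key

# Programmatic alternative to exporting OPENAI_API_KEY
set_default_openai_key("sk-...")
```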
## Creating a Basic Agent

Define an agent with name and instructions

```python
from agents import Agent, Runner

agent = Agent(
    name="Math Tutor",
    instructions="You provide help with math problems. Explain your reasoning at each step and include examples"
)

result = Runner.run_sync(agent, "What is 15% of 80?")
print(result.final_output)
# 15% of 80 is 12. To calculate: 0.15 × 80 = 12
```

## Running Agents Asynchronously

Execute agent with async/await pattern

```python
import asyncio
from agents import Agent, Runner

async def main():
    agent = Agent(
        name="Assistant",
        instructions="Reply very concisely."
    )

    result = await Runner.run(agent, "What city is the Golden Gate Bridge in?")
    print(result.final_output)
    # San Francisco

asyncio.run(main())
```

## Agent with Function Tools

Decorate Python functions to create tools with automatic schema generation

```python
from agents import Agent, Runner, function_tool
import asyncio

@function_tool
async def get_weather(city: str) -> str:
    """Fetch the weather for a given location.

    Args:
        city: The city to fetch weather for.
    """
    # In production, call actual weather API
    return f"The weather in {city} is sunny and 72°F"

@function_tool
def calculate_sum(a: int, b: int) -> int:
    """Add two numbers together.

    Args:
        a: First number
        b: Second number
    """
    return a + b

agent = Agent(
    name="Assistant",
    instructions="Use the provided tools to help the user",
    tools=[get_weather, calculate_sum]
)

async def main():
    result = await Runner.run(agent, "What's the weather in Seattle?")
    print(result.final_output)
    # The weather in Seattle is sunny and 72°F

asyncio.run(main())
```

## Agent with Hosted Tools

Use OpenAI's built-in tools for web search and file retrieval

```python
from agents import Agent, Runner, WebSearchTool, FileSearchTool
import asyncio

agent = Agent(
    name="Research Assistant",
    instructions="Use web search and file search to answer questions thoroughly",
    tools=[
        WebSearchTool(),
        FileSearchTool(
            max_num_results=5,
            vector_store_ids=["vs_abc123"]
        )
    ]
)

async def main():
    result = await Runner.run(
        agent,
        "What are the latest developments in quantum computing?"
    )
    print(result.final_output)

asyncio.run(main())
```

## Multi-Agent Handoffs

Create specialized agents that delegate to each other

```python
from agents import Agent, Runner
import asyncio

billing_agent = Agent(
    name="Billing Agent",
    handoff_description="Specialist for billing questions and payment issues",
    instructions="You handle billing inquiries. Check account status and process refunds."
)

technical_agent = Agent(
    name="Technical Agent",
    handoff_description="Specialist for technical support and troubleshooting",
    instructions="You handle technical issues. Diagnose problems and provide solutions."
)

triage_agent = Agent(
    name="Triage Agent",
    instructions=(
        "Determine which specialist agent should handle the user's request. "
        "Hand off to the appropriate agent based on the question type."
    ),
    handoffs=[billing_agent, technical_agent]
)

async def main():
    result = await Runner.run(
        triage_agent,
        "I was charged twice for my subscription this month"
    )
    print(result.final_output)
    # Output from billing_agent after handoff

asyncio.run(main())
```

## Custom Handoff with Input Data

Configure handoffs with structured input and callbacks

```python
from agents import Agent, Runner, handoff, RunContextWrapper
from pydantic import BaseModel
import asyncio

class EscalationData(BaseModel):
    reason: str
    severity: str

async def on_escalation(ctx: RunContextWrapper[None], input_data: EscalationData):
    print(f"Escalated: {input_data.reason} (severity: {input_data.severity})")
    # Log to monitoring system, send alert, etc.

escalation_agent = Agent(
    name="Manager",
    instructions="Handle escalated customer issues with priority"
)

support_agent = Agent(
    name="Support Agent",
    instructions="Help customers. Escalate to manager if issue is severe.",
    handoffs=[
        handoff(
            agent=escalation_agent,
            on_handoff=on_escalation,
            input_type=EscalationData,
            tool_description_override="Escalate urgent issues to management"
        )
    ]
)

async def main():
    result = await Runner.run(
        support_agent,
        "This is completely unacceptable! I demand to speak to a manager!"
    )
    print(result.final_output)

asyncio.run(main())
```

## Input Guardrails

Validate user input before processing with the main agent

```python
from agents import Agent, Runner, input_guardrail, GuardrailFunctionOutput
from agents import InputGuardrailTripwireTriggered, RunContextWrapper, TResponseInputItem
from pydantic import BaseModel
import asyncio

class HomeworkCheck(BaseModel):
    is_homework: bool
    reasoning: str

guardrail_agent = Agent(
    name="Homework Detector",
    instructions="Determine if the user is asking for homework help",
    output_type=HomeworkCheck
)

@input_guardrail
async def homework_guardrail(
    ctx: RunContextWrapper[None],
    agent: Agent,
    input_data: str | list[TResponseInputItem]
) -> GuardrailFunctionOutput:
    result = await Runner.run(guardrail_agent, input_data, context=ctx.context)

    return GuardrailFunctionOutput(
        output_info=result.final_output,
        tripwire_triggered=result.final_output.is_homework
    )

tutoring_agent = Agent(
    name="Tutoring Service",
    instructions="You help students understand concepts, not do their homework",
    input_guardrails=[homework_guardrail]
)

async def main():
    try:
        result = await Runner.run(
            tutoring_agent,
            "Can you solve this equation for me: 2x + 5 = 15?"
        )
        print(result.final_output)
    except InputGuardrailTripwireTriggered:
        print("Request blocked: This appears to be homework help")

asyncio.run(main())
```

## Output Guardrails

Validate agent responses before returning to user

```python
from agents import Agent, Runner, output_guardrail, GuardrailFunctionOutput
from agents import OutputGuardrailTripwireTriggered, RunContextWrapper
from pydantic import BaseModel
import asyncio

class ToxicityCheck(BaseModel):
    is_toxic: bool
    confidence: float

class AgentResponse(BaseModel):
    message: str

toxicity_checker = Agent(
    name="Toxicity Detector",
    instructions="Analyze if the message contains toxic or harmful content",
    output_type=ToxicityCheck
)

@output_guardrail
async def toxicity_guardrail(
    ctx: RunContextWrapper[None],
    agent: Agent,
    output: AgentResponse
) -> GuardrailFunctionOutput:
    result = await Runner.run(toxicity_checker, output.message, context=ctx.context)

    return GuardrailFunctionOutput(
        output_info=result.final_output,
        tripwire_triggered=result.final_output.is_toxic and result.final_output.confidence > 0.8
    )

chatbot = Agent(
    name="Chatbot",
    instructions="You are a friendly assistant",
    output_guardrails=[toxicity_guardrail],
    output_type=AgentResponse
)

async def main():
    try:
        result = await Runner.run(chatbot, "Tell me about your day")
        print(result.final_output.message)
    except OutputGuardrailTripwireTriggered:
        print("Response blocked by content filter")

asyncio.run(main())
```

## Sessions for Conversation Memory

Automatically maintain conversation history across multiple turns

```python
from agents import Agent, Runner, SQLiteSession
import asyncio

async def main():
    agent = Agent(
        name="Assistant",
        instructions="Reply concisely and remember previous context"
    )

    # Create persistent session with SQLite backend
    session = SQLiteSession("user_123", "conversations.db")

    # First turn
    result = await Runner.run(
        agent,
        "What city is the Golden Gate Bridge in?",
        session=session
    )
    print(result.final_output)
    # San Francisco

    # Second turn - agent remembers previous context
    result = await Runner.run(
        agent,
        "What state is it in?",
        session=session
    )
    print(result.final_output)
    # California

    # Third turn - continuing the conversation
    result = await Runner.run(
        agent,
        "What's the population?",
        session=session
    )
    print(result.final_output)
    # Approximately 39 million

asyncio.run(main())
```
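If persistence across processes is not needed, the same API can back a purely in-memory conversation. A minimal sketch, assuming `SQLiteSession` falls back to an in-memory database when no file path is supplied:

```python
from agents import Agent, Runner, SQLiteSession
import asyncio

async def main():
    agent = Agent(name="Assistant", instructions="Reply concisely")

    # No file path: history lives in memory and disappears when the process exits
    session = SQLiteSession("ephemeral_user")

    await Runner.run(agent, "Remember the number 42", session=session)
    result = await Runner.run(agent, "What number did I ask you to remember?", session=session)
    print(result.final_output)

asyncio.run(main())
```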
## Session Management Operations

Manipulate conversation history programmatically

```python
from agents import Agent, Runner, SQLiteSession
import asyncio

async def main():
    session = SQLiteSession("conversation_456", "chats.db")

    # Get all conversation items
    items = await session.get_items()
    print(f"Total messages: {len(items)}")

    # Add items manually
    await session.add_items([
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi! How can I help?"}
    ])

    agent = Agent(name="Assistant")

    result = await Runner.run(agent, "What's 2 + 2?", session=session)
    print(result.final_output)

    # The user wants to correct their question: pop_item removes the
    # most recent item, which is useful for corrections
    await session.pop_item()  # Remove assistant response
    await session.pop_item()  # Remove user question

    result = await Runner.run(agent, "What's 2 + 3?", session=session)
    print(result.final_output)

    # Clear entire session
    await session.clear_session()

asyncio.run(main())
```

## OpenAI Conversations Session

Use OpenAI-hosted conversation storage

```python
from agents import Agent, Runner, OpenAIConversationsSession
import asyncio

async def main():
    agent = Agent(name="Assistant")

    # Create new conversation or resume existing one
    session = OpenAIConversationsSession()
    # Or with existing conversation ID:
    # session = OpenAIConversationsSession(conversation_id="conv_abc123")

    result = await Runner.run(
        agent,
        "Remember that my favorite color is blue",
        session=session
    )

    # Later conversation with same session
    result = await Runner.run(
        agent,
        "What's my favorite color?",
        session=session
    )
    print(result.final_output)
    # Your favorite color is blue

asyncio.run(main())
```

## Structured Outputs

Force agents to return specific data types with validation

```python
from agents import Agent, Runner
from pydantic import BaseModel
import asyncio

class CalendarEvent(BaseModel):
    title: str
    date: str
    participants: list[str]
    location: str | None = None

agent = Agent(
    name="Calendar Parser",
    instructions="Extract calendar event information from text",
    output_type=CalendarEvent
)

async def main():
    text = "Schedule a team meeting on March 15th with John, Sarah, and Mike"
    result = await Runner.run(agent, text)

    event = result.final_output_as(CalendarEvent)
    print(f"Event: {event.title}")
    print(f"Date: {event.date}")
    print(f"Attendees: {', '.join(event.participants)}")
    # Event: Team Meeting
    # Date: March 15th
    # Attendees: John, Sarah, Mike

asyncio.run(main())
```

## Agent Context and Dependency Injection

Pass custom context objects to agents and tools

```python
from dataclasses import dataclass
from agents import Agent, Runner, RunContextWrapper, function_tool
import asyncio

@dataclass
class UserContext:
    user_id: str
    is_premium: bool
    api_token: str

@function_tool
async def get_user_data(ctx: RunContextWrapper[UserContext]) -> str:
    """Fetch user-specific data using context."""
    user_id = ctx.context.user_id
    is_premium = ctx.context.is_premium

    if is_premium:
        return f"Premium user {user_id} has access to all features"
    return f"User {user_id} has basic access"

agent = Agent[UserContext](
    name="Account Manager",
    instructions="Provide user information based on their account status",
    tools=[get_user_data]
)

async def main():
    context = UserContext(
        user_id="user_789",
        is_premium=True,
        api_token="secret_token"
    )

    result = await Runner.run(agent, "What's my account status?", context=context)
    print(result.final_output)

asyncio.run(main())
```

## Dynamic Instructions

Generate agent instructions at runtime based on context

```python
from agents import Agent, Runner, RunContextWrapper
from dataclasses import dataclass
import asyncio

@dataclass
class AppContext:
    username: str
    language: str
    timezone: str

def dynamic_instructions(
    context: RunContextWrapper[AppContext],
    agent: Agent[AppContext]
) -> str:
    user = context.context
    return f"""You are a helpful assistant for {user.username}.
- Respond in {user.language}
- Use {user.timezone} timezone for all time references
- Be friendly and personalized"""

agent = Agent[AppContext](
    name="Personal Assistant",
    instructions=dynamic_instructions
)

async def main():
    context = AppContext(
        username="Alice",
        language="Spanish",
        timezone="PST"
    )

    result = await Runner.run(agent, "What time is it?", context=context)
    print(result.final_output)

asyncio.run(main())
```

## Streaming Agent Responses

Stream token-by-token responses from the agent

```python
from agents import Agent, Runner
from openai.types.responses import ResponseTextDeltaEvent
import asyncio

async def main():
    agent = Agent(
        name="Storyteller",
        instructions="Tell engaging short stories"
    )

    result = Runner.run_streamed(agent, "Tell me a story about a robot")

    print("Streaming response: ", end="", flush=True)
    async for event in result.stream_events():
        if event.type == "raw_response_event":
            if isinstance(event.data, ResponseTextDeltaEvent):
                print(event.data.delta, end="", flush=True)

    print("\n\nFinal output:", result.final_output)

asyncio.run(main())
```

## Streaming with Item-Level Events

Stream higher-level events like tool calls and messages

```python
from agents import Agent, Runner, ItemHelpers, function_tool
import asyncio
import random

@function_tool
def roll_dice(sides: int = 6) -> int:
    """Roll a dice with specified number of sides."""
    return random.randint(1, sides)

async def main():
    agent = Agent(
        name="Game Master",
        instructions="Use the dice rolling tool when asked",
        tools=[roll_dice]
    )

    result = Runner.run_streamed(agent, "Roll two dice for me")

    async for event in result.stream_events():
        if event.type == "raw_response_event":
            continue  # Skip token-level events
        elif event.type == "agent_updated_stream_event":
            print(f"Agent: {event.new_agent.name}")
        elif event.type == "run_item_stream_event":
            if event.item.type == "tool_call_item":
                print("🔧 Tool called")
            elif event.item.type == "tool_call_output_item":
                print(f"📤 Tool result: {event.item.output}")
            elif event.item.type == "message_output_item":
                text = ItemHelpers.text_message_output(event.item)
                print(f"💬 Agent: {text}")

asyncio.run(main())
```

## Agents as Tools Pattern

Use specialized agents as tools in a central orchestrator

```python
from agents import Agent, Runner
import asyncio

translation_agent = Agent(
    name="Translator",
    instructions="Translate the user's message to the specified language"
)

summarization_agent = Agent(
    name="Summarizer",
    instructions="Create a concise summary of the provided text"
)

orchestrator = Agent(
    name="Orchestrator",
    instructions="Use the available tools to process user requests efficiently",
    tools=[
        translation_agent.as_tool(
            tool_name="translate_text",
            tool_description="Translate text to another language"
        ),
        summarization_agent.as_tool(
            tool_name="summarize_text",
            tool_description="Generate a summary of long text"
        )
    ]
)

async def main():
    result = await Runner.run(
        orchestrator,
        "Translate 'Hello, how are you?' to French and Spanish"
    )
    print(result.final_output)

asyncio.run(main())
```

## Custom Model Configuration

Configure model settings and use different models per agent

```python
from agents import Agent, Runner, ModelSettings
from openai.types.shared import Reasoning
import asyncio

reasoning_agent = Agent(
    name="Deep Thinker",
    instructions="Analyze complex problems thoroughly",
    model="gpt-5",
    model_settings=ModelSettings(
        reasoning=Reasoning(effort="high"),
        temperature=0.7,
        verbosity="high"
    )
)

fast_agent = Agent(
    name="Quick Responder",
    instructions="Provide rapid responses",
    model="gpt-5-nano",
    model_settings=ModelSettings(
        reasoning=Reasoning(effort="minimal"),
        temperature=0.3,
        verbosity="low"
    )
)

triage_agent = Agent(
    name="Router",
    instructions="Route complex problems to the deep thinker, simple ones to the quick responder",
    handoffs=[reasoning_agent, fast_agent]
)

async def main():
    result = await Runner.run(
        triage_agent,
        "Explain quantum entanglement in simple terms"
    )
    print(result.final_output)

asyncio.run(main())
```

## MCP Hosted Tool Integration

Use Model Context Protocol servers as hosted tools

```python
from agents import Agent, Runner, HostedMCPTool
import asyncio

async def main():
    agent = Agent(
        name="Code Assistant",
        instructions="Help with repository questions using git tools",
        tools=[
            HostedMCPTool(
                tool_config={
                    "type": "mcp",
                    "server_label": "gitmcp",
                    "server_url": "https://gitmcp.io/openai/codex",
                    "require_approval": "never"
                }
            )
        ]
    )

    result = await Runner.run(
        agent,
        "What programming languages are used in this repository?"
    )
    print(result.final_output)

asyncio.run(main())
```

## MCP Server with Streamable HTTP

Connect to local or remote MCP servers via HTTP

```python
from agents import Agent, Runner, ModelSettings
from agents.mcp import MCPServerStreamableHttp
import asyncio
import os

async def main():
    token = os.environ["MCP_SERVER_TOKEN"]

    async with MCPServerStreamableHttp(
        name="Calculator Server",
        params={
            "url": "http://localhost:8000/mcp",
            "headers": {"Authorization": f"Bearer {token}"},
            "timeout": 10
        },
        cache_tools_list=True,
        max_retry_attempts=3
    ) as server:
        agent = Agent(
            name="Math Assistant",
            instructions="Use MCP tools to perform calculations",
            mcp_servers=[server],
            model_settings=ModelSettings(tool_choice="required")
        )

        result = await Runner.run(agent, "Calculate 47 + 89")
        print(result.final_output)

asyncio.run(main())
```

## MCP stdio Server

Launch local MCP server processes

```python
from agents import Agent, Runner
from agents.mcp import MCPServerStdio
from pathlib import Path
import asyncio

async def main():
    samples_dir = Path(__file__).parent / "sample_files"

    async with MCPServerStdio(
        name="Filesystem Server",
        params={
            "command": "npx",
            "args": ["-y", "@modelcontextprotocol/server-filesystem", str(samples_dir)]
        }
    ) as server:
        agent = Agent(
            name="File Assistant",
            instructions="Help users work with files in the sample directory",
            mcp_servers=[server]
        )

        result = await Runner.run(agent, "List all files in the directory")
        print(result.final_output)

asyncio.run(main())
```

## Tracing and Monitoring

Built-in tracing for debugging and monitoring agent workflows

```python
from agents import Agent, Runner, trace
import asyncio

async def main():
    agent = Agent(
        name="Research Agent",
        instructions="Research topics thoroughly"
    )

    # Trace multiple runs under a single workflow
    with trace(
        workflow_name="Research Workflow",
        group_id="session_123",
        metadata={"user": "alice", "environment": "production"}
    ):
        result1 = await Runner.run(agent, "What is machine learning?")
        print(f"Response 1: {result1.final_output}")

        result2 = await Runner.run(agent, "Explain neural networks")
        print(f"Response 2: {result2.final_output}")

    # View traces at: https://platform.openai.com/traces

asyncio.run(main())
```

## Error Handling

Handle exceptions from agent runs, guardrails, and tool failures

```python
from agents import Agent, Runner, function_tool
from agents.exceptions import (
    MaxTurnsExceeded,
    InputGuardrailTripwireTriggered,
    ModelBehaviorError
)
import asyncio

@function_tool
def risky_operation() -> str:
    """An operation that might fail."""
    raise ValueError("Operation failed!")

agent = Agent(
    name="Assistant",
    instructions="Help users with tasks",
    tools=[risky_operation]
)

async def main():
    try:
        result = await Runner.run(
            agent,
            "Run the risky operation",
            max_turns=5
        )
        print(result.final_output)

    except MaxTurnsExceeded:
        print("Error: Agent exceeded maximum turns")
    except InputGuardrailTripwireTriggered as e:
        print(f"Error: Input blocked by guardrail: {e}")
    except ModelBehaviorError as e:
        print(f"Error: Model produced invalid output: {e}")
    except Exception as e:
        print(f"Unexpected error: {e}")

asyncio.run(main())
```

## Using Alternative Model Providers

Integrate non-OpenAI models via LiteLLM

```bash
pip install "openai-agents[litellm]"
```

```python
from agents import Agent, Runner
import asyncio

async def main():
    # Use Claude via LiteLLM
    claude_agent = Agent(
        name="Claude Assistant",
        instructions="You are a helpful assistant",
        model="litellm/anthropic/claude-3-5-sonnet-20240620"
    )

    # Use Gemini via LiteLLM
    gemini_agent = Agent(
        name="Gemini Assistant",
        instructions="You are a helpful assistant",
        model="litellm/gemini/gemini-2.5-flash-preview-04-17"
    )

    result = await Runner.run(claude_agent, "Explain photosynthesis briefly")
    print(result.final_output)

asyncio.run(main())
```
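When a provider API key should be passed explicitly rather than picked up from the environment, the LiteLLM extension also exposes a model class that can be handed to an agent directly. A sketch, assuming the `LitellmModel` class shipped with the extension in recent SDK versions:

```python
from agents import Agent, Runner
from agents.extensions.models.litellm_model import LitellmModel
import os

agent = Agent(
    name="Claude Assistant",
    instructions="You are a helpful assistant",
    # Explicit provider key instead of environment-based configuration
    model=LitellmModel(
        model="anthropic/claude-3-5-sonnet-20240620",
        api_key=os.environ["ANTHROPIC_API_KEY"]
    )
)

result = Runner.run_sync(agent, "Explain photosynthesis briefly")
print(result.final_output)
```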
---

## Summary

The OpenAI Agents SDK provides a comprehensive yet simple framework for building agentic AI applications in Python. Core use cases include single-agent assistants with tool access, multi-agent systems with specialized roles using handoffs, conversational applications with automatic session memory, and workflows with input/output validation via guardrails. The SDK excels at building customer service bots with agent routing, research assistants with web search and file retrieval, code generation tools with MCP integration, and any application requiring LLM orchestration with minimal boilerplate.

Integration patterns follow Python-first principles using native async/await, context managers for resource handling, decorators for function tools, and Pydantic models for structured outputs. The framework supports horizontal scaling through session persistence with SQLite or SQLAlchemy backends, vertical scaling with model mixing (fast models for triage, powerful models for complex tasks), and comprehensive observability through built-in tracing to OpenAI's dashboard or custom processors. Whether building prototypes or production systems, the SDK's balance of simplicity and power makes it an ideal choice for Python developers working with AI agents.
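For the SQLAlchemy-backed persistence mentioned above, a session extension is available. A sketch only, assuming the `SQLAlchemySession` extension present in recent SDK releases; the import path, the `from_url` signature, and the install extra may differ by version, and the connection URL below is a placeholder:

```python
# pip install "openai-agents[sqlalchemy]"  (assumed install extra)
from agents import Agent, Runner
from agents.extensions.memory import SQLAlchemySession
import asyncio

async def main():
    agent = Agent(name="Assistant", instructions="Reply concisely")

    # Placeholder async connection URL; any SQLAlchemy-supported database works
    session = SQLAlchemySession.from_url(
        "user_123",
        url="postgresql+asyncpg://app:secret@localhost/agent_history",
        create_tables=True
    )

    result = await Runner.run(agent, "What is the capital of France?", session=session)
    print(result.final_output)

asyncio.run(main())
```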
uv.lock
ADDED
The diff for this file is too large to render.