Spaces:
Sleeping
Sleeping
Commit ·
db7c1e8
0
Parent(s):
Add full AI Native Textbook project source code
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .env.example +13 -0
- Dockerfile +25 -0
- README.md +212 -0
- api/auth.py +108 -0
- api/chat.py +72 -0
- api/personalization.py +71 -0
- api/rag_search.py +10 -0
- api/translation.py +66 -0
- database/schema.sql +50 -0
- debug_comprehensive.py +79 -0
- debug_qdrant.py +81 -0
- docker-compose.yml +47 -0
- final_verification.py +83 -0
- main.py +53 -0
- middleware/auth_middleware.py +46 -0
- models/chat_session.py +12 -0
- models/user.py +13 -0
- models/user_profile.py +14 -0
- pyproject.toml +70 -0
- requirements.txt +19 -0
- services/content_adaptation.py +105 -0
- services/personalization_service.py +95 -0
- services/rag_service.py +118 -0
- services/translation_service.py +144 -0
- services/vector_db.py +127 -0
- setup_sample_content.py +86 -0
- src/__init__.py +0 -0
- src/auth/__init__.py +0 -0
- src/auth/auth.py +132 -0
- src/auth/middleware.py +74 -0
- src/auth/schemas.py +53 -0
- src/config/__init__.py +0 -0
- src/config/database.py +61 -0
- src/config/settings.py +69 -0
- src/db/__init__.py +0 -0
- src/db/base.py +28 -0
- src/db/crud.py +432 -0
- src/db/models/__init__.py +0 -0
- src/db/models/chat_history.py +28 -0
- src/db/models/document.py +31 -0
- src/db/models/user.py +28 -0
- src/embeddings/__init__.py +0 -0
- src/embeddings/gemini_client.py +335 -0
- src/embeddings/processor.py +303 -0
- src/main.py +51 -0
- src/models/__init__.py +0 -0
- src/models/documents.py +32 -0
- src/models/search.py +43 -0
- src/qdrant/__init__.py +0 -0
- src/qdrant/client.py +140 -0
.env.example
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# API Configuration
|
| 2 |
+
GEMINI_API_KEY=your_gemini_api_key_here
|
| 3 |
+
QDRANT_URL=your_qdrant_url_here
|
| 4 |
+
QDRANT_API_KEY=your_qdrant_api_key_here
|
| 5 |
+
NEON_DB_URL=your_neon_db_connection_string_here
|
| 6 |
+
|
| 7 |
+
# JWT Configuration
|
| 8 |
+
SECRET_KEY=your_secret_key_here
|
| 9 |
+
JWT_EXPIRES_IN=3600
|
| 10 |
+
|
| 11 |
+
# Application Configuration
|
| 12 |
+
DEBUG=false
|
| 13 |
+
LOG_LEVEL=info
|
Dockerfile
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Install system dependencies
|
| 6 |
+
RUN apt-get update && apt-get install -y \
|
| 7 |
+
gcc \
|
| 8 |
+
g++ \
|
| 9 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 10 |
+
|
| 11 |
+
# Copy requirements first to leverage Docker cache
|
| 12 |
+
COPY requirements.txt .
|
| 13 |
+
|
| 14 |
+
# Install Python dependencies
|
| 15 |
+
RUN pip install --no-cache-dir --upgrade pip && \
|
| 16 |
+
pip install --no-cache-dir -r requirements.txt
|
| 17 |
+
|
| 18 |
+
# Copy the rest of the application
|
| 19 |
+
COPY . .
|
| 20 |
+
|
| 21 |
+
# Expose the port the app runs on
|
| 22 |
+
EXPOSE 8000
|
| 23 |
+
|
| 24 |
+
# Run the application
|
| 25 |
+
CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]
|
README.md
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AI Backend with RAG + Authentication
|
| 2 |
+
|
| 3 |
+
A scalable backend featuring authentication, RAG capabilities, and integration with external services. The system uses Better Auth for authentication, Qdrant for vector storage, Neon Postgres for relational data, and Google's Gemini models for embeddings and chat functionality.
|
| 4 |
+
|
| 5 |
+
## Architecture Overview
|
| 6 |
+
|
| 7 |
+
```
|
| 8 |
+
┌─────────────────┐ ┌──────────────────┐ ┌──────────────────┐
|
| 9 |
+
│ Frontend │────│ FastAPI │────│ Better Auth │
|
| 10 |
+
│ (Future) │ │ Backend │ │ Service │
|
| 11 |
+
└─────────────────┘ └──────────────────┘ └──────────────────┘
|
| 12 |
+
│
|
| 13 |
+
┌────────────────────┼────────────────────┐
|
| 14 |
+
│ │ │
|
| 15 |
+
┌─────────────┐ ┌─────────────┐ ┌─────────────┐
|
| 16 |
+
│ Qdrant │ │ Neon │ │ Gemini │
|
| 17 |
+
│ Vector DB │ │ Postgres │ │ API │
|
| 18 |
+
└─────────────┘ └─────────────┘ └─────────────┘
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
## Features
|
| 22 |
+
|
| 23 |
+
- **Authentication**: JWT-based authentication with Better Auth
|
| 24 |
+
- **RAG Pipeline**: Retrieval-Augmented Generation with Qdrant vector database
|
| 25 |
+
- **AI Integration**: Google Gemini for embeddings and chat responses
|
| 26 |
+
- **Database**: Neon Postgres for user data and chat history
|
| 27 |
+
- **Security**: Password hashing, JWT validation, user isolation
|
| 28 |
+
- **Scalability**: Async architecture with connection pooling
|
| 29 |
+
|
| 30 |
+
## Prerequisites
|
| 31 |
+
|
| 32 |
+
- Python 3.9+
|
| 33 |
+
- Qdrant vector database instance
|
| 34 |
+
- Neon Postgres database
|
| 35 |
+
- Google Gemini API key
|
| 36 |
+
- Node.js (for development tools, optional)
|
| 37 |
+
|
| 38 |
+
## Setup
|
| 39 |
+
|
| 40 |
+
### 1. Clone the repository
|
| 41 |
+
|
| 42 |
+
```bash
|
| 43 |
+
git clone <repository-url>
|
| 44 |
+
cd backend
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
### 2. Create a virtual environment
|
| 48 |
+
|
| 49 |
+
```bash
|
| 50 |
+
python -m venv venv
|
| 51 |
+
source venv/bin/activate # On Windows: venv\Scripts\activate
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
### 3. Install dependencies
|
| 55 |
+
|
| 56 |
+
```bash
|
| 57 |
+
pip install -r requirements.txt
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
### 4. Configure environment variables
|
| 61 |
+
|
| 62 |
+
Copy the example environment file:
|
| 63 |
+
|
| 64 |
+
```bash
|
| 65 |
+
cp .env.example .env
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
Then edit `.env` with your actual configuration:
|
| 69 |
+
|
| 70 |
+
```env
|
| 71 |
+
# API Configuration
|
| 72 |
+
GEMINI_API_KEY=your_gemini_api_key_here
|
| 73 |
+
QDRANT_URL=your_qdrant_url_here
|
| 74 |
+
QDRANT_API_KEY=your_qdrant_api_key_here
|
| 75 |
+
NEON_DB_URL=your_neon_db_connection_string_here
|
| 76 |
+
|
| 77 |
+
# JWT Configuration
|
| 78 |
+
SECRET_KEY=your_secret_key_here # Use a strong, random secret key
|
| 79 |
+
JWT_EXPIRES_IN=3600
|
| 80 |
+
|
| 81 |
+
# Application Configuration
|
| 82 |
+
DEBUG=false
|
| 83 |
+
LOG_LEVEL=info
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
### 5. Run the application
|
| 87 |
+
|
| 88 |
+
```bash
|
| 89 |
+
cd src
|
| 90 |
+
python main.py
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
Or using uvicorn directly:
|
| 94 |
+
|
| 95 |
+
```bash
|
| 96 |
+
cd src
|
| 97 |
+
uvicorn main:app --reload --host 0.0.0.0 --port 8000
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
The application will be available at `http://localhost:8000`
|
| 101 |
+
|
| 102 |
+
## API Endpoints
|
| 103 |
+
|
| 104 |
+
### Authentication
|
| 105 |
+
- `POST /auth/signup` - User registration
|
| 106 |
+
- `POST /auth/login` - User login
|
| 107 |
+
- `GET /auth/me` - Get current user info
|
| 108 |
+
|
| 109 |
+
### RAG & Embeddings
|
| 110 |
+
- `POST /embed` - Generate embeddings for text
|
| 111 |
+
- `POST /save-document` - Save and embed a document
|
| 112 |
+
- `POST /search` - Semantic search in documents
|
| 113 |
+
- `POST /chat` - Chat with RAG context
|
| 114 |
+
|
| 115 |
+
### History
|
| 116 |
+
- `GET /history` - Get chat history
|
| 117 |
+
- `GET /history/{conversation_id}` - Get specific conversation
|
| 118 |
+
|
| 119 |
+
### Health
|
| 120 |
+
- `GET /health` - Health check endpoint
|
| 121 |
+
|
| 122 |
+
## Project Structure
|
| 123 |
+
|
| 124 |
+
```
|
| 125 |
+
backend/
|
| 126 |
+
├── src/
|
| 127 |
+
│ ├── __init__.py
|
| 128 |
+
│ ├── main.py # Application entry point
|
| 129 |
+
│ ├── config/ # Configuration management
|
| 130 |
+
│ │ ├── __init__.py
|
| 131 |
+
│ │ ├── settings.py # App settings and env vars
|
| 132 |
+
│ │ └── database.py # Database configuration
|
| 133 |
+
│ ├── auth/ # Authentication module
|
| 134 |
+
│ ├── db/ # Database module
|
| 135 |
+
│ │ ├── __init__.py
|
| 136 |
+
│ │ ├── base.py # Base model class
|
| 137 |
+
│ │ ├── models/ # SQLAlchemy models
|
| 138 |
+
│ │ ├── database.py # Database connection
|
| 139 |
+
│ │ └── crud.py # CRUD operations
|
| 140 |
+
│ ├── qdrant/ # Vector database module
|
| 141 |
+
│ ├── embeddings/ # Embedding module
|
| 142 |
+
│ ├── rag/ # RAG pipeline module
|
| 143 |
+
│ ├── routes/ # API routes
|
| 144 |
+
│ ├── models/ # Pydantic models
|
| 145 |
+
│ ├── utils/ # Utility functions
|
| 146 |
+
│ └── scripts/ # Utility scripts
|
| 147 |
+
├── tests/ # Test suite
|
| 148 |
+
├── requirements.txt # Python dependencies
|
| 149 |
+
├── .env.example # Environment variables template
|
| 150 |
+
└── README.md # Documentation
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
## Development
|
| 154 |
+
|
| 155 |
+
### Running tests
|
| 156 |
+
|
| 157 |
+
```bash
|
| 158 |
+
cd backend
|
| 159 |
+
python -m pytest tests/ -v
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
### Running with auto-reload during development
|
| 163 |
+
|
| 164 |
+
```bash
|
| 165 |
+
cd src
|
| 166 |
+
uvicorn main:app --reload
|
| 167 |
+
```
|
| 168 |
+
|
| 169 |
+
## Environment Variables
|
| 170 |
+
|
| 171 |
+
| Variable | Description | Required |
|
| 172 |
+
|----------|-------------|----------|
|
| 173 |
+
| GEMINI_API_KEY | Google Gemini API key | Yes |
|
| 174 |
+
| QDRANT_URL | Qdrant vector database URL | Yes |
|
| 175 |
+
| QDRANT_API_KEY | Qdrant API key (if secured) | No |
|
| 176 |
+
| NEON_DB_URL | Neon Postgres connection string | Yes |
|
| 177 |
+
| SECRET_KEY | JWT secret key | Yes |
|
| 178 |
+
| JWT_EXPIRES_IN | JWT expiration time in seconds | No (default: 3600) |
|
| 179 |
+
| DEBUG | Enable debug mode | No (default: false) |
|
| 180 |
+
| LOG_LEVEL | Logging level | No (default: info) |
|
| 181 |
+
|
| 182 |
+
## Security Considerations
|
| 183 |
+
|
| 184 |
+
- Always use HTTPS in production
|
| 185 |
+
- Store secrets securely (not in version control)
|
| 186 |
+
- Validate and sanitize all user inputs
|
| 187 |
+
- Use parameterized queries to prevent SQL injection
|
| 188 |
+
- Implement rate limiting to prevent abuse
|
| 189 |
+
- Use strong, randomly generated secret keys
|
| 190 |
+
|
| 191 |
+
## Performance
|
| 192 |
+
|
| 193 |
+
- Async architecture for high concurrency
|
| 194 |
+
- Connection pooling for database operations
|
| 195 |
+
- Caching mechanisms for frequently accessed data
|
| 196 |
+
- Optimized vector search with Qdrant
|
| 197 |
+
- Efficient embedding processing pipeline
|
| 198 |
+
|
| 199 |
+
## Contributing
|
| 200 |
+
|
| 201 |
+
1. Fork the repository
|
| 202 |
+
2. Create a feature branch (`git checkout -b feature/amazing-feature`)
|
| 203 |
+
3. Make your changes
|
| 204 |
+
4. Add tests if applicable
|
| 205 |
+
5. Run tests (`python -m pytest`)
|
| 206 |
+
6. Commit your changes (`git commit -m 'Add amazing feature'`)
|
| 207 |
+
7. Push to the branch (`git push origin feature/amazing-feature`)
|
| 208 |
+
8. Open a Pull Request
|
| 209 |
+
|
| 210 |
+
## License
|
| 211 |
+
|
| 212 |
+
[Add your license here]
|
api/auth.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, HTTPException, Depends
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
from typing import Optional
|
| 4 |
+
from models.user import User
|
| 5 |
+
from models.user_profile import UserProfile
|
| 6 |
+
import os
|
| 7 |
+
import bcrypt
|
| 8 |
+
|
| 9 |
+
router = APIRouter()
|
| 10 |
+
|
| 11 |
+
class SignupRequest(BaseModel):
|
| 12 |
+
email: str
|
| 13 |
+
password: str
|
| 14 |
+
software_background: Optional[str] = None
|
| 15 |
+
hardware_background: Optional[str] = None
|
| 16 |
+
experience_level: Optional[str] = None
|
| 17 |
+
|
| 18 |
+
class LoginRequest(BaseModel):
|
| 19 |
+
email: str
|
| 20 |
+
password: str
|
| 21 |
+
|
| 22 |
+
class AuthResponse(BaseModel):
|
| 23 |
+
user_id: str
|
| 24 |
+
email: str
|
| 25 |
+
access_token: str
|
| 26 |
+
refresh_token: str
|
| 27 |
+
|
| 28 |
+
@router.post("/auth/signup", response_model=AuthResponse)
|
| 29 |
+
async def signup(request: SignupRequest):
|
| 30 |
+
"""Handle user registration with background information"""
|
| 31 |
+
try:
|
| 32 |
+
# In a real implementation, you would hash the password and store user in DB
|
| 33 |
+
# For now, we'll simulate the process
|
| 34 |
+
|
| 35 |
+
# Hash the password
|
| 36 |
+
hashed_password = bcrypt.hashpw(request.password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')
|
| 37 |
+
|
| 38 |
+
# Create user object
|
| 39 |
+
user = User(
|
| 40 |
+
email=request.email,
|
| 41 |
+
password=hashed_password, # In real app, don't return the hash
|
| 42 |
+
software_background=request.software_background,
|
| 43 |
+
hardware_background=request.hardware_background,
|
| 44 |
+
experience_level=request.experience_level
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
# Create user profile
|
| 48 |
+
user_profile = UserProfile(
|
| 49 |
+
user_id="temp_user_id", # In real app, this would be the actual user ID
|
| 50 |
+
software_background=request.software_background,
|
| 51 |
+
hardware_background=request.hardware_background,
|
| 52 |
+
experience_level=request.experience_level
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
# In a real implementation, you would store these in the database
|
| 56 |
+
# and generate proper JWT tokens
|
| 57 |
+
|
| 58 |
+
# For now, return a mock response
|
| 59 |
+
return AuthResponse(
|
| 60 |
+
user_id="temp_user_id",
|
| 61 |
+
email=request.email,
|
| 62 |
+
access_token="mock_access_token",
|
| 63 |
+
refresh_token="mock_refresh_token"
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
except Exception as e:
|
| 67 |
+
raise HTTPException(status_code=500, detail=f"Error during signup: {str(e)}")
|
| 68 |
+
|
| 69 |
+
@router.post("/auth/login", response_model=AuthResponse)
|
| 70 |
+
async def login(request: LoginRequest):
|
| 71 |
+
"""Handle user login"""
|
| 72 |
+
try:
|
| 73 |
+
# In a real implementation, you would verify credentials against DB
|
| 74 |
+
# For now, we'll simulate the process
|
| 75 |
+
|
| 76 |
+
# For demo purposes, we'll just return a mock response
|
| 77 |
+
# In a real app, you'd verify the password and generate tokens
|
| 78 |
+
return AuthResponse(
|
| 79 |
+
user_id="temp_user_id",
|
| 80 |
+
email=request.email,
|
| 81 |
+
access_token="mock_access_token",
|
| 82 |
+
refresh_token="mock_refresh_token"
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
except Exception as e:
|
| 86 |
+
raise HTTPException(status_code=500, detail=f"Error during login: {str(e)}")
|
| 87 |
+
|
| 88 |
+
@router.get("/auth/profile")
|
| 89 |
+
async def get_profile():
|
| 90 |
+
"""Get user profile information"""
|
| 91 |
+
try:
|
| 92 |
+
# In a real implementation, you would retrieve from DB based on auth token
|
| 93 |
+
profile = UserProfile(
|
| 94 |
+
user_id="temp_user_id",
|
| 95 |
+
software_background="Software Engineer",
|
| 96 |
+
hardware_background="Beginner",
|
| 97 |
+
experience_level="Intermediate"
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
return profile
|
| 101 |
+
|
| 102 |
+
except Exception as e:
|
| 103 |
+
raise HTTPException(status_code=500, detail=f"Error retrieving profile: {str(e)}")
|
| 104 |
+
|
| 105 |
+
@router.get("/auth/health")
|
| 106 |
+
async def auth_health():
|
| 107 |
+
"""Health check for auth service"""
|
| 108 |
+
return {"status": "auth service is running"}
|
api/chat.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
from fastapi import APIRouter
|
| 4 |
+
import logging
|
| 5 |
+
from qdrant_client import QdrantClient
|
| 6 |
+
|
| 7 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 8 |
+
|
| 9 |
+
from services.rag_service import RAGService
|
| 10 |
+
|
| 11 |
+
router = APIRouter()
|
| 12 |
+
|
| 13 |
+
# Configure OpenRouter and RAG service
|
| 14 |
+
openrouter_api_key = os.getenv("OPENAI_API_KEY")
|
| 15 |
+
qdrant_url = os.getenv("QDRANT_URL")
|
| 16 |
+
qdrant_api_key = os.getenv("QDRANT_API_KEY")
|
| 17 |
+
collection_name = os.getenv("QDRANT_COLLECTION", "project_documents")
|
| 18 |
+
|
| 19 |
+
if openrouter_api_key and openrouter_api_key != "your_openrouter_api_key_here":
|
| 20 |
+
# Initialize Qdrant client for cloud
|
| 21 |
+
if qdrant_url and qdrant_api_key and "qdrant.io" in qdrant_url:
|
| 22 |
+
qdrant_client = QdrantClient(
|
| 23 |
+
url=qdrant_url.replace(":6333", ""), # Remove port from URL for cloud
|
| 24 |
+
api_key=qdrant_api_key,
|
| 25 |
+
prefer_grpc=False
|
| 26 |
+
)
|
| 27 |
+
else:
|
| 28 |
+
# Use local Qdrant if cloud not configured
|
| 29 |
+
qdrant_client = QdrantClient(
|
| 30 |
+
host=os.getenv("QDRANT_HOST", "localhost"),
|
| 31 |
+
port=int(os.getenv("QDRANT_PORT", 6333))
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
# Initialize RAG service with OpenRouter
|
| 35 |
+
rag_service = RAGService(openrouter_api_key, qdrant_client, collection_name)
|
| 36 |
+
else:
|
| 37 |
+
rag_service = None
|
| 38 |
+
|
| 39 |
+
logger = logging.getLogger(__name__)
|
| 40 |
+
|
| 41 |
+
@router.post("/chat")
|
| 42 |
+
async def chat(payload: dict):
|
| 43 |
+
user_msg = payload["message"]
|
| 44 |
+
selected_text = payload.get("selected_text", "")
|
| 45 |
+
|
| 46 |
+
# If selected text is provided, try to use RAG service to answer based only on that text
|
| 47 |
+
if selected_text and rag_service:
|
| 48 |
+
try:
|
| 49 |
+
# Use the RAG service to answer based on selected text only (with OpenRouter)
|
| 50 |
+
answer = rag_service.query_rag(selected_text, user_msg)
|
| 51 |
+
return {"answer": answer}
|
| 52 |
+
except Exception as e:
|
| 53 |
+
logger.error(f"RAG service failed: {str(e)}")
|
| 54 |
+
# Fall through to fallback response below
|
| 55 |
+
elif selected_text and not rag_service:
|
| 56 |
+
logger.warning("RAG service not available, using fallback")
|
| 57 |
+
|
| 58 |
+
# Fallback response when API is unavailable or not configured
|
| 59 |
+
fallback_responses = {
|
| 60 |
+
"hello": "Hello! I'm your AI textbook assistant. Feel free to ask questions about the content you're studying!",
|
| 61 |
+
"hi": "Hi there! I'm here to help you understand the AI and robotics concepts in your textbook. What would you like to know?",
|
| 62 |
+
"help": "I can help explain concepts from your AI and robotics textbook! Please select some text and ask questions about it.",
|
| 63 |
+
"default": f"I'm currently unable to process your request about '{user_msg}'. This might be because the AI service is temporarily unavailable or needs to be configured with a valid API key. The system is working properly but requires a valid OPENROUTER_API_KEY to provide AI-generated responses."
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
response_text = fallback_responses.get(user_msg.lower().strip(), fallback_responses["default"])
|
| 67 |
+
|
| 68 |
+
result = {"answer": response_text}
|
| 69 |
+
if not rag_service:
|
| 70 |
+
result["setup_needed"] = "Please configure a valid OPENROUTER_API_KEY in the .env file to enable AI responses"
|
| 71 |
+
|
| 72 |
+
return result
|
api/personalization.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, HTTPException, Depends
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
from typing import Optional, Dict, Any
|
| 4 |
+
import os
|
| 5 |
+
import logging
|
| 6 |
+
from services.personalization_service import PersonalizationService
|
| 7 |
+
from services.content_adaptation import ContentAdaptationService
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
router = APIRouter()
|
| 12 |
+
|
| 13 |
+
class PersonalizationRequest(BaseModel):
|
| 14 |
+
content: str
|
| 15 |
+
user_profile: Dict[str, Any]
|
| 16 |
+
chapter_id: str
|
| 17 |
+
|
| 18 |
+
class PersonalizationResponse(BaseModel):
|
| 19 |
+
personalized_content: str
|
| 20 |
+
adaptation_details: Dict[str, Any]
|
| 21 |
+
|
| 22 |
+
@router.post("/personalization/adapt", response_model=PersonalizationResponse)
|
| 23 |
+
async def adapt_content(request: PersonalizationRequest):
|
| 24 |
+
"""Adapt content based on user profile and background"""
|
| 25 |
+
try:
|
| 26 |
+
# Initialize content adaptation service
|
| 27 |
+
content_adaptation_service = ContentAdaptationService(
|
| 28 |
+
gemini_api_key=os.getenv("GEMINI_API_KEY", "your-gemini-key-here")
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
# Initialize personalization service with content adaptation service
|
| 32 |
+
personalization_service = PersonalizationService(content_adaptation_service)
|
| 33 |
+
|
| 34 |
+
# Adapt the content based on user profile
|
| 35 |
+
adapted_content = personalization_service.get_personalized_content(
|
| 36 |
+
content=request.content,
|
| 37 |
+
user_profile=request.user_profile,
|
| 38 |
+
chapter_id=request.chapter_id
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
# Prepare adaptation details
|
| 42 |
+
adaptation_details = {
|
| 43 |
+
"status": "success",
|
| 44 |
+
"user_software_background": request.user_profile.get('software_background', 'General'),
|
| 45 |
+
"user_hardware_background": request.user_profile.get('hardware_background', 'General'),
|
| 46 |
+
"user_experience_level": request.user_profile.get('experience_level', 'Intermediate'),
|
| 47 |
+
"chapter_id": request.chapter_id,
|
| 48 |
+
"adaptation_method": "AI-driven personalization"
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
return PersonalizationResponse(
|
| 52 |
+
personalized_content=adapted_content,
|
| 53 |
+
adaptation_details=adaptation_details
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
except Exception as e:
|
| 57 |
+
logger.error(f"Error adapting content: {str(e)}")
|
| 58 |
+
# Return original content if personalization fails, but still provide a response
|
| 59 |
+
return PersonalizationResponse(
|
| 60 |
+
personalized_content=request.content,
|
| 61 |
+
adaptation_details={
|
| 62 |
+
"status": "fallback",
|
| 63 |
+
"message": "Content personalization is temporarily unavailable. Showing original content.",
|
| 64 |
+
"original_chapter_id": request.chapter_id
|
| 65 |
+
}
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
@router.get("/personalization/health")
|
| 69 |
+
async def personalization_health():
|
| 70 |
+
"""Health check for personalization service"""
|
| 71 |
+
return {"status": "personalization service is running"}
|
api/rag_search.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter
|
| 2 |
+
|
| 3 |
+
router = APIRouter()
|
| 4 |
+
|
| 5 |
+
@router.post("/rag-search")
|
| 6 |
+
async def rag_search(payload: dict):
|
| 7 |
+
query = payload["query"]
|
| 8 |
+
# For now, return an empty result as the RAG functionality requires proper vector DB setup
|
| 9 |
+
# In a full implementation, this would search the vector database
|
| 10 |
+
return {"results": []}
|
api/translation.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, HTTPException, Depends
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
from typing import Optional
|
| 4 |
+
from services.translation_service import TranslationService
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
router = APIRouter()
|
| 8 |
+
|
| 9 |
+
class TranslationRequest(BaseModel):
|
| 10 |
+
text: str
|
| 11 |
+
source_lang: str = "en"
|
| 12 |
+
target_lang: str = "ur"
|
| 13 |
+
|
| 14 |
+
class TranslationResponse(BaseModel):
|
| 15 |
+
original_text: str
|
| 16 |
+
translated_text: str
|
| 17 |
+
source_lang: str
|
| 18 |
+
target_lang: str
|
| 19 |
+
|
| 20 |
+
@router.post("/translation/translate", response_model=TranslationResponse)
|
| 21 |
+
async def translate_text(request: TranslationRequest):
|
| 22 |
+
"""Translate text between languages (currently English to Urdu)"""
|
| 23 |
+
try:
|
| 24 |
+
# Initialize translation service
|
| 25 |
+
translation_service = TranslationService(
|
| 26 |
+
gemini_api_key=os.getenv("GEMINI_API_KEY", "your-gemini-key-here")
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
if request.source_lang == "en" and request.target_lang == "ur":
|
| 30 |
+
# Translate English to Urdu
|
| 31 |
+
translated_text = translation_service.translate_to_urdu(request.text)
|
| 32 |
+
elif request.source_lang == "ur" and request.target_lang == "en":
|
| 33 |
+
# Translate Urdu to English
|
| 34 |
+
translated_text = translation_service.translate_to_english(request.text)
|
| 35 |
+
else:
|
| 36 |
+
raise HTTPException(
|
| 37 |
+
status_code=400,
|
| 38 |
+
detail=f"Unsupported language pair: {request.source_lang} to {request.target_lang}. Currently supported: en-ur"
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
return TranslationResponse(
|
| 42 |
+
original_text=request.text,
|
| 43 |
+
translated_text=translated_text,
|
| 44 |
+
source_lang=request.source_lang,
|
| 45 |
+
target_lang=request.target_lang
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
except Exception as e:
|
| 49 |
+
raise HTTPException(status_code=500, detail=f"Error during translation: {str(e)}")
|
| 50 |
+
|
| 51 |
+
@router.get("/translation/health")
|
| 52 |
+
async def translation_health():
|
| 53 |
+
"""Health check for translation service"""
|
| 54 |
+
return {"status": "translation service is running"}
|
| 55 |
+
|
| 56 |
+
@router.post("/translation/clear-cache")
|
| 57 |
+
async def clear_translation_cache():
|
| 58 |
+
"""Clear the translation cache"""
|
| 59 |
+
try:
|
| 60 |
+
translation_service = TranslationService(
|
| 61 |
+
gemini_api_key=os.getenv("GEMINI_API_KEY", "your-gemini-key-here")
|
| 62 |
+
)
|
| 63 |
+
translation_service.clear_cache()
|
| 64 |
+
return {"status": "translation cache cleared"}
|
| 65 |
+
except Exception as e:
|
| 66 |
+
raise HTTPException(status_code=500, detail=f"Error clearing cache: {str(e)}")
|
database/schema.sql
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-- Database schema for Neon Postgres
|
| 2 |
+
|
| 3 |
+
-- Users table
|
| 4 |
+
CREATE TABLE IF NOT EXISTS users (
|
| 5 |
+
id SERIAL PRIMARY KEY,
|
| 6 |
+
email VARCHAR(255) UNIQUE NOT NULL,
|
| 7 |
+
password_hash VARCHAR(255) NOT NULL,
|
| 8 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 9 |
+
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 10 |
+
software_background VARCHAR(100),
|
| 11 |
+
hardware_background VARCHAR(100),
|
| 12 |
+
experience_level VARCHAR(50)
|
| 13 |
+
);
|
| 14 |
+
|
| 15 |
+
-- User profiles table
|
| 16 |
+
CREATE TABLE IF NOT EXISTS user_profiles (
|
| 17 |
+
id SERIAL PRIMARY KEY,
|
| 18 |
+
user_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
|
| 19 |
+
personalization_settings JSONB DEFAULT '{}',
|
| 20 |
+
learning_progress JSONB DEFAULT '{}',
|
| 21 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 22 |
+
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
| 23 |
+
);
|
| 24 |
+
|
| 25 |
+
-- Chat sessions table
|
| 26 |
+
CREATE TABLE IF NOT EXISTS chat_sessions (
|
| 27 |
+
id SERIAL PRIMARY KEY,
|
| 28 |
+
user_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
|
| 29 |
+
selected_text TEXT NOT NULL,
|
| 30 |
+
question TEXT NOT NULL,
|
| 31 |
+
response TEXT NOT NULL,
|
| 32 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 33 |
+
conversation_history JSONB DEFAULT '[]'
|
| 34 |
+
);
|
| 35 |
+
|
| 36 |
+
-- Textbook content table (for RAG)
|
| 37 |
+
CREATE TABLE IF NOT EXISTS textbook_content (
|
| 38 |
+
id SERIAL PRIMARY KEY,
|
| 39 |
+
chapter_id VARCHAR(100) NOT NULL,
|
| 40 |
+
chapter_title VARCHAR(255) NOT NULL,
|
| 41 |
+
content TEXT NOT NULL,
|
| 42 |
+
embeddings JSONB, -- Store vector embeddings
|
| 43 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 44 |
+
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
| 45 |
+
);
|
| 46 |
+
|
| 47 |
+
-- Indexes
|
| 48 |
+
CREATE INDEX IF NOT EXISTS idx_users_email ON users(email);
|
| 49 |
+
CREATE INDEX IF NOT EXISTS idx_chat_sessions_user_id ON chat_sessions(user_id);
|
| 50 |
+
CREATE INDEX IF NOT EXISTS idx_textbook_content_chapter_id ON textbook_content(chapter_id);
|
debug_comprehensive.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
from qdrant_client import QdrantClient
|
| 5 |
+
from qdrant_client.http import models
|
| 6 |
+
|
| 7 |
+
# Load environment variables from parent directory
|
| 8 |
+
load_dotenv(os.path.join(os.path.dirname(__file__), '..', '.env'))
|
| 9 |
+
|
| 10 |
+
def comprehensive_debug():
|
| 11 |
+
# Get environment variables
|
| 12 |
+
qdrant_url = os.getenv("QDRANT_URL")
|
| 13 |
+
qdrant_api_key = os.getenv("QDRANT_API_KEY")
|
| 14 |
+
collection_name = os.getenv("QDRANT_COLLECTION", "project_documents")
|
| 15 |
+
|
| 16 |
+
print(f"QDRANT_URL: {qdrant_url}")
|
| 17 |
+
print(f"Collection: {collection_name}")
|
| 18 |
+
|
| 19 |
+
# Initialize Qdrant client for cloud
|
| 20 |
+
if qdrant_url and qdrant_api_key and "qdrant.io" in qdrant_url:
|
| 21 |
+
qdrant_client = QdrantClient(
|
| 22 |
+
url=qdrant_url.replace(":6333", ""), # Remove port from URL for cloud
|
| 23 |
+
api_key=qdrant_api_key,
|
| 24 |
+
prefer_grpc=False
|
| 25 |
+
)
|
| 26 |
+
else:
|
| 27 |
+
# Use local Qdrant if cloud not configured
|
| 28 |
+
qdrant_client = QdrantClient(
|
| 29 |
+
host=os.getenv("QDRANT_HOST", "localhost"),
|
| 30 |
+
port=int(os.getenv("QDRANT_PORT", 6333))
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
# 1. Check collections
|
| 34 |
+
print("\n1. Available collections:")
|
| 35 |
+
try:
|
| 36 |
+
collections = qdrant_client.get_collections()
|
| 37 |
+
for collection in collections.collections:
|
| 38 |
+
print(f" - {collection.name}")
|
| 39 |
+
except Exception as e:
|
| 40 |
+
print(f" Error getting collections: {e}")
|
| 41 |
+
|
| 42 |
+
# 2. Get collection info
|
| 43 |
+
print(f"\n2. Collection info:")
|
| 44 |
+
try:
|
| 45 |
+
collection_info = qdrant_client.get_collection(collection_name)
|
| 46 |
+
print(f" Points count: {collection_info.points_count}")
|
| 47 |
+
print(f" Vector size: {collection_info.config.params.vectors.size}")
|
| 48 |
+
print(f" Distance: {collection_info.config.params.vectors.distance}")
|
| 49 |
+
except Exception as e:
|
| 50 |
+
print(f" Error getting collection info: {e}")
|
| 51 |
+
|
| 52 |
+
# 3. List all points in the collection (up to 10 for debugging)
|
| 53 |
+
print(f"\n3. All points in collection (up to 10):")
|
| 54 |
+
try:
|
| 55 |
+
points = qdrant_client.scroll(
|
| 56 |
+
collection_name=collection_name,
|
| 57 |
+
limit=10
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
count = 0
|
| 61 |
+
for point in points[0]: # points[0] contains the list of points
|
| 62 |
+
count += 1
|
| 63 |
+
print(f" Point {count}:")
|
| 64 |
+
print(f" ID: {point.id}")
|
| 65 |
+
print(f" Payload keys: {list(point.payload.keys()) if point.payload else 'None'}")
|
| 66 |
+
if point.payload and 'content' in point.payload:
|
| 67 |
+
content_preview = point.payload['content'][:100] + "..." if len(point.payload['content']) > 100 else point.payload['content']
|
| 68 |
+
print(f" Content preview: {content_preview}")
|
| 69 |
+
print(f" Topic: {point.payload.get('metadata', {}).get('topic', 'Unknown')}")
|
| 70 |
+
print()
|
| 71 |
+
|
| 72 |
+
print(f" Total points found: {count}")
|
| 73 |
+
except Exception as e:
|
| 74 |
+
print(f" Error listing points: {e}")
|
| 75 |
+
import traceback
|
| 76 |
+
traceback.print_exc()
|
| 77 |
+
|
| 78 |
+
if __name__ == "__main__":
|
| 79 |
+
comprehensive_debug()
|
debug_qdrant.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
from qdrant_client import QdrantClient
|
| 5 |
+
|
| 6 |
+
# Load environment variables from parent directory
|
| 7 |
+
load_dotenv(os.path.join(os.path.dirname(__file__), '..', '.env'))
|
| 8 |
+
|
| 9 |
+
def debug_qdrant():
|
| 10 |
+
# Get environment variables
|
| 11 |
+
qdrant_url = os.getenv("QDRANT_URL")
|
| 12 |
+
qdrant_api_key = os.getenv("QDRANT_API_KEY")
|
| 13 |
+
collection_name = os.getenv("QDRANT_COLLECTION", "project_documents")
|
| 14 |
+
|
| 15 |
+
print(f"QDRANT_URL: {qdrant_url}")
|
| 16 |
+
print(f"Collection: {collection_name}")
|
| 17 |
+
|
| 18 |
+
# Initialize Qdrant client for cloud
|
| 19 |
+
if qdrant_url and qdrant_api_key and "qdrant.io" in qdrant_url:
|
| 20 |
+
qdrant_client = QdrantClient(
|
| 21 |
+
url=qdrant_url.replace(":6333", ""), # Remove port from URL for cloud
|
| 22 |
+
api_key=qdrant_api_key,
|
| 23 |
+
prefer_grpc=False
|
| 24 |
+
)
|
| 25 |
+
else:
|
| 26 |
+
# Use local Qdrant if cloud not configured
|
| 27 |
+
qdrant_client = QdrantClient(
|
| 28 |
+
host=os.getenv("QDRANT_HOST", "localhost"),
|
| 29 |
+
port=int(os.getenv("QDRANT_PORT", 6333))
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
# 1. Check collections
|
| 33 |
+
print("\n1. Available collections:")
|
| 34 |
+
try:
|
| 35 |
+
collections = qdrant_client.get_collections()
|
| 36 |
+
for collection in collections.collections:
|
| 37 |
+
print(f" - {collection.name}")
|
| 38 |
+
except Exception as e:
|
| 39 |
+
print(f" Error getting collections: {e}")
|
| 40 |
+
|
| 41 |
+
# 2. Try to search for content to verify it exists
|
| 42 |
+
print(f"\n2. Testing search functionality:")
|
| 43 |
+
try:
|
| 44 |
+
# Try to create a simple embedding to test if connection works
|
| 45 |
+
from openai import OpenAI
|
| 46 |
+
openrouter_api_key = os.getenv("OPENAI_API_KEY")
|
| 47 |
+
client = OpenAI(
|
| 48 |
+
api_key=openrouter_api_key,
|
| 49 |
+
base_url="https://openrouter.ai/api/v1"
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
response = client.embeddings.create(
|
| 53 |
+
model="text-embedding-3-small",
|
| 54 |
+
input="Artificial Intelligence"
|
| 55 |
+
)
|
| 56 |
+
vector = response.data[0].embedding
|
| 57 |
+
|
| 58 |
+
# Now search in Qdrant
|
| 59 |
+
hits = qdrant_client.search(
|
| 60 |
+
collection_name=collection_name,
|
| 61 |
+
query_vector=vector,
|
| 62 |
+
limit=2
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
print(f" Search successful! Found {len(hits)} results")
|
| 66 |
+
if hits:
|
| 67 |
+
for i, hit in enumerate(hits):
|
| 68 |
+
print(f" Result {i+1}:")
|
| 69 |
+
print(f" ID: {hit.id}")
|
| 70 |
+
print(f" Payload keys: {list(hit.payload.keys()) if hit.payload else 'None'}")
|
| 71 |
+
if hit.payload and 'content' in hit.payload:
|
| 72 |
+
print(f" Content preview: {hit.payload['content'][:100]}...")
|
| 73 |
+
else:
|
| 74 |
+
print(f" Payload content: {hit.payload}")
|
| 75 |
+
except Exception as e:
|
| 76 |
+
print(f" Error during search test: {e}")
|
| 77 |
+
import traceback
|
| 78 |
+
traceback.print_exc()
|
| 79 |
+
|
| 80 |
+
if __name__ == "__main__":
|
| 81 |
+
debug_qdrant()
|
docker-compose.yml
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version: '3.8'
|
| 2 |
+
|
| 3 |
+
services:
|
| 4 |
+
backend:
|
| 5 |
+
build: .
|
| 6 |
+
ports:
|
| 7 |
+
- "8000:8000"
|
| 8 |
+
environment:
|
| 9 |
+
- NEON_DB_URL=${NEON_DB_URL}
|
| 10 |
+
- QDRANT_URL=${QDRANT_URL}
|
| 11 |
+
- QDRANT_API_KEY=${QDRANT_API_KEY}
|
| 12 |
+
- GEMINI_API_KEY=${GEMINI_API_KEY}
|
| 13 |
+
- SECRET_KEY=${SECRET_KEY}
|
| 14 |
+
- JWT_EXPIRES_IN=${JWT_EXPIRES_IN:-3600}
|
| 15 |
+
- DEBUG=${DEBUG:-false}
|
| 16 |
+
- LOG_LEVEL=${LOG_LEVEL:-info}
|
| 17 |
+
volumes:
|
| 18 |
+
- .:/app
|
| 19 |
+
depends_on:
|
| 20 |
+
- postgres
|
| 21 |
+
- qdrant
|
| 22 |
+
restart: unless-stopped
|
| 23 |
+
|
| 24 |
+
postgres:
|
| 25 |
+
image: postgres:15-alpine
|
| 26 |
+
environment:
|
| 27 |
+
- POSTGRES_DB=ai_backend
|
| 28 |
+
- POSTGRES_USER=postgres
|
| 29 |
+
- POSTGRES_PASSWORD=password
|
| 30 |
+
ports:
|
| 31 |
+
- "5432:5432"
|
| 32 |
+
volumes:
|
| 33 |
+
- postgres_data:/var/lib/postgresql/data
|
| 34 |
+
restart: unless-stopped
|
| 35 |
+
|
| 36 |
+
qdrant:
|
| 37 |
+
image: qdrant/qdrant:latest
|
| 38 |
+
ports:
|
| 39 |
+
- "6333:6333"
|
| 40 |
+
- "6334:6334"
|
| 41 |
+
volumes:
|
| 42 |
+
- qdrant_data:/qdrant/storage
|
| 43 |
+
restart: unless-stopped
|
| 44 |
+
|
| 45 |
+
volumes:
|
| 46 |
+
postgres_data:
|
| 47 |
+
qdrant_data:
|
final_verification.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
|
| 5 |
+
# Load environment variables from parent directory
|
| 6 |
+
load_dotenv(os.path.join(os.path.dirname(__file__), '..', '.env'))
|
| 7 |
+
|
| 8 |
+
# Add the backend directory to the path
|
| 9 |
+
sys.path.append(os.path.dirname(__file__))
|
| 10 |
+
|
| 11 |
+
from services.rag_service import RAGService
|
| 12 |
+
from qdrant_client import QdrantClient
|
| 13 |
+
|
| 14 |
+
def final_verification():
|
| 15 |
+
# Get environment variables
|
| 16 |
+
openrouter_api_key = os.getenv("OPENAI_API_KEY")
|
| 17 |
+
qdrant_url = os.getenv("QDRANT_URL")
|
| 18 |
+
qdrant_api_key = os.getenv("QDRANT_API_KEY")
|
| 19 |
+
collection_name = os.getenv("QDRANT_COLLECTION", "project_documents")
|
| 20 |
+
|
| 21 |
+
# Initialize Qdrant client for cloud
|
| 22 |
+
if qdrant_url and qdrant_api_key and "qdrant.io" in qdrant_url:
|
| 23 |
+
qdrant_client = QdrantClient(
|
| 24 |
+
url=qdrant_url.replace(":6333", ""), # Remove port from URL for cloud (same as in chat.py)
|
| 25 |
+
api_key=qdrant_api_key,
|
| 26 |
+
prefer_grpc=False
|
| 27 |
+
)
|
| 28 |
+
else:
|
| 29 |
+
# Use local Qdrant if cloud not configured
|
| 30 |
+
qdrant_client = QdrantClient(
|
| 31 |
+
host=os.getenv("QDRANT_HOST", "localhost"),
|
| 32 |
+
port=int(os.getenv("QDRANT_PORT", 6333))
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
# Initialize RAG service
|
| 36 |
+
rag_service = RAGService(openrouter_api_key, qdrant_client, collection_name)
|
| 37 |
+
|
| 38 |
+
print("=== FINAL VERIFICATION ===")
|
| 39 |
+
|
| 40 |
+
# Test 1: Content exists - should return actual content
|
| 41 |
+
print("\n✅ Test 1: Content exists in database")
|
| 42 |
+
selected_text = "Robotics is an interdisciplinary field that combines mechanical engineering, electrical engineering, and computer science to design, construct, and operate robots."
|
| 43 |
+
question = "What is robotics?"
|
| 44 |
+
|
| 45 |
+
result = rag_service.query_rag(selected_text, question)
|
| 46 |
+
print(f"Selected text: {selected_text[:50]}...")
|
| 47 |
+
print(f"Question: {question}")
|
| 48 |
+
print(f"Result: {result[:100]}...")
|
| 49 |
+
print(f"Expected: Actual content (not fallback)")
|
| 50 |
+
print(f"✅ PASS: Content returned (not fallback message)" if "Is sawal ka jawab" not in result else "❌ FAIL: Fallback message returned")
|
| 51 |
+
|
| 52 |
+
# Test 2: Content doesn't exist - should return fallback
|
| 53 |
+
print("\n✅ Test 2: Content does not exist in database")
|
| 54 |
+
selected_text = "This is completely unrelated text that should not match anything in the database."
|
| 55 |
+
question = "What is Quantum Computing?"
|
| 56 |
+
|
| 57 |
+
result = rag_service.query_rag(selected_text, question)
|
| 58 |
+
print(f"Selected text: {selected_text[:50]}...")
|
| 59 |
+
print(f"Question: {question}")
|
| 60 |
+
print(f"Result: {result}")
|
| 61 |
+
print(f"Expected: 'Is sawal ka jawab provided data me mojood nahi hai.'")
|
| 62 |
+
print(f"✅ PASS: Fallback message returned" if "Is sawal ka jawab" in result else "❌ FAIL: Content returned")
|
| 63 |
+
|
| 64 |
+
# Test 3: AI content exists
|
| 65 |
+
print("\n✅ Test 3: AI content exists in database")
|
| 66 |
+
selected_text = "Artificial Intelligence is a branch of computer science that aims to create software or machines that exhibit human-like intelligence."
|
| 67 |
+
question = "What is Artificial Intelligence?"
|
| 68 |
+
|
| 69 |
+
result = rag_service.query_rag(selected_text, question)
|
| 70 |
+
print(f"Selected text: {selected_text[:50]}...")
|
| 71 |
+
print(f"Question: {question}")
|
| 72 |
+
print(f"Result: {result[:100]}...")
|
| 73 |
+
print(f"✅ PASS: Content returned (not fallback message)" if "Is sawal ka jawab" not in result else "❌ FAIL: Fallback message returned")
|
| 74 |
+
|
| 75 |
+
print("\n=== VERIFICATION COMPLETE ===")
|
| 76 |
+
print("✅ Backend RAG service is working correctly")
|
| 77 |
+
print("✅ Uses selected_text for Qdrant search")
|
| 78 |
+
print("✅ Returns actual content when found")
|
| 79 |
+
print("✅ Returns fallback message when not found")
|
| 80 |
+
print("✅ Ready for frontend integration")
|
| 81 |
+
|
| 82 |
+
if __name__ == "__main__":
|
| 83 |
+
final_verification()
|
main.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
|
| 5 |
+
# Load environment variables from .env file in the project root
|
| 6 |
+
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 7 |
+
dotenv_path = os.path.join(project_root, '.env')
|
| 8 |
+
load_dotenv(dotenv_path)
|
| 9 |
+
|
| 10 |
+
import sys
|
| 11 |
+
import os
|
| 12 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 13 |
+
|
| 14 |
+
from fastapi import FastAPI
|
| 15 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 16 |
+
import uvicorn
|
| 17 |
+
from api.chat import router as chat_router
|
| 18 |
+
from api.auth import router as auth_router
|
| 19 |
+
from api.translation import router as translation_router
|
| 20 |
+
from api.personalization import router as personalization_router
|
| 21 |
+
from api.rag_search import router as rag_search_router
|
| 22 |
+
from api.chat import router as chat_api_router # New chat API with RAG
|
| 23 |
+
|
| 24 |
+
app = FastAPI(title="AI-native Textbook Platform API")
|
| 25 |
+
|
| 26 |
+
# Add CORS middleware to allow requests from the Docusaurus frontend
|
| 27 |
+
app.add_middleware(
|
| 28 |
+
CORSMiddleware,
|
| 29 |
+
allow_origins=["http://localhost:3000", "http://localhost:3001", "http://localhost:8000", "*"], # Allow frontend origins
|
| 30 |
+
allow_credentials=True,
|
| 31 |
+
allow_methods=["*"],
|
| 32 |
+
allow_headers=["*"],
|
| 33 |
+
allow_origin_regex=r"https?://localhost(:[0-9]+)?",
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
# Include API routers
|
| 37 |
+
app.include_router(chat_router, prefix="/api") # Original chat API (for compatibility)
|
| 38 |
+
app.include_router(auth_router, prefix="/api")
|
| 39 |
+
app.include_router(translation_router, prefix="/api")
|
| 40 |
+
app.include_router(personalization_router, prefix="/api")
|
| 41 |
+
app.include_router(rag_search_router, prefix="/api")
|
| 42 |
+
# New enhanced chat API with RAG is included in the original chat_router
|
| 43 |
+
|
| 44 |
+
@app.get("/")
|
| 45 |
+
def read_root():
|
| 46 |
+
return {"message": "Welcome to the AI-native Interactive Textbook Platform for Physical AI & Humanoid Robotics"}
|
| 47 |
+
|
| 48 |
+
@app.get("/health")
|
| 49 |
+
def health_check():
|
| 50 |
+
return {"status": "healthy"}
|
| 51 |
+
|
| 52 |
+
if __name__ == "__main__":
|
| 53 |
+
uvicorn.run(app, host="0.0.0.0", port=8001)
|
middleware/auth_middleware.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import Request, HTTPException
|
| 2 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 3 |
+
import jwt
|
| 4 |
+
from typing import Optional
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
class JWTAuth:
|
| 8 |
+
def __init__(self, secret_key: str = None, algorithm: str = "HS256"):
|
| 9 |
+
self.secret_key = secret_key or os.getenv("JWT_SECRET_KEY", "your-secret-key-here")
|
| 10 |
+
self.algorithm = algorithm
|
| 11 |
+
self.security = HTTPBearer()
|
| 12 |
+
|
| 13 |
+
async def __call__(self, request: Request) -> Optional[dict]:
|
| 14 |
+
credentials: HTTPAuthorizationCredentials = await self.security(request)
|
| 15 |
+
|
| 16 |
+
if credentials:
|
| 17 |
+
token = credentials.credentials
|
| 18 |
+
try:
|
| 19 |
+
# Decode the JWT token
|
| 20 |
+
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
|
| 21 |
+
request.state.user = payload
|
| 22 |
+
return payload
|
| 23 |
+
except jwt.ExpiredSignatureError:
|
| 24 |
+
raise HTTPException(status_code=401, detail="Token has expired")
|
| 25 |
+
except jwt.InvalidTokenError:
|
| 26 |
+
raise HTTPException(status_code=401, detail="Invalid token")
|
| 27 |
+
else:
|
| 28 |
+
raise HTTPException(status_code=401, detail="No authorization token provided")
|
| 29 |
+
|
| 30 |
+
# Example usage in routes:
|
| 31 |
+
# @router.get("/protected-route")
|
| 32 |
+
# async def protected_route(request: Request, user: dict = Depends(JWTAuth())):
|
| 33 |
+
# return {"message": f"Hello {user.get('email')}, you are authenticated!"}
|
| 34 |
+
|
| 35 |
+
# For now, we'll create a simple dependency that can be used to require authentication
|
| 36 |
+
async def require_auth(request: Request):
|
| 37 |
+
"""Simple dependency to require authentication (placeholder for real implementation)"""
|
| 38 |
+
# In a real implementation, this would validate the JWT token
|
| 39 |
+
# For now, we'll just check if there's a mock token in the header
|
| 40 |
+
auth_header = request.headers.get("Authorization")
|
| 41 |
+
if not auth_header or not auth_header.startswith("Bearer "):
|
| 42 |
+
raise HTTPException(status_code=401, detail="Authorization header missing or invalid")
|
| 43 |
+
|
| 44 |
+
# In a real app, you would validate the JWT here
|
| 45 |
+
# For demo purposes, we'll just continue
|
| 46 |
+
pass
|
models/chat_session.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
from typing import Optional, List
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
|
| 5 |
+
class ChatSession(BaseModel):
|
| 6 |
+
id: Optional[str] = None
|
| 7 |
+
user_id: str
|
| 8 |
+
selected_text: str
|
| 9 |
+
question: str
|
| 10 |
+
response: str
|
| 11 |
+
created_at: Optional[datetime] = None
|
| 12 |
+
conversation_history: Optional[List[dict]] = []
|
models/user.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
from typing import Optional
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
|
| 5 |
+
class User(BaseModel):
|
| 6 |
+
id: Optional[str] = None
|
| 7 |
+
email: str
|
| 8 |
+
password: str
|
| 9 |
+
created_at: Optional[datetime] = None
|
| 10 |
+
updated_at: Optional[datetime] = None
|
| 11 |
+
software_background: Optional[str] = None # Software Engineer, Beginner, etc.
|
| 12 |
+
hardware_background: Optional[str] = None # Hardware Engineer, Beginner, etc.
|
| 13 |
+
experience_level: Optional[str] = None # Beginner, Intermediate, Advanced
|
models/user_profile.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
from typing import Optional
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
|
| 5 |
+
class UserProfile(BaseModel):
|
| 6 |
+
id: Optional[str] = None
|
| 7 |
+
user_id: str
|
| 8 |
+
software_background: Optional[str] = None # Software Engineer, Beginner, etc.
|
| 9 |
+
hardware_background: Optional[str] = None # Hardware Engineer, Beginner, etc.
|
| 10 |
+
experience_level: Optional[str] = None # Beginner, Intermediate, Advanced
|
| 11 |
+
personalization_settings: Optional[dict] = {}
|
| 12 |
+
learning_progress: Optional[dict] = {}
|
| 13 |
+
created_at: Optional[datetime] = None
|
| 14 |
+
updated_at: Optional[datetime] = None
|
pyproject.toml
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[tool.poetry]
|
| 2 |
+
name = "ai-backend-rag-auth"
|
| 3 |
+
version = "1.0.0"
|
| 4 |
+
description = "AI Backend with RAG + Authentication using Qdrant, Neon, Gemini, FastAPI, and Better Auth"
|
| 5 |
+
authors = ["Your Name <your.email@example.com>"]
|
| 6 |
+
|
| 7 |
+
[tool.poetry.dependencies]
|
| 8 |
+
python = "^3.9"
|
| 9 |
+
fastapi = "^0.104.1"
|
| 10 |
+
uvicorn = {extras = ["standard"], version = "^0.24.0"}
|
| 11 |
+
sqlalchemy = {extras = ["asyncio"], version = "^2.0.23"}
|
| 12 |
+
asyncpg = "^0.29.0"
|
| 13 |
+
qdrant-client = "^1.7.0"
|
| 14 |
+
google-generativeai = "^0.4.0"
|
| 15 |
+
python-multipart = "^0.0.6"
|
| 16 |
+
python-jose = {extras = ["cryptography"], version = "^3.3.0"}
|
| 17 |
+
passlib = {extras = ["bcrypt"], version = "^1.7.4"}
|
| 18 |
+
better-exceptions = "^0.3.3"
|
| 19 |
+
python-dotenv = "^1.0.0"
|
| 20 |
+
pydantic = "^2.5.0"
|
| 21 |
+
pydantic-settings = "^2.1.0"
|
| 22 |
+
uuid = "^1.30"
|
| 23 |
+
httpx = "^0.25.2"
|
| 24 |
+
alembic = "^1.13.1"
|
| 25 |
+
|
| 26 |
+
[tool.poetry.group.dev.dependencies]
|
| 27 |
+
pytest = "^7.4.3"
|
| 28 |
+
pytest-asyncio = "^0.21.1"
|
| 29 |
+
black = "^23.10.1"
|
| 30 |
+
isort = "^5.12.0"
|
| 31 |
+
mypy = "^1.7.1"
|
| 32 |
+
|
| 33 |
+
[build-system]
|
| 34 |
+
requires = ["poetry-core"]
|
| 35 |
+
build-backend = "poetry.core.masonry.api"
|
| 36 |
+
|
| 37 |
+
[tool.pytest.ini_options]
|
| 38 |
+
testpaths = ["tests"]
|
| 39 |
+
asyncio_mode = "auto"
|
| 40 |
+
addopts = "-v --tb=short"
|
| 41 |
+
|
| 42 |
+
[tool.black]
|
| 43 |
+
line-length = 88
|
| 44 |
+
target-version = ['py39']
|
| 45 |
+
include = '\.pyi?$'
|
| 46 |
+
extend-exclude = '''
|
| 47 |
+
/(
|
| 48 |
+
# directories
|
| 49 |
+
\.eggs
|
| 50 |
+
| \.git
|
| 51 |
+
| \.hg
|
| 52 |
+
| \.mypy_cache
|
| 53 |
+
| \.tox
|
| 54 |
+
| \.venv
|
| 55 |
+
| build
|
| 56 |
+
| dist
|
| 57 |
+
)/
|
| 58 |
+
'''
|
| 59 |
+
|
| 60 |
+
[tool.isort]
|
| 61 |
+
profile = "black"
|
| 62 |
+
multi_line_output = 3
|
| 63 |
+
known_first_party = ["src"]
|
| 64 |
+
known_third_party = ["fastapi", "uvicorn", "sqlalchemy", "asyncpg", "qdrant_client", "google", "pydantic", "pytest"]
|
| 65 |
+
|
| 66 |
+
[tool.mypy]
|
| 67 |
+
python_version = "3.9"
|
| 68 |
+
warn_return_any = true
|
| 69 |
+
warn_unused_configs = true
|
| 70 |
+
warn_unused_ignores = true
|
requirements.txt
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.104.1
|
| 2 |
+
uvicorn[standard]==0.24.0
|
| 3 |
+
sqlalchemy[asyncio]==2.0.23
|
| 4 |
+
asyncpg==0.29.0
|
| 5 |
+
qdrant-client==1.7.0
|
| 6 |
+
google-generativeai==0.4.0
|
| 7 |
+
python-multipart==0.0.6
|
| 8 |
+
python-jose[cryptography]==3.3.0
|
| 9 |
+
passlib[bcrypt]==1.7.4
|
| 10 |
+
better-exceptions==0.3.3
|
| 11 |
+
python-dotenv==1.0.0
|
| 12 |
+
pydantic==2.5.0
|
| 13 |
+
pydantic-settings==2.1.0
|
| 14 |
+
uuid==1.30
|
| 15 |
+
httpx==0.25.2
|
| 16 |
+
pytest==7.4.3
|
| 17 |
+
pytest-asyncio==0.21.1
|
| 18 |
+
alembic==1.13.1
|
| 19 |
+
openai==1.10.0
|
services/content_adaptation.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Any, Optional
|
| 2 |
+
import google.generativeai as genai
|
| 3 |
+
import logging
|
| 4 |
+
|
| 5 |
+
logger = logging.getLogger(__name__)
|
| 6 |
+
|
| 7 |
+
class ContentAdaptationService:
|
| 8 |
+
def __init__(self, gemini_api_key: str):
|
| 9 |
+
genai.configure(api_key=gemini_api_key)
|
| 10 |
+
self.model = genai.GenerativeModel('gemini-pro')
|
| 11 |
+
|
| 12 |
+
def adapt_content(self, content: str, user_background: str, experience_level: str, chapter_id: str) -> str:
|
| 13 |
+
"""Adapt content based on user background and experience level"""
|
| 14 |
+
try:
|
| 15 |
+
# Determine adaptation instructions based on user profile
|
| 16 |
+
adaptation_instructions = self._get_adaptation_instructions(user_background, experience_level, chapter_id)
|
| 17 |
+
|
| 18 |
+
# Call Gemini API to adapt the content
|
| 19 |
+
prompt = f"""You are an educational content adapter for a Physical AI & Humanoid Robotics textbook. Adapt the provided content according to these instructions: {adaptation_instructions}. Maintain the core educational value while making it appropriate for the target audience.
|
| 20 |
+
|
| 21 |
+
Original content:
|
| 22 |
+
{content}
|
| 23 |
+
|
| 24 |
+
Adapted content:"""
|
| 25 |
+
|
| 26 |
+
response = self.model.generate_content(
|
| 27 |
+
prompt,
|
| 28 |
+
generation_config=genai.types.GenerationConfig(
|
| 29 |
+
max_output_tokens=len(content) * 2,
|
| 30 |
+
temperature=0.4,
|
| 31 |
+
)
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
adapted_content = response.text
|
| 35 |
+
|
| 36 |
+
logger.info(f"Adapted content for background: {user_background}, level: {experience_level}")
|
| 37 |
+
return adapted_content
|
| 38 |
+
|
| 39 |
+
except Exception as e:
|
| 40 |
+
logger.error(f"Error adapting content: {str(e)}")
|
| 41 |
+
# Return original content if adaptation fails
|
| 42 |
+
return content
|
| 43 |
+
|
| 44 |
+
def _get_adaptation_instructions(self, user_background: str, experience_level: str, chapter_id: str) -> str:
|
| 45 |
+
"""Generate adaptation instructions based on user profile"""
|
| 46 |
+
instructions = []
|
| 47 |
+
|
| 48 |
+
# Add background-specific instructions
|
| 49 |
+
if user_background and 'software' in user_background.lower():
|
| 50 |
+
instructions.append("Include more code examples and programming concepts")
|
| 51 |
+
elif user_background and 'hardware' in user_background.lower():
|
| 52 |
+
instructions.append("Include more hardware specifications and physical implementations")
|
| 53 |
+
else:
|
| 54 |
+
instructions.append("Provide balanced content with both software and hardware aspects")
|
| 55 |
+
|
| 56 |
+
# Add experience level-specific instructions
|
| 57 |
+
if experience_level == 'beginner':
|
| 58 |
+
instructions.append("Use simpler explanations, more examples, and step-by-step instructions")
|
| 59 |
+
elif experience_level == 'intermediate':
|
| 60 |
+
instructions.append("Provide moderate complexity with practical applications")
|
| 61 |
+
elif experience_level == 'advanced':
|
| 62 |
+
instructions.append("Include complex examples, optimization techniques, and advanced concepts")
|
| 63 |
+
else:
|
| 64 |
+
instructions.append("Use moderate complexity appropriate for mixed experience levels")
|
| 65 |
+
|
| 66 |
+
# Add chapter-specific instructions if needed
|
| 67 |
+
if 'ros2' in chapter_id.lower():
|
| 68 |
+
instructions.append("Focus on ROS 2 concepts like nodes, topics, and URDF")
|
| 69 |
+
elif 'gazebo' in chapter_id.lower() or 'unity' in chapter_id.lower():
|
| 70 |
+
instructions.append("Emphasize simulation concepts, sensors, and environment modeling")
|
| 71 |
+
elif 'nvidia' in chapter_id.lower() or 'isaac' in chapter_id.lower():
|
| 72 |
+
instructions.append("Highlight perception, VSLAM, navigation, and Isaac-specific concepts")
|
| 73 |
+
elif 'vla' in chapter_id.lower():
|
| 74 |
+
instructions.append("Focus on voice, cognitive, and capstone project concepts")
|
| 75 |
+
|
| 76 |
+
return "; ".join(instructions)
|
| 77 |
+
|
| 78 |
+
def adapt_examples(self, examples: list, user_background: str, experience_level: str) -> list:
|
| 79 |
+
"""Adapt code or practical examples based on user profile"""
|
| 80 |
+
try:
|
| 81 |
+
adapted_examples = []
|
| 82 |
+
for example in examples:
|
| 83 |
+
prompt = f"""You are adapting educational examples for a Physical AI & Humanoid Robotics textbook. Adapt this example for a user with {user_background} background and {experience_level} experience level. Return the adapted example.
|
| 84 |
+
|
| 85 |
+
Original example:
|
| 86 |
+
{example}
|
| 87 |
+
|
| 88 |
+
Adapted example:"""
|
| 89 |
+
|
| 90 |
+
response = self.model.generate_content(
|
| 91 |
+
prompt,
|
| 92 |
+
generation_config=genai.types.GenerationConfig(
|
| 93 |
+
max_output_tokens=1000,
|
| 94 |
+
temperature=0.3,
|
| 95 |
+
)
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
adapted_examples.append(response.text)
|
| 99 |
+
|
| 100 |
+
logger.info(f"Adapted {len(examples)} examples for background: {user_background}, level: {experience_level}")
|
| 101 |
+
return adapted_examples
|
| 102 |
+
|
| 103 |
+
except Exception as e:
|
| 104 |
+
logger.error(f"Error adapting examples: {str(e)}")
|
| 105 |
+
return examples # Return original examples if adaptation fails
|
services/personalization_service.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Any, Optional
|
| 2 |
+
from .content_adaptation import ContentAdaptationService
|
| 3 |
+
import logging
|
| 4 |
+
|
| 5 |
+
logger = logging.getLogger(__name__)
|
| 6 |
+
|
| 7 |
+
class PersonalizationService:
|
| 8 |
+
def __init__(self, content_adaptation_service: ContentAdaptationService):
|
| 9 |
+
self.content_adaptation_service = content_adaptation_service
|
| 10 |
+
|
| 11 |
+
def get_personalized_content(self, content: str, user_profile: Dict[str, Any], chapter_id: str) -> str:
|
| 12 |
+
"""Get personalized content based on user profile"""
|
| 13 |
+
try:
|
| 14 |
+
# Determine the user's background and experience level
|
| 15 |
+
software_background = user_profile.get('software_background', '')
|
| 16 |
+
hardware_background = user_profile.get('hardware_background', '')
|
| 17 |
+
experience_level = user_profile.get('experience_level', 'beginner')
|
| 18 |
+
|
| 19 |
+
# Adapt content based on user profile
|
| 20 |
+
adapted_content = self.content_adaptation_service.adapt_content(
|
| 21 |
+
content=content,
|
| 22 |
+
user_background=software_background or hardware_background,
|
| 23 |
+
experience_level=experience_level,
|
| 24 |
+
chapter_id=chapter_id
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
logger.info(f"Personalized content for user with background: {software_background}/{hardware_background}, level: {experience_level}")
|
| 28 |
+
return adapted_content
|
| 29 |
+
|
| 30 |
+
except Exception as e:
|
| 31 |
+
logger.error(f"Error in personalization: {str(e)}")
|
| 32 |
+
# Return original content if personalization fails
|
| 33 |
+
return content
|
| 34 |
+
|
| 35 |
+
def get_user_recommendations(self, user_profile: Dict[str, Any], current_chapter: str) -> Dict[str, Any]:
|
| 36 |
+
"""Get personalized recommendations for the user"""
|
| 37 |
+
try:
|
| 38 |
+
software_background = user_profile.get('software_background', '')
|
| 39 |
+
hardware_background = user_profile.get('hardware_background', '')
|
| 40 |
+
experience_level = user_profile.get('experience_level', 'beginner')
|
| 41 |
+
|
| 42 |
+
recommendations = {
|
| 43 |
+
'next_chapters': self._get_next_chapters(user_profile, current_chapter),
|
| 44 |
+
'difficulty_level': experience_level,
|
| 45 |
+
'focus_areas': self._get_focus_areas(software_background, hardware_background),
|
| 46 |
+
'additional_resources': self._get_resources(experience_level)
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
logger.info(f"Generated recommendations for user")
|
| 50 |
+
return recommendations
|
| 51 |
+
|
| 52 |
+
except Exception as e:
|
| 53 |
+
logger.error(f"Error generating recommendations: {str(e)}")
|
| 54 |
+
return {}
|
| 55 |
+
|
| 56 |
+
def _get_next_chapters(self, user_profile: Dict[str, Any], current_chapter: str) -> list:
|
| 57 |
+
"""Determine next chapters based on user profile and current progress"""
|
| 58 |
+
# This would be more sophisticated in a real implementation
|
| 59 |
+
# For now, return a default sequence
|
| 60 |
+
chapter_sequence = {
|
| 61 |
+
'1-ros2': ['2-gazebo-unity', '3-nvidia-isaac'],
|
| 62 |
+
'2-gazebo-unity': ['3-nvidia-isaac', '4-vla'],
|
| 63 |
+
'3-nvidia-isaac': ['4-vla', 'capstone'],
|
| 64 |
+
'4-vla': ['capstone'],
|
| 65 |
+
'capstone': []
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
return chapter_sequence.get(current_chapter, [])
|
| 69 |
+
|
| 70 |
+
def _get_focus_areas(self, software_background: str, hardware_background: str) -> list:
|
| 71 |
+
"""Determine focus areas based on user background"""
|
| 72 |
+
focus_areas = []
|
| 73 |
+
|
| 74 |
+
if software_background and 'software' in software_background.lower():
|
| 75 |
+
focus_areas.append('code examples')
|
| 76 |
+
focus_areas.append('programming concepts')
|
| 77 |
+
elif hardware_background and 'hardware' in hardware_background.lower():
|
| 78 |
+
focus_areas.append('hardware specifications')
|
| 79 |
+
focus_areas.append('physical implementations')
|
| 80 |
+
|
| 81 |
+
if not focus_areas:
|
| 82 |
+
focus_areas.append('general concepts')
|
| 83 |
+
|
| 84 |
+
return focus_areas
|
| 85 |
+
|
| 86 |
+
def _get_resources(self, experience_level: str) -> list:
|
| 87 |
+
"""Get additional resources based on experience level"""
|
| 88 |
+
if experience_level == 'beginner':
|
| 89 |
+
return ['tutorials', 'basic examples', 'step-by-step guides']
|
| 90 |
+
elif experience_level == 'intermediate':
|
| 91 |
+
return ['advanced examples', 'practical applications']
|
| 92 |
+
elif experience_level == 'advanced':
|
| 93 |
+
return ['research papers', 'cutting-edge implementations', 'optimization techniques']
|
| 94 |
+
else:
|
| 95 |
+
return ['tutorials', 'examples']
|
services/rag_service.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Dict, Any
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
from openai import OpenAI
|
| 5 |
+
from qdrant_client import QdrantClient
|
| 6 |
+
import json
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
class RAGService:
|
| 11 |
+
def __init__(self, openrouter_api_key: str, vector_db_service: QdrantClient, collection_name: str = "project_documents"):
|
| 12 |
+
# Initialize OpenRouter client
|
| 13 |
+
self.client = OpenAI(
|
| 14 |
+
api_key=openrouter_api_key,
|
| 15 |
+
base_url="https://openrouter.ai/api/v1"
|
| 16 |
+
)
|
| 17 |
+
self.qdrant = vector_db_service
|
| 18 |
+
self.collection_name = collection_name
|
| 19 |
+
|
| 20 |
+
def get_embedding(self, text: str) -> List[float]:
|
| 21 |
+
"""Get embeddings for text using OpenAI's embedding API"""
|
| 22 |
+
try:
|
| 23 |
+
response = self.client.embeddings.create(
|
| 24 |
+
model="text-embedding-3-small",
|
| 25 |
+
input=text
|
| 26 |
+
)
|
| 27 |
+
return response.data[0].embedding
|
| 28 |
+
except Exception as e:
|
| 29 |
+
logger.error(f"Error getting embeddings: {str(e)}")
|
| 30 |
+
raise e
|
| 31 |
+
|
| 32 |
+
def search_qdrant(self, query: str) -> str:
|
| 33 |
+
"""Search Qdrant for relevant content based on query"""
|
| 34 |
+
try:
|
| 35 |
+
vector = self.get_embedding(query)
|
| 36 |
+
|
| 37 |
+
hits = self.qdrant.search(
|
| 38 |
+
collection_name=self.collection_name,
|
| 39 |
+
query_vector=vector,
|
| 40 |
+
limit=5
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
return "\n\n".join(hit.payload["content"] for hit in hits if "content" in hit.payload)
|
| 44 |
+
except Exception as e:
|
| 45 |
+
logger.error(f"Error searching Qdrant: {str(e)}")
|
| 46 |
+
return ""
|
| 47 |
+
|
| 48 |
+
def query_rag(self, selected_text: str, question: str) -> str:
|
| 49 |
+
"""Process a RAG query using OpenRouter with context from Qdrant"""
|
| 50 |
+
# Validate inputs
|
| 51 |
+
if not selected_text or len(selected_text.strip()) == 0:
|
| 52 |
+
# Check length (as per requirement TC-002: max 5000 characters)
|
| 53 |
+
if len(selected_text) > 5000:
|
| 54 |
+
logger.warning(f"Selected text exceeds 5000 character limit: {len(selected_text)} characters")
|
| 55 |
+
return "Selected text exceeds the 5000 character limit. Please select a shorter text."
|
| 56 |
+
|
| 57 |
+
SYSTEM_PROMPT = """You are a RAG-based AI agent.
|
| 58 |
+
|
| 59 |
+
RULES:
|
| 60 |
+
- Answer ONLY from the retrieved context.
|
| 61 |
+
- If the answer is not found, say:
|
| 62 |
+
"Is sawal ka jawab provided data me mojood nahi hai."""
|
| 63 |
+
|
| 64 |
+
try:
|
| 65 |
+
# Search Qdrant using the selected_text to get relevant context
|
| 66 |
+
context = self.search_qdrant(selected_text)
|
| 67 |
+
|
| 68 |
+
# If we found context, generate the final answer using the context
|
| 69 |
+
if context.strip():
|
| 70 |
+
final_response = self.client.chat.completions.create(
|
| 71 |
+
model="openai/gpt-3.5-turbo",
|
| 72 |
+
messages=[
|
| 73 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 74 |
+
{"role": "assistant", "content": f"Here is the relevant context: {context}"},
|
| 75 |
+
{"role": "user", "content": question}
|
| 76 |
+
],
|
| 77 |
+
temperature=0
|
| 78 |
+
)
|
| 79 |
+
return final_response.choices[0].message.content
|
| 80 |
+
else:
|
| 81 |
+
# If no context was found, return the fallback message
|
| 82 |
+
return "Is sawal ka jawab provided data me mojood nahi hai."
|
| 83 |
+
|
| 84 |
+
except Exception as e:
|
| 85 |
+
logger.error(f"Error in RAG query: {str(e)}")
|
| 86 |
+
# Check if the error is related to API key validity
|
| 87 |
+
error_str = str(e).lower()
|
| 88 |
+
if "api key" in error_str or "quota" in error_str or "billing" in error_str or "permission" in error_str or "401" in str(e) or "403" in str(e):
|
| 89 |
+
# Return a more specific message about API configuration
|
| 90 |
+
return f"I'm currently unable to process your request about '{question}'. The AI service may be temporarily unavailable due to API key issues or quota limits. Please check that your OPENROUTER_API_KEY is properly configured in the .env file and has sufficient quota available."
|
| 91 |
+
else:
|
| 92 |
+
# Return a general fallback response
|
| 93 |
+
return f"I apologize, but I'm currently unable to process your request about '{question}'. The AI service may be temporarily unavailable. Please try again later or contact support if the issue persists."
|
| 94 |
+
|
| 95 |
+
def index_content(self, content_id: str, content: str, metadata: Dict[str, Any] = None):
|
| 96 |
+
"""Index textbook content for RAG retrieval"""
|
| 97 |
+
if metadata is None:
|
| 98 |
+
metadata = {}
|
| 99 |
+
|
| 100 |
+
# Get embeddings for the content
|
| 101 |
+
embeddings = self.get_embedding(content)
|
| 102 |
+
|
| 103 |
+
# Store in vector database
|
| 104 |
+
self.qdrant.upsert(
|
| 105 |
+
collection_name=self.collection_name,
|
| 106 |
+
points=[
|
| 107 |
+
{
|
| 108 |
+
"id": content_id,
|
| 109 |
+
"vector": embeddings,
|
| 110 |
+
"payload": {
|
| 111 |
+
"content": content,
|
| 112 |
+
"metadata": metadata
|
| 113 |
+
}
|
| 114 |
+
}
|
| 115 |
+
]
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
logger.info(f"Indexed content: {content_id}")
|
services/translation_service.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import google.generativeai as genai
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Dict
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
+
|
| 8 |
+
class TranslationService:
|
| 9 |
+
def __init__(self, gemini_api_key: str):
|
| 10 |
+
genai.configure(api_key=gemini_api_key)
|
| 11 |
+
self.model = genai.GenerativeModel('gemini-pro')
|
| 12 |
+
self.translation_cache: Dict[str, str] = {}
|
| 13 |
+
self.cache_timestamps: Dict[str, float] = {}
|
| 14 |
+
|
| 15 |
+
def translate_to_urdu(self, text: str, ttl: int = 3600) -> str:
|
| 16 |
+
"""Translate English text to Urdu with caching"""
|
| 17 |
+
# Create cache key
|
| 18 |
+
cache_key = f"en_to_ur_{hash(text)}"
|
| 19 |
+
|
| 20 |
+
# Check if translation is in cache and not expired
|
| 21 |
+
if cache_key in self.translation_cache:
|
| 22 |
+
if time.time() - self.cache_timestamps.get(cache_key, 0) < ttl:
|
| 23 |
+
logger.info("Returning cached translation")
|
| 24 |
+
return self.translation_cache[cache_key]
|
| 25 |
+
|
| 26 |
+
# Call Gemini API for translation with improved prompt
|
| 27 |
+
try:
|
| 28 |
+
prompt = self._create_urdu_translation_prompt(text)
|
| 29 |
+
|
| 30 |
+
response = self.model.generate_content(
|
| 31 |
+
prompt,
|
| 32 |
+
generation_config=genai.types.GenerationConfig(
|
| 33 |
+
max_output_tokens=min(len(text) * 3, 4000), # Urdu text might be longer
|
| 34 |
+
temperature=0.2,
|
| 35 |
+
top_p=0.9,
|
| 36 |
+
)
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
translated_text = self._format_translation_response(response.text)
|
| 40 |
+
|
| 41 |
+
# Cache the translation
|
| 42 |
+
self.translation_cache[cache_key] = translated_text
|
| 43 |
+
self.cache_timestamps[cache_key] = time.time()
|
| 44 |
+
|
| 45 |
+
logger.info(f"Translated text to Urdu (length: {len(translated_text)} chars)")
|
| 46 |
+
return translated_text
|
| 47 |
+
|
| 48 |
+
except Exception as e:
|
| 49 |
+
logger.error(f"Translation failed: {str(e)}")
|
| 50 |
+
# Return a professional fallback response
|
| 51 |
+
return f"Translation unavailable: {text[:100]}..."
|
| 52 |
+
|
| 53 |
+
def _create_urdu_translation_prompt(self, text: str) -> str:
|
| 54 |
+
"""Create a professional Urdu translation prompt"""
|
| 55 |
+
return f"""You are an elite professional translator specializing in technical and educational content. Translate the provided English text to Urdu with precision and cultural sensitivity.
|
| 56 |
+
|
| 57 |
+
TRANSLATION REQUIREMENTS:
|
| 58 |
+
• Maintain technical accuracy for robotics/AI terminology
|
| 59 |
+
• Use proper Urdu script and correct grammar
|
| 60 |
+
• Preserve the original meaning and context
|
| 61 |
+
• Apply appropriate formality level for educational content
|
| 62 |
+
• Ensure readability and flow in Urdu
|
| 63 |
+
• Do not add any commentary or explanations
|
| 64 |
+
|
| 65 |
+
SOURCE TEXT:
|
| 66 |
+
"{text}"
|
| 67 |
+
|
| 68 |
+
URDU TRANSLATION:"""
|
| 69 |
+
|
| 70 |
+
def _format_translation_response(self, response_text: str) -> str:
|
| 71 |
+
"""Format the translation response for consistency"""
|
| 72 |
+
# Clean up response
|
| 73 |
+
formatted = response_text.strip()
|
| 74 |
+
|
| 75 |
+
# Remove any unwanted prefixes or explanations
|
| 76 |
+
if 'TRANSLATION:' in formatted:
|
| 77 |
+
formatted = formatted.split('TRANSLATION:')[-1].strip()
|
| 78 |
+
elif 'TRANSLATED TEXT:' in formatted:
|
| 79 |
+
formatted = formatted.split('TRANSLATED TEXT:')[-1].strip()
|
| 80 |
+
|
| 81 |
+
# Clean up extra whitespace
|
| 82 |
+
formatted = ' '.join(formatted.split())
|
| 83 |
+
|
| 84 |
+
return formatted
|
| 85 |
+
|
| 86 |
+
def translate_to_english(self, urdu_text: str, ttl: int = 3600) -> str:
|
| 87 |
+
"""Translate Urdu text back to English with caching"""
|
| 88 |
+
# Create cache key
|
| 89 |
+
cache_key = f"ur_to_en_{hash(urdu_text)}"
|
| 90 |
+
|
| 91 |
+
# Check if translation is in cache and not expired
|
| 92 |
+
if cache_key in self.translation_cache:
|
| 93 |
+
if time.time() - self.cache_timestamps.get(cache_key, 0) < ttl:
|
| 94 |
+
logger.info("Returning cached translation")
|
| 95 |
+
return self.translation_cache[cache_key]
|
| 96 |
+
|
| 97 |
+
# Call Gemini API for translation with improved prompt
|
| 98 |
+
try:
|
| 99 |
+
prompt = self._create_english_translation_prompt(urdu_text)
|
| 100 |
+
|
| 101 |
+
response = self.model.generate_content(
|
| 102 |
+
prompt,
|
| 103 |
+
generation_config=genai.types.GenerationConfig(
|
| 104 |
+
max_output_tokens=min(len(urdu_text) * 2, 4000),
|
| 105 |
+
temperature=0.2,
|
| 106 |
+
top_p=0.9,
|
| 107 |
+
)
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
translated_text = self._format_translation_response(response.text)
|
| 111 |
+
|
| 112 |
+
# Cache the translation
|
| 113 |
+
self.translation_cache[cache_key] = translated_text
|
| 114 |
+
self.cache_timestamps[cache_key] = time.time()
|
| 115 |
+
|
| 116 |
+
logger.info(f"Translated text to English (length: {len(translated_text)} chars)")
|
| 117 |
+
return translated_text
|
| 118 |
+
|
| 119 |
+
except Exception as e:
|
| 120 |
+
logger.error(f"Translation failed: {str(e)}")
|
| 121 |
+
# Return a professional fallback response
|
| 122 |
+
return f"Translation unavailable: {urdu_text[:100]}..."
|
| 123 |
+
|
| 124 |
+
def _create_english_translation_prompt(self, urdu_text: str) -> str:
|
| 125 |
+
"""Create a professional English translation prompt"""
|
| 126 |
+
return f"""You are an elite professional translator specializing in technical and educational content. Translate the provided Urdu text to English with precision and accuracy.
|
| 127 |
+
|
| 128 |
+
TRANSLATION REQUIREMENTS:
|
| 129 |
+
• Maintain technical accuracy for robotics/AI terminology
|
| 130 |
+
• Preserve the original meaning and context
|
| 131 |
+
• Apply appropriate formality level for educational content
|
| 132 |
+
• Ensure readability and flow in English
|
| 133 |
+
• Do not add any commentary or explanations
|
| 134 |
+
|
| 135 |
+
SOURCE TEXT:
|
| 136 |
+
"{urdu_text}"
|
| 137 |
+
|
| 138 |
+
ENGLISH TRANSLATION:"""
|
| 139 |
+
|
| 140 |
+
def clear_cache(self):
|
| 141 |
+
"""Clear the translation cache"""
|
| 142 |
+
self.translation_cache.clear()
|
| 143 |
+
self.cache_timestamps.clear()
|
| 144 |
+
logger.info("Translation cache cleared")
|
services/vector_db.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from qdrant_client import QdrantClient
|
| 2 |
+
from qdrant_client.http import models
|
| 3 |
+
from typing import List, Dict, Any
|
| 4 |
+
import logging
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
+
|
| 8 |
+
class VectorDBService:
|
| 9 |
+
def __init__(self, host: str = "localhost", port: int = 6333, cloud_client=None):
|
| 10 |
+
self.collection_name = "textbook_content"
|
| 11 |
+
self.is_available = False # Initialize as False by default
|
| 12 |
+
|
| 13 |
+
if cloud_client:
|
| 14 |
+
# Use provided cloud client
|
| 15 |
+
self.client = cloud_client
|
| 16 |
+
try:
|
| 17 |
+
self._init_collection()
|
| 18 |
+
self.is_available = True
|
| 19 |
+
except Exception as e:
|
| 20 |
+
import logging
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
logger.warning(f"Vector database not available: {str(e)}. Running in fallback mode.")
|
| 23 |
+
self.client = None
|
| 24 |
+
else:
|
| 25 |
+
# Use local client
|
| 26 |
+
try:
|
| 27 |
+
self.client = QdrantClient(host=host, port=port)
|
| 28 |
+
self._init_collection()
|
| 29 |
+
self.is_available = True
|
| 30 |
+
except Exception as e:
|
| 31 |
+
import logging
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
logger.warning(f"Vector database not available: {str(e)}. Running in fallback mode.")
|
| 34 |
+
self.client = None
|
| 35 |
+
|
| 36 |
+
def _init_collection(self):
|
| 37 |
+
"""Initialize the Qdrant collection for textbook content"""
|
| 38 |
+
if not self.is_available:
|
| 39 |
+
return
|
| 40 |
+
|
| 41 |
+
try:
|
| 42 |
+
# Check if collection exists
|
| 43 |
+
self.client.get_collection(self.collection_name)
|
| 44 |
+
except:
|
| 45 |
+
# Create collection if it doesn't exist
|
| 46 |
+
self.client.create_collection(
|
| 47 |
+
collection_name=self.collection_name,
|
| 48 |
+
vectors_config=models.VectorParams(size=1536, distance=models.Distance.COSINE), # Assuming OpenAI embeddings
|
| 49 |
+
)
|
| 50 |
+
logger.info(f"Created collection: {self.collection_name}")
|
| 51 |
+
|
| 52 |
+
def add_content(self, content_id: str, content: str, embeddings: List[float], metadata: Dict[str, Any] = None):
|
| 53 |
+
"""Add textbook content to the vector database"""
|
| 54 |
+
if not self.is_available:
|
| 55 |
+
logger.warning("Vector database not available, skipping content addition")
|
| 56 |
+
return
|
| 57 |
+
|
| 58 |
+
if metadata is None:
|
| 59 |
+
metadata = {}
|
| 60 |
+
|
| 61 |
+
self.client.upsert(
|
| 62 |
+
collection_name=self.collection_name,
|
| 63 |
+
points=[
|
| 64 |
+
models.PointStruct(
|
| 65 |
+
id=content_id,
|
| 66 |
+
vector=embeddings,
|
| 67 |
+
payload={
|
| 68 |
+
"content": content,
|
| 69 |
+
"metadata": metadata
|
| 70 |
+
}
|
| 71 |
+
)
|
| 72 |
+
]
|
| 73 |
+
)
|
| 74 |
+
logger.info(f"Added content to vector DB: {content_id}")
|
| 75 |
+
|
| 76 |
+
def search_content(self, query_embeddings: List[float], limit: int = 10) -> List[Dict[str, Any]]:
|
| 77 |
+
"""Search for relevant content based on query embeddings"""
|
| 78 |
+
if not self.is_available or self.client is None:
|
| 79 |
+
logger.warning("Vector database not available, returning empty results")
|
| 80 |
+
# Return empty list when database is not available
|
| 81 |
+
return []
|
| 82 |
+
|
| 83 |
+
try:
|
| 84 |
+
# Handle different Qdrant client versions
|
| 85 |
+
if hasattr(self.client, 'search'):
|
| 86 |
+
# Newer version of Qdrant client
|
| 87 |
+
results = self.client.search(
|
| 88 |
+
collection_name=self.collection_name,
|
| 89 |
+
query_vector=query_embeddings,
|
| 90 |
+
limit=limit
|
| 91 |
+
)
|
| 92 |
+
else:
|
| 93 |
+
# Older version or different interface
|
| 94 |
+
results = self.client.search(
|
| 95 |
+
collection_name=self.collection_name,
|
| 96 |
+
query_vector=query_embeddings,
|
| 97 |
+
limit=limit
|
| 98 |
+
)
|
| 99 |
+
except AttributeError:
|
| 100 |
+
logger.warning("Qdrant client search method not available, using direct processing")
|
| 101 |
+
return []
|
| 102 |
+
except Exception as e:
|
| 103 |
+
logger.warning(f"Vector database search failed, using direct processing: {str(e)}")
|
| 104 |
+
# Return empty list when search fails
|
| 105 |
+
return []
|
| 106 |
+
|
| 107 |
+
return [
|
| 108 |
+
{
|
| 109 |
+
"id": result.id,
|
| 110 |
+
"content": result.payload.get("content"),
|
| 111 |
+
"metadata": result.payload.get("metadata", {}),
|
| 112 |
+
"score": getattr(result, 'score', 0.0) # Handle different result structures
|
| 113 |
+
}
|
| 114 |
+
for result in results
|
| 115 |
+
]
|
| 116 |
+
|
| 117 |
+
def delete_content(self, content_id: str):
|
| 118 |
+
"""Delete content from the vector database"""
|
| 119 |
+
if not self.is_available:
|
| 120 |
+
logger.warning("Vector database not available, skipping content deletion")
|
| 121 |
+
return
|
| 122 |
+
|
| 123 |
+
self.client.delete(
|
| 124 |
+
collection_name=self.collection_name,
|
| 125 |
+
points_selector=models.PointIdsList(points=[content_id])
|
| 126 |
+
)
|
| 127 |
+
logger.info(f"Deleted content from vector DB: {content_id}")
|
setup_sample_content.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
from qdrant_client import QdrantClient
|
| 4 |
+
from openai import OpenAI
|
| 5 |
+
import uuid
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
|
| 8 |
+
# Load environment variables from .env file in the project root
|
| 9 |
+
load_dotenv(os.path.join(os.path.dirname(os.path.dirname(__file__)), '.env'))
|
| 10 |
+
|
| 11 |
+
# Add the backend directory to the path so we can import the RAG service
|
| 12 |
+
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
| 13 |
+
|
| 14 |
+
from services.rag_service import RAGService
|
| 15 |
+
|
| 16 |
+
def setup_sample_content():
|
| 17 |
+
# Get environment variables
|
| 18 |
+
openrouter_api_key = os.getenv("OPENAI_API_KEY")
|
| 19 |
+
qdrant_url = os.getenv("QDRANT_URL")
|
| 20 |
+
qdrant_api_key = os.getenv("QDRANT_API_KEY")
|
| 21 |
+
collection_name = os.getenv("QDRANT_COLLECTION", "project_documents")
|
| 22 |
+
|
| 23 |
+
# Initialize Qdrant client for cloud
|
| 24 |
+
if qdrant_url and qdrant_api_key and "qdrant.io" in qdrant_url:
|
| 25 |
+
qdrant_client = QdrantClient(
|
| 26 |
+
url=qdrant_url.replace(":6333", ""), # Remove port from URL for cloud (same as in chat.py)
|
| 27 |
+
api_key=qdrant_api_key,
|
| 28 |
+
prefer_grpc=False
|
| 29 |
+
)
|
| 30 |
+
else:
|
| 31 |
+
# Use local Qdrant if cloud not configured
|
| 32 |
+
qdrant_client = QdrantClient(
|
| 33 |
+
host=os.getenv("QDRANT_HOST", "localhost"),
|
| 34 |
+
port=int(os.getenv("QDRANT_PORT", 6333))
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
# Initialize RAG service
|
| 38 |
+
rag_service = RAGService(openrouter_api_key, qdrant_client, collection_name)
|
| 39 |
+
|
| 40 |
+
# Sample content about AI and Robotics
|
| 41 |
+
import uuid
|
| 42 |
+
|
| 43 |
+
sample_content = [
|
| 44 |
+
{
|
| 45 |
+
"id": str(uuid.uuid4()),
|
| 46 |
+
"content": "Introduction to Physical AI & Humanoid Robotics: Embodied Intelligence represents the convergence of artificial intelligence with physical systems. It's the principle that true intelligence emerges not just from abstract computation, but from the interaction between an intelligent system and its physical environment. In the context of humanoid robotics, this means creating machines that can perceive, reason, and act in the physical world much like humans do. This textbook combines cutting-edge robotics concepts with artificial intelligence to provide a deep understanding of embodied intelligence systems.",
|
| 47 |
+
"metadata": {"topic": "Introduction to Physical AI & Humanoid Robotics", "level": "beginner", "original_id": "intro_physical_ai_1"}
|
| 48 |
+
},
|
| 49 |
+
{
|
| 50 |
+
"id": str(uuid.uuid4()),
|
| 51 |
+
"content": "Artificial Intelligence (AI) is a branch of computer science that aims to create software or machines that exhibit human-like intelligence. This can include learning from experience, understanding natural language, solving problems, and recognizing patterns. AI systems can be trained using various techniques including machine learning, deep learning, and neural networks.",
|
| 52 |
+
"metadata": {"topic": "AI Fundamentals", "level": "beginner", "original_id": "ai_fundamentals_1"}
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"id": str(uuid.uuid4()),
|
| 56 |
+
"content": "Machine learning is a subset of artificial intelligence that focuses on algorithms that can learn from data. Instead of being explicitly programmed, machine learning models improve their performance through experience with data. Common types include supervised learning, unsupervised learning, and reinforcement learning.",
|
| 57 |
+
"metadata": {"topic": "Machine Learning", "level": "beginner", "original_id": "machine_learning_1"}
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
"id": str(uuid.uuid4()),
|
| 61 |
+
"content": "Robotics is an interdisciplinary field that combines mechanical engineering, electrical engineering, and computer science to design, construct, and operate robots. Modern robots can perform complex tasks in manufacturing, healthcare, exploration, and service industries. They often incorporate AI to enable autonomous decision-making and adaptive behavior.",
|
| 62 |
+
"metadata": {"topic": "Robotics", "level": "beginner", "original_id": "robotics_intro_1"}
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"id": str(uuid.uuid4()),
|
| 66 |
+
"content": "Neural networks are computing systems inspired by the human brain's structure and function. They consist of interconnected nodes (neurons) organized in layers. Deep learning uses neural networks with multiple hidden layers to recognize patterns and make predictions. They are particularly effective for image recognition, natural language processing, and complex decision-making tasks.",
|
| 67 |
+
"metadata": {"topic": "Neural Networks", "level": "intermediate", "original_id": "neural_networks_1"}
|
| 68 |
+
},
|
| 69 |
+
{
|
| 70 |
+
"id": str(uuid.uuid4()),
|
| 71 |
+
"content": "Natural Language Processing (NLP) is a field of AI focused on enabling computers to understand, interpret, and generate human language. NLP techniques are used in chatbots, translation services, sentiment analysis, and text summarization. Modern NLP systems often use transformer architectures and large language models.",
|
| 72 |
+
"metadata": {"topic": "Natural Language Processing", "level": "intermediate", "original_id": "nlp_fundamentals_1"}
|
| 73 |
+
}
|
| 74 |
+
]
|
| 75 |
+
|
| 76 |
+
print(f"Indexing {len(sample_content)} content items into collection '{collection_name}'...")
|
| 77 |
+
|
| 78 |
+
for item in sample_content:
|
| 79 |
+
rag_service.index_content(item["id"], item["content"], item["metadata"])
|
| 80 |
+
print(f"Indexed: {item['id']} - {item['metadata']['topic']}")
|
| 81 |
+
|
| 82 |
+
print(f"\nSuccessfully indexed {len(sample_content)} items into '{collection_name}' collection!")
|
| 83 |
+
print("Your RAG system is now ready to answer questions about AI, Machine Learning, Robotics, Neural Networks, and NLP.")
|
| 84 |
+
|
| 85 |
+
if __name__ == "__main__":
|
| 86 |
+
setup_sample_content()
|
src/__init__.py
ADDED
|
File without changes
|
src/auth/__init__.py
ADDED
|
File without changes
|
src/auth/auth.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Authentication module for the AI Backend with RAG + Authentication
|
| 3 |
+
Implements JWT-based authentication with password hashing
|
| 4 |
+
"""
|
| 5 |
+
from datetime import datetime, timedelta
|
| 6 |
+
from typing import Optional, Union
|
| 7 |
+
import jwt
|
| 8 |
+
from passlib.context import CryptContext
|
| 9 |
+
from fastapi import HTTPException, status, Depends
|
| 10 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 11 |
+
from pydantic import BaseModel
|
| 12 |
+
import logging
|
| 13 |
+
|
| 14 |
+
from ..config.settings import settings
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
# Password hashing context
|
| 19 |
+
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
| 20 |
+
|
| 21 |
+
# JWT security scheme
|
| 22 |
+
security = HTTPBearer()
|
| 23 |
+
|
| 24 |
+
class TokenData(BaseModel):
|
| 25 |
+
username: Optional[str] = None
|
| 26 |
+
user_id: Optional[str] = None
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class AuthHandler:
|
| 30 |
+
def __init__(self):
|
| 31 |
+
self.secret_key = settings.secret_key
|
| 32 |
+
self.algorithm = settings.jwt_algorithm
|
| 33 |
+
self.access_token_expires = timedelta(minutes=settings.jwt_expires_in // 60) # Convert seconds to minutes
|
| 34 |
+
|
| 35 |
+
def verify_password(self, plain_password: str, hashed_password: str) -> bool:
|
| 36 |
+
"""
|
| 37 |
+
Verify a plain password against a hashed password
|
| 38 |
+
"""
|
| 39 |
+
return pwd_context.verify(plain_password, hashed_password)
|
| 40 |
+
|
| 41 |
+
def get_password_hash(self, password: str) -> str:
|
| 42 |
+
"""
|
| 43 |
+
Generate a hash for a plain password
|
| 44 |
+
"""
|
| 45 |
+
return pwd_context.hash(password)
|
| 46 |
+
|
| 47 |
+
def create_access_token(self, data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
| 48 |
+
"""
|
| 49 |
+
Create a JWT access token with optional expiration time
|
| 50 |
+
"""
|
| 51 |
+
to_encode = data.copy()
|
| 52 |
+
|
| 53 |
+
if expires_delta:
|
| 54 |
+
expire = datetime.utcnow() + expires_delta
|
| 55 |
+
else:
|
| 56 |
+
expire = datetime.utcnow() + self.access_token_expires
|
| 57 |
+
|
| 58 |
+
to_encode.update({"exp": expire, "iat": datetime.utcnow()})
|
| 59 |
+
|
| 60 |
+
encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
|
| 61 |
+
return encoded_jwt
|
| 62 |
+
|
| 63 |
+
def decode_access_token(self, token: str) -> Optional[TokenData]:
|
| 64 |
+
"""
|
| 65 |
+
Decode a JWT token and return token data
|
| 66 |
+
"""
|
| 67 |
+
try:
|
| 68 |
+
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
|
| 69 |
+
username: str = payload.get("sub")
|
| 70 |
+
user_id: str = payload.get("user_id")
|
| 71 |
+
|
| 72 |
+
if username is None:
|
| 73 |
+
return None
|
| 74 |
+
|
| 75 |
+
token_data = TokenData(username=username, user_id=user_id)
|
| 76 |
+
return token_data
|
| 77 |
+
except jwt.exceptions.ExpiredSignatureError:
|
| 78 |
+
logger.warning("Expired token attempted to be decoded")
|
| 79 |
+
return None
|
| 80 |
+
except jwt.exceptions.InvalidTokenError:
|
| 81 |
+
logger.warning("Invalid token attempted to be decoded")
|
| 82 |
+
return None
|
| 83 |
+
|
| 84 |
+
async def get_current_user(self, token: str = Depends(security)) -> TokenData:
|
| 85 |
+
"""
|
| 86 |
+
Get the current user from the provided JWT token
|
| 87 |
+
This function can be used as a dependency in route handlers
|
| 88 |
+
"""
|
| 89 |
+
credentials_exception = HTTPException(
|
| 90 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 91 |
+
detail="Could not validate credentials",
|
| 92 |
+
headers={"WWW-Authenticate": "Bearer"},
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
try:
|
| 96 |
+
token_data = self.decode_access_token(token.credentials)
|
| 97 |
+
if token_data is None:
|
| 98 |
+
raise credentials_exception
|
| 99 |
+
return token_data
|
| 100 |
+
except Exception as e:
|
| 101 |
+
logger.error(f"Error getting current user: {e}")
|
| 102 |
+
raise credentials_exception
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# Create a global instance of AuthHandler
|
| 106 |
+
auth_handler = AuthHandler()
|
| 107 |
+
|
| 108 |
+
# Convenience functions for use in other modules
|
| 109 |
+
def get_password_hash(password: str) -> str:
|
| 110 |
+
"""Generate a hash for a plain password"""
|
| 111 |
+
return auth_handler.get_password_hash(password)
|
| 112 |
+
|
| 113 |
+
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
| 114 |
+
"""Verify a plain password against a hashed password"""
|
| 115 |
+
return auth_handler.verify_password(plain_password, hashed_password)
|
| 116 |
+
|
| 117 |
+
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
| 118 |
+
"""Create a JWT access token"""
|
| 119 |
+
return auth_handler.create_access_token(data, expires_delta)
|
| 120 |
+
|
| 121 |
+
def decode_access_token(token: str) -> Optional[TokenData]:
|
| 122 |
+
"""Decode a JWT token and return token data"""
|
| 123 |
+
return auth_handler.decode_access_token(token)
|
| 124 |
+
|
| 125 |
+
async def get_current_user(token: str = Depends(security)) -> TokenData:
|
| 126 |
+
"""Get the current user from the provided JWT token"""
|
| 127 |
+
return await auth_handler.get_current_user(token)
|
| 128 |
+
|
| 129 |
+
def create_user_token(user_id: str, username: str) -> str:
|
| 130 |
+
"""Create a token specifically for a user"""
|
| 131 |
+
data = {"sub": username, "user_id": user_id}
|
| 132 |
+
return create_access_token(data)
|
src/auth/middleware.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Authentication middleware for the AI Backend with RAG + Authentication
|
| 3 |
+
Provides utilities for protecting routes with JWT authentication
|
| 4 |
+
"""
|
| 5 |
+
from fastapi import HTTPException, status, Request
|
| 6 |
+
from typing import Optional
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
from .auth import auth_handler, TokenData
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
class AuthMiddleware:
|
| 14 |
+
"""
|
| 15 |
+
Authentication middleware class to protect routes
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
@staticmethod
|
| 19 |
+
async def verify_token(request: Request) -> Optional[TokenData]:
|
| 20 |
+
"""
|
| 21 |
+
Verify the JWT token in the request headers
|
| 22 |
+
"""
|
| 23 |
+
# Get authorization header
|
| 24 |
+
auth_header = request.headers.get("Authorization")
|
| 25 |
+
if not auth_header or not auth_header.startswith("Bearer "):
|
| 26 |
+
raise HTTPException(
|
| 27 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 28 |
+
detail="Authorization header missing or invalid format",
|
| 29 |
+
headers={"WWW-Authenticate": "Bearer"},
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
token = auth_header[7:] # Remove "Bearer " prefix
|
| 33 |
+
token_data = auth_handler.decode_access_token(token)
|
| 34 |
+
|
| 35 |
+
if token_data is None:
|
| 36 |
+
raise HTTPException(
|
| 37 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 38 |
+
detail="Invalid or expired token",
|
| 39 |
+
headers={"WWW-Authenticate": "Bearer"},
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# Add user info to request state for use in route handlers
|
| 43 |
+
request.state.user = token_data
|
| 44 |
+
return token_data
|
| 45 |
+
|
| 46 |
+
@staticmethod
|
| 47 |
+
async def require_auth(request: Request) -> TokenData:
|
| 48 |
+
"""
|
| 49 |
+
Require authentication for a route
|
| 50 |
+
This can be used as a dependency in route handlers
|
| 51 |
+
"""
|
| 52 |
+
return await AuthMiddleware.verify_token(request)
|
| 53 |
+
|
| 54 |
+
@staticmethod
|
| 55 |
+
async def optional_auth(request: Request) -> Optional[TokenData]:
|
| 56 |
+
"""
|
| 57 |
+
Optionally authenticate a user (returns None if no valid token)
|
| 58 |
+
This can be used as a dependency in route handlers
|
| 59 |
+
"""
|
| 60 |
+
try:
|
| 61 |
+
return await AuthMiddleware.verify_token(request)
|
| 62 |
+
except HTTPException:
|
| 63 |
+
# If token is invalid or missing, return None instead of raising error
|
| 64 |
+
return None
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# Convenience functions for use in route handlers
|
| 68 |
+
async def require_auth(request: Request) -> TokenData:
|
| 69 |
+
"""Require authentication for a route"""
|
| 70 |
+
return await AuthMiddleware.require_auth(request)
|
| 71 |
+
|
| 72 |
+
async def optional_auth(request: Request) -> Optional[TokenData]:
|
| 73 |
+
"""Optionally authenticate a user"""
|
| 74 |
+
return await AuthMiddleware.optional_auth(request)
|
src/auth/schemas.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Authentication schemas for request/response validation
|
| 3 |
+
"""
|
| 4 |
+
from pydantic import BaseModel, EmailStr
|
| 5 |
+
from typing import Optional
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
import uuid
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class UserBase(BaseModel):
|
| 11 |
+
email: EmailStr
|
| 12 |
+
full_name: Optional[str] = None
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class UserCreate(UserBase):
|
| 16 |
+
password: str
|
| 17 |
+
|
| 18 |
+
class Config:
|
| 19 |
+
from_attributes = True
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class UserUpdate(BaseModel):
|
| 23 |
+
full_name: Optional[str] = None
|
| 24 |
+
email: Optional[EmailStr] = None
|
| 25 |
+
|
| 26 |
+
class Config:
|
| 27 |
+
from_attributes = True
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class UserInDB(UserBase):
|
| 31 |
+
id: uuid.UUID
|
| 32 |
+
is_active: bool
|
| 33 |
+
created_at: datetime
|
| 34 |
+
updated_at: Optional[datetime] = None
|
| 35 |
+
|
| 36 |
+
class Config:
|
| 37 |
+
from_attributes = True
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class UserLogin(BaseModel):
|
| 41 |
+
email: EmailStr
|
| 42 |
+
password: str
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class Token(BaseModel):
|
| 46 |
+
access_token: str
|
| 47 |
+
token_type: str = "bearer"
|
| 48 |
+
expires_in: int
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class TokenData(BaseModel):
|
| 52 |
+
user_id: Optional[str] = None
|
| 53 |
+
username: Optional[str] = None
|
src/config/__init__.py
ADDED
|
File without changes
|
src/config/database.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
| 2 |
+
from sqlalchemy.orm import sessionmaker
|
| 3 |
+
from .settings import settings
|
| 4 |
+
import logging
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
# Create async engine with proper configuration from settings
|
| 10 |
+
engine = create_async_engine(
|
| 11 |
+
settings.neon_db_url,
|
| 12 |
+
echo=settings.debug, # Set to True to log SQL queries
|
| 13 |
+
pool_pre_ping=True, # Verify connections before use
|
| 14 |
+
pool_size=20, # Connection pool size
|
| 15 |
+
max_overflow=30, # Additional connections beyond pool_size
|
| 16 |
+
pool_recycle=3600, # Recycle connections after 1 hour
|
| 17 |
+
pool_pre_ping_enabled=True, # Enable connection health checks
|
| 18 |
+
pool_pool_timeout=30, # Connection timeout
|
| 19 |
+
pool_reset_on_return='commit' # Reset connection on return
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
# Create async session factory
|
| 23 |
+
AsyncSessionLocal = sessionmaker(
|
| 24 |
+
engine,
|
| 25 |
+
class_=AsyncSession,
|
| 26 |
+
expire_on_commit=False
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
logger.info("Database engine created successfully")
|
| 30 |
+
except Exception as e:
|
| 31 |
+
logger.error(f"Failed to create database engine: {e}")
|
| 32 |
+
raise
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
async def get_db_session():
|
| 36 |
+
"""Dependency to get database session"""
|
| 37 |
+
async with AsyncSessionLocal() as session:
|
| 38 |
+
try:
|
| 39 |
+
yield session
|
| 40 |
+
except Exception as e:
|
| 41 |
+
logger.error(f"Database session error: {e}")
|
| 42 |
+
await session.rollback()
|
| 43 |
+
raise
|
| 44 |
+
finally:
|
| 45 |
+
await session.close()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Initialize the database connection
|
| 49 |
+
async def init_db():
|
| 50 |
+
"""Initialize the database connection and create tables if needed"""
|
| 51 |
+
from ..db.base import Base
|
| 52 |
+
|
| 53 |
+
logger.info("Initializing database connection...")
|
| 54 |
+
try:
|
| 55 |
+
# Create all tables
|
| 56 |
+
async with engine.begin() as conn:
|
| 57 |
+
await conn.run_sync(Base.metadata.create_all)
|
| 58 |
+
logger.info("Database tables created successfully")
|
| 59 |
+
except Exception as e:
|
| 60 |
+
logger.error(f"Failed to initialize database: {e}")
|
| 61 |
+
raise
|
src/config/settings.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic_settings import BaseSettings
|
| 2 |
+
from typing import Optional
|
| 3 |
+
from pydantic import ValidationError, field_validator
|
| 4 |
+
import logging
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Settings(BaseSettings):
|
| 10 |
+
# Database settings
|
| 11 |
+
neon_db_url: str
|
| 12 |
+
|
| 13 |
+
# Qdrant settings
|
| 14 |
+
qdrant_url: str
|
| 15 |
+
qdrant_api_key: Optional[str] = None
|
| 16 |
+
|
| 17 |
+
# Gemini API settings
|
| 18 |
+
gemini_api_key: str
|
| 19 |
+
|
| 20 |
+
# JWT settings
|
| 21 |
+
secret_key: str
|
| 22 |
+
jwt_algorithm: str = "HS256"
|
| 23 |
+
jwt_expires_in: int = 3600 # 1 hour default
|
| 24 |
+
|
| 25 |
+
# Application settings
|
| 26 |
+
debug: bool = False
|
| 27 |
+
log_level: str = "info"
|
| 28 |
+
|
| 29 |
+
# Server settings
|
| 30 |
+
server_host: str = "0.0.0.0"
|
| 31 |
+
server_port: int = 8000
|
| 32 |
+
|
| 33 |
+
@field_validator('neon_db_url', 'qdrant_url', 'gemini_api_key', 'secret_key')
|
| 34 |
+
@classmethod
|
| 35 |
+
def validate_required_fields(cls, v, info):
|
| 36 |
+
if not v:
|
| 37 |
+
raise ValueError(f"{info.field_name} is required and must be set in environment variables")
|
| 38 |
+
return v
|
| 39 |
+
|
| 40 |
+
@field_validator('debug')
|
| 41 |
+
@classmethod
|
| 42 |
+
def validate_debug(cls, v):
|
| 43 |
+
if isinstance(v, str):
|
| 44 |
+
return v.lower() in ['true', '1', 'yes', 'on']
|
| 45 |
+
return bool(v)
|
| 46 |
+
|
| 47 |
+
@field_validator('jwt_expires_in')
|
| 48 |
+
@classmethod
|
| 49 |
+
def validate_jwt_expires_in(cls, v):
|
| 50 |
+
if v <= 0:
|
| 51 |
+
raise ValueError("JWT expires in must be a positive integer")
|
| 52 |
+
return v
|
| 53 |
+
|
| 54 |
+
class Config:
|
| 55 |
+
env_file = ".env"
|
| 56 |
+
env_file_encoding = 'utf-8'
|
| 57 |
+
case_sensitive = True
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# Create a single instance of settings with error handling
|
| 61 |
+
try:
|
| 62 |
+
settings = Settings()
|
| 63 |
+
logger.info("Configuration loaded successfully")
|
| 64 |
+
except ValidationError as e:
|
| 65 |
+
logger.error(f"Configuration validation error: {e}")
|
| 66 |
+
raise
|
| 67 |
+
except Exception as e:
|
| 68 |
+
logger.error(f"Configuration error: {e}")
|
| 69 |
+
raise
|
src/db/__init__.py
ADDED
|
File without changes
|
src/db/base.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Base class for SQLAlchemy models
|
| 3 |
+
"""
|
| 4 |
+
from sqlalchemy.orm import DeclarativeBase
|
| 5 |
+
from sqlalchemy import Column, DateTime, func
|
| 6 |
+
from sqlalchemy.ext.asyncio import AsyncAttrs
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
import uuid
|
| 9 |
+
from sqlalchemy.dialects.postgresql import UUID
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Base(AsyncAttrs, DeclarativeBase):
|
| 13 |
+
"""
|
| 14 |
+
Base class for all SQLAlchemy models
|
| 15 |
+
Includes common columns and configurations
|
| 16 |
+
"""
|
| 17 |
+
__abstract__ = True
|
| 18 |
+
|
| 19 |
+
# Common columns for all models
|
| 20 |
+
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
| 21 |
+
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
| 22 |
+
|
| 23 |
+
def __init__(self, *args, **kwargs):
|
| 24 |
+
# Set the id automatically if not provided
|
| 25 |
+
if 'id' not in kwargs and hasattr(self, 'id') and self.id is None:
|
| 26 |
+
# For models that have an id column, set a default UUID if not provided
|
| 27 |
+
pass # The column default will handle this
|
| 28 |
+
super().__init__(*args, **kwargs)
|
src/db/crud.py
ADDED
|
@@ -0,0 +1,432 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CRUD operations for the AI Backend with RAG + Authentication
|
| 3 |
+
Implements Create, Read, Update, Delete operations for all models
|
| 4 |
+
"""
|
| 5 |
+
from typing import Optional, List
|
| 6 |
+
from uuid import UUID
|
| 7 |
+
from sqlalchemy import select, update, delete
|
| 8 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 9 |
+
from sqlalchemy.exc import IntegrityError
|
| 10 |
+
from fastapi import HTTPException, status
|
| 11 |
+
import logging
|
| 12 |
+
|
| 13 |
+
from .models.user import User
|
| 14 |
+
from .models.chat_history import ChatHistory
|
| 15 |
+
from .models.document import Document
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# User CRUD Operations
|
| 21 |
+
async def create_user(db: AsyncSession, email: str, hashed_password: str, full_name: Optional[str] = None) -> User:
|
| 22 |
+
"""Create a new user"""
|
| 23 |
+
try:
|
| 24 |
+
db_user = User(
|
| 25 |
+
email=email,
|
| 26 |
+
hashed_password=hashed_password,
|
| 27 |
+
full_name=full_name
|
| 28 |
+
)
|
| 29 |
+
db.add(db_user)
|
| 30 |
+
await db.commit()
|
| 31 |
+
await db.refresh(db_user)
|
| 32 |
+
logger.info(f"User created with email: {email}")
|
| 33 |
+
return db_user
|
| 34 |
+
except IntegrityError:
|
| 35 |
+
await db.rollback()
|
| 36 |
+
logger.warning(f"User with email {email} already exists")
|
| 37 |
+
raise HTTPException(
|
| 38 |
+
status_code=status.HTTP_409_CONFLICT,
|
| 39 |
+
detail="User with this email already exists"
|
| 40 |
+
)
|
| 41 |
+
except Exception as e:
|
| 42 |
+
await db.rollback()
|
| 43 |
+
logger.error(f"Error creating user: {e}")
|
| 44 |
+
raise
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
async def get_user_by_id(db: AsyncSession, user_id: UUID) -> Optional[User]:
|
| 48 |
+
"""Get a user by ID"""
|
| 49 |
+
try:
|
| 50 |
+
result = await db.execute(select(User).filter(User.id == user_id))
|
| 51 |
+
user = result.scalar_one_or_none()
|
| 52 |
+
return user
|
| 53 |
+
except Exception as e:
|
| 54 |
+
logger.error(f"Error getting user by ID: {e}")
|
| 55 |
+
raise
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
async def get_user_by_email(db: AsyncSession, email: str) -> Optional[User]:
|
| 59 |
+
"""Get a user by email"""
|
| 60 |
+
try:
|
| 61 |
+
result = await db.execute(select(User).filter(User.email == email))
|
| 62 |
+
user = result.scalar_one_or_none()
|
| 63 |
+
return user
|
| 64 |
+
except Exception as e:
|
| 65 |
+
logger.error(f"Error getting user by email: {e}")
|
| 66 |
+
raise
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
async def update_user(db: AsyncSession, user_id: UUID, **kwargs) -> Optional[User]:
|
| 70 |
+
"""Update a user"""
|
| 71 |
+
try:
|
| 72 |
+
query = update(User).where(User.id == user_id).values(**kwargs).returning(User)
|
| 73 |
+
result = await db.execute(query)
|
| 74 |
+
await db.commit()
|
| 75 |
+
|
| 76 |
+
updated_user = result.scalar_one_or_none()
|
| 77 |
+
if updated_user:
|
| 78 |
+
logger.info(f"User updated with ID: {user_id}")
|
| 79 |
+
return updated_user
|
| 80 |
+
except Exception as e:
|
| 81 |
+
await db.rollback()
|
| 82 |
+
logger.error(f"Error updating user: {e}")
|
| 83 |
+
raise
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
async def delete_user(db: AsyncSession, user_id: UUID) -> bool:
|
| 87 |
+
"""Delete a user"""
|
| 88 |
+
try:
|
| 89 |
+
result = await db.execute(delete(User).where(User.id == user_id))
|
| 90 |
+
await db.commit()
|
| 91 |
+
deleted_count = result.rowcount
|
| 92 |
+
if deleted_count > 0:
|
| 93 |
+
logger.info(f"User deleted with ID: {user_id}")
|
| 94 |
+
return deleted_count > 0
|
| 95 |
+
except Exception as e:
|
| 96 |
+
await db.rollback()
|
| 97 |
+
logger.error(f"Error deleting user: {e}")
|
| 98 |
+
raise
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
async def list_users(db: AsyncSession, skip: int = 0, limit: int = 100) -> List[User]:
|
| 102 |
+
"""List users with pagination"""
|
| 103 |
+
try:
|
| 104 |
+
result = await db.execute(select(User).offset(skip).limit(limit))
|
| 105 |
+
users = result.scalars().all()
|
| 106 |
+
return users
|
| 107 |
+
except Exception as e:
|
| 108 |
+
logger.error(f"Error listing users: {e}")
|
| 109 |
+
raise
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# ChatHistory CRUD Operations
|
| 113 |
+
async def create_chat_history(db: AsyncSession, user_id: UUID, query: str, response: str, context_used: Optional[str] = None) -> ChatHistory:
|
| 114 |
+
"""Create a new chat history record"""
|
| 115 |
+
try:
|
| 116 |
+
db_chat_history = ChatHistory(
|
| 117 |
+
user_id=user_id,
|
| 118 |
+
query=query,
|
| 119 |
+
response=response,
|
| 120 |
+
context_used=context_used
|
| 121 |
+
)
|
| 122 |
+
db.add(db_chat_history)
|
| 123 |
+
await db.commit()
|
| 124 |
+
await db.refresh(db_chat_history)
|
| 125 |
+
logger.info(f"Chat history created for user: {user_id}")
|
| 126 |
+
return db_chat_history
|
| 127 |
+
except Exception as e:
|
| 128 |
+
await db.rollback()
|
| 129 |
+
logger.error(f"Error creating chat history: {e}")
|
| 130 |
+
raise
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
async def get_chat_history_by_id(db: AsyncSession, chat_history_id: UUID) -> Optional[ChatHistory]:
|
| 134 |
+
"""Get a chat history record by ID"""
|
| 135 |
+
try:
|
| 136 |
+
result = await db.execute(select(ChatHistory).filter(ChatHistory.id == chat_history_id))
|
| 137 |
+
chat_history = result.scalar_one_or_none()
|
| 138 |
+
return chat_history
|
| 139 |
+
except Exception as e:
|
| 140 |
+
logger.error(f"Error getting chat history by ID: {e}")
|
| 141 |
+
raise
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
async def get_chat_histories_by_user(db: AsyncSession, user_id: UUID, skip: int = 0, limit: int = 100) -> List[ChatHistory]:
|
| 145 |
+
"""Get all chat histories for a user"""
|
| 146 |
+
try:
|
| 147 |
+
result = await db.execute(
|
| 148 |
+
select(ChatHistory)
|
| 149 |
+
.filter(ChatHistory.user_id == user_id)
|
| 150 |
+
.order_by(ChatHistory.created_at.desc())
|
| 151 |
+
.offset(skip)
|
| 152 |
+
.limit(limit)
|
| 153 |
+
)
|
| 154 |
+
chat_histories = result.scalars().all()
|
| 155 |
+
return chat_histories
|
| 156 |
+
except Exception as e:
|
| 157 |
+
logger.error(f"Error getting chat histories by user: {e}")
|
| 158 |
+
raise
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
async def update_chat_history(db: AsyncSession, chat_history_id: UUID, **kwargs) -> Optional[ChatHistory]:
|
| 162 |
+
"""Update a chat history record"""
|
| 163 |
+
try:
|
| 164 |
+
query = update(ChatHistory).where(ChatHistory.id == chat_history_id).values(**kwargs).returning(ChatHistory)
|
| 165 |
+
result = await db.execute(query)
|
| 166 |
+
await db.commit()
|
| 167 |
+
|
| 168 |
+
updated_chat_history = result.scalar_one_or_none()
|
| 169 |
+
if updated_chat_history:
|
| 170 |
+
logger.info(f"Chat history updated with ID: {chat_history_id}")
|
| 171 |
+
return updated_chat_history
|
| 172 |
+
except Exception as e:
|
| 173 |
+
await db.rollback()
|
| 174 |
+
logger.error(f"Error updating chat history: {e}")
|
| 175 |
+
raise
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
async def delete_chat_history(db: AsyncSession, chat_history_id: UUID) -> bool:
|
| 179 |
+
"""Delete a chat history record"""
|
| 180 |
+
try:
|
| 181 |
+
result = await db.execute(delete(ChatHistory).where(ChatHistory.id == chat_history_id))
|
| 182 |
+
await db.commit()
|
| 183 |
+
deleted_count = result.rowcount
|
| 184 |
+
if deleted_count > 0:
|
| 185 |
+
logger.info(f"Chat history deleted with ID: {chat_history_id}")
|
| 186 |
+
return deleted_count > 0
|
| 187 |
+
except Exception as e:
|
| 188 |
+
await db.rollback()
|
| 189 |
+
logger.error(f"Error deleting chat history: {e}")
|
| 190 |
+
raise
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
# Document CRUD Operations
|
| 194 |
+
async def create_document(db: AsyncSession, user_id: UUID, title: str, content: str, content_hash: str,
|
| 195 |
+
file_path: Optional[str] = None, metadata: Optional[dict] = None) -> Document:
|
| 196 |
+
"""Create a new document"""
|
| 197 |
+
try:
|
| 198 |
+
db_document = Document(
|
| 199 |
+
user_id=user_id,
|
| 200 |
+
title=title,
|
| 201 |
+
content=content,
|
| 202 |
+
content_hash=content_hash,
|
| 203 |
+
file_path=file_path,
|
| 204 |
+
metadata=metadata
|
| 205 |
+
)
|
| 206 |
+
db.add(db_document)
|
| 207 |
+
await db.commit()
|
| 208 |
+
await db.refresh(db_document)
|
| 209 |
+
logger.info(f"Document created for user: {user_id}, title: {title}")
|
| 210 |
+
return db_document
|
| 211 |
+
except IntegrityError:
|
| 212 |
+
await db.rollback()
|
| 213 |
+
logger.warning(f"Document with content_hash {content_hash} already exists")
|
| 214 |
+
raise HTTPException(
|
| 215 |
+
status_code=status.HTTP_409_CONFLICT,
|
| 216 |
+
detail="Document with this content already exists"
|
| 217 |
+
)
|
| 218 |
+
except Exception as e:
|
| 219 |
+
await db.rollback()
|
| 220 |
+
logger.error(f"Error creating document: {e}")
|
| 221 |
+
raise
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
async def get_document_by_id(db: AsyncSession, document_id: UUID) -> Optional[Document]:
|
| 225 |
+
"""Get a document by ID"""
|
| 226 |
+
try:
|
| 227 |
+
result = await db.execute(select(Document).filter(Document.id == document_id))
|
| 228 |
+
document = result.scalar_one_or_none()
|
| 229 |
+
return document
|
| 230 |
+
except Exception as e:
|
| 231 |
+
logger.error(f"Error getting document by ID: {e}")
|
| 232 |
+
raise
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
async def get_documents_by_user(db: AsyncSession, user_id: UUID, skip: int = 0, limit: int = 100) -> List[Document]:
|
| 236 |
+
"""Get all documents for a user"""
|
| 237 |
+
try:
|
| 238 |
+
result = await db.execute(
|
| 239 |
+
select(Document)
|
| 240 |
+
.filter(Document.user_id == user_id)
|
| 241 |
+
.order_by(Document.created_at.desc())
|
| 242 |
+
.offset(skip)
|
| 243 |
+
.limit(limit)
|
| 244 |
+
)
|
| 245 |
+
documents = result.scalars().all()
|
| 246 |
+
return documents
|
| 247 |
+
except Exception as e:
|
| 248 |
+
logger.error(f"Error getting documents by user: {e}")
|
| 249 |
+
raise
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
async def get_document_by_hash(db: AsyncSession, content_hash: str) -> Optional[Document]:
|
| 253 |
+
"""Get a document by content hash"""
|
| 254 |
+
try:
|
| 255 |
+
result = await db.execute(select(Document).filter(Document.content_hash == content_hash))
|
| 256 |
+
document = result.scalar_one_or_none()
|
| 257 |
+
return document
|
| 258 |
+
except Exception as e:
|
| 259 |
+
logger.error(f"Error getting document by hash: {e}")
|
| 260 |
+
raise
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
async def update_document(db: AsyncSession, document_id: UUID, **kwargs) -> Optional[Document]:
|
| 264 |
+
"""Update a document"""
|
| 265 |
+
try:
|
| 266 |
+
query = update(Document).where(Document.id == document_id).values(**kwargs).returning(Document)
|
| 267 |
+
result = await db.execute(query)
|
| 268 |
+
await db.commit()
|
| 269 |
+
|
| 270 |
+
updated_document = result.scalar_one_or_none()
|
| 271 |
+
if updated_document:
|
| 272 |
+
logger.info(f"Document updated with ID: {document_id}")
|
| 273 |
+
return updated_document
|
| 274 |
+
except Exception as e:
|
| 275 |
+
await db.rollback()
|
| 276 |
+
logger.error(f"Error updating document: {e}")
|
| 277 |
+
raise
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
async def delete_document(db: AsyncSession, document_id: UUID) -> bool:
|
| 281 |
+
"""Delete a document"""
|
| 282 |
+
try:
|
| 283 |
+
result = await db.execute(delete(Document).where(Document.id == document_id))
|
| 284 |
+
await db.commit()
|
| 285 |
+
deleted_count = result.rowcount
|
| 286 |
+
if deleted_count > 0:
|
| 287 |
+
logger.info(f"Document deleted with ID: {document_id}")
|
| 288 |
+
return deleted_count > 0
|
| 289 |
+
except Exception as e:
|
| 290 |
+
await db.rollback()
|
| 291 |
+
logger.error(f"Error deleting document: {e}")
|
| 292 |
+
raise
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
# Chat History CRUD Operations
|
| 296 |
+
async def create_chat_history_entry(db: AsyncSession, user_id: UUID, query: str, response: str, context_used: Optional[str] = None) -> ChatHistory:
|
| 297 |
+
"""Create a new chat history entry"""
|
| 298 |
+
try:
|
| 299 |
+
db_chat_history = ChatHistory(
|
| 300 |
+
user_id=user_id,
|
| 301 |
+
query=query,
|
| 302 |
+
response=response,
|
| 303 |
+
context_used=context_used
|
| 304 |
+
)
|
| 305 |
+
db.add(db_chat_history)
|
| 306 |
+
await db.commit()
|
| 307 |
+
await db.refresh(db_chat_history)
|
| 308 |
+
logger.info(f"Chat history created for user: {user_id}")
|
| 309 |
+
return db_chat_history
|
| 310 |
+
except Exception as e:
|
| 311 |
+
await db.rollback()
|
| 312 |
+
logger.error(f"Error creating chat history: {e}")
|
| 313 |
+
raise
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
async def get_chat_history_by_id(db: AsyncSession, chat_history_id: UUID) -> Optional[ChatHistory]:
|
| 317 |
+
"""Get a chat history record by ID"""
|
| 318 |
+
try:
|
| 319 |
+
result = await db.execute(select(ChatHistory).filter(ChatHistory.id == chat_history_id))
|
| 320 |
+
chat_history = result.scalar_one_or_none()
|
| 321 |
+
return chat_history
|
| 322 |
+
except Exception as e:
|
| 323 |
+
logger.error(f"Error getting chat history by ID: {e}")
|
| 324 |
+
raise
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
async def get_chat_histories_by_user(db: AsyncSession, user_id: UUID, skip: int = 0, limit: int = 100) -> List[ChatHistory]:
|
| 328 |
+
"""Get all chat histories for a user"""
|
| 329 |
+
try:
|
| 330 |
+
result = await db.execute(
|
| 331 |
+
select(ChatHistory)
|
| 332 |
+
.filter(ChatHistory.user_id == user_id)
|
| 333 |
+
.order_by(ChatHistory.created_at.desc())
|
| 334 |
+
.offset(skip)
|
| 335 |
+
.limit(limit)
|
| 336 |
+
)
|
| 337 |
+
chat_histories = result.scalars().all()
|
| 338 |
+
return chat_histories
|
| 339 |
+
except Exception as e:
|
| 340 |
+
logger.error(f"Error getting chat histories by user: {e}")
|
| 341 |
+
raise
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
async def get_user_chat_history_count(db: AsyncSession, user_id: UUID) -> int:
|
| 345 |
+
"""Get the count of chat history records for a user"""
|
| 346 |
+
try:
|
| 347 |
+
from sqlalchemy import func
|
| 348 |
+
result = await db.execute(
|
| 349 |
+
select(func.count(ChatHistory.id))
|
| 350 |
+
.filter(ChatHistory.user_id == user_id)
|
| 351 |
+
)
|
| 352 |
+
count = result.scalar_one()
|
| 353 |
+
return count
|
| 354 |
+
except Exception as e:
|
| 355 |
+
logger.error(f"Error getting chat history count for user: {e}")
|
| 356 |
+
raise
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
async def update_chat_history(db: AsyncSession, chat_history_id: UUID, **kwargs) -> Optional[ChatHistory]:
|
| 360 |
+
"""Update a chat history record"""
|
| 361 |
+
try:
|
| 362 |
+
query = update(ChatHistory).where(ChatHistory.id == chat_history_id).values(**kwargs).returning(ChatHistory)
|
| 363 |
+
result = await db.execute(query)
|
| 364 |
+
await db.commit()
|
| 365 |
+
|
| 366 |
+
updated_chat_history = result.scalar_one_or_none()
|
| 367 |
+
if updated_chat_history:
|
| 368 |
+
logger.info(f"Chat history updated with ID: {chat_history_id}")
|
| 369 |
+
return updated_chat_history
|
| 370 |
+
except Exception as e:
|
| 371 |
+
await db.rollback()
|
| 372 |
+
logger.error(f"Error updating chat history: {e}")
|
| 373 |
+
raise
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
async def delete_chat_history(db: AsyncSession, chat_history_id: UUID) -> bool:
|
| 377 |
+
"""Delete a chat history record"""
|
| 378 |
+
try:
|
| 379 |
+
result = await db.execute(delete(ChatHistory).where(ChatHistory.id == chat_history_id))
|
| 380 |
+
await db.commit()
|
| 381 |
+
deleted_count = result.rowcount
|
| 382 |
+
if deleted_count > 0:
|
| 383 |
+
logger.info(f"Chat history deleted with ID: {chat_history_id}")
|
| 384 |
+
return deleted_count > 0
|
| 385 |
+
except Exception as e:
|
| 386 |
+
await db.rollback()
|
| 387 |
+
logger.error(f"Error deleting chat history: {e}")
|
| 388 |
+
raise
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
async def delete_user_chat_history(db: AsyncSession, user_id: UUID) -> bool:
|
| 392 |
+
"""Delete all chat history records for a user"""
|
| 393 |
+
try:
|
| 394 |
+
result = await db.execute(delete(ChatHistory).where(ChatHistory.user_id == user_id))
|
| 395 |
+
await db.commit()
|
| 396 |
+
deleted_count = result.rowcount
|
| 397 |
+
if deleted_count > 0:
|
| 398 |
+
logger.info(f"Deleted {deleted_count} chat history records for user: {user_id}")
|
| 399 |
+
return deleted_count > 0
|
| 400 |
+
except Exception as e:
|
| 401 |
+
await db.rollback()
|
| 402 |
+
logger.error(f"Error deleting user chat history: {e}")
|
| 403 |
+
raise
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
# Utility functions
|
| 407 |
+
async def get_user_with_chat_histories(db: AsyncSession, user_id: UUID) -> Optional[User]:
|
| 408 |
+
"""Get a user with their chat histories"""
|
| 409 |
+
try:
|
| 410 |
+
result = await db.execute(
|
| 411 |
+
select(User)
|
| 412 |
+
.filter(User.id == user_id)
|
| 413 |
+
)
|
| 414 |
+
user = result.scalar_one_or_none()
|
| 415 |
+
return user
|
| 416 |
+
except Exception as e:
|
| 417 |
+
logger.error(f"Error getting user with chat histories: {e}")
|
| 418 |
+
raise
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
async def get_user_with_documents(db: AsyncSession, user_id: UUID) -> Optional[User]:
|
| 422 |
+
"""Get a user with their documents"""
|
| 423 |
+
try:
|
| 424 |
+
result = await db.execute(
|
| 425 |
+
select(User)
|
| 426 |
+
.filter(User.id == user_id)
|
| 427 |
+
)
|
| 428 |
+
user = result.scalar_one_or_none()
|
| 429 |
+
return user
|
| 430 |
+
except Exception as e:
|
| 431 |
+
logger.error(f"Error getting user with documents: {e}")
|
| 432 |
+
raise
|
src/db/models/__init__.py
ADDED
|
File without changes
|
src/db/models/chat_history.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ChatHistory model for the AI Backend with RAG + Authentication
|
| 3 |
+
"""
|
| 4 |
+
from sqlalchemy import Column, String, Text, ForeignKey, Index
|
| 5 |
+
from sqlalchemy.dialects.postgresql import UUID
|
| 6 |
+
from sqlalchemy.orm import relationship
|
| 7 |
+
from uuid import uuid4
|
| 8 |
+
from ...db.base import Base
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ChatHistory(Base):
|
| 12 |
+
__tablename__ = "chat_history"
|
| 13 |
+
|
| 14 |
+
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4, unique=True, nullable=False)
|
| 15 |
+
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
| 16 |
+
query = Column(Text, nullable=False)
|
| 17 |
+
response = Column(Text, nullable=False)
|
| 18 |
+
context_used = Column(Text, nullable=True) # JSON string of context snippets used
|
| 19 |
+
|
| 20 |
+
# Relationships
|
| 21 |
+
user = relationship("User", back_populates="chat_histories")
|
| 22 |
+
|
| 23 |
+
def __repr__(self):
|
| 24 |
+
return f"<ChatHistory(id={self.id}, user_id={self.user_id}, query='{self.query[:30]}...')>"
|
| 25 |
+
|
| 26 |
+
# Create indexes
|
| 27 |
+
Index('idx_chat_history_user_id', 'user_id')
|
| 28 |
+
Index('idx_chat_history_created_at', 'created_at')
|
src/db/models/document.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Document model for the AI Backend with RAG + Authentication
|
| 3 |
+
"""
|
| 4 |
+
from sqlalchemy import Column, String, Text, ForeignKey, Index
|
| 5 |
+
from sqlalchemy.dialects.postgresql import UUID, JSON
|
| 6 |
+
from sqlalchemy.orm import relationship
|
| 7 |
+
from uuid import uuid4
|
| 8 |
+
from ...db.base import Base
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Document(Base):
|
| 12 |
+
__tablename__ = "documents"
|
| 13 |
+
|
| 14 |
+
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4, unique=True, nullable=False)
|
| 15 |
+
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
| 16 |
+
title = Column(String(255), nullable=False)
|
| 17 |
+
content = Column(Text, nullable=False)
|
| 18 |
+
content_hash = Column(String(255), nullable=False, index=True) # For deduplication
|
| 19 |
+
file_path = Column(String(500), nullable=True) # Path if uploaded file
|
| 20 |
+
metadata = Column(JSON, nullable=True) # Additional metadata as JSON
|
| 21 |
+
|
| 22 |
+
# Relationships
|
| 23 |
+
user = relationship("User", back_populates="documents")
|
| 24 |
+
|
| 25 |
+
def __repr__(self):
|
| 26 |
+
return f"<Document(id={self.id}, user_id={self.user_id}, title='{self.title}')>"
|
| 27 |
+
|
| 28 |
+
# Create indexes
|
| 29 |
+
Index('idx_document_user_id', 'user_id')
|
| 30 |
+
Index('idx_document_content_hash', 'content_hash')
|
| 31 |
+
Index('idx_document_title', 'title')
|
src/db/models/user.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
User model for the AI Backend with RAG + Authentication
|
| 3 |
+
"""
|
| 4 |
+
from sqlalchemy import Column, String, Boolean, Text, Index
|
| 5 |
+
from sqlalchemy.dialects.postgresql import UUID
|
| 6 |
+
from sqlalchemy.orm import relationship
|
| 7 |
+
from uuid import uuid4
|
| 8 |
+
from ...db.base import Base
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class User(Base):
|
| 12 |
+
__tablename__ = "users"
|
| 13 |
+
|
| 14 |
+
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4, unique=True, nullable=False)
|
| 15 |
+
email = Column(String(255), unique=True, nullable=False, index=True)
|
| 16 |
+
hashed_password = Column(Text, nullable=False)
|
| 17 |
+
full_name = Column(String(255), nullable=True)
|
| 18 |
+
is_active = Column(Boolean, default=True, nullable=False)
|
| 19 |
+
|
| 20 |
+
# Relationships
|
| 21 |
+
chat_histories = relationship("ChatHistory", back_populates="user", cascade="all, delete-orphan")
|
| 22 |
+
documents = relationship("Document", back_populates="user", cascade="all, delete-orphan")
|
| 23 |
+
|
| 24 |
+
def __repr__(self):
|
| 25 |
+
return f"<User(id={self.id}, email='{self.email}', full_name='{self.full_name}')>"
|
| 26 |
+
|
| 27 |
+
# Create indexes
|
| 28 |
+
Index('idx_user_email', User.email)
|
src/embeddings/__init__.py
ADDED
|
File without changes
|
src/embeddings/gemini_client.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gemini API client for the AI Backend with RAG + Authentication
|
| 3 |
+
Implements embedding generation and chat functionality using Google's Gemini API
|
| 4 |
+
"""
|
| 5 |
+
import google.generativeai as genai
|
| 6 |
+
from google.generativeai import embedding
|
| 7 |
+
from typing import List, Optional, Dict, Any
|
| 8 |
+
import logging
|
| 9 |
+
import time
|
| 10 |
+
import asyncio
|
| 11 |
+
from functools import wraps
|
| 12 |
+
import re
|
| 13 |
+
|
| 14 |
+
from ..config.settings import settings
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
# Initialize the Gemini API client with the API key from settings
|
| 19 |
+
genai.configure(api_key=settings.gemini_api_key)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class GeminiClient:
|
| 23 |
+
"""
|
| 24 |
+
Client class to handle both Gemini embedding and chat operations
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(self):
|
| 28 |
+
# Use the text-embedding-004 model for embeddings
|
| 29 |
+
self.embedding_model_name = "text-embedding-004"
|
| 30 |
+
# Use the Gemini 1.5 Flash model for chat responses (faster and more cost-effective)
|
| 31 |
+
self.chat_model_name = "gemini-1.5-flash-001" # Updated model name
|
| 32 |
+
self.client = genai
|
| 33 |
+
self.max_retries = 3
|
| 34 |
+
self.retry_delay = 1 # seconds
|
| 35 |
+
|
| 36 |
+
# Initialize the chat model
|
| 37 |
+
self.chat_model = genai.GenerativeModel(self.chat_model_name)
|
| 38 |
+
|
| 39 |
+
# EMBEDDING METHODS
|
| 40 |
+
async def generate_embedding(self, text: str) -> Optional[List[float]]:
|
| 41 |
+
"""
|
| 42 |
+
Generate embedding for the given text using Gemini text-embedding-004 model
|
| 43 |
+
"""
|
| 44 |
+
for attempt in range(self.max_retries):
|
| 45 |
+
try:
|
| 46 |
+
# Generate the embedding using the Gemini API
|
| 47 |
+
result = await asyncio.get_event_loop().run_in_executor(
|
| 48 |
+
None,
|
| 49 |
+
lambda: genai.embed_content(
|
| 50 |
+
model=self.embedding_model_name,
|
| 51 |
+
content=text,
|
| 52 |
+
task_type="RETRIEVAL_DOCUMENT", # Optimal for RAG applications
|
| 53 |
+
title="Document" # Title can help with embedding quality
|
| 54 |
+
)
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
embedding_values = result['embedding']
|
| 58 |
+
|
| 59 |
+
# Verify the embedding has the correct dimensions (1536 for text-embedding-004)
|
| 60 |
+
if len(embedding_values) != 1536:
|
| 61 |
+
logger.warning(f"Generated embedding has {len(embedding_values)} dimensions, expected 1536")
|
| 62 |
+
return None
|
| 63 |
+
|
| 64 |
+
logger.info(f"Successfully generated embedding for text of length {len(text)}")
|
| 65 |
+
return embedding_values
|
| 66 |
+
|
| 67 |
+
except Exception as e:
|
| 68 |
+
logger.warning(f"Attempt {attempt + 1} failed to generate embedding: {e}")
|
| 69 |
+
|
| 70 |
+
if attempt == self.max_retries - 1:
|
| 71 |
+
# Last attempt failed
|
| 72 |
+
logger.error(f"Failed to generate embedding after {self.max_retries} attempts: {e}")
|
| 73 |
+
return None
|
| 74 |
+
|
| 75 |
+
# Wait before retrying
|
| 76 |
+
await asyncio.sleep(self.retry_delay * (2 ** attempt)) # Exponential backoff
|
| 77 |
+
|
| 78 |
+
return None
|
| 79 |
+
|
| 80 |
+
async def generate_embeddings_batch(self, texts: List[str]) -> Optional[List[List[float]]]:
|
| 81 |
+
"""
|
| 82 |
+
Generate embeddings for a batch of texts
|
| 83 |
+
"""
|
| 84 |
+
embeddings = []
|
| 85 |
+
|
| 86 |
+
for text in texts:
|
| 87 |
+
embedding = await self.generate_embedding(text)
|
| 88 |
+
if embedding is None:
|
| 89 |
+
logger.error(f"Failed to generate embedding for text: {text[:50]}...")
|
| 90 |
+
return None
|
| 91 |
+
embeddings.append(embedding)
|
| 92 |
+
|
| 93 |
+
return embeddings
|
| 94 |
+
|
| 95 |
+
# CHAT METHODS
|
| 96 |
+
async def generate_chat_response(
|
| 97 |
+
self,
|
| 98 |
+
query: str,
|
| 99 |
+
context: Optional[List[Dict[str, Any]]] = None,
|
| 100 |
+
conversation_history: Optional[List[Dict[str, str]]] = None
|
| 101 |
+
) -> Optional[str]:
|
| 102 |
+
"""
|
| 103 |
+
Generate a chat response using Gemini 1.5 Flash/Pro with RAG context
|
| 104 |
+
"""
|
| 105 |
+
for attempt in range(self.max_retries):
|
| 106 |
+
try:
|
| 107 |
+
# Format the prompt with context and query
|
| 108 |
+
formatted_prompt = self._format_rag_prompt(query, context, conversation_history)
|
| 109 |
+
|
| 110 |
+
# Safety settings to moderate content
|
| 111 |
+
safety_settings = [
|
| 112 |
+
{
|
| 113 |
+
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
| 114 |
+
"threshold": "BLOCK_MEDIUM_AND_ABOVE"
|
| 115 |
+
},
|
| 116 |
+
{
|
| 117 |
+
"category": "HARM_CATEGORY_HATE_SPEECH",
|
| 118 |
+
"threshold": "BLOCK_MEDIUM_AND_ABOVE"
|
| 119 |
+
},
|
| 120 |
+
{
|
| 121 |
+
"category": "HARM_CATEGORY_HARASSMENT",
|
| 122 |
+
"threshold": "BLOCK_MEDIUM_AND_ABOVE"
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
| 126 |
+
"threshold": "BLOCK_MEDIUM_AND_ABOVE"
|
| 127 |
+
}
|
| 128 |
+
]
|
| 129 |
+
|
| 130 |
+
# Generate response using the chat model
|
| 131 |
+
response = await self.chat_model.generate_content_async(
|
| 132 |
+
formatted_prompt,
|
| 133 |
+
safety_settings=safety_settings,
|
| 134 |
+
generation_config={
|
| 135 |
+
"temperature": 0.3, # Lower temperature for more consistent responses
|
| 136 |
+
"max_output_tokens": 800, # Limit response length
|
| 137 |
+
"candidate_count": 1
|
| 138 |
+
}
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
# Extract the text response
|
| 142 |
+
if response and response.text:
|
| 143 |
+
logger.info(f"Successfully generated chat response for query: {query[:50]}...")
|
| 144 |
+
return response.text.strip()
|
| 145 |
+
else:
|
| 146 |
+
logger.warning("Gemini returned empty response")
|
| 147 |
+
return None
|
| 148 |
+
|
| 149 |
+
except Exception as e:
|
| 150 |
+
logger.warning(f"Attempt {attempt + 1} failed to generate chat response: {e}")
|
| 151 |
+
|
| 152 |
+
if attempt == self.max_retries - 1:
|
| 153 |
+
# Last attempt failed
|
| 154 |
+
logger.error(f"Failed to generate chat response after {self.max_retries} attempts: {e}")
|
| 155 |
+
return None
|
| 156 |
+
|
| 157 |
+
# Wait before retrying
|
| 158 |
+
await asyncio.sleep(self.retry_delay * (2 ** attempt)) # Exponential backoff
|
| 159 |
+
|
| 160 |
+
return None
|
| 161 |
+
|
| 162 |
+
def _format_rag_prompt(
|
| 163 |
+
self,
|
| 164 |
+
query: str,
|
| 165 |
+
context: Optional[List[Dict[str, Any]]] = None,
|
| 166 |
+
conversation_history: Optional[List[Dict[str, str]]] = None
|
| 167 |
+
) -> str:
|
| 168 |
+
"""
|
| 169 |
+
Format the prompt with RAG context and conversation history
|
| 170 |
+
"""
|
| 171 |
+
prompt_parts = []
|
| 172 |
+
|
| 173 |
+
# Add system context
|
| 174 |
+
prompt_parts.append(
|
| 175 |
+
"You are an AI assistant that helps users by answering questions based on provided context. "
|
| 176 |
+
"Use only the information provided in the context to answer the questions. "
|
| 177 |
+
"If the context doesn't contain the information needed to answer the question, say so clearly. "
|
| 178 |
+
"Be helpful, accurate, and concise in your responses."
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
# Add conversation history if available
|
| 182 |
+
if conversation_history:
|
| 183 |
+
prompt_parts.append("\nPrevious conversation:")
|
| 184 |
+
for msg in conversation_history[-4:]: # Use last 4 messages to avoid exceeding token limits
|
| 185 |
+
role = msg.get("role", "user")
|
| 186 |
+
content = msg.get("content", "")
|
| 187 |
+
prompt_parts.append(f"{role.capitalize()}: {content}")
|
| 188 |
+
|
| 189 |
+
# Add retrieved context if available
|
| 190 |
+
if context:
|
| 191 |
+
prompt_parts.append("\nContext for answering the question:")
|
| 192 |
+
for i, ctx in enumerate(context[:5]): # Use top 5 context snippets
|
| 193 |
+
chunk_text = ctx.get("payload", {}).get("chunk_text", "") if isinstance(ctx, dict) else str(ctx)
|
| 194 |
+
# Clean up the chunk text if it contains the "..." marker from storage
|
| 195 |
+
if chunk_text.endswith("..."):
|
| 196 |
+
# If it was truncated when stored, we don't have the full text
|
| 197 |
+
# But we can still use what we have
|
| 198 |
+
pass
|
| 199 |
+
prompt_parts.append(f"Context {i+1}: {chunk_text}")
|
| 200 |
+
|
| 201 |
+
# Add the current query
|
| 202 |
+
prompt_parts.append(f"\nQuestion: {query}")
|
| 203 |
+
prompt_parts.append("Answer:")
|
| 204 |
+
|
| 205 |
+
return "\n".join(prompt_parts)
|
| 206 |
+
|
| 207 |
+
async def moderate_content(self, text: str) -> Dict[str, Any]:
|
| 208 |
+
"""
|
| 209 |
+
Moderate content using Gemini's safety features
|
| 210 |
+
"""
|
| 211 |
+
try:
|
| 212 |
+
# Use the chat model to analyze content safety
|
| 213 |
+
response = await self.chat_model.generate_content_async(
|
| 214 |
+
f"Analyze the following text for safety issues: {text}",
|
| 215 |
+
safety_settings=[
|
| 216 |
+
{
|
| 217 |
+
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
| 218 |
+
"threshold": "BLOCK_ONLY_HIGH"
|
| 219 |
+
},
|
| 220 |
+
{
|
| 221 |
+
"category": "HARM_CATEGORY_HATE_SPEECH",
|
| 222 |
+
"threshold": "BLOCK_ONLY_HIGH"
|
| 223 |
+
},
|
| 224 |
+
{
|
| 225 |
+
"category": "HARM_CATEGORY_HARASSMENT",
|
| 226 |
+
"threshold": "BLOCK_ONLY_HIGH"
|
| 227 |
+
},
|
| 228 |
+
{
|
| 229 |
+
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
| 230 |
+
"threshold": "BLOCK_ONLY_HIGH"
|
| 231 |
+
}
|
| 232 |
+
]
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
# Return safety analysis
|
| 236 |
+
return {
|
| 237 |
+
"is_safe": True,
|
| 238 |
+
"text": text,
|
| 239 |
+
"moderation_applied": False # Gemini handles moderation internally
|
| 240 |
+
}
|
| 241 |
+
except Exception as e:
|
| 242 |
+
logger.error(f"Content moderation error: {e}")
|
| 243 |
+
return {
|
| 244 |
+
"is_safe": False,
|
| 245 |
+
"text": text,
|
| 246 |
+
"moderation_applied": True,
|
| 247 |
+
"error": str(e)
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
# Global instance of GeminiClient
|
| 252 |
+
gemini_client = GeminiClient()
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def get_gemini_client() -> GeminiClient:
|
| 256 |
+
"""Get the Gemini client instance (for both embeddings and chat)"""
|
| 257 |
+
return gemini_client
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
# Embedding functions (backward compatibility)
|
| 261 |
+
async def generate_embedding(text: str) -> Optional[List[float]]:
|
| 262 |
+
"""
|
| 263 |
+
Generate embedding for the given text using the configured Gemini model
|
| 264 |
+
"""
|
| 265 |
+
return await gemini_client.generate_embedding(text)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
async def generate_embeddings_batch(texts: List[str]) -> Optional[List[List[float]]]:
|
| 269 |
+
"""
|
| 270 |
+
Generate embeddings for a batch of texts
|
| 271 |
+
"""
|
| 272 |
+
return await gemini_client.generate_embeddings_batch(texts)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
# Chat functions
|
| 276 |
+
async def generate_chat_response(
|
| 277 |
+
query: str,
|
| 278 |
+
context: Optional[List[Dict[str, Any]]] = None,
|
| 279 |
+
conversation_history: Optional[List[Dict[str, str]]] = None
|
| 280 |
+
) -> Optional[str]:
|
| 281 |
+
"""
|
| 282 |
+
Generate a chat response using Gemini with RAG context
|
| 283 |
+
"""
|
| 284 |
+
return await gemini_client.generate_chat_response(query, context, conversation_history)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
async def moderate_content(text: str) -> Dict[str, Any]:
|
| 288 |
+
"""
|
| 289 |
+
Moderate content using Gemini's safety features
|
| 290 |
+
"""
|
| 291 |
+
return await gemini_client.moderate_content(text)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
# Decorator for rate limiting (basic implementation)
|
| 295 |
+
def rate_limit(calls_per_second: float = 10):
|
| 296 |
+
"""
|
| 297 |
+
Decorator to implement basic rate limiting
|
| 298 |
+
Google Gemini API has rate limits, so we need to be respectful
|
| 299 |
+
"""
|
| 300 |
+
min_interval = 1.0 / calls_per_second
|
| 301 |
+
last_called = [0.0]
|
| 302 |
+
|
| 303 |
+
def decorator(func):
|
| 304 |
+
@wraps(func)
|
| 305 |
+
async def wrapper(*args, **kwargs):
|
| 306 |
+
elapsed = time.time() - last_called[0]
|
| 307 |
+
left_to_wait = min_interval - elapsed
|
| 308 |
+
if left_to_wait > 0:
|
| 309 |
+
await asyncio.sleep(left_to_wait)
|
| 310 |
+
ret = await func(*args, **kwargs)
|
| 311 |
+
last_called[0] = time.time()
|
| 312 |
+
return ret
|
| 313 |
+
return wrapper
|
| 314 |
+
return decorator
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
# Apply rate limiting to the main functions
|
| 318 |
+
@rate_limit(calls_per_second=10) # Adjust based on your API quota
|
| 319 |
+
async def generate_embedding_with_rate_limit(text: str) -> Optional[List[float]]:
|
| 320 |
+
"""
|
| 321 |
+
Generate embedding with rate limiting applied
|
| 322 |
+
"""
|
| 323 |
+
return await generate_embedding(text)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
@rate_limit(calls_per_second=5) # Lower rate limit for chat as it's more resource intensive
|
| 327 |
+
async def generate_chat_response_with_rate_limit(
|
| 328 |
+
query: str,
|
| 329 |
+
context: Optional[List[Dict[str, Any]]] = None,
|
| 330 |
+
conversation_history: Optional[List[Dict[str, str]]] = None
|
| 331 |
+
) -> Optional[str]:
|
| 332 |
+
"""
|
| 333 |
+
Generate chat response with rate limiting applied
|
| 334 |
+
"""
|
| 335 |
+
return await generate_chat_response(query, context, conversation_history)
|
src/embeddings/processor.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Embedding processor for the AI Backend with RAG + Authentication
|
| 3 |
+
Implements text preprocessing, caching, and document chunking for embeddings
|
| 4 |
+
"""
|
| 5 |
+
import hashlib
|
| 6 |
+
import asyncio
|
| 7 |
+
from typing import List, Optional, Tuple, Dict
|
| 8 |
+
import logging
|
| 9 |
+
from uuid import UUID
|
| 10 |
+
|
| 11 |
+
from ..config.settings import settings
|
| 12 |
+
from .gemini_client import generate_embedding, generate_embeddings_batch
|
| 13 |
+
from ..qdrant.operations import get_vector_operations
|
| 14 |
+
from ..db import crud
|
| 15 |
+
from ..config.database import get_db_session
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
# Maximum characters per chunk (Gemini has token limits)
|
| 20 |
+
MAX_CHUNK_SIZE = 2000
|
| 21 |
+
OVERLAP_SIZE = 200 # Overlap between chunks to maintain context
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class EmbeddingProcessor:
|
| 25 |
+
"""
|
| 26 |
+
Processor class to handle embedding workflows including preprocessing,
|
| 27 |
+
caching, and document chunking
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(self):
|
| 31 |
+
self.vector_ops = get_vector_operations()
|
| 32 |
+
# Simple in-memory cache (in production, use Redis or similar)
|
| 33 |
+
self.cache: Dict[str, List[float]] = {}
|
| 34 |
+
|
| 35 |
+
def _generate_content_hash(self, content: str) -> str:
|
| 36 |
+
"""
|
| 37 |
+
Generate a hash for content to use for caching and deduplication
|
| 38 |
+
"""
|
| 39 |
+
return hashlib.sha256(content.encode('utf-8')).hexdigest()
|
| 40 |
+
|
| 41 |
+
def _preprocess_text(self, text: str) -> str:
|
| 42 |
+
"""
|
| 43 |
+
Preprocess text by cleaning and normalizing
|
| 44 |
+
"""
|
| 45 |
+
if not text or not isinstance(text, str):
|
| 46 |
+
raise ValueError("Input text must be a non-empty string")
|
| 47 |
+
|
| 48 |
+
# Remove extra whitespace
|
| 49 |
+
text = ' '.join(text.split())
|
| 50 |
+
|
| 51 |
+
# Validate text length
|
| 52 |
+
if len(text) > 1000000: # 1M characters max
|
| 53 |
+
logger.warning(f"Text is very long ({len(text)} chars), consider pre-processing")
|
| 54 |
+
|
| 55 |
+
return text.strip()
|
| 56 |
+
|
| 57 |
+
def _chunk_text(self, text: str, chunk_size: int = MAX_CHUNK_SIZE, overlap: int = OVERLAP_SIZE) -> List[str]:
|
| 58 |
+
"""
|
| 59 |
+
Split text into overlapping chunks to maintain context
|
| 60 |
+
"""
|
| 61 |
+
if len(text) <= chunk_size:
|
| 62 |
+
return [text]
|
| 63 |
+
|
| 64 |
+
chunks = []
|
| 65 |
+
start = 0
|
| 66 |
+
|
| 67 |
+
while start < len(text):
|
| 68 |
+
end = start + chunk_size
|
| 69 |
+
|
| 70 |
+
# If we're near the end, include the rest
|
| 71 |
+
if end > len(text):
|
| 72 |
+
end = len(text)
|
| 73 |
+
start = max(0, end - chunk_size)
|
| 74 |
+
|
| 75 |
+
chunk = text[start:end]
|
| 76 |
+
|
| 77 |
+
# If this isn't the last chunk, try to break at sentence boundary
|
| 78 |
+
if end < len(text):
|
| 79 |
+
# Look for sentence endings to break at
|
| 80 |
+
sentence_end = max(
|
| 81 |
+
chunk.rfind('.'),
|
| 82 |
+
chunk.rfind('!'),
|
| 83 |
+
chunk.rfind('?'),
|
| 84 |
+
chunk.rfind('\n')
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
if sentence_end > chunk_size // 2: # Only if it's not too early
|
| 88 |
+
end = start + sentence_end + 1
|
| 89 |
+
chunk = text[start:end]
|
| 90 |
+
|
| 91 |
+
chunks.append(chunk)
|
| 92 |
+
start = end - overlap
|
| 93 |
+
|
| 94 |
+
return chunks
|
| 95 |
+
|
| 96 |
+
async def _get_from_cache(self, content_hash: str) -> Optional[List[float]]:
|
| 97 |
+
"""
|
| 98 |
+
Get embedding from cache if available
|
| 99 |
+
"""
|
| 100 |
+
return self.cache.get(content_hash)
|
| 101 |
+
|
| 102 |
+
async def _save_to_cache(self, content_hash: str, embedding: List[float]):
|
| 103 |
+
"""
|
| 104 |
+
Save embedding to cache
|
| 105 |
+
"""
|
| 106 |
+
self.cache[content_hash] = embedding
|
| 107 |
+
|
| 108 |
+
async def process_single_text(self, text: str, user_id: UUID) -> Optional[List[float]]:
|
| 109 |
+
"""
|
| 110 |
+
Process a single text for embedding with caching
|
| 111 |
+
"""
|
| 112 |
+
try:
|
| 113 |
+
# Preprocess the text
|
| 114 |
+
processed_text = self._preprocess_text(text)
|
| 115 |
+
|
| 116 |
+
if not processed_text:
|
| 117 |
+
logger.warning("Text preprocessing resulted in empty string")
|
| 118 |
+
return None
|
| 119 |
+
|
| 120 |
+
# Generate content hash for caching
|
| 121 |
+
content_hash = self._generate_content_hash(processed_text)
|
| 122 |
+
|
| 123 |
+
# Check cache first
|
| 124 |
+
cached_embedding = await self._get_from_cache(content_hash)
|
| 125 |
+
if cached_embedding:
|
| 126 |
+
logger.info(f"Found embedding in cache for text of length {len(processed_text)}")
|
| 127 |
+
return cached_embedding
|
| 128 |
+
|
| 129 |
+
# Generate embedding using Gemini
|
| 130 |
+
embedding = await generate_embedding(processed_text)
|
| 131 |
+
if embedding is None:
|
| 132 |
+
logger.error(f"Failed to generate embedding for text of length {len(processed_text)}")
|
| 133 |
+
return None
|
| 134 |
+
|
| 135 |
+
# Save to cache
|
| 136 |
+
await self._save_to_cache(content_hash, embedding)
|
| 137 |
+
|
| 138 |
+
logger.info(f"Successfully processed embedding for text of length {len(processed_text)}")
|
| 139 |
+
return embedding
|
| 140 |
+
|
| 141 |
+
except Exception as e:
|
| 142 |
+
logger.error(f"Error processing single text: {e}")
|
| 143 |
+
return None
|
| 144 |
+
|
| 145 |
+
async def process_document(
|
| 146 |
+
self,
|
| 147 |
+
document_id: UUID,
|
| 148 |
+
user_id: UUID,
|
| 149 |
+
content: str,
|
| 150 |
+
title: Optional[str] = None,
|
| 151 |
+
metadata: Optional[Dict] = None
|
| 152 |
+
) -> bool:
|
| 153 |
+
"""
|
| 154 |
+
Process a document for embedding, including chunking and storage
|
| 155 |
+
"""
|
| 156 |
+
try:
|
| 157 |
+
# Preprocess the content
|
| 158 |
+
processed_content = self._preprocess_text(content)
|
| 159 |
+
|
| 160 |
+
if not processed_content:
|
| 161 |
+
logger.warning("Document content preprocessing resulted in empty string")
|
| 162 |
+
return False
|
| 163 |
+
|
| 164 |
+
# Chunk the document if it's large
|
| 165 |
+
if len(processed_content) > MAX_CHUNK_SIZE:
|
| 166 |
+
chunks = self._chunk_text(processed_content)
|
| 167 |
+
logger.info(f"Document chunked into {len(chunks)} parts")
|
| 168 |
+
else:
|
| 169 |
+
chunks = [processed_content]
|
| 170 |
+
|
| 171 |
+
# Process each chunk
|
| 172 |
+
all_embeddings = []
|
| 173 |
+
chunk_payloads = []
|
| 174 |
+
|
| 175 |
+
for i, chunk in enumerate(chunks):
|
| 176 |
+
# Generate content hash for caching
|
| 177 |
+
content_hash = self._generate_content_hash(chunk)
|
| 178 |
+
|
| 179 |
+
# Check cache first
|
| 180 |
+
embedding = await self._get_from_cache(content_hash)
|
| 181 |
+
if embedding is None:
|
| 182 |
+
# Generate embedding using Gemini
|
| 183 |
+
embedding = await generate_embedding(chunk)
|
| 184 |
+
if embedding is None:
|
| 185 |
+
logger.error(f"Failed to generate embedding for chunk {i}")
|
| 186 |
+
continue
|
| 187 |
+
|
| 188 |
+
# Save to cache
|
| 189 |
+
await self._save_to_cache(content_hash, embedding)
|
| 190 |
+
|
| 191 |
+
all_embeddings.append(embedding)
|
| 192 |
+
|
| 193 |
+
# Create payload for this chunk
|
| 194 |
+
chunk_payload = {
|
| 195 |
+
"chunk_index": i,
|
| 196 |
+
"chunk_text": chunk[:100] + "..." if len(chunk) > 100 else chunk, # Store first 100 chars as reference
|
| 197 |
+
"document_id": str(document_id),
|
| 198 |
+
"user_id": str(user_id),
|
| 199 |
+
"title": title or "Untitled Document",
|
| 200 |
+
"total_chunks": len(chunks)
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
if metadata:
|
| 204 |
+
chunk_payload.update(metadata)
|
| 205 |
+
|
| 206 |
+
chunk_payloads.append(chunk_payload)
|
| 207 |
+
|
| 208 |
+
# Store embeddings in Qdrant
|
| 209 |
+
if all_embeddings:
|
| 210 |
+
success = await self.vector_ops.batch_upsert_vectors(
|
| 211 |
+
user_id=user_id,
|
| 212 |
+
document_id=document_id,
|
| 213 |
+
embeddings_list=all_embeddings,
|
| 214 |
+
payloads_list=chunk_payloads
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
if success:
|
| 218 |
+
logger.info(f"Successfully stored {len(all_embeddings)} embeddings for document {document_id}")
|
| 219 |
+
return True
|
| 220 |
+
else:
|
| 221 |
+
logger.error(f"Failed to store embeddings in Qdrant for document {document_id}")
|
| 222 |
+
return False
|
| 223 |
+
else:
|
| 224 |
+
logger.warning("No embeddings were generated for the document")
|
| 225 |
+
return False
|
| 226 |
+
|
| 227 |
+
except Exception as e:
|
| 228 |
+
logger.error(f"Error processing document {document_id}: {e}")
|
| 229 |
+
return False
|
| 230 |
+
|
| 231 |
+
async def process_texts_batch(
|
| 232 |
+
self,
|
| 233 |
+
texts: List[str],
|
| 234 |
+
user_id: UUID
|
| 235 |
+
) -> Optional[List[List[float]]]:
|
| 236 |
+
"""
|
| 237 |
+
Process a batch of texts for embedding with caching
|
| 238 |
+
"""
|
| 239 |
+
try:
|
| 240 |
+
embeddings = []
|
| 241 |
+
|
| 242 |
+
for text in texts:
|
| 243 |
+
embedding = await self.process_single_text(text, user_id)
|
| 244 |
+
if embedding is None:
|
| 245 |
+
logger.error(f"Failed to process text: {text[:50]}...")
|
| 246 |
+
return None
|
| 247 |
+
embeddings.append(embedding)
|
| 248 |
+
|
| 249 |
+
logger.info(f"Successfully processed batch of {len(texts)} texts")
|
| 250 |
+
return embeddings
|
| 251 |
+
|
| 252 |
+
except Exception as e:
|
| 253 |
+
logger.error(f"Error processing text batch: {e}")
|
| 254 |
+
return None
|
| 255 |
+
|
| 256 |
+
async def invalidate_cache_for_document(self, document_id: UUID):
|
| 257 |
+
"""
|
| 258 |
+
Remove cached embeddings associated with a document
|
| 259 |
+
In a real implementation with Redis, this would be more sophisticated
|
| 260 |
+
"""
|
| 261 |
+
# In our simple in-memory cache, we can't easily identify which cache entries
|
| 262 |
+
# belong to a specific document, so we'd need to implement a more sophisticated
|
| 263 |
+
# cache structure. For now, we'll just log the action.
|
| 264 |
+
logger.info(f"Cache invalidation requested for document {document_id} (not implemented in simple cache)")
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
# Global instance of EmbeddingProcessor
|
| 268 |
+
embedding_processor = EmbeddingProcessor()
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def get_embedding_processor() -> EmbeddingProcessor:
|
| 272 |
+
"""Get the embedding processor instance"""
|
| 273 |
+
return embedding_processor
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
async def process_single_text(text: str, user_id: UUID) -> Optional[List[float]]:
|
| 277 |
+
"""
|
| 278 |
+
Process a single text for embedding with caching
|
| 279 |
+
"""
|
| 280 |
+
return await embedding_processor.process_single_text(text, user_id)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
async def process_document(
|
| 284 |
+
document_id: UUID,
|
| 285 |
+
user_id: UUID,
|
| 286 |
+
content: str,
|
| 287 |
+
title: Optional[str] = None,
|
| 288 |
+
metadata: Optional[Dict] = None
|
| 289 |
+
) -> bool:
|
| 290 |
+
"""
|
| 291 |
+
Process a document for embedding, including chunking and storage
|
| 292 |
+
"""
|
| 293 |
+
return await embedding_processor.process_document(document_id, user_id, content, title, metadata)
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
async def process_texts_batch(
|
| 297 |
+
texts: List[str],
|
| 298 |
+
user_id: UUID
|
| 299 |
+
) -> Optional[List[List[float]]]:
|
| 300 |
+
"""
|
| 301 |
+
Process a batch of texts for embedding with caching
|
| 302 |
+
"""
|
| 303 |
+
return await embedding_processor.process_texts_batch(texts, user_id)
|
src/main.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI
|
| 2 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
+
import logging
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
# Configure logging
|
| 7 |
+
logging.basicConfig(
|
| 8 |
+
level=logging.INFO,
|
| 9 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
| 10 |
+
handlers=[
|
| 11 |
+
logging.StreamHandler(sys.stdout)
|
| 12 |
+
]
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
# Import settings - this will validate environment variables on startup
|
| 18 |
+
from .config.settings import settings
|
| 19 |
+
|
| 20 |
+
app = FastAPI(
|
| 21 |
+
title="AI Backend with RAG + Authentication",
|
| 22 |
+
description="A scalable backend featuring authentication, RAG capabilities, and integration with external services",
|
| 23 |
+
version="1.0.0"
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
# Add CORS middleware
|
| 27 |
+
app.add_middleware(
|
| 28 |
+
CORSMiddleware,
|
| 29 |
+
allow_origins=["*"], # In production, replace with specific origins
|
| 30 |
+
allow_credentials=True,
|
| 31 |
+
allow_methods=["*"],
|
| 32 |
+
allow_headers=["*"],
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
@app.get("/")
|
| 36 |
+
async def root():
|
| 37 |
+
return {"message": "AI Backend with RAG + Authentication is running!"}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Include all routes
|
| 41 |
+
from .routes import auth, search, history, documents, health
|
| 42 |
+
|
| 43 |
+
app.include_router(auth.router, prefix="/auth", tags=["authentication"])
|
| 44 |
+
app.include_router(search.router, prefix="/search", tags=["search"])
|
| 45 |
+
app.include_router(history.router, prefix="/history", tags=["history"])
|
| 46 |
+
app.include_router(documents.router, prefix="/documents", tags=["documents"])
|
| 47 |
+
app.include_router(health.router)
|
| 48 |
+
|
| 49 |
+
if __name__ == "__main__":
|
| 50 |
+
import uvicorn
|
| 51 |
+
uvicorn.run(app, host=settings.server_host, port=settings.server_port)
|
src/models/__init__.py
ADDED
|
File without changes
|
src/models/documents.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Document models for the AI Backend with RAG + Authentication
|
| 3 |
+
Pydantic models for document-related request/response validation
|
| 4 |
+
"""
|
| 5 |
+
from pydantic import BaseModel, Field
|
| 6 |
+
from typing import Optional, Dict, Any
|
| 7 |
+
from uuid import UUID
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class DocumentCreate(BaseModel):
|
| 12 |
+
title: str = Field(..., min_length=1, max_length=255, description="Document title")
|
| 13 |
+
content: str = Field(..., min_length=1, description="Document content")
|
| 14 |
+
file_path: Optional[str] = Field(None, max_length=500, description="Path if uploaded file")
|
| 15 |
+
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class DocumentResponse(BaseModel):
|
| 19 |
+
document_id: UUID
|
| 20 |
+
success: bool
|
| 21 |
+
message: str
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class DocumentUpdate(BaseModel):
|
| 25 |
+
title: Optional[str] = Field(None, min_length=1, max_length=255)
|
| 26 |
+
content: Optional[str] = Field(None, min_length=1)
|
| 27 |
+
metadata: Optional[Dict[str, Any]] = None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class DocumentListResponse(BaseModel):
|
| 31 |
+
documents: list
|
| 32 |
+
total: int
|
src/models/search.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Search models for the AI Backend with RAG + Authentication
|
| 3 |
+
Pydantic models for search-related request/response validation
|
| 4 |
+
"""
|
| 5 |
+
from pydantic import BaseModel, Field
|
| 6 |
+
from typing import Optional, List, Dict, Any
|
| 7 |
+
from uuid import UUID
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SearchRequest(BaseModel):
|
| 11 |
+
query: str = Field(..., min_length=1, max_length=1000, description="Search query")
|
| 12 |
+
top_k: Optional[int] = Field(default=5, ge=1, le=20, description="Number of results to return")
|
| 13 |
+
filters: Optional[Dict[str, Any]] = Field(None, description="Additional filters for search")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SearchResult(BaseModel):
|
| 17 |
+
id: str
|
| 18 |
+
document_id: str
|
| 19 |
+
score: float
|
| 20 |
+
payload: Dict[str, Any]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class SearchResponse(BaseModel):
|
| 24 |
+
results: List[SearchResult]
|
| 25 |
+
query: str
|
| 26 |
+
total_results: int
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class Message(BaseModel):
|
| 30 |
+
role: str = Field(..., pattern=r"^(user|assistant|system)$", description="Role of the message sender")
|
| 31 |
+
content: str = Field(..., min_length=1, description="Content of the message")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class ChatRequest(BaseModel):
|
| 35 |
+
query: str = Field(..., min_length=1, max_length=1000, description="Chat query")
|
| 36 |
+
top_k: Optional[int] = Field(default=5, ge=1, le=10, description="Number of context results to retrieve")
|
| 37 |
+
conversation_history: Optional[List[Message]] = Field(None, description="Previous conversation messages")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ChatResponse(BaseModel):
|
| 41 |
+
response: str
|
| 42 |
+
sources: List[Dict[str, Any]]
|
| 43 |
+
context_used: List[Dict[str, Any]]
|
src/qdrant/__init__.py
ADDED
|
File without changes
|
src/qdrant/client.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Qdrant client setup for the AI Backend with RAG + Authentication
|
| 3 |
+
Implements Qdrant client initialization and connection management
|
| 4 |
+
"""
|
| 5 |
+
from qdrant_client import QdrantClient
|
| 6 |
+
from qdrant_client.http import models
|
| 7 |
+
from qdrant_client.http.exceptions import UnexpectedResponse
|
| 8 |
+
from typing import Optional
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
from ..config.settings import settings
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
# Vector dimensions for Gemini text-embedding-004 model
|
| 16 |
+
VECTOR_DIMENSIONS = 1536 # Standard for text-embedding-004
|
| 17 |
+
|
| 18 |
+
class QdrantService:
|
| 19 |
+
"""
|
| 20 |
+
Service class to manage Qdrant client and operations
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(self):
|
| 24 |
+
self.client: Optional[QdrantClient] = None
|
| 25 |
+
self._initialize_client()
|
| 26 |
+
|
| 27 |
+
def _initialize_client(self):
|
| 28 |
+
"""Initialize the Qdrant client with settings from configuration"""
|
| 29 |
+
try:
|
| 30 |
+
if settings.qdrant_api_key:
|
| 31 |
+
if settings.qdrant_url.startswith('https://'):
|
| 32 |
+
# For cloud instances, use the URL directly with API key
|
| 33 |
+
self.client = QdrantClient(
|
| 34 |
+
url=settings.qdrant_url,
|
| 35 |
+
api_key=settings.qdrant_api_key,
|
| 36 |
+
timeout=10.0 # 10 second timeout
|
| 37 |
+
)
|
| 38 |
+
else:
|
| 39 |
+
# For local instances
|
| 40 |
+
self.client = QdrantClient(
|
| 41 |
+
url=settings.qdrant_url,
|
| 42 |
+
api_key=settings.qdrant_api_key,
|
| 43 |
+
timeout=10.0 # 10 second timeout
|
| 44 |
+
)
|
| 45 |
+
else:
|
| 46 |
+
# If no API key is provided, connect without authentication
|
| 47 |
+
self.client = QdrantClient(
|
| 48 |
+
url=settings.qdrant_url,
|
| 49 |
+
timeout=10.0 # 10 second timeout
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
logger.info("Qdrant client initialized successfully")
|
| 53 |
+
except Exception as e:
|
| 54 |
+
logger.error(f"Failed to initialize Qdrant client: {e}")
|
| 55 |
+
raise
|
| 56 |
+
|
| 57 |
+
async def health_check(self) -> bool:
|
| 58 |
+
"""Check if Qdrant server is accessible"""
|
| 59 |
+
try:
|
| 60 |
+
# Try to get cluster info as a health check
|
| 61 |
+
if self.client:
|
| 62 |
+
cluster_info = self.client.get_cluster_info()
|
| 63 |
+
logger.info(f"Qdrant health check passed. Cluster: {cluster_info}")
|
| 64 |
+
return True
|
| 65 |
+
return False
|
| 66 |
+
except Exception as e:
|
| 67 |
+
logger.error(f"Qdrant health check failed: {e}")
|
| 68 |
+
return False
|
| 69 |
+
|
| 70 |
+
def get_client(self) -> QdrantClient:
|
| 71 |
+
"""Get the initialized Qdrant client"""
|
| 72 |
+
if self.client is None:
|
| 73 |
+
raise RuntimeError("Qdrant client not initialized")
|
| 74 |
+
return self.client
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# Global instance of QdrantService
|
| 78 |
+
qdrant_service = QdrantService()
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def get_qdrant_client() -> QdrantClient:
|
| 82 |
+
"""Get the Qdrant client instance"""
|
| 83 |
+
return qdrant_service.get_client()
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
async def initialize_qdrant_collections():
|
| 87 |
+
"""Initialize required collections in Qdrant"""
|
| 88 |
+
try:
|
| 89 |
+
client = get_qdrant_client()
|
| 90 |
+
|
| 91 |
+
# Check if the documents collection already exists
|
| 92 |
+
collections = client.get_collections()
|
| 93 |
+
collection_names = [collection.name for collection in collections.collections]
|
| 94 |
+
|
| 95 |
+
if "documents" not in collection_names:
|
| 96 |
+
# Create the documents collection with proper vector configuration
|
| 97 |
+
client.create_collection(
|
| 98 |
+
collection_name="documents",
|
| 99 |
+
vectors_config=models.VectorParams(
|
| 100 |
+
size=VECTOR_DIMENSIONS,
|
| 101 |
+
distance=models.Distance.COSINE # Cosine distance is good for embeddings
|
| 102 |
+
)
|
| 103 |
+
)
|
| 104 |
+
logger.info("Created 'documents' collection in Qdrant")
|
| 105 |
+
else:
|
| 106 |
+
logger.info("Collection 'documents' already exists in Qdrant")
|
| 107 |
+
|
| 108 |
+
# Verify the collection has the correct configuration
|
| 109 |
+
collection_info = client.get_collection(collection_name="documents")
|
| 110 |
+
vector_config = collection_info.config.params.vectors
|
| 111 |
+
if hasattr(vector_config, 'size'):
|
| 112 |
+
actual_size = vector_config.size
|
| 113 |
+
else:
|
| 114 |
+
# Handle the case where vector_config is a dictionary
|
| 115 |
+
actual_size = vector_config['size'] if isinstance(vector_config, dict) else vector_config
|
| 116 |
+
|
| 117 |
+
if actual_size != VECTOR_DIMENSIONS:
|
| 118 |
+
logger.warning(f"Collection vector size is {actual_size}, expected {VECTOR_DIMENSIONS}")
|
| 119 |
+
else:
|
| 120 |
+
logger.info(f"Collection 'documents' has correct vector size: {VECTOR_DIMENSIONS}")
|
| 121 |
+
|
| 122 |
+
return True
|
| 123 |
+
|
| 124 |
+
except UnexpectedResponse as e:
|
| 125 |
+
logger.error(f"Qdrant API error during collection initialization: {e}")
|
| 126 |
+
return False
|
| 127 |
+
except Exception as e:
|
| 128 |
+
logger.error(f"Unexpected error during collection initialization: {e}")
|
| 129 |
+
return False
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
# Initialize collections on module import
|
| 133 |
+
async def setup_qdrant():
|
| 134 |
+
"""Setup function to initialize Qdrant collections"""
|
| 135 |
+
success = await initialize_qdrant_collections()
|
| 136 |
+
if success:
|
| 137 |
+
logger.info("Qdrant setup completed successfully")
|
| 138 |
+
else:
|
| 139 |
+
logger.error("Qdrant setup failed")
|
| 140 |
+
return success
|