Spaces:
Sleeping
Sleeping
Upload 62 files
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .dockerignore +73 -0
- .env +40 -0
- Dockerfile +52 -0
- docker-compose.yml +46 -0
- scripts/embed_book_content.py +358 -0
- src/__init__.py +1 -0
- src/__pycache__/__init__.cpython-312.pyc +0 -0
- src/__pycache__/main.cpython-312.pyc +0 -0
- src/api/__init__.py +1 -0
- src/api/__pycache__/__init__.cpython-312.pyc +0 -0
- src/api/middleware/__init__.py +1 -0
- src/api/middleware/__pycache__/__init__.cpython-312.pyc +0 -0
- src/api/middleware/__pycache__/auth_middleware.cpython-312.pyc +0 -0
- src/api/middleware/__pycache__/logging_middleware.cpython-312.pyc +0 -0
- src/api/middleware/__pycache__/rate_limit.cpython-312.pyc +0 -0
- src/api/middleware/auth_middleware.py +148 -0
- src/api/middleware/logging_middleware.py +37 -0
- src/api/middleware/rate_limit.py +76 -0
- src/api/routes/__init__.py +1 -0
- src/api/routes/__pycache__/__init__.cpython-312.pyc +0 -0
- src/api/routes/__pycache__/auth.cpython-312.pyc +0 -0
- src/api/routes/__pycache__/chat.cpython-312.pyc +0 -0
- src/api/routes/__pycache__/health.cpython-312.pyc +0 -0
- src/api/routes/auth.py +304 -0
- src/api/routes/chat.py +264 -0
- src/api/routes/health.py +73 -0
- src/config/__init__.py +1 -0
- src/config/__pycache__/__init__.cpython-312.pyc +0 -0
- src/config/__pycache__/database.cpython-312.pyc +0 -0
- src/config/__pycache__/settings.cpython-312.pyc +0 -0
- src/config/database.py +152 -0
- src/config/settings.py +103 -0
- src/main.py +191 -0
- src/models/__init__.py +6 -0
- src/models/__pycache__/__init__.cpython-312.pyc +0 -0
- src/models/__pycache__/chat_message.cpython-312.pyc +0 -0
- src/models/__pycache__/schemas.cpython-312.pyc +0 -0
- src/models/__pycache__/session.cpython-312.pyc +0 -0
- src/models/__pycache__/user.cpython-312.pyc +0 -0
- src/models/chat_message.py +55 -0
- src/models/schemas.py +143 -0
- src/models/session.py +53 -0
- src/models/user.py +44 -0
- src/services/__init__.py +6 -0
- src/services/__pycache__/__init__.cpython-312.pyc +0 -0
- src/services/__pycache__/auth_service.cpython-312.pyc +0 -0
- src/services/__pycache__/chat_service.cpython-312.pyc +0 -0
- src/services/__pycache__/rag_service.cpython-312.pyc +0 -0
- src/services/__pycache__/vector_service.cpython-312.pyc +0 -0
- src/services/auth_service.py +312 -0
.dockerignore
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
build/
|
| 8 |
+
develop-eggs/
|
| 9 |
+
dist/
|
| 10 |
+
downloads/
|
| 11 |
+
eggs/
|
| 12 |
+
.eggs/
|
| 13 |
+
lib/
|
| 14 |
+
lib64/
|
| 15 |
+
parts/
|
| 16 |
+
sdist/
|
| 17 |
+
var/
|
| 18 |
+
wheels/
|
| 19 |
+
*.egg-info/
|
| 20 |
+
.installed.cfg
|
| 21 |
+
*.egg
|
| 22 |
+
.venv/
|
| 23 |
+
venv/
|
| 24 |
+
ENV/
|
| 25 |
+
env/
|
| 26 |
+
|
| 27 |
+
# Testing
|
| 28 |
+
.pytest_cache/
|
| 29 |
+
.coverage
|
| 30 |
+
htmlcov/
|
| 31 |
+
.tox/
|
| 32 |
+
.nox/
|
| 33 |
+
|
| 34 |
+
# IDEs
|
| 35 |
+
.vscode/
|
| 36 |
+
.idea/
|
| 37 |
+
*.swp
|
| 38 |
+
*.swo
|
| 39 |
+
*~
|
| 40 |
+
.DS_Store
|
| 41 |
+
|
| 42 |
+
# Environment
|
| 43 |
+
.env
|
| 44 |
+
.env.local
|
| 45 |
+
.env.*.local
|
| 46 |
+
|
| 47 |
+
# Git
|
| 48 |
+
.git/
|
| 49 |
+
.gitignore
|
| 50 |
+
.gitattributes
|
| 51 |
+
|
| 52 |
+
# Documentation
|
| 53 |
+
*.md
|
| 54 |
+
!README.md
|
| 55 |
+
docs/
|
| 56 |
+
|
| 57 |
+
# Logs
|
| 58 |
+
*.log
|
| 59 |
+
logs/
|
| 60 |
+
|
| 61 |
+
# Docker
|
| 62 |
+
Dockerfile
|
| 63 |
+
docker-compose.yml
|
| 64 |
+
.dockerignore
|
| 65 |
+
|
| 66 |
+
# CI/CD
|
| 67 |
+
.github/
|
| 68 |
+
.gitlab-ci.yml
|
| 69 |
+
.travis.yml
|
| 70 |
+
|
| 71 |
+
# Misc
|
| 72 |
+
node_modules/
|
| 73 |
+
coverage/
|
.env
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Database Configuration
|
| 2 |
+
# Neon Serverless Postgres connection string
|
| 3 |
+
# Format: postgresql://user:password@host.neon.tech/dbname?sslmode=require
|
| 4 |
+
DATABASE_URL=postgresql://user:password@host.neon.tech/dbname?sslmode=require
|
| 5 |
+
|
| 6 |
+
# Vector Database Configuration
|
| 7 |
+
# Qdrant Cloud cluster URL and API key
|
| 8 |
+
QDRANT_URL=https://dd8a681c-65ea-4ca6-ac50-e7e4873fdba1.us-east4-0.gcp.cloud.qdrant.io
|
| 9 |
+
QDRANT_API_KEY=your-qdrant-api-key
|
| 10 |
+
|
| 11 |
+
# OpenAI API Configuration
|
| 12 |
+
# Get your API key from https://platform.openai.com/api-keys
|
| 13 |
+
OPENAI_API_KEY=your-openai-api-key
|
| 14 |
+
# Optional: Organization ID if you have one
|
| 15 |
+
OPENAI_ORG_ID=
|
| 16 |
+
|
| 17 |
+
# Authentication Configuration
|
| 18 |
+
# Generate with: openssl rand -hex 32
|
| 19 |
+
BETTER_AUTH_SECRET=replace-with-output-of-openssl-rand-hex-32
|
| 20 |
+
# Session expiration in seconds (default: 7 days)
|
| 21 |
+
SESSION_TTL=604800
|
| 22 |
+
|
| 23 |
+
# Application Settings
|
| 24 |
+
ENVIRONMENT=development
|
| 25 |
+
LOG_LEVEL=INFO
|
| 26 |
+
|
| 27 |
+
# Rate Limiting
|
| 28 |
+
# Maximum requests per minute per user
|
| 29 |
+
RATE_LIMIT_PER_MINUTE=20
|
| 30 |
+
|
| 31 |
+
# CORS Configuration
|
| 32 |
+
# Comma-separated list of allowed origins
|
| 33 |
+
ALLOWED_ORIGINS=http://localhost:3000
|
| 34 |
+
|
| 35 |
+
# Server Configuration
|
| 36 |
+
HOST=0.0.0.0
|
| 37 |
+
PORT=8000
|
| 38 |
+
|
| 39 |
+
# Redis Configuration (for rate limiting)
|
| 40 |
+
REDIS_URL="redis://default:password@your-redis-host:6379"
|
Dockerfile
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Multi-stage build for production-ready image

# Stage 1: Build dependencies
FROM python:3.11-slim AS builder

WORKDIR /app

# Install system dependencies (gcc is only needed while building wheels)
RUN apt-get update && apt-get install -y --no-install-recommends \
    gcc \
    postgresql-client \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements and install Python dependencies into the user site (/root/.local)
COPY requirements.txt .
RUN pip install --no-cache-dir --user -r requirements.txt

# Stage 2: Runtime
FROM python:3.11-slim

WORKDIR /app

# Install runtime dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    postgresql-client \
    && rm -rf /var/lib/apt/lists/*

# Create non-root user for security.
# This happens BEFORE copying dependencies so they can live in (and be owned by)
# the user's own home: packages left under /root/.local are not reachable once
# the container drops to appuser.
RUN useradd -m -u 1000 appuser && chown appuser:appuser /app

# Copy Python dependencies from builder into appuser's user site
COPY --from=builder --chown=appuser:appuser /root/.local /home/appuser/.local

# Copy application code
COPY --chown=appuser:appuser src/ ./src/
COPY --chown=appuser:appuser scripts/ ./scripts/
COPY --chown=appuser:appuser alembic/ ./alembic/
COPY --chown=appuser:appuser alembic.ini ./

# Make sure console scripts installed by pip (e.g. uvicorn) are on PATH
ENV PATH=/home/appuser/.local/bin:$PATH

USER appuser

# Expose port
EXPOSE 8000

# Health check
# Uses only the standard library; urlopen raises on connection errors and on
# HTTP error statuses, so an unhealthy app makes the command exit non-zero.
HEALTHCHECK --interval=30s --timeout=3s --start-period=40s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health', timeout=2)" || exit 1

# Run application
CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
docker-compose.yml
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version: '3.8'

services:
  backend:
    build:
      context: .
      dockerfile: Dockerfile
    ports:
      - "8000:8000"
    environment:
      - DATABASE_URL=${DATABASE_URL}
      - QDRANT_URL=${QDRANT_URL}
      - QDRANT_API_KEY=${QDRANT_API_KEY}
      - OPENAI_API_KEY=${OPENAI_API_KEY}
      - OPENAI_ORG_ID=${OPENAI_ORG_ID:-}
      - BETTER_AUTH_SECRET=${BETTER_AUTH_SECRET}
      - ENVIRONMENT=development
      - LOG_LEVEL=${LOG_LEVEL:-INFO}
      - RATE_LIMIT_PER_MINUTE=${RATE_LIMIT_PER_MINUTE:-20}
      - ALLOWED_ORIGINS=${ALLOWED_ORIGINS:-http://localhost:3000}
      # Variables are only passed to the container when listed here, so the
      # Redis URL must be forwarded explicitly. Defaults to the redis service
      # defined below.
      # NOTE(review): assumes the backend reads REDIS_URL (as listed in .env) — confirm against settings.
      - REDIS_URL=${REDIS_URL:-redis://redis:6379/0}
    volumes:
      # Bind-mount sources for live code changes during development
      - ./src:/app/src
      - ./scripts:/app/scripts
    depends_on:
      - redis
    networks:
      - chatbot-network
    restart: unless-stopped

  redis:
    image: redis:7-alpine
    ports:
      - "6379:6379"
    volumes:
      - redis-data:/data
    networks:
      - chatbot-network
    restart: unless-stopped
    # Append-only file persistence so data survives container restarts
    command: redis-server --appendonly yes

volumes:
  redis-data:

networks:
  chatbot-network:
    driver: bridge
|
scripts/embed_book_content.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Book content embedding script
|
| 4 |
+
|
| 5 |
+
Reads markdown files from docs/ (including all nested subdirectories), chunks content by headings or word count,
|
| 6 |
+
generates embeddings with OpenAI, and uploads to Qdrant vector database.
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
python backend/scripts/embed_book_content.py --book-path docs/ --collection-name humanoid-robotics-book-v1
|
| 10 |
+
"""
|
| 11 |
+
import argparse
|
| 12 |
+
import asyncio
|
| 13 |
+
import os
|
| 14 |
+
import re
|
| 15 |
+
import sys
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from typing import List, Dict, Any
|
| 18 |
+
from uuid import uuid4
|
| 19 |
+
|
| 20 |
+
# Add parent directory to path for imports
|
| 21 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 22 |
+
|
| 23 |
+
from openai import AsyncOpenAI
|
| 24 |
+
from qdrant_client import AsyncQdrantClient
|
| 25 |
+
from qdrant_client.models import Distance, VectorParams, PointStruct
|
| 26 |
+
|
| 27 |
+
from src.config.settings import settings
|
| 28 |
+
from src.utils.logger import setup_logging, get_logger
|
| 29 |
+
|
| 30 |
+
setup_logging(level="INFO")
|
| 31 |
+
logger = get_logger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class BookContentChunker:
    """Chunks markdown content intelligently by headings and word limits

    A document is first divided at level-2/level-3 headings (``##`` / ``###``);
    each section is then cut into windows of at most ``chunk_size`` words, where
    consecutive windows share ``overlap`` words of context.
    """

    def __init__(self, chunk_size: int = 500, overlap: int = 50):
        """
        Initialize chunker

        Args:
            chunk_size: Target chunk size in words
            overlap: Word overlap between chunks
        """
        self.chunk_size = chunk_size
        self.overlap = overlap

    def chunk_markdown(self, content: str, file_path: str) -> List[Dict[str, Any]]:
        """
        Chunk markdown content by headings and word limits

        Text that appears before the first ``##``/``###`` heading is filed under
        the section name "Introduction".

        Args:
            content: Markdown file content
            file_path: Path to markdown file (for metadata)

        Returns:
            List of chunk dictionaries with content and metadata
        """
        chunks: List[Dict[str, Any]] = []

        # Extract chapter/module name from file path
        chapter = self._extract_chapter_name(Path(file_path))

        # Split by headings (## and ###); the capture group keeps the heading
        # lines in the result so they alternate with the body text.
        sections = re.split(r'(^#{2,3}\s+.+$)', content, flags=re.MULTILINE)

        current_section_heading = "Introduction"
        current_content: List[str] = []

        for section in sections:
            stripped = section.strip()
            heading_match = re.match(r'^(#{2,3})\s+(.+)$', stripped)

            if heading_match:
                # A new heading closes the previous section, if it had any text
                if current_content:
                    chunks.extend(
                        self._chunk_section(
                            "\n".join(current_content),
                            chapter,
                            current_section_heading,
                        )
                    )

                # Start new section
                current_section_heading = heading_match.group(2).strip()
                current_content = []
            elif stripped:
                # Accumulate content
                current_content.append(stripped)

        # Process last section
        if current_content:
            chunks.extend(
                self._chunk_section(
                    "\n".join(current_content),
                    chapter,
                    current_section_heading,
                )
            )

        return chunks

    def _chunk_section(self, content: str, chapter: str, section: str) -> List[Dict[str, Any]]:
        """Chunk a section by word count with overlap

        Args:
            content: Section text
            chapter: Chapter name stored on every chunk
            section: Section heading stored on every chunk (as both
                ``section`` and ``heading``)

        Returns:
            One chunk when the section fits in ``chunk_size`` words, otherwise
            a list of overlapping windows that together cover every word.

        Raises:
            ValueError: If the section must be split but ``overlap`` is not
                smaller than ``chunk_size`` (the window could never advance).
        """
        words = content.split()

        def build(text: str, index: int, word_count: int) -> Dict[str, Any]:
            return {
                "content": text,
                "chapter": chapter,
                "section": section,
                "heading": section,
                "chunk_index": index,
                "word_count": word_count,
            }

        if len(words) <= self.chunk_size:
            # Section fits in one chunk; keep the original text (newlines included)
            return [build(content, 0, len(words))]

        step = self.chunk_size - self.overlap
        if step <= 0:
            raise ValueError(
                "overlap must be smaller than chunk_size to split a section into multiple chunks"
            )

        # Split into multiple chunks with overlap
        chunks: List[Dict[str, Any]] = []
        start = 0
        while start < len(words):
            end = start + self.chunk_size
            chunk_words = words[start:end]
            chunks.append(build(" ".join(chunk_words), len(chunks), len(chunk_words)))

            # Once a window reaches the end of the section every word is
            # covered; sliding further would only emit a tail that is already
            # fully contained in this chunk.
            if end >= len(words):
                break

            start += step  # Overlap for context

        return chunks

    def _extract_chapter_name(self, path: Path) -> str:
        """Extract chapter/module name from file path

        The deepest path component that starts with "module" followed by a
        number (e.g. "module1-ros2", "Module 1") wins; otherwise the file name
        without its extension is used. Separators become spaces and the result
        is title-cased.
        """
        # Look for patterns like "module1-ros2", "Module 1", etc.
        for part in reversed(path.parts):
            if re.match(r'module[-\s]*\d+', part, re.IGNORECASE):
                return part.replace('-', ' ').title()

        # Fallback to filename without extension
        return path.stem.replace('-', ' ').replace('_', ' ').title()
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class BookEmbedder:
    """Handles embedding generation and Qdrant upload"""

    def __init__(self, collection_name: str = "book_content"):
        """
        Initialize embedder

        Creates the OpenAI and Qdrant async clients from the project settings.

        Args:
            collection_name: Qdrant collection name
        """
        self.collection_name = collection_name
        self.openai_client = AsyncOpenAI(api_key=settings.openai_api_key)
        self.qdrant_client = AsyncQdrantClient(
            url=settings.qdrant_url,
            api_key=settings.qdrant_api_key,
            timeout=30,  # Set a higher timeout (seconds)
        )

    async def create_collection(self):
        """Create Qdrant collection if it doesn't exist, with improved connection error handling

        If the collection list cannot be fetched, the error is logged and the
        process exits with status 1.
        """
        try:
            collections = await self.qdrant_client.get_collections()
        except Exception as e:
            logger.error(
                "\nCannot connect to Qdrant. "
                f"Error: {type(e).__name__}: {e}\n"
                "-> Please make sure your Qdrant server is running and accessible at the configured URL.\n"
                f"-> Current Qdrant URL: {settings.qdrant_url}"
            )
            logger.error("Exiting due to Qdrant connection failure.")
            sys.exit(1)

        collection_names = [col.name for col in collections.collections]

        if self.collection_name not in collection_names:
            await self.qdrant_client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=settings.vector_size,
                    distance=Distance.COSINE,
                ),
            )
            logger.info(f"Created collection: {self.collection_name}")
        else:
            logger.info(f"Collection already exists: {self.collection_name}")

    async def embed_text(self, text: str) -> List[float]:
        """
        Generate embedding for text using OpenAI

        Args:
            text: Text to embed

        Returns:
            Embedding vector
        """
        response = await self.openai_client.embeddings.create(
            model=settings.openai_embedding_model,
            input=text
        )
        return response.data[0].embedding

    async def upload_chunks(self, chunks: List[Dict[str, Any]], doc_version: str = "v1.0.0"):
        """
        Upload chunks with embeddings to Qdrant

        Each chunk is embedded individually and stored under a fresh random
        UUID, so running the upload twice stores the same content twice.

        Args:
            chunks: List of chunk dictionaries
            doc_version: Document version identifier
        """
        logger.info(f"Uploading {len(chunks)} chunks to Qdrant...")

        batch_size = 100
        points = []

        for i, chunk in enumerate(chunks):
            # Generate embedding
            embedding = await self.embed_text(chunk["content"])

            # Create point
            points.append(
                PointStruct(
                    id=str(uuid4()),
                    vector=embedding,
                    payload={
                        "content": chunk["content"],
                        "chapter": chunk["chapter"],
                        "section": chunk["section"],
                        "heading": chunk["heading"],
                        "chunk_index": chunk["chunk_index"],
                        "word_count": chunk["word_count"],
                        "doc_version": doc_version,
                    }
                )
            )

            # Upload in batches of 100
            if len(points) >= batch_size:
                await self.qdrant_client.upsert(
                    collection_name=self.collection_name,
                    points=points
                )
                logger.info(f"Uploaded batch {i // batch_size + 1} ({len(points)} points)")
                points = []

        # Upload remaining points
        if points:
            await self.qdrant_client.upsert(
                collection_name=self.collection_name,
                points=points
            )
            logger.info(f"Uploaded final batch ({len(points)} points)")

    async def close(self):
        """Close connections

        Closes both clients created in ``__init__``; the Qdrant client is
        closed even if closing the OpenAI client raises.
        """
        try:
            await self.openai_client.close()
        finally:
            await self.qdrant_client.close()
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def get_all_markdown_files_recursively(root_path: Path) -> List[Path]:
    """
    Find all markdown files recursively (as deep as needed) in the given root_path.
    This function will walk all subdirectories and return both *.md and *.mdx files.

    Anything located under a directory named ``node_modules`` is skipped. The
    check is made per path component, so a file that merely has
    "node_modules" inside its name is still returned.

    Args:
        root_path: Path to the root directory

    Returns:
        List[Path]: List of all markdown file Paths, sorted so repeated runs
        process files in the same order
    """
    found: List[Path] = []
    for pattern in ("*.md", "*.mdx"):
        for candidate in root_path.rglob(pattern):
            if candidate.is_file() and "node_modules" not in candidate.parts:
                found.append(candidate)
    return sorted(found)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
async def main():
    """Main embedding script

    Parses the command line, chunks every markdown file found under
    ``--book-path`` and uploads the embedded chunks to the Qdrant collection
    named by ``--collection-name``. Exits with status 1 when the book path is
    not an existing directory.
    """
    parser = argparse.ArgumentParser(description="Embed book content into Qdrant")
    parser.add_argument(
        "--book-path",
        type=str,
        required=True,
        help="Path to book content directory (e.g., docs/)"
    )
    parser.add_argument(
        "--collection-name",
        type=str,
        default="humanoid-robotics-book-v1",
        help="Qdrant collection name"
    )
    parser.add_argument(
        "--doc-version",
        type=str,
        default="v1.0.0",
        help="Document version identifier"
    )

    args = parser.parse_args()

    # Fail fast on a bad path, before any client is created or Qdrant is contacted
    book_path = Path(args.book_path)
    if not book_path.is_dir():
        logger.error(f"Book path does not exist or is not a directory: {book_path}")
        sys.exit(1)

    # Initialize components
    chunker = BookContentChunker(chunk_size=500, overlap=50)
    embedder = BookEmbedder(collection_name=args.collection_name)

    try:
        # Create collection (connection errors are handled inside create_collection)
        await embedder.create_collection()

        # Find all markdown files as deep as needed
        md_files = get_all_markdown_files_recursively(book_path)
        logger.info(f"Found {len(md_files)} markdown files (.md and .mdx) recursively in all subdirectories")

        # Process each file
        all_chunks = []
        for md_file in md_files:
            logger.info(f"Processing: {md_file}")

            content = md_file.read_text(encoding='utf-8')

            chunks = chunker.chunk_markdown(content, str(md_file))
            all_chunks.extend(chunks)
            logger.info(f"  -> Generated {len(chunks)} chunks")

        logger.info(f"Total chunks: {len(all_chunks)}")

        # Do not report success when there was nothing to embed
        if not all_chunks:
            logger.warning("No content found to embed; nothing was uploaded.")
            return

        # Upload to Qdrant
        await embedder.upload_chunks(all_chunks, doc_version=args.doc_version)

        logger.info("✅ Embedding complete!")

    finally:
        await embedder.close()
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
if __name__ == "__main__":
    # Run main in asyncio loop, but trap connection errors globally as a last resort.
    # sys is already imported at the top of the module, so no local import is needed.
    try:
        asyncio.run(main())
    except Exception as e:
        logger.error(f"FATAL: Exception occurred: {type(e).__name__}: {e}")
        logger.error("Please check if Qdrant is running, accessible, and credentials are set correctly.")
        sys.exit(1)
|
src/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""RAG Chatbot Backend Package"""
|
src/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (196 Bytes). View file
|
|
|
src/__pycache__/main.cpython-312.pyc
ADDED
|
Binary file (7.32 kB). View file
|
|
|
src/api/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""API package"""
|
src/api/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (184 Bytes). View file
|
|
|
src/api/middleware/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Middleware package"""
|
src/api/middleware/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (202 Bytes). View file
|
|
|
src/api/middleware/__pycache__/auth_middleware.cpython-312.pyc
ADDED
|
Binary file (5.34 kB). View file
|
|
|
src/api/middleware/__pycache__/logging_middleware.cpython-312.pyc
ADDED
|
Binary file (1.81 kB). View file
|
|
|
src/api/middleware/__pycache__/rate_limit.cpython-312.pyc
ADDED
|
Binary file (3.36 kB). View file
|
|
|
src/api/middleware/auth_middleware.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Authentication middleware for protecting API endpoints
|
| 2 |
+
|
| 3 |
+
Provides dependency injection for current user authentication.
|
| 4 |
+
Validates JWT tokens from HTTP-only cookies and extracts user information.
|
| 5 |
+
"""
|
| 6 |
+
from typing import Optional
|
| 7 |
+
from fastapi import Depends, HTTPException, status, Request
|
| 8 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 9 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 10 |
+
|
| 11 |
+
from src.config.database import get_db_session
|
| 12 |
+
from src.models.user import User
|
| 13 |
+
from src.services.auth_service import AuthService
|
| 14 |
+
from src.utils.logger import get_logger
|
| 15 |
+
|
| 16 |
+
logger = get_logger(__name__)
|
| 17 |
+
|
| 18 |
+
# HTTP Bearer scheme for Authorization header (optional fallback)
|
| 19 |
+
security = HTTPBearer(auto_error=False)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
async def get_token_from_request(request: Request) -> Optional[str]:
    """Extract JWT token from cookie or Authorization header

    The ``auth_token`` cookie takes precedence. Otherwise the token is read
    from an ``Authorization: Bearer <token>`` header; the scheme name is
    matched case-insensitively and surrounding whitespace is ignored.

    Args:
        request: FastAPI request object

    Returns:
        JWT token string or None (also None when the header carries the
        Bearer scheme but no token)
    """
    # First, try to get token from HTTP-only cookie
    cookie_token = request.cookies.get("auth_token")
    if cookie_token:
        return cookie_token

    # Fallback: try Authorization header (for API clients)
    auth_header = request.headers.get("Authorization")
    if not auth_header:
        return None

    # partition() splits on the first space only, so extra spaces between the
    # scheme and the token end up in `credentials` and are stripped below
    scheme, _, credentials = auth_header.strip().partition(" ")
    if scheme.lower() != "bearer":
        return None

    return credentials.strip() or None
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _unauthorized(detail: str) -> HTTPException:
    """Build the 401 error shared by every authentication failure path."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user(
    request: Request,
    db: AsyncSession = Depends(get_db_session)
) -> User:
    """Dependency to get the current authenticated user

    Validates JWT token from cookie/header and returns the associated user.
    Raises 401 Unauthorized if token is invalid or user not found.

    Checks run in this order: token present, JWT decodes, payload has a
    ``sub`` claim, session is still valid in the database, ``sub`` parses as a
    UUID, user exists.

    Args:
        request: FastAPI request object
        db: Database session

    Returns:
        Authenticated User instance

    Raises:
        HTTPException: 401 if authentication fails
    """
    from uuid import UUID

    # Extract token from request
    token = await get_token_from_request(request)
    if not token:
        logger.warning("Authentication failed: no token provided")
        raise _unauthorized("Not authenticated")

    # Decode and validate JWT token
    payload = AuthService.decode_jwt_token(token)
    if not payload:
        logger.warning("Authentication failed: invalid JWT token")
        raise _unauthorized("Invalid authentication credentials")

    # Extract user_id from token payload
    user_id_str = payload.get("sub")
    if not user_id_str:
        logger.warning("Authentication failed: no user_id in token payload")
        raise _unauthorized("Invalid token payload")

    # Validate session still exists in database
    session = await AuthService.validate_session(db, token)
    if not session:
        logger.warning("Authentication failed: session not found or expired for token")
        raise _unauthorized("Session expired or invalid")

    # Get user from database
    try:
        # str() so that a non-string "sub" claim is rejected with a 401 below
        # instead of escaping as an unhandled AttributeError from UUID()
        user_id = UUID(str(user_id_str))
    except ValueError:
        logger.warning(f"Authentication failed: invalid user_id format: {user_id_str}")
        raise _unauthorized("Invalid user identifier") from None

    user = await AuthService.get_user_by_id(db, user_id)
    if not user:
        logger.warning(f"Authentication failed: user {user_id} not found")
        raise _unauthorized("User not found")

    logger.debug(f"User authenticated: {user.id}")
    return user
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
async def get_current_user_optional(
    request: Request,
    db: AsyncSession = Depends(get_db_session)
) -> Optional[User]:
    """Resolve the requesting user when possible, without failing the request.

    Delegates to ``get_current_user`` and converts any ``HTTPException`` it
    raises into a ``None`` result, so a route can serve anonymous callers too.

    Args:
        request: FastAPI request object
        db: Database session

    Returns:
        The authenticated User, or None when authentication did not succeed
    """
    try:
        authenticated_user = await get_current_user(request, db)
    except HTTPException:
        # Any authentication failure simply means "anonymous caller" here.
        authenticated_user = None
    return authenticated_user
|
src/api/middleware/logging_middleware.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Logging middleware for API requests and responses.
|
| 2 |
+
|
| 3 |
+
This middleware logs details of each API request and its corresponding response.
|
| 4 |
+
"""
|
| 5 |
+
import time
|
| 6 |
+
from fastapi import Request
|
| 7 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 8 |
+
from src.utils.logger import get_logger
|
| 9 |
+
|
| 10 |
+
logger = get_logger(__name__)
|
| 11 |
+
|
| 12 |
+
class LoggingMiddleware(BaseHTTPMiddleware):
    """Middleware that logs the start and the completion of every request."""

    async def dispatch(self, request: Request, call_next):
        """Log the request, run the downstream handler, then log the outcome.

        Args:
            request: Incoming request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            The response produced by the downstream handler, unmodified.
        """
        start_time = time.time()

        # request.client is None when the server cannot determine the peer
        # address; guard the access so logging never breaks the request itself.
        client_host = request.client.host if request.client else None

        # Log request details
        request_log_details = {
            "method": request.method,
            "path": request.url.path,
            "client": client_host,
        }
        logger.info("Request started", extra=request_log_details)

        response = await call_next(request)

        # Elapsed handling time in milliseconds.
        process_time = (time.time() - start_time) * 1000

        # Log response details
        response_log_details = {
            "method": request.method,
            "path": request.url.path,
            "status_code": response.status_code,
            "process_time_ms": f"{process_time:.2f}",
        }
        logger.info("Request finished", extra=response_log_details)

        return response
|
src/api/middleware/rate_limit.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/api/middleware/rate_limit.py
|
| 2 |
+
from typing import Callable
|
| 3 |
+
from fastapi import Request
|
| 4 |
+
from fastapi.responses import JSONResponse
|
| 5 |
+
import inspect
|
| 6 |
+
|
| 7 |
+
from slowapi import Limiter
|
| 8 |
+
from slowapi.util import get_remote_address
|
| 9 |
+
from slowapi.errors import RateLimitExceeded
|
| 10 |
+
|
| 11 |
+
from src.config.settings import settings
|
| 12 |
+
from src.utils.logger import get_logger
|
| 13 |
+
|
| 14 |
+
logger = get_logger(__name__)
|
| 15 |
+
|
| 16 |
+
def get_user_identifier(request: Request) -> str:
    """Build the rate-limit key for a request.

    Args:
        request: Incoming request.

    Returns:
        ``"user:<id>"`` when ``request.state.user`` is set and truthy,
        otherwise ``"ip:<remote address>"``.
    """
    authenticated = getattr(request.state, "user", None)
    if not authenticated:
        # No user attached to the request state: key on the client address.
        return f"ip:{get_remote_address(request)}"
    return f"user:{authenticated.id}"
|
| 21 |
+
|
| 22 |
+
# Create limiter instance (single authoritative limiter in this module)
limiter = Limiter(
    # Requests are keyed per authenticated user, falling back to the client IP.
    key_func=get_user_identifier,
    # Default per-minute quota, read from application settings.
    default_limits=[f"{settings.rate_limit_per_minute}/minute"],
    # Storage backend URI taken from settings.redis_url.
    storage_uri=settings.redis_url,
    strategy="fixed-window"
)
|
| 29 |
+
|
| 30 |
+
# The limiter decorator inspects the wrapped function's signature and needs a
# parameter named ``request``; decorating once at import time also avoids
# re-registering the same limit on every incoming request.
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def _rate_limited_noop(request: Request):
    """No-op handler whose only purpose is to carry the limiter decorator."""
    return None


async def rate_limit_dependency(request: Request):
    """FastAPI dependency that applies the default rate limit to a route.

    Calls a module-level no-op handler that is wrapped by ``limiter.limit``,
    so the limiter's check runs for this request without the route itself
    having to be decorated.

    Args:
        request: Incoming request the limit is evaluated for.

    Returns:
        True when the request is within the limit, so the route proceeds.

    Raises:
        RateLimitExceeded: Propagated from the limiter when the quota is
            exhausted (expected to be handled by the registered handler).
    """
    # The wrapper may be sync or async depending on the limiter; support both.
    result = _rate_limited_noop(request=request)
    if inspect.isawaitable(result):
        await result

    # dependency returns truthy so route proceeds
    return True
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Convert a RateLimitExceeded error into a JSON 429 response.

    Args:
        request: Request that exceeded its limit (used for the log identifier).
        exc: The rate-limit exception; its ``detail`` is scanned for a
            "Retry after <n>" hint.

    Returns:
        JSONResponse with status 429, an error payload and a Retry-After header.
    """
    identifier = get_user_identifier(request)
    logger.warning(f"Rate limit exceeded for {identifier}")

    fallback_seconds = 60
    marker = "Retry after"
    wait_seconds = fallback_seconds
    try:
        detail = exc.detail
        if isinstance(detail, str) and marker in detail:
            # Text following the first marker (up to the next marker, if any).
            tail = detail.split(marker)[1]
            digits = "".join(ch for ch in tail if ch.isdigit())
            wait_seconds = int(digits) if digits else fallback_seconds
    except Exception:
        # Parsing is best-effort only; never let it break the 429 response.
        wait_seconds = fallback_seconds

    body = {
        "error": "rate_limit_exceeded",
        "message": f"Too many requests. Please try again in {wait_seconds} seconds.",
        "retry_after": int(wait_seconds)
    }

    return JSONResponse(status_code=429, content=body, headers={"Retry-After": str(wait_seconds)})
|
src/api/routes/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""API routes package"""
|
src/api/routes/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (198 Bytes). View file
|
|
|
src/api/routes/__pycache__/auth.cpython-312.pyc
ADDED
|
Binary file (11 kB). View file
|
|
|
src/api/routes/__pycache__/chat.cpython-312.pyc
ADDED
|
Binary file (10.8 kB). View file
|
|
|
src/api/routes/__pycache__/health.cpython-312.pyc
ADDED
|
Binary file (3.47 kB). View file
|
|
|
src/api/routes/auth.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Authentication routes for user registration, login, and session management
|
| 2 |
+
|
| 3 |
+
Provides endpoints for:
|
| 4 |
+
- POST /auth/register - User registration
|
| 5 |
+
- POST /auth/login - User login
|
| 6 |
+
- POST /auth/logout - User logout
|
| 7 |
+
- GET /auth/me - Get current user
|
| 8 |
+
"""
|
| 9 |
+
from fastapi import APIRouter, Depends, HTTPException, status, Response
|
| 10 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 11 |
+
from sqlalchemy import select
|
| 12 |
+
|
| 13 |
+
from src.config.database import get_db_session
|
| 14 |
+
from src.models.user import User
|
| 15 |
+
from src.models.schemas import (
|
| 16 |
+
UserCreate,
|
| 17 |
+
UserLogin,
|
| 18 |
+
UserResponse,
|
| 19 |
+
AuthResponse,
|
| 20 |
+
MessageResponse
|
| 21 |
+
)
|
| 22 |
+
from src.services.auth_service import AuthService
|
| 23 |
+
from src.api.middleware.auth_middleware import get_current_user, get_token_from_request
|
| 24 |
+
from src.utils.logger import get_logger
|
| 25 |
+
from src.utils.validators import sanitize_html
|
| 26 |
+
|
| 27 |
+
logger = get_logger(__name__)
|
| 28 |
+
|
| 29 |
+
router = APIRouter(prefix="/auth", tags=["Authentication"])
|
| 30 |
+
|
| 31 |
+
# Max password length, in bytes, that bcrypt actually hashes. bcrypt ignores
# (or, in recent releases, rejects) input beyond 72 bytes, so the truncation
# performed before hashing must use this limit to be effective.
BCRYPT_PASSWORD_MAX_BYTES = 72
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@router.post(
    "/register",
    response_model=AuthResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Register a new user",
    description="Create a new user account with email and password"
)
async def register(
    user_data: UserCreate,
    response: Response,
    db: AsyncSession = Depends(get_db_session)
) -> AuthResponse:
    """Create a user account, open a session and set the auth cookie.

    Args:
        user_data: Registration payload (email, password)
        response: FastAPI response object, used to set the auth cookie
        db: Database session

    Returns:
        AuthResponse carrying the created user and a success message; the
        JWT itself is delivered through the ``auth_token`` cookie.

    Raises:
        HTTPException: 400 if the email is already registered or the hashing
            backend reports the password as too long; 500 on any other
            failure while creating the user.
    """
    # Refuse duplicates up front (lookup is done on the lower-cased email).
    lookup = select(User).where(User.email == user_data.email.lower())
    existing_user = (await db.execute(lookup)).scalar_one_or_none()
    if existing_user:
        logger.warning(f"Registration failed: email {user_data.email} already exists")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered"
        )

    # Cap the password at BCRYPT_PASSWORD_MAX_BYTES before hashing.
    password = user_data.password
    encoded_password = password.encode("utf-8")
    if len(encoded_password) > BCRYPT_PASSWORD_MAX_BYTES:
        logger.warning(
            f"Password for {user_data.email} is longer than {BCRYPT_PASSWORD_MAX_BYTES} bytes. Truncating..."
        )
        # A byte-offset cut can split a multi-byte character; errors="ignore"
        # drops that dangling partial sequence so the result ends on a
        # character boundary.
        password = encoded_password[:BCRYPT_PASSWORD_MAX_BYTES].decode("utf-8", errors="ignore")

    try:
        user = await AuthService.create_user(
            db=db,
            email=user_data.email,
            password=password
        )
    except AttributeError as e:
        logger.error(
            f"User creation failed due to bcrypt or passlib error: {e}. "
            "This is likely caused by an incompatible version of bcrypt. "
            "Please ensure 'bcrypt' and 'passlib' are installed and up to date. "
            "Upgrade them using: pip install --upgrade bcrypt passlib"
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Server configuration error: bcrypt module issue, please contact support."
        )
    except Exception as e:
        if "password cannot be longer than " not in str(e):
            # Anything other than the known length complaint is a server error.
            logger.error(f"User creation failed: {e}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to create user account"
            )
        logger.error(
            f"User creation failed: password too long for bcrypt for {user_data.email}: {e}. Manual truncation should have prevented this."
        )
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Password cannot be longer than {BCRYPT_PASSWORD_MAX_BYTES} bytes. Please use a shorter password."
        )

    # Issue a JWT and record the matching session row.
    token = AuthService.generate_jwt_token(user.id)
    await AuthService.create_session(db=db, user_id=user.id, token=token)

    # Deliver the token as an HTTP-only cookie valid for seven days.
    seven_days_in_seconds = 60 * 60 * 24 * 7
    response.set_cookie(
        key="auth_token",
        value=token,
        httponly=True,
        secure=True,  # HTTPS only
        samesite="lax",  # CSRF protection
        max_age=seven_days_in_seconds
    )

    logger.info(f"User registered and logged in: {user.id}")

    return AuthResponse(
        user=UserResponse.model_validate(user),
        message="Registration successful"
    )
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
@router.post(
    "/login",
    response_model=AuthResponse,
    status_code=status.HTTP_200_OK,
    summary="Login user",
    description="Authenticate user with email and password, create session"
)
async def login(
    credentials: UserLogin,
    response: Response,
    db: AsyncSession = Depends(get_db_session)
) -> AuthResponse:
    """Verify credentials, open a session and set the auth cookie.

    Args:
        credentials: Login payload (email, password)
        response: FastAPI response object, used to set the auth cookie
        db: Database session

    Returns:
        AuthResponse carrying the authenticated user and a success message;
        the JWT itself is delivered through the ``auth_token`` cookie.

    Raises:
        HTTPException: 401 if the credentials are rejected; 400 if the
            hashing backend reports the password as too long; 500 on any
            other failure during authentication.
    """
    # Apply the same byte cap as registration so both paths hash identical input.
    password = credentials.password
    encoded_password = password.encode("utf-8")
    if len(encoded_password) > BCRYPT_PASSWORD_MAX_BYTES:
        logger.warning(
            f"Login attempt with password longer than {BCRYPT_PASSWORD_MAX_BYTES} bytes. Truncating for bcrypt."
        )
        # errors="ignore" discards a multi-byte character split by the cut.
        password = encoded_password[:BCRYPT_PASSWORD_MAX_BYTES].decode("utf-8", errors="ignore")

    try:
        user = await AuthService.authenticate_user(
            db=db,
            email=credentials.email,
            password=password
        )
    except AttributeError as e:
        logger.error(
            f"User authentication failed due to bcrypt or passlib error: {e}. "
            "Please ensure 'bcrypt' and 'passlib' are installed and up to date."
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Server configuration error: bcrypt module issue, please contact support."
        )
    except Exception as e:
        if "password cannot be longer than " not in str(e):
            # Anything other than the known length complaint is a server error.
            logger.error(f"User authentication failed: {e}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Server error during authentication"
            )
        logger.error(
            f"User authentication failed: password too long for bcrypt: {e}. Manual truncation should have prevented this."
        )
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Password cannot be longer than {BCRYPT_PASSWORD_MAX_BYTES} bytes."
        )

    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Issue a JWT and record the matching session row.
    token = AuthService.generate_jwt_token(user.id)
    await AuthService.create_session(db=db, user_id=user.id, token=token)

    # Deliver the token as an HTTP-only cookie valid for seven days.
    seven_days_in_seconds = 60 * 60 * 24 * 7
    response.set_cookie(
        key="auth_token",
        value=token,
        httponly=True,
        secure=True,
        samesite="lax",
        max_age=seven_days_in_seconds
    )

    logger.info(f"User logged in: {user.id}")

    return AuthResponse(
        user=UserResponse.model_validate(user),
        message="Login successful"
    )
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
@router.post(
    "/logout",
    response_model=MessageResponse,
    status_code=status.HTTP_200_OK,
    summary="Logout user",
    description="Revoke user session and clear authentication cookie"
)
async def logout(
    response: Response,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db_session)
) -> MessageResponse:
    """Log the user out by clearing the authentication cookie.

    Args:
        response: FastAPI response object for clearing cookies
        current_user: Authenticated user (from middleware)
        db: Database session (currently unused by this handler)

    Returns:
        MessageResponse confirming logout
    """
    # NOTE(review): only the cookie is cleared here. The session row created
    # at login/registration is not revoked, so the token itself stays valid
    # until it expires. TODO: revoke the server-side session as well.

    # Clear HTTP-only cookie
    response.delete_cookie(key="auth_token", httponly=True, secure=True, samesite="lax")

    logger.info(f"User logged out: {current_user.id}")

    return MessageResponse(message="Logout successful")
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
@router.get(
    "/me",
    response_model=UserResponse,
    status_code=status.HTTP_200_OK,
    summary="Get current user",
    description="Get authenticated user's profile information"
)
async def get_current_user_profile(
    current_user: User = Depends(get_current_user)
) -> UserResponse:
    """Return the profile of the user making the request.

    Args:
        current_user: Authenticated user resolved by the auth dependency

    Returns:
        UserResponse built from the authenticated user
    """
    profile = UserResponse.model_validate(current_user)
    return profile
|
src/api/routes/chat.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Chat message endpoints for interacting with the RAG chatbot.
|
| 2 |
+
|
| 3 |
+
Provides endpoints for:
|
| 4 |
+
- POST /chat/message - Send a new message to the chatbot, get a RAG-powered response
|
| 5 |
+
- GET /chat/history - Get chat history for a specific thread
|
| 6 |
+
- GET /chat/threads - Get a list of user's chat threads
|
| 7 |
+
"""
|
| 8 |
+
from typing import List, Dict, Any
|
| 9 |
+
from uuid import UUID, uuid4
|
| 10 |
+
import time
|
| 11 |
+
import asyncio
|
| 12 |
+
|
| 13 |
+
from fastapi import APIRouter, Depends, HTTPException, status
|
| 14 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 15 |
+
from sqlalchemy import select, func
|
| 16 |
+
from openai import OpenAIError, APITimeoutError, APIConnectionError, APIStatusError
|
| 17 |
+
|
| 18 |
+
from src.api.middleware.auth_middleware import get_current_user
|
| 19 |
+
# from src.api.middleware.rate_limit import rate_limit_dependency
|
| 20 |
+
from src.config.database import get_db_session
|
| 21 |
+
from src.models.user import User
|
| 22 |
+
from src.models.schemas import (
|
| 23 |
+
ChatMessageCreate,
|
| 24 |
+
ChatMessageResponse,
|
| 25 |
+
ChatResponse,
|
| 26 |
+
ChatHistoryResponse
|
| 27 |
+
)
|
| 28 |
+
from src.models.chat_message import ChatUserRole, ChatMessage
|
| 29 |
+
from src.services.chat_service import ChatService
|
| 30 |
+
from src.services.rag_service import RAGService
|
| 31 |
+
from src.services.vector_service import VectorService
|
| 32 |
+
from src.utils.validators import sanitize_html
|
| 33 |
+
from src.utils.logger import get_logger
|
| 34 |
+
|
| 35 |
+
logger = get_logger(__name__)
|
| 36 |
+
|
| 37 |
+
router = APIRouter(prefix="/chat", tags=["Chatbot"])
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@router.post(
    "/message",
    response_model=ChatResponse,
    status_code=status.HTTP_200_OK,
    summary="Send message to chatbot",
    description="Send a message to the RAG chatbot and get a response based on book content"
)
async def send_message(
    chat_message_data: ChatMessageCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db_session),
    # rate_limit_status: Dict[str, Any] = Depends(rate_limit_dependency)
) -> ChatResponse:
    """Send a message to the RAG chatbot.

    Handles full-book queries and selected-text queries. Saves the user
    message, fetches similar chunks and recent thread history concurrently,
    asks the RAG service for a reply, and saves the assistant message.

    Args:
        chat_message_data: Message payload (message, optional selected_text,
            query_mode and thread_id).
        current_user: The authenticated user.
        db: The database session.

    Returns:
        ChatResponse with the stored user message, the stored assistant
        message and the thread id.

    Raises:
        HTTPException: 400 when 'selection' mode is requested without
            selected text; 429/502/503/504 for OpenAI API failures; 500 for
            any other retrieval or generation failure.
    """
    start_time = time.time()
    user_id = current_user.id
    query_mode = chat_message_data.query_mode or "full_book"
    user_message_content = sanitize_html(chat_message_data.message, strip=True)
    selected_text_content = (
        sanitize_html(chat_message_data.selected_text, strip=True)
        if chat_message_data.selected_text
        else None
    )

    # Validate before anything is written, so a rejected request does not
    # leave a user message behind without a matching assistant reply.
    if query_mode == "selection" and not selected_text_content:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="selected_text is required for 'selection' query mode"
        )

    # Determine thread_id: create new if not provided (simple UUID for conversation grouping)
    thread_id = chat_message_data.thread_id if chat_message_data.thread_id else str(uuid4())
    logger.info(f"Processing message for user {user_id}, thread {thread_id}")

    # 1. Save user message to DB
    user_message_db = await ChatService.save_message(
        db=db,
        user_id=user_id,
        thread_id=thread_id,
        role=ChatUserRole.USER,
        content=user_message_content,
        metadata={
            "query_mode": query_mode,
            "selected_text": selected_text_content
        }
    )

    # 2 & 3. Run vector search and chat history retrieval in parallel for speed
    try:
        # Selection mode searches on the selected passage with a smaller top_k.
        search_text = selected_text_content if query_mode == "selection" else user_message_content
        top_k = 3 if query_mode == "selection" else 5

        context_chunks, chat_history = await asyncio.gather(
            VectorService.search_similar_chunks(query_text=search_text, top_k=top_k),
            ChatService.get_chat_history(db=db, user_id=user_id, thread_id=thread_id, limit=10)
        )

        # Convert history to dict format for RAG service
        history_dicts = [
            {"role": msg.role, "content": msg.content}
            for msg in reversed(chat_history)  # Reverse to chronological order
        ]

    except HTTPException:
        raise  # Pass through HTTP errors raised by the services untouched
    except Exception as e:
        logger.error(f"Error retrieving context or history: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve relevant book content. Please try again."
        ) from e

    # 4. Generate response using RAG service with Agents SDK
    try:
        rag_response = await RAGService.generate_response(
            user_message=user_message_content,
            context_chunks=context_chunks,
            chat_history=history_dicts,
            query_mode=query_mode,
            selected_text=selected_text_content
        )
        assistant_message_content = rag_response["content"]
        chunk_ids = rag_response["chunk_ids"]
        model_used = rag_response["model_used"]

    except APIStatusError as e:
        logger.error(f"OpenAI API error (status: {e.status_code}): {e.message}")
        if e.status_code == 429:
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail="OpenAI API rate limit exceeded. Please try again shortly."
            ) from e
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"OpenAI API error: {e.message}"
        ) from e
    except APITimeoutError as e:
        logger.error(f"OpenAI API timeout error: {e}")
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT,
            detail="OpenAI API timed out. Please try again."
        ) from e
    except APIConnectionError as e:
        logger.error(f"OpenAI API connection error: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Could not connect to OpenAI API. Please check your internet connection or try again later."
        ) from e
    except Exception as e:
        logger.error(f"Generic error from RAG service: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to generate response from chatbot. Please try again."
        ) from e

    # Calculate response time
    response_time_ms = int((time.time() - start_time) * 1000)

    # 5. Save assistant message to DB
    assistant_message_db = await ChatService.save_message(
        db=db,
        user_id=user_id,
        thread_id=thread_id,
        role=ChatUserRole.ASSISTANT,
        content=assistant_message_content,
        metadata={
            "query_mode": query_mode,
            "selected_text_context": selected_text_content if query_mode == "selection" else None,
            "chunk_ids": chunk_ids,
            "model_used": model_used,
            "response_time_ms": response_time_ms
        }
    )

    logger.info(f"Chat response generated for user {user_id} in thread {thread_id} ({response_time_ms}ms)")

    return ChatResponse(
        user_message=ChatMessageResponse.model_validate(user_message_db),
        assistant_message=ChatMessageResponse.model_validate(assistant_message_db),
        thread_id=thread_id
    )
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
@router.get(
    "/history",
    response_model=ChatHistoryResponse,
    status_code=status.HTTP_200_OK,
    summary="Get chat history",
    description="Retrieve paginated chat message history for a specific thread"
)
async def get_chat_history(
    thread_id: str,
    limit: int = 50,
    offset: int = 0,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db_session)
) -> ChatHistoryResponse:
    """Retrieve chat history for a specific thread.

    Args:
        thread_id: The ID of the conversation thread.
        limit: The maximum number of messages to return.
        offset: The number of messages to skip for pagination.
        current_user: The authenticated user.
        db: The database session.

    Returns:
        ChatHistoryResponse containing messages and total count.

    Raises:
        HTTPException: 400 if limit or offset is negative.
    """
    # Reject nonsensical pagination here instead of letting it surface as an
    # opaque failure from the query layer.
    if limit < 0 or offset < 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="limit and offset must be non-negative"
        )

    user_id = current_user.id

    messages_db = await ChatService.get_chat_history(
        db=db,
        user_id=user_id,
        thread_id=thread_id,
        limit=limit,
        offset=offset
    )

    # Convert to Pydantic models
    messages_response = [ChatMessageResponse.model_validate(msg) for msg in messages_db]

    # Get total count for pagination metadata (scoped to this user and thread)
    total_messages = await db.scalar(
        select(func.count(ChatMessage.id))
        .where(ChatMessage.user_id == user_id, ChatMessage.thread_id == thread_id)
    )

    logger.info(f"Retrieved chat history for user {user_id}, thread {thread_id}")

    return ChatHistoryResponse(
        messages=messages_response,
        total=total_messages if total_messages is not None else 0,
        thread_id=thread_id
    )
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
@router.get(
    "/threads",
    response_model=List[Dict[str, Any]],
    status_code=status.HTTP_200_OK,
    summary="Get user chat threads",
    description="Retrieve a list of all chat threads for the authenticated user"
)
async def get_user_chat_threads(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db_session)
) -> List[Dict[str, Any]]:
    """List the chat threads that belong to the authenticated user.

    Args:
        current_user: The authenticated user (resolved by dependency).
        db: The database session (resolved by dependency).

    Returns:
        Whatever ``ChatService.get_user_threads`` produces for this user:
        a list of dictionaries, one per thread.
    """
    user_id = current_user.id

    # All querying is delegated to the service layer.
    threads = await ChatService.get_user_threads(
        db=db,
        user_id=user_id,
    )

    logger.info(f"Retrieved {len(threads)} chat threads for user {user_id}")
    return threads
|
src/api/routes/health.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Health check endpoint
|
| 2 |
+
|
| 3 |
+
Provides health status for the application and its dependencies.
|
| 4 |
+
"""
|
| 5 |
+
from fastapi import APIRouter, status
|
| 6 |
+
from pydantic import BaseModel
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from typing import Dict, Any
|
| 9 |
+
|
| 10 |
+
from sqlalchemy import text
|
| 11 |
+
|
| 12 |
+
from src.config.database import get_engine, get_qdrant_client
|
| 13 |
+
from src.config.settings import settings
|
| 14 |
+
from src.utils.logger import get_logger
|
| 15 |
+
|
| 16 |
+
logger = get_logger(__name__)
|
| 17 |
+
|
| 18 |
+
router = APIRouter(tags=["health"])
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class HealthResponse(BaseModel):
    """Response body returned by the health check endpoint."""

    # Overall result: "healthy" or "degraded" (set by health_check)
    status: str
    # ISO-8601 timestamp of when the check ran
    timestamp: str
    # Value of settings.environment
    environment: str
    # Per-dependency result, e.g. {"database": "healthy", "qdrant": "unhealthy"}
    dependencies: Dict[str, str]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@router.get(
    "/health",
    response_model=HealthResponse,
    status_code=status.HTTP_200_OK,
    summary="Health Check",
    description="Check the health status of the application and its dependencies"
)
async def health_check() -> HealthResponse:
    """Perform health check on application and dependencies

    Probes the SQL database (``SELECT 1``) and Qdrant (``get_collections``).
    A failing probe is logged and reported as "unhealthy"; it never raises,
    so the endpoint itself still answers with HTTP 200.

    Returns:
        HealthResponse with status ("healthy" when every dependency is
        healthy, otherwise "degraded"), a UTC timestamp and the
        per-dependency results.
    """
    # Local import: only `datetime` is imported at module level.
    from datetime import timezone

    dependencies: Dict[str, str] = {}

    # Check database connection
    try:
        engine = get_engine()
        async with engine.connect() as conn:
            await conn.execute(text("SELECT 1"))
        dependencies["database"] = "healthy"
    except Exception as e:
        # Broad catch is deliberate: any failure means "unhealthy".
        logger.error(f"Database health check failed: {e}")
        dependencies["database"] = "unhealthy"

    # Check Qdrant connection
    try:
        client = get_qdrant_client()
        await client.get_collections()
        dependencies["qdrant"] = "healthy"
    except Exception as e:
        logger.error(f"Qdrant health check failed: {e}")
        dependencies["qdrant"] = "unhealthy"

    # Overall status
    overall_status = "healthy" if all(
        dep_status == "healthy" for dep_status in dependencies.values()
    ) else "degraded"

    return HealthResponse(
        status=overall_status,
        # datetime.utcnow() is deprecated and returns a naive value; emit an
        # explicit UTC offset so clients can parse the timestamp unambiguously.
        timestamp=datetime.now(timezone.utc).isoformat(),
        environment=settings.environment,
        dependencies=dependencies
    )
|
src/config/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Configuration package"""
|
src/config/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (197 Bytes). View file
|
|
|
src/config/__pycache__/database.cpython-312.pyc
ADDED
|
Binary file (5.53 kB). View file
|
|
|
src/config/__pycache__/settings.cpython-312.pyc
ADDED
|
Binary file (4.15 kB). View file
|
|
|
src/config/database.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Database and vector store configuration
|
| 2 |
+
|
| 3 |
+
Provides async database engine, session management, and Qdrant client setup.
|
| 4 |
+
"""
|
| 5 |
+
from typing import AsyncGenerator
|
| 6 |
+
from sqlalchemy.ext.asyncio import (
|
| 7 |
+
AsyncEngine,
|
| 8 |
+
AsyncSession,
|
| 9 |
+
create_async_engine,
|
| 10 |
+
async_sessionmaker,
|
| 11 |
+
)
|
| 12 |
+
from sqlalchemy.orm import declarative_base
|
| 13 |
+
from qdrant_client import AsyncQdrantClient
|
| 14 |
+
from qdrant_client.models import Distance, VectorParams
|
| 15 |
+
|
| 16 |
+
from src.config.settings import settings
|
| 17 |
+
from src.utils.logger import get_logger
|
| 18 |
+
|
| 19 |
+
logger = get_logger(__name__)
|
| 20 |
+
|
| 21 |
+
# SQLAlchemy Base for models
|
| 22 |
+
Base = declarative_base()
|
| 23 |
+
|
| 24 |
+
# Global engine instance
|
| 25 |
+
_engine: AsyncEngine | None = None
|
| 26 |
+
|
| 27 |
+
# Global session maker
|
| 28 |
+
_async_session_maker: async_sessionmaker[AsyncSession] | None = None
|
| 29 |
+
|
| 30 |
+
# Global Qdrant client
|
| 31 |
+
_qdrant_client: AsyncQdrantClient | None = None
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_engine() -> AsyncEngine:
    """Return the process-wide async engine, building it on first use.

    Returns:
        The cached AsyncEngine instance.
    """
    global _engine

    if _engine is not None:
        return _engine

    # asyncpg takes its SSL setting through connect_args; require SSL for
    # Neon-hosted databases and for any production deployment.
    needs_ssl = "neon.tech" in settings.database_url or settings.is_production
    ssl_connect_args = {"ssl": "require"} if needs_ssl else {}

    _engine = create_async_engine(
        settings.async_database_url,
        echo=not settings.is_production,  # SQL echo outside production only
        pool_pre_ping=True,  # test connections before handing them out
        pool_size=5,
        max_overflow=10,
        connect_args=ssl_connect_args,
    )
    logger.info("Database engine created")
    return _engine
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def get_session_maker() -> async_sessionmaker[AsyncSession]:
    """Return the process-wide session factory, building it on first use.

    Returns:
        The cached async_sessionmaker bound to the engine from get_engine().
    """
    global _async_session_maker

    if _async_session_maker is not None:
        return _async_session_maker

    bound_engine = get_engine()
    _async_session_maker = async_sessionmaker(
        bound_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )
    logger.info("Session maker created")
    return _async_session_maker
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield one async database session per dependency resolution.

    Yields:
        An AsyncSession created from the shared session factory.
    """
    factory = get_session_maker()
    async with factory() as db_session:
        try:
            yield db_session
        finally:
            # Explicit close once the consumer is done, whether it finished
            # normally or raised.
            await db_session.close()
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def get_qdrant_client() -> AsyncQdrantClient:
    """Return the process-wide Qdrant client, building it on first use.

    Returns:
        The cached AsyncQdrantClient instance.
    """
    global _qdrant_client

    if _qdrant_client is not None:
        return _qdrant_client

    _qdrant_client = AsyncQdrantClient(
        url=settings.qdrant_url,
        api_key=settings.qdrant_api_key,
        timeout=30.0,
    )
    logger.info("Qdrant client created")
    return _qdrant_client
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
async def init_qdrant_collection() -> None:
    """Ensure the configured Qdrant collection exists.

    Lists the existing collections and creates the configured one, with
    cosine distance and the configured vector size, only when it is missing.
    """
    client = get_qdrant_client()

    listing = await client.get_collections()
    already_present = any(
        col.name == settings.qdrant_collection_name
        for col in listing.collections
    )

    if already_present:
        logger.info(f"Qdrant collection already exists: {settings.qdrant_collection_name}")
        return

    vector_params = VectorParams(
        size=settings.vector_size,
        distance=Distance.COSINE,
    )
    await client.create_collection(
        collection_name=settings.qdrant_collection_name,
        vectors_config=vector_params,
    )
    logger.info(f"Created Qdrant collection: {settings.qdrant_collection_name}")
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
async def close_database_connections() -> None:
    """Close all database connections gracefully

    Disposes the SQLAlchemy engine, closes the Qdrant client and clears every
    cached module-level handle so the next ``get_*`` call builds fresh ones.
    """
    global _engine, _async_session_maker, _qdrant_client

    if _engine is not None:
        await _engine.dispose()
        logger.info("Database engine disposed")
        _engine = None

    # The cached session maker is bound to the engine that was just disposed.
    # Drop it too; otherwise get_session_maker() would keep handing out
    # sessions tied to the old engine while get_engine() builds a new one.
    _async_session_maker = None

    if _qdrant_client is not None:
        await _qdrant_client.close()
        logger.info("Qdrant client closed")
        _qdrant_client = None
|
src/config/settings.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Application settings and configuration
|
| 2 |
+
|
| 3 |
+
Loads environment variables and provides type-safe configuration.
|
| 4 |
+
"""
|
| 5 |
+
from typing import List, Union, Optional
|
| 6 |
+
from pydantic import field_validator
|
| 7 |
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 8 |
+
from openai import AsyncOpenAI
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Settings(BaseSettings):
    """Application settings loaded from environment variables

    Values come from the process environment and from a local ``.env`` file;
    names are matched case-insensitively and unknown keys are ignored.
    Fields without a default are required.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="ignore"
    )

    # Database
    database_url: str

    # Qdrant Vector Database
    qdrant_url: str
    qdrant_api_key: str
    qdrant_collection_name: str = "humanoid-robotics-book-v1"
    vector_size: int = 1536  # OpenAI text-embedding-3-small dimension

    # OpenAI
    openai_api_key: str
    openai_embedding_model: str = "text-embedding-3-small"
    chat_model: str = "gpt-4o-mini"  # Fast, cost-effective model for chat (was gpt-4-turbo-preview)

    # Authentication
    better_auth_secret: str
    session_expiry_days: int = 7

    # Rate Limiting
    rate_limit_per_minute: int = 20
    redis_url: str = "redis://localhost:6379"

    # CORS
    allowed_origins: Union[str, List[str]] = "http://localhost:3000,http://localhost:8000"

    # Application
    environment: str = "development"
    log_level: str = "INFO"

    @field_validator("allowed_origins", mode="before")
    @classmethod
    def parse_cors_origins(cls, v):
        """Parse CORS origins from comma-separated string or list

        Surrounding whitespace is stripped and empty entries (e.g. from a
        trailing comma) are dropped so no blank origin reaches the CORS
        middleware. Non-string input is returned unchanged.
        """
        if isinstance(v, str):
            return [origin.strip() for origin in v.split(",") if origin.strip()]
        return v

    @property
    def is_production(self) -> bool:
        """Check if running in production environment"""
        return self.environment.lower() == "production"

    @property
    def async_database_url(self) -> str:
        """Get async database URL for SQLAlchemy

        Converts ``postgresql://`` (and the legacy ``postgres://`` alias) to
        ``postgresql+asyncpg://`` and removes the ``sslmode`` and
        ``channel_binding`` query parameters, which asyncpg does not accept.
        URLs with any other scheme are returned with only the query-parameter
        cleanup applied.
        """
        url = self.database_url

        # Select the asyncpg driver. "postgres://" is a common alias that
        # SQLAlchemy does not recognise as a dialect name, so normalise it too.
        for prefix in ("postgresql://", "postgres://"):
            if url.startswith(prefix):
                url = "postgresql+asyncpg://" + url[len(prefix):]
                break

        # Remove sslmode and channel_binding parameters that asyncpg doesn't
        # support. SSL is configured separately through the engine's
        # connect_args (see get_engine in src/config/database.py).
        if "sslmode=" in url or "channel_binding=" in url:
            from urllib.parse import urlparse, parse_qs, urlencode, urlunparse
            parsed = urlparse(url)
            query_params = parse_qs(parsed.query)
            query_params.pop('sslmode', None)
            query_params.pop('channel_binding', None)
            # Reconstruct the query string without the removed keys
            new_query = urlencode(query_params, doseq=True)
            url = urlunparse((
                parsed.scheme,
                parsed.netloc,
                parsed.path,
                parsed.params,
                new_query,
                parsed.fragment
            ))

        return url
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# Global settings instance
|
| 100 |
+
settings = Settings()
|
| 101 |
+
|
| 102 |
+
# Global OpenAI client instance (only if API key is provided)
|
| 103 |
+
openai_client = AsyncOpenAI(api_key=settings.openai_api_key) if settings.openai_api_key else None
|
src/main.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI application entry point
|
| 2 |
+
|
| 3 |
+
Main application setup with middleware, CORS, and route configuration.
|
| 4 |
+
"""
|
| 5 |
+
from contextlib import asynccontextmanager
|
| 6 |
+
from typing import AsyncGenerator
|
| 7 |
+
from fastapi import FastAPI, Request
|
| 8 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 9 |
+
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
| 10 |
+
# Security-header helper (active; no longer disabled for debugging)
|
| 11 |
+
from secure import Secure
|
| 12 |
+
|
| 13 |
+
# Secure middleware configuration
|
| 14 |
+
# Active: this instance is consulted by secure_headers_middleware below
|
| 15 |
+
secure_headers = Secure()
|
| 16 |
+
from fastapi.responses import JSONResponse
|
| 17 |
+
# Rate-limiting imports (active; the limiter in use comes from src.api.middleware.rate_limit)
|
| 18 |
+
from slowapi import Limiter, _rate_limit_exceeded_handler
|
| 19 |
+
from slowapi.util import get_remote_address
|
| 20 |
+
from slowapi.errors import RateLimitExceeded
|
| 21 |
+
|
| 22 |
+
from src.config.settings import settings
|
| 23 |
+
from src.config.database import init_qdrant_collection, close_database_connections
|
| 24 |
+
from src.utils.logger import setup_logging, get_logger
|
| 25 |
+
from src.api.routes import health, auth, chat
|
| 26 |
+
|
| 27 |
+
# Setup logging
|
| 28 |
+
setup_logging(
|
| 29 |
+
level=settings.log_level,
|
| 30 |
+
use_json=settings.is_production
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
logger = get_logger(__name__)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Rate limiter configuration (using the one from rate_limit middleware)
|
| 37 |
+
# limiter = Limiter(
|
| 38 |
+
# key_func=get_remote_address,
|
| 39 |
+
# default_limits=[f"{settings.rate_limit_per_minute}/minute"]
|
| 40 |
+
# )
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Application lifespan manager

    Handles startup and shutdown events.

    Startup initializes the Qdrant collection on a best-effort basis: a
    failure is logged and the application still starts. Shutdown always
    closes the database and Qdrant connections.

    Args:
        app: The FastAPI application (unused here; required by the
            lifespan protocol).
    """
    # Startup
    logger.info("Starting up application...")
    logger.info(f"Environment: {settings.environment}")

    # Initialize Qdrant collection (best effort — do not block startup)
    try:
        await init_qdrant_collection()
    except Exception as e:
        logger.error(f"Failed to initialize Qdrant collection: {e}")

    logger.info("Application startup complete")

    try:
        yield
    finally:
        # Shutdown: run cleanup even when an exception or cancellation is
        # thrown into the generator at the yield point, so connections are
        # not leaked.
        logger.info("Shutting down application...")
        await close_database_connections()
        logger.info("Application shutdown complete")
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# Create FastAPI application
|
| 70 |
+
app = FastAPI(
|
| 71 |
+
title="RAG Chatbot API",
|
| 72 |
+
description="Retrieval-Augmented Generation chatbot for humanoid robotics textbook",
|
| 73 |
+
version="1.0.0",
|
| 74 |
+
docs_url="/api/docs" if not settings.is_production else None,
|
| 75 |
+
redoc_url="/api/redoc" if not settings.is_production else None,
|
| 76 |
+
lifespan=lifespan
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
# Middleware imports (active); placed after app creation
|
| 80 |
+
from src.api.middleware.logging_middleware import LoggingMiddleware
|
| 81 |
+
from src.api.middleware.rate_limit import limiter, rate_limit_exceeded_handler
|
| 82 |
+
|
| 83 |
+
# Add rate limiter to app state
|
| 84 |
+
# Temporarily disabled for debugging
|
| 85 |
+
# app.state.limiter = limiter
|
| 86 |
+
app.add_exception_handler(RateLimitExceeded, rate_limit_exceeded_handler)
|
| 87 |
+
|
| 88 |
+
# Add logging middleware
|
| 89 |
+
# NOTE: enabled (was previously disabled due to a middleware stack build error)
|
| 90 |
+
app.add_middleware(LoggingMiddleware)
|
| 91 |
+
|
| 92 |
+
# Add secure headers middleware
|
| 93 |
+
# NOTE: enabled (no longer disabled for debugging)
|
| 94 |
+
@app.middleware("http")
async def secure_headers_middleware(request: Request, call_next):
    """Attach security headers to every outgoing response.

    Uses the ``secure`` library's framework integration when the installed
    version exposes one; otherwise sets a conservative group of headers
    by hand. Any failure is logged and the response is returned unchanged.

    Args:
        request: Incoming request.
        call_next: Callable that runs the rest of the middleware/route stack.

    Returns:
        The downstream response, with security headers applied when possible.
    """
    response = await call_next(request)

    # Only filled in for headers the response does not already carry.
    fallback_headers = (
        ("X-Content-Type-Options", "nosniff"),
        ("X-Frame-Options", "DENY"),
        ("Referrer-Policy", "no-referrer"),
        ("Strict-Transport-Security", "max-age=63072000; includeSubDomains; preload"),
        ("Permissions-Policy", "geolocation=()"),
        ("X-XSS-Protection", "1; mode=block"),
    )

    try:
        integration = getattr(secure_headers, "framework", None)
        if integration is None:
            # No framework helper on this Secure instance: apply defaults.
            for header_name, header_value in fallback_headers:
                response.headers.setdefault(header_name, header_value)
        elif hasattr(integration, "fastapi"):
            integration.fastapi(response)
        elif hasattr(integration, "starlette"):
            integration.starlette(response)
        elif hasattr(secure_headers, "apply"):
            # Helper exists but has no known adapter: try generic methods.
            secure_headers.apply(response)
        elif hasattr(secure_headers, "add"):
            secure_headers.add(response)
        else:
            raise AttributeError("Secure instance has no recognized integration methods")
    except Exception:
        # Header decoration must never break the request pipeline.
        logger.exception("Failed to apply secure headers")

    return response
|
| 130 |
+
|
| 131 |
+
# CORS middleware configuration
|
| 132 |
+
# NOTE: enabled (no longer disabled for debugging)
|
| 133 |
+
app.add_middleware(
|
| 134 |
+
CORSMiddleware,
|
| 135 |
+
allow_origins=settings.allowed_origins,
|
| 136 |
+
allow_credentials=True,
|
| 137 |
+
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
| 138 |
+
allow_headers=["*"],
|
| 139 |
+
expose_headers=["X-Request-ID"],
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# Exception handlers
|
| 144 |
+
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Global exception handler for unhandled errors

    Logs the exception (with traceback and request context) and returns a
    generic HTTP 500 JSON body. The raw exception text is included in the
    response only outside production.

    Args:
        request: FastAPI request
        exc: Exception that was raised

    Returns:
        JSONResponse with error details
    """
    logger.error(
        f"Unhandled exception: {exc}",
        # Attach the traceback; without it the log only carries str(exc),
        # which is rarely enough to locate the failure.
        exc_info=exc,
        extra={
            "path": request.url.path,
            "method": request.method,
            "client": request.client.host if request.client else "unknown"
        }
    )

    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal server error",
            "message": "An unexpected error occurred" if settings.is_production else str(exc)
        }
    )
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# Include routers
|
| 174 |
+
app.include_router(health.router, prefix="/api")
|
| 175 |
+
app.include_router(auth.router, prefix="/api")
|
| 176 |
+
app.include_router(chat.router, prefix="/api")
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# Root endpoint
|
| 180 |
+
@app.get("/", tags=["root"])
async def root() -> dict[str, str]:
    """Landing endpoint identifying the service.

    Returns:
        A small dictionary with the API name, a running status and the
        docs location ("disabled" in production).
    """
    if settings.is_production:
        docs_location = "disabled"
    else:
        docs_location = "/api/docs"

    return {
        "message": "RAG Chatbot API",
        "status": "running",
        "docs": docs_location,
    }
|
src/models/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Database models package"""
|
| 2 |
+
from src.models.user import User
|
| 3 |
+
from src.models.session import Session
|
| 4 |
+
from src.models.chat_message import ChatMessage
|
| 5 |
+
|
| 6 |
+
__all__ = ["User", "Session", "ChatMessage"]
|
src/models/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (399 Bytes). View file
|
|
|
src/models/__pycache__/chat_message.cpython-312.pyc
ADDED
|
Binary file (2.79 kB). View file
|
|
|
src/models/__pycache__/schemas.cpython-312.pyc
ADDED
|
Binary file (7.65 kB). View file
|
|
|
src/models/__pycache__/session.cpython-312.pyc
ADDED
|
Binary file (2.5 kB). View file
|
|
|
src/models/__pycache__/user.cpython-312.pyc
ADDED
|
Binary file (2.22 kB). View file
|
|
|
src/models/chat_message.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Chat message model for storing conversation history
|
| 2 |
+
|
| 3 |
+
Represents a single question-answer exchange in the chatbot.
|
| 4 |
+
Aligns with data-model.md ChatMessage entity specification.
|
| 5 |
+
"""
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from uuid import UUID, uuid4
|
| 8 |
+
from enum import Enum
|
| 9 |
+
|
| 10 |
+
from sqlalchemy import Column, String, Text, TIMESTAMP, ForeignKey, text
|
| 11 |
+
from sqlalchemy.dialects.postgresql import UUID as PGUUID, JSONB
|
| 12 |
+
from sqlalchemy.orm import relationship
|
| 13 |
+
|
| 14 |
+
from src.config.database import Base
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ChatUserRole(str, Enum):
    """Who sent a chat message.

    Members are also plain strings, so they compare equal to their values.
    """

    # Message written by the end user
    USER = "user"
    # Message produced by the chatbot
    ASSISTANT = "assistant"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ChatMessage(Base):
    """ChatMessage model for storing conversation history

    Attributes:
        id: Unique message identifier (UUID)
        user_id: Foreign key to User
        thread_id: OpenAI Agents SDK thread identifier
        role: Message sender (user or assistant)
        content: Message text content
        message_metadata: Additional context (JSONB), stored in the
            ``metadata`` database column
        created_at: Message creation timestamp
    """

    __tablename__ = "chat_messages"

    id = Column(
        PGUUID(as_uuid=True),
        primary_key=True,
        default=uuid4,
        # Must be a text() construct: a plain string is rendered as a quoted
        # literal (DEFAULT 'gen_random_uuid()'), not as a function call.
        server_default=text("gen_random_uuid()")
    )
    user_id = Column(PGUUID(as_uuid=True), ForeignKey("users.id"), nullable=False)
    thread_id = Column(String(255), nullable=False)
    role = Column(String(10), nullable=False)  # CHECK (role IN ('user', 'assistant')) will be added in migration
    content = Column(Text, nullable=False)
    # The Python attribute is message_metadata; the database column is "metadata".
    message_metadata = Column('metadata', JSONB, default=dict, server_default=text("'{}'::jsonb"), nullable=False)
    # text("NOW()") for the same quoting reason as `id` above. The Python-side
    # default stays naive UTC to match the timezone-less TIMESTAMP column.
    created_at = Column(TIMESTAMP, nullable=False, default=datetime.utcnow, server_default=text("NOW()"))

    # Relationships
    user = relationship("User", back_populates="chat_messages")

    def __repr__(self) -> str:
        return f"<ChatMessage(id={self.id}, role={self.role}, thread_id={self.thread_id})>"
|
src/models/schemas.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pydantic schemas for request/response validation
|
| 2 |
+
|
| 3 |
+
Provides type-safe data validation for API endpoints.
|
| 4 |
+
"""
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from uuid import UUID
|
| 7 |
+
from typing import Optional, Dict, Any
|
| 8 |
+
from pydantic import BaseModel, EmailStr, Field, field_validator
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# ============================================================================
|
| 12 |
+
# User Schemas
|
| 13 |
+
# ============================================================================
|
| 14 |
+
|
| 15 |
+
class UserCreate(BaseModel):
    """Request body for registering a new user."""

    email: EmailStr = Field(..., description="User's email address")
    password: str = Field(..., min_length=8, max_length=128, description="User's password")

    @field_validator("password")
    @classmethod
    def validate_password_strength(cls, v: str) -> str:
        """Reject passwords missing any required character class.

        Rules are checked in order (uppercase, lowercase, digit, special
        character) and the first unmet rule raises.

        Raises:
            ValueError: Naming the first requirement the password fails.
        """
        special_characters = "!@#$%^&*(),.?\":{}|<>"
        requirements = (
            (str.isupper, "Password must contain at least one uppercase letter"),
            (str.islower, "Password must contain at least one lowercase letter"),
            (str.isdigit, "Password must contain at least one digit"),
            (
                lambda ch: ch in special_characters,
                "Password must contain at least one special character",
            ),
        )

        for is_match, failure_message in requirements:
            if not any(is_match(ch) for ch in v):
                raise ValueError(failure_message)

        return v
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class UserLogin(BaseModel):
    """Schema for user login request"""

    # Address the account was registered with
    email: EmailStr = Field(
        ...,
        description="User's email address",
    )
    # Only bounded in length here (1-128 characters); no strength validator is attached
    password: str = Field(
        ...,
        description="User's password",
        min_length=1,
        max_length=128,
    )
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class UserResponse(BaseModel):
    """Schema for user data in responses"""

    # Fields are populated from attributes of the same name (see model_config)
    id: UUID
    email: str
    created_at: datetime
    updated_at: datetime

    # from_attributes lets the model be built from an object's attributes
    # rather than only from a mapping
    model_config = dict(from_attributes=True)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class AuthResponse(BaseModel):
    """Schema for authentication response"""

    # The user the authentication request resolved to
    user: UserResponse
    # Human-readable status; callers may override the default text
    message: str = "Authentication successful"
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# ============================================================================
|
| 58 |
+
# Session Schemas
|
| 59 |
+
# ============================================================================
|
| 60 |
+
|
| 61 |
+
class SessionResponse(BaseModel):
    """Schema for session data in responses"""

    # Session identifier and the user it belongs to
    id: UUID
    user_id: UUID
    # Expiry and creation timestamps of the session
    expires_at: datetime
    created_at: datetime

    # Allow construction from an object's attributes as well as from a mapping
    model_config = dict(from_attributes=True)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# ============================================================================
|
| 72 |
+
# Chat Message Schemas
|
| 73 |
+
# ============================================================================
|
| 74 |
+
|
| 75 |
+
class ChatMessageCreate(BaseModel):
    """Schema for creating a new chat message"""

    message: str = Field(..., min_length=1, max_length=10000, description="User's message content")
    thread_id: Optional[str] = Field(None, min_length=1, max_length=255, description="OpenAI thread ID, optional for new threads")
    # Declared before selected_text so it is already validated (and present in
    # info.data) when the selected_text validator runs.
    query_mode: Optional[str] = Field(None, description="Query mode: 'full_book' or 'selection'")
    # validate_default=True is required for the cross-field check below:
    # Pydantic v2 skips field validators when the default value is used, so
    # without it a request with query_mode='selection' and no selected_text
    # would be accepted.
    selected_text: Optional[str] = Field(
        None,
        max_length=5000,
        validate_default=True,
        description="Selected text for context queries",
    )

    @field_validator("query_mode")
    @classmethod
    def validate_query_mode(cls, v: Optional[str]) -> Optional[str]:
        """Validate query mode is one of allowed values.

        Args:
            v: Submitted query mode, or None when omitted

        Returns:
            The value unchanged when it is None, 'full_book' or 'selection'

        Raises:
            ValueError: If a value other than the allowed ones is supplied
        """
        if v is not None and v not in ("full_book", "selection"):
            raise ValueError("query_mode must be 'full_book' or 'selection'")
        return v

    @field_validator("selected_text")
    @classmethod
    def validate_selected_text(cls, v: Optional[str], info: Any) -> Optional[str]:
        """Validate selected_text is present when query_mode is 'selection'.

        Args:
            v: Submitted selected text, or None when omitted
            info: Validation info; info.data holds the already-validated fields

        Returns:
            The selected text unchanged

        Raises:
            ValueError: If query_mode is 'selection' and the text is missing or empty
        """
        # query_mode is absent from info.data when it failed its own
        # validation; .get() then yields None and this check is skipped.
        if info.data.get("query_mode") == "selection" and not v:
            raise ValueError("selected_text is required when query_mode is 'selection'")
        return v
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class ChatMessageResponse(BaseModel):
    """Schema for chat message in responses"""

    id: UUID
    user_id: UUID
    thread_id: str
    role: str
    content: str
    # Required field exposed as `metadata` but populated through the alias
    # `message_metadata`; populate_by_name below also accepts the field name.
    metadata: Dict[str, Any] = Field(
        ...,
        alias="message_metadata",
    )  # Use alias for Pydantic V2
    created_at: datetime

    # from_attributes: build from object attributes;
    # populate_by_name: accept either the alias or the field name on input
    model_config = dict(
        from_attributes=True,
        populate_by_name=True,
    )
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class ChatResponse(BaseModel):
    """Schema for chat response with user and assistant messages"""

    # The message that was submitted and the reply produced for it
    user_message: ChatMessageResponse
    assistant_message: ChatMessageResponse
    # Identifier of the thread the exchange belongs to
    thread_id: str
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class ChatHistoryResponse(BaseModel):
    """Schema for thread history response"""

    # Messages of the thread
    messages: list[ChatMessageResponse]
    # Message count reported alongside the list
    total: int
    # Identifier of the thread the history belongs to
    thread_id: str
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# ============================================================================
|
| 131 |
+
# Generic Response Schemas
|
| 132 |
+
# ============================================================================
|
| 133 |
+
|
| 134 |
+
class MessageResponse(BaseModel):
    """Generic message response"""

    # Free-form text returned to the client
    message: str
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class ErrorResponse(BaseModel):
    """Error response schema"""

    # Error identifier and its accompanying text
    error: str
    message: str
    # Optional structured extras; omitted (None) unless supplied
    details: Optional[Dict[str, Any]] = None
|
src/models/session.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Session model for authentication token management
|
| 2 |
+
|
| 3 |
+
Represents an active authentication session with JWT tokens and expiration.
|
| 4 |
+
Aligns with data-model.md Session entity specification.
|
| 5 |
+
"""
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from uuid import UUID, uuid4
|
| 8 |
+
from sqlalchemy import Column, String, TIMESTAMP, ForeignKey
|
| 9 |
+
from sqlalchemy.dialects.postgresql import UUID as PGUUID
|
| 10 |
+
from sqlalchemy.orm import relationship
|
| 11 |
+
|
| 12 |
+
from src.config.database import Base
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
from sqlalchemy import text  # needed so server defaults render as SQL expressions


class Session(Base):
    """Session model for JWT token management

    Attributes:
        id: Unique session identifier (UUID)
        user_id: Foreign key to User
        token_hash: Hashed JWT token (unique)
        expires_at: Session expiration timestamp
        created_at: Session creation timestamp
    """

    __tablename__ = "sessions"

    # NOTE: server_default must be a text() construct. A plain string is
    # emitted by SQLAlchemy as a quoted literal (DEFAULT 'gen_random_uuid()'),
    # not as a function call.
    id = Column(
        PGUUID(as_uuid=True),
        primary_key=True,
        default=uuid4,
        server_default=text("gen_random_uuid()"),
    )
    # Deleting the referenced user removes the row at the database level (CASCADE)
    user_id = Column(
        PGUUID(as_uuid=True),
        ForeignKey("users.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    token_hash = Column(String(255), unique=True, nullable=False, index=True)
    expires_at = Column(TIMESTAMP, nullable=False, index=True)
    created_at = Column(
        TIMESTAMP,
        nullable=False,
        default=datetime.utcnow,
        server_default=text("NOW()"),
    )

    # Relationships
    user = relationship("User", back_populates="sessions")

    def __repr__(self) -> str:
        return f"<Session(id={self.id}, user_id={self.user_id}, expires_at={self.expires_at})>"

    @property
    def is_expired(self) -> bool:
        """Check if session has expired.

        Returns:
            True when the current naive-UTC time is later than expires_at
        """
        # Naive UTC on both sides: expires_at is a TIMESTAMP column without a
        # timezone flag, so an aware datetime here would not be comparable.
        return datetime.utcnow() > self.expires_at
|
src/models/user.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""User model for authentication and authorization
|
| 2 |
+
|
| 3 |
+
Represents an authenticated reader with account credentials.
|
| 4 |
+
Aligns with data-model.md User entity specification.
|
| 5 |
+
"""
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from uuid import UUID, uuid4
|
| 8 |
+
from sqlalchemy import Column, String, TIMESTAMP
|
| 9 |
+
from sqlalchemy.dialects.postgresql import UUID as PGUUID
|
| 10 |
+
from sqlalchemy.orm import relationship
|
| 11 |
+
|
| 12 |
+
from src.config.database import Base
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
from sqlalchemy import text  # needed so server defaults render as SQL expressions


class User(Base):
    """User model for authentication

    Attributes:
        id: Unique user identifier (UUID)
        email: User's email address (unique)
        password_hash: Hashed password (managed by auth service)
        created_at: Account creation timestamp
        updated_at: Last account modification timestamp
    """

    __tablename__ = "users"

    # NOTE: server_default must be a text() construct. A plain string is
    # emitted by SQLAlchemy as a quoted literal (DEFAULT 'gen_random_uuid()'),
    # not as a function call.
    id = Column(
        PGUUID(as_uuid=True),
        primary_key=True,
        default=uuid4,
        server_default=text("gen_random_uuid()"),
    )
    email = Column(String(255), unique=True, nullable=False, index=True)
    password_hash = Column(String(255), nullable=False)
    created_at = Column(
        TIMESTAMP,
        nullable=False,
        default=datetime.utcnow,
        server_default=text("NOW()"),
    )
    # onupdate refreshes the value on ORM-issued UPDATE statements
    updated_at = Column(
        TIMESTAMP,
        nullable=False,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
        server_default=text("NOW()"),
    )

    # Relationships
    # delete-orphan: children removed from these collections are deleted as well
    sessions = relationship("Session", back_populates="user", cascade="all, delete-orphan")
    chat_messages = relationship("ChatMessage", back_populates="user", cascade="all, delete-orphan")

    def __repr__(self) -> str:
        return f"<User(id={self.id}, email={self.email})>"
|
src/services/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Services package"""
|
| 2 |
+
from src.services.auth_service import AuthService
|
| 3 |
+
from src.services.vector_service import VectorService
|
| 4 |
+
from src.services.chat_service import ChatService
|
| 5 |
+
|
| 6 |
+
__all__ = ["AuthService", "VectorService", "ChatService"]
|
src/services/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (428 Bytes). View file
|
|
|
src/services/__pycache__/auth_service.cpython-312.pyc
ADDED
|
Binary file (12 kB). View file
|
|
|
src/services/__pycache__/chat_service.cpython-312.pyc
ADDED
|
Binary file (6.02 kB). View file
|
|
|
src/services/__pycache__/rag_service.cpython-312.pyc
ADDED
|
Binary file (9.74 kB). View file
|
|
|
src/services/__pycache__/vector_service.cpython-312.pyc
ADDED
|
Binary file (3.36 kB). View file
|
|
|
src/services/auth_service.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Authentication service for user management and session handling
|
| 2 |
+
|
| 3 |
+
Provides user registration, authentication, password hashing,
|
| 4 |
+
JWT token generation, and session management.
|
| 5 |
+
"""
|
| 6 |
+
from datetime import datetime, timedelta
|
| 7 |
+
from typing import Optional
|
| 8 |
+
from uuid import UUID
|
| 9 |
+
import hashlib
|
| 10 |
+
from passlib.context import CryptContext
|
| 11 |
+
from jose import JWTError, jwt
|
| 12 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 13 |
+
from sqlalchemy import select, delete
|
| 14 |
+
|
| 15 |
+
from src.models.user import User
|
| 16 |
+
from src.models.session import Session
|
| 17 |
+
from src.config.settings import settings
|
| 18 |
+
from src.utils.logger import get_logger
|
| 19 |
+
|
| 20 |
+
# Module-level logger for the authentication service
logger = get_logger(__name__)

# Password hashing configuration.
# Argon2 is listed first; bcrypt stays in the list so hashes created with it
# can still be verified.
pwd_context = CryptContext(
    schemes=["argon2", "bcrypt"],
    deprecated="auto",
)

# JWT configuration: algorithm used for both encoding and decoding tokens
ALGORITHM = "HS256"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class AuthService:
    """Authentication service for user and session management.

    Every operation is a static method; the database-backed ones receive the
    AsyncSession to work with as their first argument.
    """

    @staticmethod
    def hash_password(password: str) -> str:
        """Hash a plain text password using the context's default scheme.

        Args:
            password: Plain text password

        Returns:
            Hashed password
        """
        # "argon2" is the first entry in pwd_context's schemes and therefore
        # the default. The scheme= keyword of CryptContext.hash() is
        # deprecated since passlib 1.7, so it is no longer passed.
        return pwd_context.hash(password)

    @staticmethod
    def verify_password(plain_password: str, hashed_password: str) -> bool:
        """Verify a password against its hash.

        Args:
            plain_password: Plain text password to verify
            hashed_password: Hashed password to compare against

        Returns:
            True if password matches, False otherwise (including when the
            stored hash cannot be interpreted)
        """
        try:
            # The context recognises hashes of every configured scheme.
            return pwd_context.verify(plain_password, hashed_password)
        except (ValueError, TypeError):
            # passlib raises ValueError (UnknownHashError) for an unrecognised
            # or corrupt hash and TypeError for a non-string hash. Treat both
            # as a failed verification instead of an unhandled error.
            logger.warning("Password verification failed: stored hash is malformed or uses an unknown scheme")
            return False

    @staticmethod
    def generate_jwt_token(user_id: UUID) -> str:
        """Generate a JWT token for a user.

        Args:
            user_id: User's UUID

        Returns:
            JWT token string
        """
        # Local import: the module imports only UUID from the uuid package.
        from uuid import uuid4

        issued_at = datetime.utcnow()
        claims = {
            "sub": str(user_id),
            "exp": issued_at + timedelta(days=settings.session_expiry_days),
            "iat": issued_at,
            # Timestamps in a JWT have one-second resolution, so without a
            # unique id two tokens issued to the same user within one second
            # would be identical and collide on the unique token_hash column.
            "jti": uuid4().hex,
        }
        return jwt.encode(
            claims,
            settings.better_auth_secret,
            algorithm=ALGORITHM,
        )

    @staticmethod
    def decode_jwt_token(token: str) -> Optional[dict]:
        """Decode and validate a JWT token.

        Args:
            token: JWT token string

        Returns:
            Decoded token payload or None if invalid
        """
        try:
            return jwt.decode(
                token,
                settings.better_auth_secret,
                algorithms=[ALGORITHM],
            )
        except JWTError as e:
            logger.warning(f"JWT decode error: {e}")
            return None

    @staticmethod
    def hash_token(token: str) -> str:
        """Create SHA-256 hash of a token for storage.

        Args:
            token: Token to hash

        Returns:
            Hex digest of token hash
        """
        return hashlib.sha256(token.encode()).hexdigest()

    @staticmethod
    async def create_user(
        db: AsyncSession,
        email: str,
        password: str
    ) -> User:
        """Create a new user account.

        Args:
            db: Database session
            email: User's email address
            password: Plain text password

        Returns:
            Created User instance
        """
        user = User(
            email=email.lower(),  # Normalize email to lowercase
            password_hash=AuthService.hash_password(password),
        )

        db.add(user)
        await db.commit()
        # Reload so database-generated values are available on the instance
        await db.refresh(user)

        logger.info(f"User created: {user.id}")
        return user

    @staticmethod
    async def authenticate_user(
        db: AsyncSession,
        email: str,
        password: str
    ) -> Optional[User]:
        """Authenticate a user by email and password.

        Args:
            db: Database session
            email: User's email address
            password: Plain text password

        Returns:
            User instance if authenticated, None otherwise
        """
        # Emails are stored lowercased, so the lookup is lowercased as well
        result = await db.execute(
            select(User).where(User.email == email.lower())
        )
        user = result.scalar_one_or_none()

        if user is None:
            logger.warning(f"Authentication failed: user not found for email {email}")
            return None

        if not AuthService.verify_password(password, user.password_hash):
            logger.warning(f"Authentication failed: invalid password for user {user.id}")
            return None

        logger.info(f"User authenticated: {user.id}")
        return user

    @staticmethod
    async def create_session(
        db: AsyncSession,
        user_id: UUID,
        token: str
    ) -> Session:
        """Create a new session for a user.

        Args:
            db: Database session
            user_id: User's UUID
            token: JWT token string

        Returns:
            Created Session instance
        """
        # Only the SHA-256 digest of the token is persisted, never the token
        session = Session(
            user_id=user_id,
            token_hash=AuthService.hash_token(token),
            expires_at=datetime.utcnow() + timedelta(days=settings.session_expiry_days),
        )

        db.add(session)
        await db.commit()
        await db.refresh(session)

        logger.info(f"Session created: {session.id} for user {user_id}")
        return session

    @staticmethod
    async def validate_session(
        db: AsyncSession,
        token: str
    ) -> Optional[Session]:
        """Validate a session token.

        Args:
            db: Database session
            token: JWT token string

        Returns:
            Session instance if valid, None otherwise
        """
        token_hash = AuthService.hash_token(token)

        # Sessions are looked up by the digest of the presented token
        result = await db.execute(
            select(Session).where(Session.token_hash == token_hash)
        )
        session = result.scalar_one_or_none()

        if session is None:
            logger.warning("Session validation failed: session not found")
            return None

        if session.is_expired:
            logger.warning(f"Session validation failed: session {session.id} expired")
            return None

        return session

    @staticmethod
    async def revoke_session(
        db: AsyncSession,
        token: str
    ) -> bool:
        """Revoke a session by token.

        Args:
            db: Database session
            token: JWT token string

        Returns:
            True if session was revoked, False otherwise
        """
        token_hash = AuthService.hash_token(token)

        result = await db.execute(
            delete(Session).where(Session.token_hash == token_hash)
        )
        await db.commit()

        revoked = result.rowcount > 0
        if revoked:
            # Log only a prefix of the digest
            logger.info(f"Session revoked for token hash {token_hash[:16]}...")

        return revoked

    @staticmethod
    async def cleanup_expired_sessions(db: AsyncSession) -> int:
        """Remove all expired sessions from database.

        Args:
            db: Database session

        Returns:
            Number of sessions deleted
        """
        result = await db.execute(
            delete(Session).where(Session.expires_at < datetime.utcnow())
        )
        await db.commit()

        count = result.rowcount
        if count > 0:
            logger.info(f"Cleaned up {count} expired sessions")

        return count

    @staticmethod
    async def get_user_by_id(
        db: AsyncSession,
        user_id: UUID
    ) -> Optional[User]:
        """Get a user by ID.

        Args:
            db: Database session
            user_id: User's UUID

        Returns:
            User instance if found, None otherwise
        """
        result = await db.execute(
            select(User).where(User.id == user_id)
        )
        return result.scalar_one_or_none()
|