GrowWithTalha committed on
Commit
a83c934
·
verified ·
1 Parent(s): febf928

Upload 62 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .dockerignore +73 -0
  2. .env +40 -0
  3. Dockerfile +52 -0
  4. docker-compose.yml +46 -0
  5. scripts/embed_book_content.py +358 -0
  6. src/__init__.py +1 -0
  7. src/__pycache__/__init__.cpython-312.pyc +0 -0
  8. src/__pycache__/main.cpython-312.pyc +0 -0
  9. src/api/__init__.py +1 -0
  10. src/api/__pycache__/__init__.cpython-312.pyc +0 -0
  11. src/api/middleware/__init__.py +1 -0
  12. src/api/middleware/__pycache__/__init__.cpython-312.pyc +0 -0
  13. src/api/middleware/__pycache__/auth_middleware.cpython-312.pyc +0 -0
  14. src/api/middleware/__pycache__/logging_middleware.cpython-312.pyc +0 -0
  15. src/api/middleware/__pycache__/rate_limit.cpython-312.pyc +0 -0
  16. src/api/middleware/auth_middleware.py +148 -0
  17. src/api/middleware/logging_middleware.py +37 -0
  18. src/api/middleware/rate_limit.py +76 -0
  19. src/api/routes/__init__.py +1 -0
  20. src/api/routes/__pycache__/__init__.cpython-312.pyc +0 -0
  21. src/api/routes/__pycache__/auth.cpython-312.pyc +0 -0
  22. src/api/routes/__pycache__/chat.cpython-312.pyc +0 -0
  23. src/api/routes/__pycache__/health.cpython-312.pyc +0 -0
  24. src/api/routes/auth.py +304 -0
  25. src/api/routes/chat.py +264 -0
  26. src/api/routes/health.py +73 -0
  27. src/config/__init__.py +1 -0
  28. src/config/__pycache__/__init__.cpython-312.pyc +0 -0
  29. src/config/__pycache__/database.cpython-312.pyc +0 -0
  30. src/config/__pycache__/settings.cpython-312.pyc +0 -0
  31. src/config/database.py +152 -0
  32. src/config/settings.py +103 -0
  33. src/main.py +191 -0
  34. src/models/__init__.py +6 -0
  35. src/models/__pycache__/__init__.cpython-312.pyc +0 -0
  36. src/models/__pycache__/chat_message.cpython-312.pyc +0 -0
  37. src/models/__pycache__/schemas.cpython-312.pyc +0 -0
  38. src/models/__pycache__/session.cpython-312.pyc +0 -0
  39. src/models/__pycache__/user.cpython-312.pyc +0 -0
  40. src/models/chat_message.py +55 -0
  41. src/models/schemas.py +143 -0
  42. src/models/session.py +53 -0
  43. src/models/user.py +44 -0
  44. src/services/__init__.py +6 -0
  45. src/services/__pycache__/__init__.cpython-312.pyc +0 -0
  46. src/services/__pycache__/auth_service.cpython-312.pyc +0 -0
  47. src/services/__pycache__/chat_service.cpython-312.pyc +0 -0
  48. src/services/__pycache__/rag_service.cpython-312.pyc +0 -0
  49. src/services/__pycache__/vector_service.cpython-312.pyc +0 -0
  50. src/services/auth_service.py +312 -0
.dockerignore ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+ .venv/
23
+ venv/
24
+ ENV/
25
+ env/
26
+
27
+ # Testing
28
+ .pytest_cache/
29
+ .coverage
30
+ htmlcov/
31
+ .tox/
32
+ .nox/
33
+
34
+ # IDEs
35
+ .vscode/
36
+ .idea/
37
+ *.swp
38
+ *.swo
39
+ *~
40
+ .DS_Store
41
+
42
+ # Environment
43
+ .env
44
+ .env.local
45
+ .env.*.local
46
+
47
+ # Git
48
+ .git/
49
+ .gitignore
50
+ .gitattributes
51
+
52
+ # Documentation
53
+ *.md
54
+ !README.md
55
+ docs/
56
+
57
+ # Logs
58
+ *.log
59
+ logs/
60
+
61
+ # Docker
62
+ Dockerfile
63
+ docker-compose.yml
64
+ .dockerignore
65
+
66
+ # CI/CD
67
+ .github/
68
+ .gitlab-ci.yml
69
+ .travis.yml
70
+
71
+ # Misc
72
+ node_modules/
73
+ coverage/
.env ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Database Configuration
2
+ # Neon Serverless Postgres connection string
3
+ # Format: postgresql://user:password@host.neon.tech/dbname?sslmode=require
4
+ DATABASE_URL=postgresql://USER:PASSWORD@HOST.neon.tech/DBNAME?sslmode=require&channel_binding=require
5
+
6
+ # Vector Database Configuration
7
+ # Qdrant Cloud cluster URL and API key
8
+ QDRANT_URL=https://dd8a681c-65ea-4ca6-ac50-e7e4873fdba1.us-east4-0.gcp.cloud.qdrant.io
9
+ QDRANT_API_KEY=YOUR_QDRANT_API_KEY
10
+
11
+ # OpenAI API Configuration
12
+ # Get your API key from https://platform.openai.com/api-keys
13
+ OPENAI_API_KEY=YOUR_OPENAI_API_KEY
14
+ # Optional: Organization ID if you have one
15
+ OPENAI_ORG_ID=
16
+
17
+ # Authentication Configuration
18
+ # Generate with: openssl rand -hex 32
19
+ BETTER_AUTH_SECRET=REPLACE_WITH_OUTPUT_OF_OPENSSL_RAND_HEX_32
20
+ # Session expiration in seconds (default: 7 days)
21
+ SESSION_TTL=604800
22
+
23
+ # Application Settings
24
+ ENVIRONMENT=development
25
+ LOG_LEVEL=INFO
26
+
27
+ # Rate Limiting
28
+ # Maximum requests per minute per user
29
+ RATE_LIMIT_PER_MINUTE=20
30
+
31
+ # CORS Configuration
32
+ # Comma-separated list of allowed origins
33
+ ALLOWED_ORIGINS=http://localhost:3000
34
+
35
+ # Server Configuration
36
+ HOST=0.0.0.0
37
+ PORT=8000
38
+
39
+ # Redis Configuration (for rate limiting)
40
+ REDIS_URL="redis://default:PASSWORD@HOST:PORT"
Dockerfile ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Multi-stage build for production-ready image

# Stage 1: Build dependencies
FROM python:3.11-slim AS builder

WORKDIR /app

# Install system dependencies needed to build wheels
RUN apt-get update && apt-get install -y --no-install-recommends \
    gcc \
    postgresql-client \
    && rm -rf /var/lib/apt/lists/*

# Install Python dependencies into a virtualenv outside /root.
# `pip install --user` would place them in /root/.local, which the
# non-root runtime user cannot read, so uvicorn would not be found.
RUN python -m venv /opt/venv
ENV PATH=/opt/venv/bin:$PATH

# Copy requirements and install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Stage 2: Runtime
FROM python:3.11-slim

WORKDIR /app

# Install runtime dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    postgresql-client \
    && rm -rf /var/lib/apt/lists/*

# Copy Python dependencies from builder (same base image, same interpreter path)
COPY --from=builder /opt/venv /opt/venv

# Copy application code
COPY src/ ./src/
COPY scripts/ ./scripts/
COPY alembic/ ./alembic/
COPY alembic.ini ./

# Make sure the virtualenv's executables (uvicorn, alembic) are on PATH
ENV PATH=/opt/venv/bin:$PATH

# Create non-root user for security
RUN useradd -m -u 1000 appuser && chown -R appuser:appuser /app
USER appuser

# Expose port
EXPOSE 8000

# Health check: stdlib only; urlopen raises on 4xx/5xx, so an unhealthy
# app makes the command exit non-zero.
HEALTHCHECK --interval=30s --timeout=3s --start-period=40s --retries=3 \
    CMD python -c "import sys, urllib.request; sys.exit(0 if urllib.request.urlopen('http://localhost:8000/health', timeout=2).status == 200 else 1)" || exit 1

# Run application
CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8000"]
docker-compose.yml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: '3.8'
2
+
3
+ services:
4
+ backend:
5
+ build:
6
+ context: .
7
+ dockerfile: Dockerfile
8
+ ports:
9
+ - "8000:8000"
10
+ environment:
11
+ - DATABASE_URL=${DATABASE_URL}
12
+ - QDRANT_URL=${QDRANT_URL}
13
+ - QDRANT_API_KEY=${QDRANT_API_KEY}
14
+ - OPENAI_API_KEY=${OPENAI_API_KEY}
15
+ - OPENAI_ORG_ID=${OPENAI_ORG_ID:-}
16
+ - BETTER_AUTH_SECRET=${BETTER_AUTH_SECRET}
17
+ - ENVIRONMENT=development
18
+ - LOG_LEVEL=${LOG_LEVEL:-INFO}
19
+ - RATE_LIMIT_PER_MINUTE=${RATE_LIMIT_PER_MINUTE:-20}
20
+ - ALLOWED_ORIGINS=${ALLOWED_ORIGINS:-http://localhost:3000}
21
+ volumes:
22
+ - ./src:/app/src
23
+ - ./scripts:/app/scripts
24
+ depends_on:
25
+ - redis
26
+ networks:
27
+ - chatbot-network
28
+ restart: unless-stopped
29
+
30
+ redis:
31
+ image: redis:7-alpine
32
+ ports:
33
+ - "6379:6379"
34
+ volumes:
35
+ - redis-data:/data
36
+ networks:
37
+ - chatbot-network
38
+ restart: unless-stopped
39
+ command: redis-server --appendonly yes
40
+
41
+ volumes:
42
+ redis-data:
43
+
44
+ networks:
45
+ chatbot-network:
46
+ driver: bridge
scripts/embed_book_content.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Book content embedding script
4
+
5
+ Reads markdown files from docs/ (including all nested subdirectories), chunks content by headings or word count,
6
+ generates embeddings with OpenAI, and uploads to Qdrant vector database.
7
+
8
+ Usage:
9
+ python backend/scripts/embed_book_content.py --book-path docs/ --collection-name humanoid-robotics-book-v1
10
+ """
11
+ import argparse
12
+ import asyncio
13
+ import os
14
+ import re
15
+ import sys
16
+ from pathlib import Path
17
+ from typing import List, Dict, Any
18
+ from uuid import uuid4
19
+
20
+ # Add parent directory to path for imports
21
+ sys.path.insert(0, str(Path(__file__).parent.parent))
22
+
23
+ from openai import AsyncOpenAI
24
+ from qdrant_client import AsyncQdrantClient
25
+ from qdrant_client.models import Distance, VectorParams, PointStruct
26
+
27
+ from src.config.settings import settings
28
+ from src.utils.logger import setup_logging, get_logger
29
+
30
+ setup_logging(level="INFO")
31
+ logger = get_logger(__name__)
32
+
33
+
34
class BookContentChunker:
    """Chunks markdown content intelligently by headings and word limits"""

    def __init__(self, chunk_size: int = 500, overlap: int = 50):
        """
        Initialize chunker

        Args:
            chunk_size: Target chunk size in words (must be positive)
            overlap: Word overlap between chunks (0 <= overlap < chunk_size)

        Raises:
            ValueError: If the sizes would stall the sliding window
                (overlap >= chunk_size) or skip words (negative overlap)
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        if not 0 <= overlap < chunk_size:
            raise ValueError(
                f"overlap must satisfy 0 <= overlap < chunk_size, "
                f"got overlap={overlap}, chunk_size={chunk_size}"
            )
        self.chunk_size = chunk_size
        self.overlap = overlap

    def chunk_markdown(self, content: str, file_path: str) -> List[Dict[str, Any]]:
        """
        Chunk markdown content by headings and word limits

        Text before the first level-2/3 heading is filed under the section
        name "Introduction". Heading lines themselves are not part of any
        chunk's content; they only label the chunks that follow.

        Args:
            content: Markdown file content
            file_path: Path to markdown file (for metadata)

        Returns:
            List of chunk dictionaries with content and metadata
        """
        chunks: List[Dict[str, Any]] = []

        # Extract chapter/module name from file path
        chapter = self._extract_chapter_name(Path(file_path))

        # Split by headings (## and ###); the capture group keeps the heading
        # lines in the result so they can be recognised below.
        sections = re.split(r'(^#{2,3}\s+.+$)', content, flags=re.MULTILINE)

        current_section_heading = "Introduction"
        current_content: List[str] = []

        for section in sections:
            stripped = section.strip()
            heading_match = re.match(r'^(#{2,3})\s+(.+)$', stripped)

            if heading_match:
                # A new heading closes the section accumulated so far
                if current_content:
                    chunks.extend(self._chunk_section(
                        "\n".join(current_content),
                        chapter,
                        current_section_heading,
                    ))
                current_section_heading = heading_match.group(2).strip()
                current_content = []
            elif stripped:
                current_content.append(stripped)

        # Process last section
        if current_content:
            chunks.extend(self._chunk_section(
                "\n".join(current_content),
                chapter,
                current_section_heading,
            ))

        return chunks

    def _chunk_section(self, content: str, chapter: str, section: str) -> List[Dict[str, Any]]:
        """
        Chunk a section by word count with overlap

        A section that fits in ``chunk_size`` words is returned as a single
        chunk with its original whitespace. Longer sections are cut into
        windows of ``chunk_size`` words whose words are re-joined with single
        spaces; consecutive windows share ``overlap`` words.

        Args:
            content: Section text (without its heading line)
            chapter: Chapter name stored in each chunk
            section: Section heading stored in each chunk

        Returns:
            List of chunk dictionaries
        """
        words = content.split()
        total = len(words)

        if total <= self.chunk_size:
            # Section fits in one chunk
            return [{
                "content": content,
                "chapter": chapter,
                "section": section,
                "heading": section,
                "chunk_index": 0,
                "word_count": total,
            }]

        chunks: List[Dict[str, Any]] = []
        # Guard against attributes mutated after __init__: a non-positive
        # step would loop forever.
        step = max(1, self.chunk_size - self.overlap)
        start = 0

        while True:
            end = start + self.chunk_size
            chunk_words = words[start:end]

            chunks.append({
                "content": " ".join(chunk_words),
                "chapter": chapter,
                "section": section,
                "heading": section,
                "chunk_index": len(chunks),
                "word_count": len(chunk_words),
            })

            # Stop once a window reaches the last word. Stepping again would
            # only emit a tail made entirely of already-covered overlap words.
            if end >= total:
                break
            start += step

        return chunks

    def _extract_chapter_name(self, path: Path) -> str:
        """
        Extract chapter/module name from file path

        The deepest path component that starts with "module<number>"
        (e.g. "module1-ros2", "Module 1") wins; otherwise the file name
        without extension is used. Hyphens (and, for the fallback,
        underscores) become spaces and the result is title-cased.
        """
        # Look for patterns like "module1-ros2", "Module 1", etc.
        for part in reversed(path.parts):
            if re.match(r'module[-\s]*\d+', part, re.IGNORECASE):
                return part.replace('-', ' ').title()

        # Fallback to filename without extension
        return path.stem.replace('-', ' ').replace('_', ' ').title()
154
+
155
+
156
class BookEmbedder:
    """Handles embedding generation and Qdrant upload"""

    def __init__(self, collection_name: str = "book_content"):
        """
        Initialize embedder

        Builds the OpenAI and Qdrant async clients from the project settings.

        Args:
            collection_name: Qdrant collection name
        """
        self.collection_name = collection_name
        self.openai_client = AsyncOpenAI(api_key=settings.openai_api_key)
        self.qdrant_client = AsyncQdrantClient(
            url=settings.qdrant_url,
            api_key=settings.qdrant_api_key,
            timeout=30,  # Set a higher timeout (seconds)
        )

    async def create_collection(self):
        """
        Create the Qdrant collection if it doesn't exist

        If listing the collections fails, the error is logged and the
        process exits with status 1. An existing collection is left as is.
        """
        try:
            collections = await self.qdrant_client.get_collections()
        except Exception as e:
            logger.error(
                "\nCannot connect to Qdrant. "
                f"Error: {type(e).__name__}: {e}\n"
                "-> Please make sure your Qdrant server is running and accessible at the configured URL.\n"
                f"-> Current Qdrant URL: {settings.qdrant_url}"
            )
            logger.error("Exiting due to Qdrant connection failure.")
            sys.exit(1)

        existing_names = {col.name for col in collections.collections}

        if self.collection_name in existing_names:
            logger.info(f"Collection already exists: {self.collection_name}")
            return

        await self.qdrant_client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(
                size=settings.vector_size,
                distance=Distance.COSINE,
            ),
        )
        logger.info(f"Created collection: {self.collection_name}")

    async def embed_text(self, text: str) -> List[float]:
        """
        Generate embedding for text using OpenAI

        Args:
            text: Text to embed

        Returns:
            Embedding vector
        """
        response = await self.openai_client.embeddings.create(
            model=settings.openai_embedding_model,
            input=text
        )
        return response.data[0].embedding

    async def upload_chunks(
        self,
        chunks: List[Dict[str, Any]],
        doc_version: str = "v1.0.0",
        batch_size: int = 100,
    ):
        """
        Upload chunks with embeddings to Qdrant

        Each chunk is embedded individually and stored under a fresh random
        UUID, so re-running the script adds new points rather than
        replacing earlier ones.

        Args:
            chunks: List of chunk dictionaries
            doc_version: Document version identifier
            batch_size: Number of points sent per upsert call

        Raises:
            ValueError: If batch_size is not positive
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        logger.info(f"Uploading {len(chunks)} chunks to Qdrant...")

        points = []
        batches_sent = 0

        for chunk in chunks:
            # Generate embedding
            embedding = await self.embed_text(chunk["content"])

            points.append(PointStruct(
                id=str(uuid4()),
                vector=embedding,
                payload={
                    "content": chunk["content"],
                    "chapter": chunk["chapter"],
                    "section": chunk["section"],
                    "heading": chunk["heading"],
                    "chunk_index": chunk["chunk_index"],
                    "word_count": chunk["word_count"],
                    "doc_version": doc_version,
                },
            ))

            # Flush a full batch to keep memory and request size bounded
            if len(points) >= batch_size:
                await self.qdrant_client.upsert(
                    collection_name=self.collection_name,
                    points=points
                )
                batches_sent += 1
                logger.info(f"Uploaded batch {batches_sent} ({len(points)} points)")
                points = []

        # Upload remaining points
        if points:
            await self.qdrant_client.upsert(
                collection_name=self.collection_name,
                points=points
            )
            logger.info(f"Uploaded final batch ({len(points)} points)")

    async def close(self):
        """Close the Qdrant and OpenAI client connections"""
        try:
            await self.qdrant_client.close()
        finally:
            # Release the OpenAI client's HTTP connections even if closing
            # Qdrant raised.
            await self.openai_client.close()
271
+
272
+
273
def get_all_markdown_files_recursively(root_path: Path) -> List[Path]:
    """
    Find all markdown files recursively (as deep as needed) in the given root_path.

    Walks every subdirectory and collects both *.md and *.mdx files. Anything
    inside a directory named exactly ``node_modules`` (below root_path) is
    skipped; files that merely contain that text in their name are kept.

    Args:
        root_path: Path to the root directory

    Returns:
        List[Path]: Sorted list of all markdown file Paths
    """
    found = set()
    for pattern in ("*.md", "*.mdx"):
        for candidate in root_path.rglob(pattern):
            if not candidate.is_file():
                continue
            # Compare whole path components, not a substring of the path
            if "node_modules" in candidate.relative_to(root_path).parts:
                continue
            found.add(candidate)

    # Sorted so the processing order does not depend on the filesystem
    return sorted(found)
288
+
289
+
290
async def main():
    """
    Main embedding script

    Parses the command line, chunks every markdown file under --book-path,
    and uploads the embedded chunks to the chosen Qdrant collection. The
    embedder's connections are closed on every exit path.
    """
    parser = argparse.ArgumentParser(description="Embed book content into Qdrant")
    parser.add_argument(
        "--book-path",
        type=str,
        required=True,
        help="Path to book content directory (e.g., docs/)"
    )
    parser.add_argument(
        "--collection-name",
        type=str,
        default="humanoid-robotics-book-v1",
        help="Qdrant collection name"
    )
    parser.add_argument(
        "--doc-version",
        type=str,
        default="v1.0.0",
        help="Document version identifier"
    )

    args = parser.parse_args()

    # Reject a bad path before any network client is created; otherwise a
    # typo silently yields zero files and a misleading success message.
    book_path = Path(args.book_path)
    if not book_path.is_dir():
        parser.error(f"--book-path is not a directory: {book_path}")

    # Initialize components
    chunker = BookContentChunker(chunk_size=500, overlap=50)
    embedder = BookEmbedder(collection_name=args.collection_name)

    try:
        # Create collection (exits the process if Qdrant is unreachable)
        await embedder.create_collection()

        # Find all markdown files as deep as needed
        md_files = get_all_markdown_files_recursively(book_path)
        logger.info(f"Found {len(md_files)} markdown files (.md and .mdx) recursively in all subdirectories")

        if not md_files:
            logger.warning(f"No markdown files found under {book_path}; nothing to embed")
            return

        # Process each file
        all_chunks = []
        for md_file in md_files:
            logger.info(f"Processing: {md_file}")

            content = md_file.read_text(encoding='utf-8')
            chunks = chunker.chunk_markdown(content, str(md_file))
            all_chunks.extend(chunks)
            logger.info(f"  -> Generated {len(chunks)} chunks")

        logger.info(f"Total chunks: {len(all_chunks)}")

        # Upload to Qdrant
        await embedder.upload_chunks(all_chunks, doc_version=args.doc_version)

        logger.info("✅ Embedding complete!")

    finally:
        await embedder.close()
348
+
349
+
350
if __name__ == "__main__":
    # Last-resort guard: any exception escaping main() is logged and turned
    # into a non-zero exit status instead of a raw traceback.
    try:
        asyncio.run(main())
    except Exception as e:
        logger.error(f"FATAL: Exception occurred: {type(e).__name__}: {e}")
        logger.error("Please check if Qdrant is running, accessible, and credentials are set correctly.")
        raise SystemExit(1)
src/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """RAG Chatbot Backend Package"""
src/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (196 Bytes). View file
 
src/__pycache__/main.cpython-312.pyc ADDED
Binary file (7.32 kB). View file
 
src/api/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """API package"""
src/api/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (184 Bytes). View file
 
src/api/middleware/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Middleware package"""
src/api/middleware/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (202 Bytes). View file
 
src/api/middleware/__pycache__/auth_middleware.cpython-312.pyc ADDED
Binary file (5.34 kB). View file
 
src/api/middleware/__pycache__/logging_middleware.cpython-312.pyc ADDED
Binary file (1.81 kB). View file
 
src/api/middleware/__pycache__/rate_limit.cpython-312.pyc ADDED
Binary file (3.36 kB). View file
 
src/api/middleware/auth_middleware.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Authentication middleware for protecting API endpoints
2
+
3
+ Provides dependency injection for current user authentication.
4
+ Validates JWT tokens from HTTP-only cookies and extracts user information.
5
+ """
6
+ from typing import Optional
7
+ from fastapi import Depends, HTTPException, status, Request
8
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
9
+ from sqlalchemy.ext.asyncio import AsyncSession
10
+
11
+ from src.config.database import get_db_session
12
+ from src.models.user import User
13
+ from src.services.auth_service import AuthService
14
+ from src.utils.logger import get_logger
15
+
16
+ logger = get_logger(__name__)
17
+
18
+ # HTTP Bearer scheme for Authorization header (optional fallback)
19
+ security = HTTPBearer(auto_error=False)
20
+
21
+
22
async def get_token_from_request(request: Request) -> Optional[str]:
    """Extract JWT token from cookie or Authorization header

    The "auth_token" cookie takes precedence. Otherwise the Authorization
    header is accepted when its scheme is "Bearer" (matched
    case-insensitively) and it carries a non-empty credential.

    Args:
        request: FastAPI request object

    Returns:
        JWT token string or None
    """
    # First, try to get token from HTTP-only cookie
    token = request.cookies.get("auth_token")
    if token:
        return token

    # Fallback: try Authorization header (for API clients)
    auth_header = request.headers.get("Authorization")
    if not auth_header:
        return None

    # partition tolerates extra spaces between the scheme and the token,
    # where split(" ")[1] would return an empty string.
    scheme, _, credentials = auth_header.strip().partition(" ")
    if scheme.lower() != "bearer":
        return None

    credentials = credentials.strip()
    return credentials or None
42
+
43
+
44
def _unauthorized(detail: str) -> HTTPException:
    """Build the 401 error shared by every authentication failure."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user(
    request: Request,
    db: AsyncSession = Depends(get_db_session)
) -> User:
    """Dependency to get the current authenticated user

    Validates JWT token from cookie/header and returns the associated user.
    Checks run in this order: token present, token decodes, payload has a
    "sub" claim, session is still valid, "sub" is a UUID, user exists.

    Args:
        request: FastAPI request object
        db: Database session

    Returns:
        Authenticated User instance

    Raises:
        HTTPException: 401 if authentication fails
    """
    from uuid import UUID

    # Extract token from request
    token = await get_token_from_request(request)
    if not token:
        logger.warning("Authentication failed: no token provided")
        raise _unauthorized("Not authenticated")

    # Decode and validate JWT token
    payload = AuthService.decode_jwt_token(token)
    if not payload:
        logger.warning("Authentication failed: invalid JWT token")
        raise _unauthorized("Invalid authentication credentials")

    # Extract user_id from token payload
    user_id_str = payload.get("sub")
    if not user_id_str:
        logger.warning("Authentication failed: no user_id in token payload")
        raise _unauthorized("Invalid token payload")

    # Validate session still exists in database
    session = await AuthService.validate_session(db, token)
    if not session:
        logger.warning("Authentication failed: session not found or expired for token")
        raise _unauthorized("Session expired or invalid")

    # Parse the subject as a UUID. str() makes a non-string claim (e.g. an
    # int) fail with ValueError and a 401, instead of an uncaught
    # AttributeError and a 500.
    try:
        user_id = UUID(str(user_id_str))
    except ValueError:
        logger.warning(f"Authentication failed: invalid user_id format: {user_id_str}")
        raise _unauthorized("Invalid user identifier")

    # Get user from database
    user = await AuthService.get_user_by_id(db, user_id)
    if not user:
        logger.warning(f"Authentication failed: user {user_id} not found")
        raise _unauthorized("User not found")

    logger.debug(f"User authenticated: {user.id}")
    return user
127
+
128
+
129
async def get_current_user_optional(
    request: Request,
    db: AsyncSession = Depends(get_db_session)
) -> Optional[User]:
    """Optional authentication dependency

    Delegates to get_current_user and converts any HTTPException it raises
    into a None result, so routes can serve anonymous callers too.

    Args:
        request: FastAPI request object
        db: Database session

    Returns:
        Authenticated User instance or None
    """
    try:
        user = await get_current_user(request, db)
    except HTTPException:
        user = None
    return user
src/api/middleware/logging_middleware.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Logging middleware for API requests and responses.
2
+
3
+ This middleware logs details of each API request and its corresponding response.
4
+ """
5
+ import time
6
+ from fastapi import Request
7
+ from starlette.middleware.base import BaseHTTPMiddleware
8
+ from src.utils.logger import get_logger
9
+
10
+ logger = get_logger(__name__)
11
+
12
class LoggingMiddleware(BaseHTTPMiddleware):
    """Logs the start and the outcome of every request.

    Each request produces a "Request started" record, followed by either
    "Request finished" (with status code and duration) or "Request failed"
    (when the downstream handler raises; the exception is re-raised).
    """

    async def dispatch(self, request: Request, call_next):
        """Log the request, run the rest of the stack, log the result.

        Args:
            request: Incoming request
            call_next: Callable that runs the remaining middleware/route

        Returns:
            The response produced downstream, unchanged
        """
        # perf_counter is monotonic, so durations are immune to clock changes
        start_time = time.perf_counter()

        # request.client is None when the server has no peer address
        client_host = request.client.host if request.client else "unknown"

        # Log request details
        logger.info(
            "Request started",
            extra={
                "method": request.method,
                "path": request.url.path,
                "client": client_host,
            },
        )

        try:
            response = await call_next(request)
        except Exception:
            process_time = (time.perf_counter() - start_time) * 1000
            logger.error(
                "Request failed",
                extra={
                    "method": request.method,
                    "path": request.url.path,
                    "process_time_ms": f"{process_time:.2f}",
                },
            )
            raise

        process_time = (time.perf_counter() - start_time) * 1000

        # Log response details
        logger.info(
            "Request finished",
            extra={
                "method": request.method,
                "path": request.url.path,
                "status_code": response.status_code,
                "process_time_ms": f"{process_time:.2f}",
            },
        )

        return response
src/api/middleware/rate_limit.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/api/middleware/rate_limit.py
2
+ from typing import Callable
3
+ from fastapi import Request
4
+ from fastapi.responses import JSONResponse
5
+ import inspect
6
+
7
+ from slowapi import Limiter
8
+ from slowapi.util import get_remote_address
9
+ from slowapi.errors import RateLimitExceeded
10
+
11
+ from src.config.settings import settings
12
+ from src.utils.logger import get_logger
13
+
14
+ logger = get_logger(__name__)
15
+
16
def get_user_identifier(request: Request) -> str:
    """Return the rate-limit key for a request.

    Uses "user:<id>" when request.state carries a truthy ``user``,
    otherwise falls back to "ip:<remote address>".
    """
    current_user = getattr(request.state, "user", None)
    if not current_user:
        return f"ip:{get_remote_address(request)}"
    return f"user:{current_user.id}"
21
+
22
+ # Create limiter instance (single authoritative limiter in this module)
23
+ limiter = Limiter(
24
+ key_func=get_user_identifier,
25
+ default_limits=[f"{settings.rate_limit_per_minute}/minute"],
26
+ storage_uri=settings.redis_url,
27
+ strategy="fixed-window"
28
+ )
29
+
30
# Decorated once at import time. Re-applying limiter.limit() on every
# request would register another limit under the same handler name each
# time. The parameter is named ``request`` because slowapi locates the
# connection object by that argument name.
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def _rate_limited_noop(request: Request) -> None:
    """No-op handler that exists only so the limiter has something to wrap."""
    return None


async def rate_limit_dependency(request: Request):
    """
    FastAPI dependency to apply default rate limiting.

    Calls a module-level no-op handler wrapped by the limiter decorator, so
    the decorator's check runs for this request without decorating the
    route itself. Whatever the limiter raises propagates to the caller.

    Args:
        request: Incoming request used to derive the rate-limit key

    Returns:
        True, so the route proceeds
    """
    await _rate_limited_noop(request)

    # dependency returns truthy so route proceeds
    return True
53
+
54
+
55
def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Turn a RateLimitExceeded error into a JSON 429 response.

    The retry delay defaults to 60 seconds. It is overridden only when the
    exception detail is a string containing "Retry after <number>".

    Args:
        request: Request that hit the limit (used for the log identifier)
        exc: Exception raised by the limiter

    Returns:
        429 response with a JSON body and a Retry-After header
    """
    import re  # local import: this module does not import re at the top

    identifier = get_user_identifier(request)
    logger.warning(f"Rate limit exceeded for {identifier}")

    retry_after = 60
    detail = getattr(exc, "detail", None)
    if isinstance(detail, str):
        # Take only the first number after the phrase. Joining every digit
        # in the remainder would turn "60 seconds (limit 20)" into 6020.
        match = re.search(r"Retry after\s*(\d+)", detail)
        if match:
            retry_after = int(match.group(1))

    payload = {
        "error": "rate_limit_exceeded",
        "message": f"Too many requests. Please try again in {retry_after} seconds.",
        "retry_after": retry_after
    }

    return JSONResponse(status_code=429, content=payload, headers={"Retry-After": str(retry_after)})
src/api/routes/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """API routes package"""
src/api/routes/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (198 Bytes). View file
 
src/api/routes/__pycache__/auth.cpython-312.pyc ADDED
Binary file (11 kB). View file
 
src/api/routes/__pycache__/chat.cpython-312.pyc ADDED
Binary file (10.8 kB). View file
 
src/api/routes/__pycache__/health.cpython-312.pyc ADDED
Binary file (3.47 kB). View file
 
src/api/routes/auth.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Authentication routes for user registration, login, and session management
2
+
3
+ Provides endpoints for:
4
+ - POST /auth/register - User registration
5
+ - POST /auth/login - User login
6
+ - POST /auth/logout - User logout
7
+ - GET /auth/me - Get current user
8
+ """
9
+ from fastapi import APIRouter, Depends, HTTPException, status, Response
10
+ from sqlalchemy.ext.asyncio import AsyncSession
11
+ from sqlalchemy import select
12
+
13
+ from src.config.database import get_db_session
14
+ from src.models.user import User
15
+ from src.models.schemas import (
16
+ UserCreate,
17
+ UserLogin,
18
+ UserResponse,
19
+ AuthResponse,
20
+ MessageResponse
21
+ )
22
+ from src.services.auth_service import AuthService
23
+ from src.api.middleware.auth_middleware import get_current_user, get_token_from_request
24
+ from src.utils.logger import get_logger
25
+ from src.utils.validators import sanitize_html
26
+
27
+ logger = get_logger(__name__)
28
+
29
+ router = APIRouter(prefix="/auth", tags=["Authentication"])
30
+
31
# Max password length, in bytes, that is handed to bcrypt.
# bcrypt only uses the first 72 bytes of its input, and recent bcrypt releases
# raise an error for longer input instead of truncating it, so the manual
# truncation in register() and login() is only effective with a cap of 72.
BCRYPT_PASSWORD_MAX_BYTES = 72
33
+
34
+
35
@router.post(
    "/register",
    response_model=AuthResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Register a new user",
    description="Create a new user account with email and password"
)
async def register(
    user_data: UserCreate,
    response: Response,
    db: AsyncSession = Depends(get_db_session)
) -> AuthResponse:
    """Register a new user account

    Rejects an already registered email, creates the user through
    AuthService, issues a JWT, stores a session for it and sets the token as
    the HTTP-only ``auth_token`` cookie.

    Args:
        user_data: User registration data (email, password)
        response: FastAPI response object for setting cookies
        db: Database session

    Returns:
        AuthResponse with user data and a confirmation message (the token is
        delivered through the cookie, not the response body)

    Raises:
        HTTPException: 400 if the email already exists or the password is
            rejected as too long, 500 if the account cannot be created
    """
    # Check if user already exists
    # NOTE(review): the lookup lower-cases the email but create_user() below
    # receives it unchanged - confirm that AuthService normalises the case.
    result = await db.execute(
        select(User).where(User.email == user_data.email.lower())
    )
    existing_user = result.scalar_one_or_none()

    if existing_user:
        logger.warning(f"Registration failed: email {user_data.email} already exists")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered"
        )

    # Truncate password to BCRYPT_PASSWORD_MAX_BYTES for hashing
    password = user_data.password
    password_bytes = password.encode("utf-8")
    if len(password_bytes) > BCRYPT_PASSWORD_MAX_BYTES:
        logger.warning(
            f"Password for {user_data.email} is longer than {BCRYPT_PASSWORD_MAX_BYTES} bytes. Truncating..."
        )
        # Cutting the byte string may split a multi-byte character. The input
        # was valid UTF-8, so only that trailing fragment can be invalid and
        # errors="ignore" drops exactly that fragment.
        password = password_bytes[:BCRYPT_PASSWORD_MAX_BYTES].decode("utf-8", errors="ignore")

    # Create new user
    try:
        user = await AuthService.create_user(
            db=db,
            email=user_data.email,
            password=password
        )
    except AttributeError as e:
        logger.error(
            f"User creation failed due to bcrypt or passlib error: {e}. "
            "This is likely caused by an incompatible version of bcrypt. "
            "Please ensure 'bcrypt' and 'passlib' are installed and up to date. "
            "Upgrade them using: pip install --upgrade bcrypt passlib"
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Server configuration error: bcrypt module issue, please contact support."
        ) from e
    except Exception as e:
        msg = str(e)
        if "password cannot be longer than " in msg:
            logger.error(
                f"User creation failed: password too long for bcrypt for {user_data.email}: {e}. Manual truncation should have prevented this."
            )
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Password cannot be longer than {BCRYPT_PASSWORD_MAX_BYTES} bytes. Please use a shorter password."
            ) from e
        logger.error(f"User creation failed: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to create user account"
        ) from e

    # Generate JWT token
    token = AuthService.generate_jwt_token(user.id)

    # Create session in database
    await AuthService.create_session(db=db, user_id=user.id, token=token)

    # Set HTTP-only cookie with JWT token
    response.set_cookie(
        key="auth_token",
        value=token,
        httponly=True,
        secure=True,  # Only send over HTTPS
        samesite="lax",  # CSRF protection
        max_age=60 * 60 * 24 * 7  # 7 days
    )

    logger.info(f"User registered and logged in: {user.id}")

    return AuthResponse(
        user=UserResponse.model_validate(user),
        message="Registration successful"
    )
146
+
147
+
148
@router.post(
    "/login",
    response_model=AuthResponse,
    status_code=status.HTTP_200_OK,
    summary="Login user",
    description="Authenticate user with email and password, create session"
)
async def login(
    credentials: UserLogin,
    response: Response,
    db: AsyncSession = Depends(get_db_session)
) -> AuthResponse:
    """Authenticate user and create session

    Verifies the credentials through AuthService, issues a JWT, stores a
    session for it and sets the token as the HTTP-only ``auth_token`` cookie.

    Args:
        credentials: User login credentials (email, password)
        response: FastAPI response object for setting cookies
        db: Database session

    Returns:
        AuthResponse with user data and a confirmation message (the token is
        delivered through the cookie, not the response body)

    Raises:
        HTTPException: 401 if credentials are invalid, 400 if the password is
            rejected as too long, 500 if authentication itself fails
    """
    # Truncate password to BCRYPT_PASSWORD_MAX_BYTES, mirroring register(), so
    # the same bytes are compared that were hashed at registration time
    password = credentials.password
    password_bytes = password.encode("utf-8")
    if len(password_bytes) > BCRYPT_PASSWORD_MAX_BYTES:
        logger.warning(
            f"Login attempt with password longer than {BCRYPT_PASSWORD_MAX_BYTES} bytes. Truncating for bcrypt."
        )
        # Cutting the byte string may split a multi-byte character. The input
        # was valid UTF-8, so only that trailing fragment can be invalid and
        # errors="ignore" drops exactly that fragment.
        password = password_bytes[:BCRYPT_PASSWORD_MAX_BYTES].decode("utf-8", errors="ignore")

    # Authenticate user
    try:
        user = await AuthService.authenticate_user(
            db=db,
            email=credentials.email,
            password=password
        )
    except AttributeError as e:
        logger.error(
            f"User authentication failed due to bcrypt or passlib error: {e}. "
            "Please ensure 'bcrypt' and 'passlib' are installed and up to date."
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Server configuration error: bcrypt module issue, please contact support."
        ) from e
    except Exception as e:
        msg = str(e)
        if "password cannot be longer than " in msg:
            logger.error(
                f"User authentication failed: password too long for bcrypt: {e}. Manual truncation should have prevented this."
            )
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Password cannot be longer than {BCRYPT_PASSWORD_MAX_BYTES} bytes."
            ) from e
        logger.error(f"User authentication failed: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Server error during authentication"
        ) from e

    # A falsy result means the email/password pair was not accepted
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Generate JWT token
    token = AuthService.generate_jwt_token(user.id)

    # Create session in database
    await AuthService.create_session(db=db, user_id=user.id, token=token)

    # Set HTTP-only cookie with JWT token
    response.set_cookie(
        key="auth_token",
        value=token,
        httponly=True,
        secure=True,
        samesite="lax",
        max_age=60 * 60 * 24 * 7  # 7 days
    )

    logger.info(f"User logged in: {user.id}")

    return AuthResponse(
        user=UserResponse.model_validate(user),
        message="Login successful"
    )
249
+
250
+
251
@router.post(
    "/logout",
    response_model=MessageResponse,
    status_code=status.HTTP_200_OK,
    summary="Logout user",
    description="Revoke user session and clear authentication cookie"
)
async def logout(
    response: Response,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db_session)
) -> MessageResponse:
    """Logout user by clearing the authentication cookie

    Args:
        response: FastAPI response object for clearing cookies
        current_user: Authenticated user (from middleware)
        db: Database session (not used yet, see the TODO below)

    Returns:
        MessageResponse confirming logout
    """
    # TODO: the stored session is not revoked here; only the cookie is
    # cleared, so the token itself is not invalidated by this endpoint.
    # Revoking it requires extracting the token from the request.

    # Clear HTTP-only cookie
    response.delete_cookie(key="auth_token", httponly=True, secure=True, samesite="lax")

    logger.info(f"User logged out: {current_user.id}")

    return MessageResponse(message="Logout successful")
284
+
285
+
286
@router.get(
    "/me",
    response_model=UserResponse,
    status_code=status.HTTP_200_OK,
    summary="Get current user",
    description="Get authenticated user's profile information"
)
async def get_current_user_profile(
    current_user: User = Depends(get_current_user)
) -> UserResponse:
    """Return the profile of the user making the request

    Args:
        current_user: User resolved by the get_current_user dependency

    Returns:
        UserResponse built from that user
    """
    profile = UserResponse.model_validate(current_user)
    return profile
src/api/routes/chat.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Chat message endpoints for interacting with the RAG chatbot.
2
+
3
+ Provides endpoints for:
4
+ - POST /chat/message - Send a new message to the chatbot, get a RAG-powered response
5
+ - GET /chat/history - Get chat history for a specific thread
6
+ - GET /chat/threads - Get a list of user's chat threads
7
+ """
8
+ from typing import List, Dict, Any
9
+ from uuid import UUID, uuid4
10
+ import time
11
+ import asyncio
12
+
13
+ from fastapi import APIRouter, Depends, HTTPException, status
14
+ from sqlalchemy.ext.asyncio import AsyncSession
15
+ from sqlalchemy import select, func
16
+ from openai import OpenAIError, APITimeoutError, APIConnectionError, APIStatusError
17
+
18
+ from src.api.middleware.auth_middleware import get_current_user
19
+ # from src.api.middleware.rate_limit import rate_limit_dependency
20
+ from src.config.database import get_db_session
21
+ from src.models.user import User
22
+ from src.models.schemas import (
23
+ ChatMessageCreate,
24
+ ChatMessageResponse,
25
+ ChatResponse,
26
+ ChatHistoryResponse
27
+ )
28
+ from src.models.chat_message import ChatUserRole, ChatMessage
29
+ from src.services.chat_service import ChatService
30
+ from src.services.rag_service import RAGService
31
+ from src.services.vector_service import VectorService
32
+ from src.utils.validators import sanitize_html
33
+ from src.utils.logger import get_logger
34
+
35
+ logger = get_logger(__name__)
36
+
37
+ router = APIRouter(prefix="/chat", tags=["Chatbot"])
38
+
39
+
40
@router.post(
    "/message",
    response_model=ChatResponse,
    status_code=status.HTTP_200_OK,
    summary="Send message to chatbot",
    description="Send a message to the RAG chatbot and get a response based on book content"
)
async def send_message(
    chat_message_data: ChatMessageCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db_session),
    # rate_limit_status: Dict[str, Any] = Depends(rate_limit_dependency)
) -> ChatResponse:
    """Send a message to the RAG chatbot.

    Handles full-book queries and selected-text queries.
    Retrieves context chunks through VectorService, generates the answer with
    RAGService, and persists both the user and the assistant message.

    Args:
        chat_message_data: Message text plus optional thread_id, query_mode
            and selected_text.
        current_user: The authenticated user.
        db: The database session.

    Returns:
        ChatResponse with the stored user message, the stored assistant
        message and the thread id they belong to.

    Raises:
        HTTPException: 400 if 'selection' mode is requested without
            selected_text; 429/502/503/504 for OpenAI API failures; 500 for
            any other retrieval or generation failure.
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments.
    start_time = time.perf_counter()
    user_id = current_user.id
    query_mode = chat_message_data.query_mode or "full_book"
    user_message_content = sanitize_html(chat_message_data.message, strip=True)
    selected_text_content = sanitize_html(chat_message_data.selected_text, strip=True) if chat_message_data.selected_text else None

    # Validate before anything is persisted, so a rejected request does not
    # leave an orphan user message in the thread.
    if query_mode == "selection" and not selected_text_content:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="selected_text is required for 'selection' query mode"
        )

    # Determine thread_id: create new if not provided (simple UUID for conversation grouping)
    thread_id = chat_message_data.thread_id if chat_message_data.thread_id else str(uuid4())
    logger.info(f"Processing message for user {user_id}, thread {thread_id}")

    # 1. Save user message to DB
    user_message_db = await ChatService.save_message(
        db=db,
        user_id=user_id,
        thread_id=thread_id,
        role=ChatUserRole.USER,
        content=user_message_content,
        metadata={
            "query_mode": query_mode,
            "selected_text": selected_text_content
        }
    )

    # 2 & 3. Run vector search and chat history retrieval in parallel for speed
    try:
        # Determine search parameters
        search_text = selected_text_content if query_mode == "selection" else user_message_content
        top_k = 3 if query_mode == "selection" else 5

        # Run both operations in parallel
        context_chunks, chat_history = await asyncio.gather(
            VectorService.search_similar_chunks(query_text=search_text, top_k=top_k),
            ChatService.get_chat_history(db=db, user_id=user_id, thread_id=thread_id, limit=10)
        )

        # Convert history to dict format for RAG service
        history_dicts = [
            {"role": msg.role, "content": msg.content}
            for msg in reversed(chat_history)  # Reverse to chronological order
        ]

    except HTTPException:
        raise  # Let HTTP errors raised by the services pass through unchanged
    except Exception as e:
        logger.error(f"Error retrieving context or history: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve relevant book content. Please try again."
        ) from e

    # 4. Generate response using RAG service with Agents SDK
    try:
        rag_response = await RAGService.generate_response(
            user_message=user_message_content,
            context_chunks=context_chunks,
            chat_history=history_dicts,
            query_mode=query_mode,
            selected_text=selected_text_content
        )
        assistant_message_content = rag_response["content"]
        chunk_ids = rag_response["chunk_ids"]
        model_used = rag_response["model_used"]

    except APIStatusError as e:
        logger.error(f"OpenAI API error (status: {e.status_code}): {e.message}")
        if e.status_code == 429:
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail="OpenAI API rate limit exceeded. Please try again shortly."
            ) from e
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"OpenAI API error: {e.message}"
        ) from e
    except APITimeoutError as e:
        # Handled before APIConnectionError so timeouts get their own status
        logger.error(f"OpenAI API timeout error: {e}")
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT,
            detail="OpenAI API timed out. Please try again."
        ) from e
    except APIConnectionError as e:
        logger.error(f"OpenAI API connection error: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Could not connect to OpenAI API. Please check your internet connection or try again later."
        ) from e
    except Exception as e:
        logger.error(f"Generic error from RAG service: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to generate response from chatbot. Please try again."
        ) from e

    # Calculate response time
    response_time_ms = int((time.perf_counter() - start_time) * 1000)

    # 5. Save assistant message to DB
    assistant_message_db = await ChatService.save_message(
        db=db,
        user_id=user_id,
        thread_id=thread_id,
        role=ChatUserRole.ASSISTANT,
        content=assistant_message_content,
        metadata={
            "query_mode": query_mode,
            "selected_text_context": selected_text_content if query_mode == "selection" else None,
            "chunk_ids": chunk_ids,
            "model_used": model_used,
            "response_time_ms": response_time_ms
        }
    )

    logger.info(f"Chat response generated for user {user_id} in thread {thread_id} ({response_time_ms}ms)")

    return ChatResponse(
        user_message=ChatMessageResponse.model_validate(user_message_db),
        assistant_message=ChatMessageResponse.model_validate(assistant_message_db),
        thread_id=thread_id
    )
185
+
186
+
187
@router.get(
    "/history",
    response_model=ChatHistoryResponse,
    status_code=status.HTTP_200_OK,
    summary="Get chat history",
    description="Retrieve paginated chat message history for a specific thread"
)
async def get_chat_history(
    thread_id: str,
    limit: int = 50,
    offset: int = 0,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db_session)
) -> ChatHistoryResponse:
    """Retrieve chat history for a specific thread.

    Args:
        thread_id: The ID of the conversation thread.
        limit: The maximum number of messages to return.
        offset: The number of messages to skip for pagination.
        current_user: The authenticated user.
        db: The database session.

    Returns:
        ChatHistoryResponse containing messages and total count.

    Raises:
        HTTPException: 400 if limit or offset is negative.
    """
    # Reject negative paging values up front with a client error.
    # NOTE(review): presumably these are forwarded to SQL LIMIT/OFFSET, where
    # a negative value would fail inside the query - confirm in ChatService.
    if limit < 0 or offset < 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="limit and offset must be non-negative"
        )

    user_id = current_user.id

    messages_db = await ChatService.get_chat_history(
        db=db,
        user_id=user_id,
        thread_id=thread_id,
        limit=limit,
        offset=offset
    )

    # Convert to Pydantic models
    messages_response = [ChatMessageResponse.model_validate(msg) for msg in messages_db]

    # Get total count for pagination metadata (all messages of this user in
    # this thread, independent of limit/offset)
    total_messages = await db.scalar(
        select(func.count(ChatMessage.id))
        .where(ChatMessage.user_id == user_id, ChatMessage.thread_id == thread_id)
    )

    logger.info(f"Retrieved chat history for user {user_id}, thread {thread_id}")

    return ChatHistoryResponse(
        messages=messages_response,
        total=total_messages if total_messages is not None else 0,
        thread_id=thread_id
    )
239
+
240
+
241
@router.get(
    "/threads",
    response_model=List[Dict[str, Any]],
    status_code=status.HTTP_200_OK,
    summary="Get user chat threads",
    description="Retrieve a list of all chat threads for the authenticated user"
)
async def get_user_chat_threads(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db_session)
) -> List[Dict[str, Any]]:
    """List the chat threads that belong to the authenticated user.

    Args:
        current_user: User resolved by the get_current_user dependency.
        db: Database session handed to ChatService.

    Returns:
        Whatever ChatService.get_user_threads returns for this user: a list
        of dictionaries, one per thread.
    """
    user_id = current_user.id

    # The lookup itself is delegated to the service layer
    threads: List[Dict[str, Any]] = await ChatService.get_user_threads(
        db=db,
        user_id=user_id,
    )

    logger.info(f"Retrieved {len(threads)} chat threads for user {user_id}")
    return threads
src/api/routes/health.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Health check endpoint
2
+
3
+ Provides health status for the application and its dependencies.
4
+ """
5
+ from fastapi import APIRouter, status
6
+ from pydantic import BaseModel
7
+ from datetime import datetime
8
+ from typing import Dict, Any
9
+
10
+ from sqlalchemy import text
11
+
12
+ from src.config.database import get_engine, get_qdrant_client
13
+ from src.config.settings import settings
14
+ from src.utils.logger import get_logger
15
+
16
+ logger = get_logger(__name__)
17
+
18
+ router = APIRouter(tags=["health"])
19
+
20
+
21
class HealthResponse(BaseModel):
    """Response body returned by the /health endpoint"""

    # Overall result: "healthy" when every dependency is healthy, else "degraded"
    status: str
    # ISO-8601 time at which the check ran
    timestamp: str
    # Value of settings.environment
    environment: str
    # Dependency name ("database", "qdrant") -> "healthy" / "unhealthy"
    dependencies: Dict[str, str]
27
+
28
+
29
@router.get(
    "/health",
    response_model=HealthResponse,
    status_code=status.HTTP_200_OK,
    summary="Health Check",
    description="Check the health status of the application and its dependencies"
)
async def health_check() -> HealthResponse:
    """Perform health check on application and dependencies

    Probes the database with ``SELECT 1`` and Qdrant with a collection
    listing. A failing probe is logged and reported as "unhealthy"; it never
    raises, so the endpoint always answers.

    Returns:
        HealthResponse with status ("healthy" or "degraded"), a
        timezone-aware UTC timestamp, the environment name and the
        per-dependency results
    """
    # Local import: only `datetime` itself is imported at module level.
    from datetime import timezone

    dependencies: Dict[str, str] = {}

    # Check database connection
    try:
        engine = get_engine()
        async with engine.connect() as conn:
            await conn.execute(text("SELECT 1"))
        dependencies["database"] = "healthy"
    except Exception as e:
        logger.error(f"Database health check failed: {e}")
        dependencies["database"] = "unhealthy"

    # Check Qdrant connection
    try:
        client = get_qdrant_client()
        await client.get_collections()
        dependencies["qdrant"] = "healthy"
    except Exception as e:
        logger.error(f"Qdrant health check failed: {e}")
        dependencies["qdrant"] = "unhealthy"

    # Overall status
    overall_status = "healthy" if all(
        dep_status == "healthy" for dep_status in dependencies.values()
    ) else "degraded"

    return HealthResponse(
        status=overall_status,
        # datetime.utcnow() is deprecated and returns a naive value; an aware
        # UTC timestamp carries an explicit +00:00 offset instead.
        timestamp=datetime.now(timezone.utc).isoformat(),
        environment=settings.environment,
        dependencies=dependencies
    )
src/config/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Configuration package"""
src/config/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (197 Bytes). View file
 
src/config/__pycache__/database.cpython-312.pyc ADDED
Binary file (5.53 kB). View file
 
src/config/__pycache__/settings.cpython-312.pyc ADDED
Binary file (4.15 kB). View file
 
src/config/database.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Database and vector store configuration
2
+
3
+ Provides async database engine, session management, and Qdrant client setup.
4
+ """
5
+ from typing import AsyncGenerator
6
+ from sqlalchemy.ext.asyncio import (
7
+ AsyncEngine,
8
+ AsyncSession,
9
+ create_async_engine,
10
+ async_sessionmaker,
11
+ )
12
+ from sqlalchemy.orm import declarative_base
13
+ from qdrant_client import AsyncQdrantClient
14
+ from qdrant_client.models import Distance, VectorParams
15
+
16
+ from src.config.settings import settings
17
+ from src.utils.logger import get_logger
18
+
19
+ logger = get_logger(__name__)
20
+
21
+ # SQLAlchemy Base for models
22
+ Base = declarative_base()
23
+
24
+ # Global engine instance
25
+ _engine: AsyncEngine | None = None
26
+
27
+ # Global session maker
28
+ _async_session_maker: async_sessionmaker[AsyncSession] | None = None
29
+
30
+ # Global Qdrant client
31
+ _qdrant_client: AsyncQdrantClient | None = None
32
+
33
+
34
def get_engine() -> AsyncEngine:
    """Return the shared async database engine, building it on first use

    Returns:
        AsyncEngine instance
    """
    global _engine

    if _engine is not None:
        return _engine

    # Neon-hosted and production databases are connected to with SSL required
    needs_ssl = "neon.tech" in settings.database_url or settings.is_production
    connect_args = {"ssl": "require"} if needs_ssl else {}

    _engine = create_async_engine(
        settings.async_database_url,
        echo=not settings.is_production,  # SQL logging outside production
        pool_pre_ping=True,  # Test each pooled connection before handing it out
        pool_size=5,
        max_overflow=10,
        connect_args=connect_args,
    )
    logger.info("Database engine created")

    return _engine
60
+
61
+
62
def get_session_maker() -> async_sessionmaker[AsyncSession]:
    """Return the shared async session factory, building it on first use

    Returns:
        async_sessionmaker instance bound to the engine from get_engine()
    """
    global _async_session_maker

    if _async_session_maker is not None:
        return _async_session_maker

    _async_session_maker = async_sessionmaker(
        get_engine(),
        class_=AsyncSession,
        expire_on_commit=False,
    )
    logger.info("Session maker created")

    return _async_session_maker
80
+
81
+
82
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Dependency for getting async database sessions

    Yields:
        AsyncSession instance, closed once the consumer is finished with it
    """
    session_maker = get_session_maker()
    # Leaving the `async with` block closes the session, also when the
    # consumer raises, so no additional close() call is needed.
    async with session_maker() as session:
        yield session
94
+
95
+
96
def get_qdrant_client() -> AsyncQdrantClient:
    """Return the shared Qdrant client, building it on first use

    Returns:
        AsyncQdrantClient instance
    """
    global _qdrant_client

    if _qdrant_client is not None:
        return _qdrant_client

    # URL and API key come from settings; timeout is fixed at 30.0
    _qdrant_client = AsyncQdrantClient(
        url=settings.qdrant_url,
        api_key=settings.qdrant_api_key,
        timeout=30.0,
    )
    logger.info("Qdrant client created")

    return _qdrant_client
113
+
114
+
115
async def init_qdrant_collection() -> None:
    """Make sure the configured Qdrant collection exists

    Lists the existing collections and, when the configured name is missing,
    creates it with the configured vector size and cosine distance.
    """
    client = get_qdrant_client()
    target_name = settings.qdrant_collection_name

    # Look for the configured name among the existing collections
    listing = await client.get_collections()
    already_present = any(col.name == target_name for col in listing.collections)

    if already_present:
        logger.info(f"Qdrant collection already exists: {settings.qdrant_collection_name}")
        return

    # Missing: create it with the vector configuration from settings
    await client.create_collection(
        collection_name=target_name,
        vectors_config=VectorParams(
            size=settings.vector_size,
            distance=Distance.COSINE,
        ),
    )
    logger.info(f"Created Qdrant collection: {settings.qdrant_collection_name}")
138
+
139
+
140
async def close_database_connections() -> None:
    """Close all database connections gracefully

    Disposes the SQLAlchemy engine and closes the Qdrant client if they were
    created, then clears the module-level singletons so the next get_*()
    call builds fresh ones.
    """
    global _engine, _async_session_maker, _qdrant_client

    if _engine is not None:
        await _engine.dispose()
        logger.info("Database engine disposed")
        _engine = None
        # The session maker is bound to the engine that was just disposed.
        # Drop it too, otherwise get_session_maker() would keep returning it
        # while get_engine() creates a different engine.
        _async_session_maker = None

    if _qdrant_client is not None:
        await _qdrant_client.close()
        logger.info("Qdrant client closed")
        _qdrant_client = None
src/config/settings.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Application settings and configuration
2
+
3
+ Loads environment variables and provides type-safe configuration.
4
+ """
5
+ from typing import List, Union, Optional
6
+ from pydantic import field_validator
7
+ from pydantic_settings import BaseSettings, SettingsConfigDict
8
+ from openai import AsyncOpenAI
9
+
10
+
11
class Settings(BaseSettings):
    """Application settings loaded from environment variables

    Values are read from the process environment and from a ``.env`` file;
    names are matched case-insensitively and unknown variables are ignored.
    Fields without a default are required.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="ignore"
    )

    # Database
    database_url: str

    # Qdrant Vector Database
    qdrant_url: str
    qdrant_api_key: str
    qdrant_collection_name: str = "humanoid-robotics-book-v1"
    vector_size: int = 1536  # OpenAI text-embedding-3-small dimension

    # OpenAI
    openai_api_key: str
    openai_embedding_model: str = "text-embedding-3-small"
    chat_model: str = "gpt-4o-mini"  # Fast, cost-effective model for chat (was gpt-4-turbo-preview)

    # Authentication
    better_auth_secret: str
    session_expiry_days: int = 7

    # Rate Limiting
    rate_limit_per_minute: int = 20
    redis_url: str = "redis://localhost:6379"

    # CORS
    allowed_origins: Union[str, List[str]] = "http://localhost:3000,http://localhost:8000"

    # Application
    environment: str = "development"
    log_level: str = "INFO"

    @field_validator("allowed_origins", mode="before")
    @classmethod
    def parse_cors_origins(cls, v):
        """Parse CORS origins from comma-separated string or list

        Surrounding whitespace is stripped and empty entries (for example
        from a trailing comma) are dropped. Non-string values are returned
        unchanged.
        """
        if isinstance(v, str):
            return [origin.strip() for origin in v.split(",") if origin.strip()]
        return v

    @property
    def is_production(self) -> bool:
        """Check if running in production environment"""
        return self.environment.lower() == "production"

    @property
    def async_database_url(self) -> str:
        """Get async database URL for SQLAlchemy

        Converts postgresql:// (and the postgres:// alias) to
        postgresql+asyncpg:// and removes the sslmode and channel_binding
        query parameters. All other parts of the URL are kept.
        """
        url = self.database_url

        # Point the URL at the asyncpg driver
        for plain_scheme in ("postgresql://", "postgres://"):
            if url.startswith(plain_scheme):
                url = "postgresql+asyncpg://" + url[len(plain_scheme):]
                break

        # Remove sslmode and channel_binding parameters that asyncpg doesn't
        # support; get_engine() passes the SSL setting through connect_args
        if "sslmode=" in url or "channel_binding=" in url:
            from urllib.parse import urlparse, parse_qs, urlencode, urlunparse
            parsed = urlparse(url)
            query_params = parse_qs(parsed.query)
            # Remove sslmode and channel_binding
            query_params.pop('sslmode', None)
            query_params.pop('channel_binding', None)
            # Reconstruct the query string
            new_query = urlencode(query_params, doseq=True)
            url = urlunparse((
                parsed.scheme,
                parsed.netloc,
                parsed.path,
                parsed.params,
                new_query,
                parsed.fragment
            ))

        return url
97
+
98
+
99
+ # Global settings instance
100
+ settings = Settings()
101
+
102
+ # Global OpenAI client instance (only if API key is provided)
103
+ openai_client = AsyncOpenAI(api_key=settings.openai_api_key) if settings.openai_api_key else None
src/main.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI application entry point
2
+
3
+ Main application setup with middleware, CORS, and route configuration.
4
+ """
5
+ from contextlib import asynccontextmanager
6
+ from typing import AsyncGenerator
7
+ from fastapi import FastAPI, Request
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from fastapi.middleware.trustedhost import TrustedHostMiddleware
10
+ # Security headers helper from the third-party `secure` package (active)
11
+ from secure import Secure
12
+
13
+ # Secure middleware configuration
14
+ # Shared Secure instance used by secure_headers_middleware
15
+ secure_headers = Secure()
16
+ from fastapi.responses import JSONResponse
17
+ # Rate limiting support (slowapi)
18
+ from slowapi import Limiter, _rate_limit_exceeded_handler
19
+ from slowapi.util import get_remote_address
20
+ from slowapi.errors import RateLimitExceeded
21
+
22
+ from src.config.settings import settings
23
+ from src.config.database import init_qdrant_collection, close_database_connections
24
+ from src.utils.logger import setup_logging, get_logger
25
+ from src.api.routes import health, auth, chat
26
+
27
+ # Setup logging
28
+ setup_logging(
29
+ level=settings.log_level,
30
+ use_json=settings.is_production
31
+ )
32
+
33
+ logger = get_logger(__name__)
34
+
35
+
36
+ # Rate limiter configuration (using the one from rate_limit middleware)
37
+ # limiter = Limiter(
38
+ # key_func=get_remote_address,
39
+ # default_limits=[f"{settings.rate_limit_per_minute}/minute"]
40
+ # )
41
+
42
+
43
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Application lifespan manager

    Runs startup work before the application begins serving and shutdown
    work once it stops.

    Args:
        app: The FastAPI application this lifespan is attached to.

    Yields:
        None. Control returns to the framework while the app is serving.
    """
    # Startup
    logger.info("Starting up application...")
    logger.info(f"Environment: {settings.environment}")

    # Initialize Qdrant collection. This is deliberately best-effort: a
    # failure is logged (with traceback) and startup continues.
    try:
        await init_qdrant_collection()
    except Exception as e:
        logger.exception(f"Failed to initialize Qdrant collection: {e}")

    logger.info("Application startup complete")

    try:
        yield
    finally:
        # Shutdown: run in ``finally`` so connections are closed even when an
        # exception is thrown into the generator while the app is running.
        logger.info("Shutting down application...")
        await close_database_connections()
        logger.info("Application shutdown complete")
67
+
68
+
69
+ # Create FastAPI application
70
+ app = FastAPI(
71
+ title="RAG Chatbot API",
72
+ description="Retrieval-Augmented Generation chatbot for humanoid robotics textbook",
73
+ version="1.0.0",
74
+ docs_url="/api/docs" if not settings.is_production else None,
75
+ redoc_url="/api/redoc" if not settings.is_production else None,
76
+ lifespan=lifespan
77
+ )
78
+
79
+ # Project middleware and the shared rate limiter (imports are active)
80
+ from src.api.middleware.logging_middleware import LoggingMiddleware
81
+ from src.api.middleware.rate_limit import limiter, rate_limit_exceeded_handler
82
+
83
+ # Add rate limiter to app state
84
+ # Temporarily disabled for debugging
85
+ # app.state.limiter = limiter
86
+ app.add_exception_handler(RateLimitExceeded, rate_limit_exceeded_handler)
87
+
88
+ # Add logging middleware
89
+ # Enabled: project-defined LoggingMiddleware
90
+ app.add_middleware(LoggingMiddleware)
91
+
92
+ # Add secure headers middleware
93
+ # Enabled: applied to every HTTP response
94
@app.middleware("http")
async def secure_headers_middleware(request: Request, call_next):
    """Attach security headers to every HTTP response

    Tries the integration hooks of the module-level ``secure_headers``
    instance first. If the instance has no ``framework`` helper, exposes no
    recognised hook, or the hook raises, a conservative set of default
    headers is applied manually so the response never leaves without them.

    Args:
        request: Incoming request
        call_next: Callable that runs the rest of the middleware/route stack

    Returns:
        The downstream response with security headers added
    """
    response = await call_next(request)

    applied = False
    try:
        framework = getattr(secure_headers, "framework", None)
        if framework is not None:
            # Different versions of the library expose different entry
            # points; probe them in order of preference.
            candidates = (
                (framework, "fastapi"),
                (framework, "starlette"),
                (secure_headers, "apply"),
                (secure_headers, "add"),
            )
            for owner, method_name in candidates:
                if hasattr(owner, method_name):
                    getattr(owner, method_name)(response)
                    applied = True
                    break
            else:
                raise AttributeError("Secure instance has no recognized integration methods")
    except Exception:
        # log the failure but do not crash the request pipeline
        logger.exception("Failed to apply secure headers")

    if not applied:
        # Library not integrated, unrecognised, or failed: fall back to safe
        # defaults. ``setdefault`` leaves any header that is already set.
        response.headers.setdefault("X-Content-Type-Options", "nosniff")
        response.headers.setdefault("X-Frame-Options", "DENY")
        response.headers.setdefault("Referrer-Policy", "no-referrer")
        response.headers.setdefault("Strict-Transport-Security", "max-age=63072000; includeSubDomains; preload")
        response.headers.setdefault("Permissions-Policy", "geolocation=()")
        response.headers.setdefault("X-XSS-Protection", "1; mode=block")

    return response
130
+
131
+ # CORS middleware configuration
132
+ # Enabled: allowed origins come from settings.allowed_origins
133
+ app.add_middleware(
134
+ CORSMiddleware,
135
+ allow_origins=settings.allowed_origins,
136
+ allow_credentials=True,
137
+ allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
138
+ allow_headers=["*"],
139
+ expose_headers=["X-Request-ID"],
140
+ )
141
+
142
+
143
+ # Exception handlers
144
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Global exception handler for unhandled errors

    Logs the exception with its traceback and request context, then returns
    a generic 500 response. The exception text is only exposed to the client
    outside production.

    Args:
        request: FastAPI request
        exc: Exception that was raised

    Returns:
        JSONResponse with error details
    """
    logger.error(
        f"Unhandled exception: {exc}",
        # Pass the exception explicitly so the traceback is recorded; without
        # it only the message string reaches the log.
        exc_info=exc,
        extra={
            "path": request.url.path,
            "method": request.method,
            "client": request.client.host if request.client else "unknown"
        }
    )

    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal server error",
            "message": "An unexpected error occurred" if settings.is_production else str(exc)
        }
    )
171
+
172
+
173
+ # Include routers
174
+ app.include_router(health.router, prefix="/api")
175
+ app.include_router(auth.router, prefix="/api")
176
+ app.include_router(chat.router, prefix="/api")
177
+
178
+
179
+ # Root endpoint
180
@app.get("/", tags=["root"])
async def root() -> dict[str, str]:
    """Service banner served at the API root

    Returns:
        Mapping with the service name, a running status, and the docs path
        (reported as "disabled" in production).
    """
    if settings.is_production:
        docs_location = "disabled"
    else:
        docs_location = "/api/docs"

    banner = {"message": "RAG Chatbot API", "status": "running"}
    banner["docs"] = docs_location
    return banner
src/models/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """Database models package"""
2
+ from src.models.user import User
3
+ from src.models.session import Session
4
+ from src.models.chat_message import ChatMessage
5
+
6
+ __all__ = ["User", "Session", "ChatMessage"]
src/models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (399 Bytes). View file
 
src/models/__pycache__/chat_message.cpython-312.pyc ADDED
Binary file (2.79 kB). View file
 
src/models/__pycache__/schemas.cpython-312.pyc ADDED
Binary file (7.65 kB). View file
 
src/models/__pycache__/session.cpython-312.pyc ADDED
Binary file (2.5 kB). View file
 
src/models/__pycache__/user.cpython-312.pyc ADDED
Binary file (2.22 kB). View file
 
src/models/chat_message.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Chat message model for storing conversation history
2
+
3
+ Represents a single question-answer exchange in the chatbot.
4
+ Aligns with data-model.md ChatMessage entity specification.
5
+ """
6
+ from datetime import datetime
7
+ from uuid import UUID, uuid4
8
+ from enum import Enum
9
+
10
+ from sqlalchemy import Column, String, Text, TIMESTAMP, ForeignKey, text
11
+ from sqlalchemy.dialects.postgresql import UUID as PGUUID, JSONB
12
+ from sqlalchemy.orm import relationship
13
+
14
+ from src.config.database import Base
15
+
16
+
17
class ChatUserRole(str, Enum):
    """Sender roles for a chat message.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    # Message written by the end user
    USER = "user"
    # Message produced by the assistant
    ASSISTANT = "assistant"
21
+
22
+
23
class ChatMessage(Base):
    """ChatMessage model for storing conversation history

    Attributes:
        id: Unique message identifier (UUID)
        user_id: Foreign key to User
        thread_id: OpenAI Agents SDK thread identifier
        role: Message sender (user or assistant)
        message_metadata: Additional context (JSONB), stored in the
            ``metadata`` column
        content: Message text content
        created_at: Message creation timestamp
    """

    __tablename__ = "chat_messages"

    # Server defaults that are SQL expressions must be wrapped in text():
    # a plain string is emitted as a quoted literal ('gen_random_uuid()'),
    # not as a function call.
    id = Column(
        PGUUID(as_uuid=True),
        primary_key=True,
        default=uuid4,
        server_default=text("gen_random_uuid()")
    )
    user_id = Column(PGUUID(as_uuid=True), ForeignKey("users.id"), nullable=False)
    thread_id = Column(String(255), nullable=False)
    role = Column(String(10), nullable=False)  # CHECK (role IN ('user', 'assistant')) will be added in migration
    content = Column(Text, nullable=False)
    # The Python attribute is named message_metadata while the database
    # column is "metadata".
    message_metadata = Column('metadata', JSONB, default=dict, server_default=text("'{}'::jsonb"), nullable=False)
    created_at = Column(TIMESTAMP, nullable=False, default=datetime.utcnow, server_default=text("NOW()"))

    # Relationships
    user = relationship("User", back_populates="chat_messages")

    def __repr__(self) -> str:
        return f"<ChatMessage(id={self.id}, role={self.role}, thread_id={self.thread_id})>"
src/models/schemas.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic schemas for request/response validation
2
+
3
+ Provides type-safe data validation for API endpoints.
4
+ """
5
+ from datetime import datetime
6
+ from uuid import UUID
7
+ from typing import Optional, Dict, Any
8
+ from pydantic import BaseModel, EmailStr, Field, field_validator
9
+
10
+
11
+ # ============================================================================
12
+ # User Schemas
13
+ # ============================================================================
14
+
15
class UserCreate(BaseModel):
    """Registration payload: an email address plus a password of 8-128 characters"""
    email: EmailStr = Field(..., description="User's email address")
    password: str = Field(..., min_length=8, max_length=128, description="User's password")

    @field_validator("password")
    @classmethod
    def validate_password_strength(cls, v: str) -> str:
        """Reject passwords that miss any required character class

        Rules are checked in order and the first unmet rule raises.
        """
        requirements = (
            (str.isupper, "Password must contain at least one uppercase letter"),
            (str.islower, "Password must contain at least one lowercase letter"),
            (str.isdigit, "Password must contain at least one digit"),
            (
                lambda ch: ch in "!@#$%^&*(),.?\":{}|<>",
                "Password must contain at least one special character",
            ),
        )
        for is_match, complaint in requirements:
            if not any(is_match(ch) for ch in v):
                raise ValueError(complaint)
        return v
33
+
34
+
35
class UserLogin(BaseModel):
    """Login payload: email and password"""
    email: EmailStr = Field(..., description="User's email address")
    # Only length bounds (1-128) are enforced on this schema.
    password: str = Field(
        ...,
        min_length=1,
        max_length=128,
        description="User's password",
    )
39
+
40
+
41
class UserResponse(BaseModel):
    """User fields exposed in API responses"""
    # from_attributes lets this schema be built from an object's attributes.
    model_config = {"from_attributes": True}

    id: UUID
    email: str
    created_at: datetime
    updated_at: datetime
49
+
50
+
51
class AuthResponse(BaseModel):
    """Authentication result: the user plus a status message"""
    user: UserResponse
    # Default text used when the caller does not supply a message.
    message: str = "Authentication successful"
55
+
56
+
57
+ # ============================================================================
58
+ # Session Schemas
59
+ # ============================================================================
60
+
61
class SessionResponse(BaseModel):
    """Session fields exposed in API responses"""
    # from_attributes lets this schema be built from an object's attributes.
    model_config = {"from_attributes": True}

    id: UUID
    user_id: UUID
    expires_at: datetime
    created_at: datetime
69
+
70
+
71
+ # ============================================================================
72
+ # Chat Message Schemas
73
+ # ============================================================================
74
+
75
class ChatMessageCreate(BaseModel):
    """Schema for creating a new chat message

    Attributes:
        message: User's message content (1-10000 characters)
        thread_id: Existing thread ID, or None to start a new thread
        query_mode: None, 'full_book' or 'selection'
        selected_text: Selected text; required when query_mode is 'selection'
    """
    message: str = Field(..., min_length=1, max_length=10000, description="User's message content")
    thread_id: Optional[str] = Field(None, min_length=1, max_length=255, description="OpenAI thread ID, optional for new threads")
    query_mode: Optional[str] = Field(None, description="Query mode: 'full_book' or 'selection'")
    # validate_default=True makes the validator below run even when the field
    # is omitted; otherwise Pydantic skips validators for default values and
    # the "required in selection mode" rule is never checked.
    selected_text: Optional[str] = Field(
        None,
        max_length=5000,
        validate_default=True,
        description="Selected text for context queries",
    )

    @field_validator("query_mode")
    @classmethod
    def validate_query_mode(cls, v: Optional[str]) -> Optional[str]:
        """Validate query mode is one of allowed values"""
        if v is not None and v not in ["full_book", "selection"]:
            raise ValueError("query_mode must be 'full_book' or 'selection'")
        return v

    @field_validator("selected_text")
    @classmethod
    def validate_selected_text(cls, v: Optional[str], info: Any) -> Optional[str]:
        """Validate selected_text is present when query_mode is 'selection'

        Relies on query_mode being declared before selected_text, so its
        validated value is already available in ``info.data``.
        """
        if info.data.get("query_mode") == "selection" and not v:
            raise ValueError("selected_text is required when query_mode is 'selection'")
        return v
97
+
98
+
99
class ChatMessageResponse(BaseModel):
    """Chat message fields exposed in API responses"""
    model_config = {
        # Allow construction from an object's attributes.
        "from_attributes": True,
        # Accept the field name as well as its alias on input.
        "populate_by_name": True,
    }

    id: UUID
    user_id: UUID
    thread_id: str
    role: str
    content: str
    # Aliased to "message_metadata", so the value is read from that name.
    metadata: Dict[str, Any] = Field(..., alias="message_metadata")
    created_at: datetime
113
+
114
+
115
+
116
class ChatResponse(BaseModel):
    """One chat exchange: the user's message, the assistant's reply, and the thread ID"""
    user_message: ChatMessageResponse
    assistant_message: ChatMessageResponse
    # Thread the two messages belong to.
    thread_id: str
121
+
122
+
123
class ChatHistoryResponse(BaseModel):
    """History of one thread: its messages, a total count, and the thread ID"""
    messages: list[ChatMessageResponse]
    # Count supplied by the caller alongside the messages.
    total: int
    thread_id: str
128
+
129
+
130
+ # ============================================================================
131
+ # Generic Response Schemas
132
+ # ============================================================================
133
+
134
class MessageResponse(BaseModel):
    """Response carrying a single text message"""
    # Text shown to the client.
    message: str
137
+
138
+
139
class ErrorResponse(BaseModel):
    """Error payload: an error label, a message, and optional details"""
    error: str
    message: str
    # Optional structured detail; None when absent.
    details: Optional[Dict[str, Any]] = None
src/models/session.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Session model for authentication token management
2
+
3
+ Represents an active authentication session with JWT tokens and expiration.
4
+ Aligns with data-model.md Session entity specification.
5
+ """
6
+ from datetime import datetime
7
+ from uuid import UUID, uuid4
8
+ from sqlalchemy import Column, String, TIMESTAMP, ForeignKey
9
+ from sqlalchemy.dialects.postgresql import UUID as PGUUID
10
+ from sqlalchemy.orm import relationship
11
+
12
+ from src.config.database import Base
13
+
14
+
15
# Imported here because this module's top-level sqlalchemy import does not
# include text(), which the server defaults below require.
from sqlalchemy import text


class Session(Base):
    """Session model for JWT token management

    Attributes:
        id: Unique session identifier (UUID)
        user_id: Foreign key to User
        token_hash: Hashed JWT token (unique)
        expires_at: Session expiration timestamp
        created_at: Session creation timestamp
    """

    __tablename__ = "sessions"

    # Server defaults that are SQL expressions must be wrapped in text():
    # a plain string is emitted as a quoted literal ('gen_random_uuid()'),
    # not as a function call.
    id = Column(
        PGUUID(as_uuid=True),
        primary_key=True,
        default=uuid4,
        server_default=text("gen_random_uuid()")
    )
    user_id = Column(
        PGUUID(as_uuid=True),
        ForeignKey("users.id", ondelete="CASCADE"),
        nullable=False,
        index=True
    )
    token_hash = Column(String(255), unique=True, nullable=False, index=True)
    expires_at = Column(TIMESTAMP, nullable=False, index=True)
    created_at = Column(TIMESTAMP, nullable=False, default=datetime.utcnow, server_default=text("NOW()"))

    # Relationships
    user = relationship("User", back_populates="sessions")

    def __repr__(self) -> str:
        return f"<Session(id={self.id}, user_id={self.user_id}, expires_at={self.expires_at})>"

    @property
    def is_expired(self) -> bool:
        """Check if session has expired

        Compares against a naive UTC "now", matching the naive
        datetime.utcnow default used for created_at.
        """
        return datetime.utcnow() > self.expires_at
src/models/user.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """User model for authentication and authorization
2
+
3
+ Represents an authenticated reader with account credentials.
4
+ Aligns with data-model.md User entity specification.
5
+ """
6
+ from datetime import datetime
7
+ from uuid import UUID, uuid4
8
+ from sqlalchemy import Column, String, TIMESTAMP
9
+ from sqlalchemy.dialects.postgresql import UUID as PGUUID
10
+ from sqlalchemy.orm import relationship
11
+
12
+ from src.config.database import Base
13
+
14
+
15
# Imported here because this module's top-level sqlalchemy import does not
# include text(), which the server defaults below require.
from sqlalchemy import text


class User(Base):
    """User model for authentication

    Attributes:
        id: Unique user identifier (UUID)
        email: User's email address (unique)
        password_hash: Hashed password (managed by auth service)
        created_at: Account creation timestamp
        updated_at: Last account modification timestamp
    """

    __tablename__ = "users"

    # Server defaults that are SQL expressions must be wrapped in text():
    # a plain string is emitted as a quoted literal ('gen_random_uuid()'),
    # not as a function call.
    id = Column(
        PGUUID(as_uuid=True),
        primary_key=True,
        default=uuid4,
        server_default=text("gen_random_uuid()")
    )
    email = Column(String(255), unique=True, nullable=False, index=True)
    password_hash = Column(String(255), nullable=False)
    created_at = Column(TIMESTAMP, nullable=False, default=datetime.utcnow, server_default=text("NOW()"))
    # onupdate refreshes the timestamp on ORM-issued UPDATEs.
    updated_at = Column(
        TIMESTAMP,
        nullable=False,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
        server_default=text("NOW()")
    )

    # Relationships
    sessions = relationship("Session", back_populates="user", cascade="all, delete-orphan")
    chat_messages = relationship("ChatMessage", back_populates="user", cascade="all, delete-orphan")

    def __repr__(self) -> str:
        return f"<User(id={self.id}, email={self.email})>"
src/services/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """Services package"""
2
+ from src.services.auth_service import AuthService
3
+ from src.services.vector_service import VectorService
4
+ from src.services.chat_service import ChatService
5
+
6
+ __all__ = ["AuthService", "VectorService", "ChatService"]
src/services/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (428 Bytes). View file
 
src/services/__pycache__/auth_service.cpython-312.pyc ADDED
Binary file (12 kB). View file
 
src/services/__pycache__/chat_service.cpython-312.pyc ADDED
Binary file (6.02 kB). View file
 
src/services/__pycache__/rag_service.cpython-312.pyc ADDED
Binary file (9.74 kB). View file
 
src/services/__pycache__/vector_service.cpython-312.pyc ADDED
Binary file (3.36 kB). View file
 
src/services/auth_service.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Authentication service for user management and session handling
2
+
3
+ Provides user registration, authentication, password hashing,
4
+ JWT token generation, and session management.
5
+ """
6
+ from datetime import datetime, timedelta
7
+ from typing import Optional
8
+ from uuid import UUID
9
+ import hashlib
10
+ from passlib.context import CryptContext
11
+ from jose import JWTError, jwt
12
+ from sqlalchemy.ext.asyncio import AsyncSession
13
+ from sqlalchemy import select, delete
14
+
15
+ from src.models.user import User
16
+ from src.models.session import Session
17
+ from src.config.settings import settings
18
+ from src.utils.logger import get_logger
19
+
20
+ logger = get_logger(__name__)
21
+
22
+ # Password hashing configuration
23
+ # Using Argon2 (modern, memory-hard algorithm) with bcrypt fallback for existing passwords
24
+ pwd_context = CryptContext(schemes=["argon2", "bcrypt"], deprecated="auto")
25
+
26
+ # JWT configuration
27
+ ALGORITHM = "HS256"
28
+
29
+
30
class AuthService:
    """Authentication service for user and session management

    All members are static methods. Database operations take the
    AsyncSession to use as their first argument.
    """

    @staticmethod
    def hash_password(password: str) -> str:
        """Hash a plain text password using Argon2

        Args:
            password: Plain text password

        Returns:
            Hashed password
        """
        # "argon2" is listed first in pwd_context's schemes, so it is the
        # default; the explicit ``scheme=`` keyword is deprecated in passlib.
        return pwd_context.hash(password)

    @staticmethod
    def verify_password(plain_password: str, hashed_password: str) -> bool:
        """Verify a password against its hash

        Args:
            plain_password: Plain text password to verify
            hashed_password: Hashed password to compare against

        Returns:
            True if password matches, False otherwise
        """
        # pwd_context is configured with both argon2 and bcrypt, so either
        # kind of stored hash can be checked.
        return pwd_context.verify(plain_password, hashed_password)

    @staticmethod
    def generate_jwt_token(user_id: UUID) -> str:
        """Generate a JWT token for a user

        Args:
            user_id: User's UUID

        Returns:
            JWT token string
        """
        # Local import: this module only imports UUID from uuid.
        from uuid import uuid4

        issued_at = datetime.utcnow()
        to_encode = {
            "sub": str(user_id),
            "exp": issued_at + timedelta(days=settings.session_expiry_days),
            "iat": issued_at,
            # Unique token ID. Without it, two tokens issued to the same user
            # within the same second are identical and their hashes collide
            # on the unique sessions.token_hash column.
            "jti": uuid4().hex,
        }

        return jwt.encode(
            to_encode,
            settings.better_auth_secret,
            algorithm=ALGORITHM
        )

    @staticmethod
    def decode_jwt_token(token: str) -> Optional[dict]:
        """Decode and validate a JWT token

        Args:
            token: JWT token string

        Returns:
            Decoded token payload or None if invalid
        """
        try:
            return jwt.decode(
                token,
                settings.better_auth_secret,
                algorithms=[ALGORITHM]
            )
        except JWTError as e:
            logger.warning(f"JWT decode error: {e}")
            return None

    @staticmethod
    def hash_token(token: str) -> str:
        """Create SHA-256 hash of a token for storage

        Args:
            token: Token to hash

        Returns:
            Hex digest of token hash
        """
        return hashlib.sha256(token.encode()).hexdigest()

    @staticmethod
    async def create_user(
        db: AsyncSession,
        email: str,
        password: str
    ) -> User:
        """Create a new user account

        Args:
            db: Database session
            email: User's email address
            password: Plain text password

        Returns:
            Created User instance
        """
        password_hash = AuthService.hash_password(password)

        user = User(
            email=email.lower(),  # Normalize email to lowercase
            password_hash=password_hash
        )

        db.add(user)
        await db.commit()
        await db.refresh(user)

        logger.info(f"User created: {user.id}")
        return user

    @staticmethod
    async def authenticate_user(
        db: AsyncSession,
        email: str,
        password: str
    ) -> Optional[User]:
        """Authenticate a user by email and password

        Args:
            db: Database session
            email: User's email address
            password: Plain text password

        Returns:
            User instance if authenticated, None otherwise
        """
        # Emails are stored lowercased (see create_user), so look up the same way
        result = await db.execute(
            select(User).where(User.email == email.lower())
        )
        user = result.scalar_one_or_none()

        if user is None:
            # Do not write the submitted email address (PII) to the logs.
            logger.warning("Authentication failed: user not found")
            return None

        if not AuthService.verify_password(password, user.password_hash):
            logger.warning(f"Authentication failed: invalid password for user {user.id}")
            return None

        logger.info(f"User authenticated: {user.id}")
        return user

    @staticmethod
    async def create_session(
        db: AsyncSession,
        user_id: UUID,
        token: str
    ) -> Session:
        """Create a new session for a user

        Only the SHA-256 hash of the token is stored, not the token itself.

        Args:
            db: Database session
            user_id: User's UUID
            token: JWT token string

        Returns:
            Created Session instance
        """
        token_hash = AuthService.hash_token(token)
        expires_at = datetime.utcnow() + timedelta(days=settings.session_expiry_days)

        session = Session(
            user_id=user_id,
            token_hash=token_hash,
            expires_at=expires_at
        )

        db.add(session)
        await db.commit()
        await db.refresh(session)

        logger.info(f"Session created: {session.id} for user {user_id}")
        return session

    @staticmethod
    async def validate_session(
        db: AsyncSession,
        token: str
    ) -> Optional[Session]:
        """Validate a session token

        Args:
            db: Database session
            token: JWT token string

        Returns:
            Session instance if valid, None otherwise
        """
        token_hash = AuthService.hash_token(token)

        # Query session by token hash
        result = await db.execute(
            select(Session).where(Session.token_hash == token_hash)
        )
        session = result.scalar_one_or_none()

        if session is None:
            logger.warning("Session validation failed: session not found")
            return None

        if session.is_expired:
            logger.warning(f"Session validation failed: session {session.id} expired")
            return None

        return session

    @staticmethod
    async def revoke_session(
        db: AsyncSession,
        token: str
    ) -> bool:
        """Revoke a session by token

        Args:
            db: Database session
            token: JWT token string

        Returns:
            True if session was revoked, False otherwise
        """
        token_hash = AuthService.hash_token(token)

        result = await db.execute(
            delete(Session).where(Session.token_hash == token_hash)
        )
        await db.commit()

        revoked = result.rowcount > 0
        if revoked:
            # Log only a prefix of the hash
            logger.info(f"Session revoked for token hash {token_hash[:16]}...")

        return revoked

    @staticmethod
    async def cleanup_expired_sessions(db: AsyncSession) -> int:
        """Remove all expired sessions from database

        Args:
            db: Database session

        Returns:
            Number of sessions deleted
        """
        result = await db.execute(
            delete(Session).where(Session.expires_at < datetime.utcnow())
        )
        await db.commit()

        count = result.rowcount
        if count > 0:
            logger.info(f"Cleaned up {count} expired sessions")

        return count

    @staticmethod
    async def get_user_by_id(
        db: AsyncSession,
        user_id: UUID
    ) -> Optional[User]:
        """Get a user by ID

        Args:
            db: Database session
            user_id: User's UUID

        Returns:
            User instance if found, None otherwise
        """
        result = await db.execute(
            select(User).where(User.id == user_id)
        )
        return result.scalar_one_or_none()