Spaces:
Runtime error
Runtime error
Commit
·
7be47d4
0
Parent(s):
clear history and fix bug websocket connection
Browse files- .dockerignore +46 -0
- .env.example +33 -0
- .gitignore +83 -0
- Dockerfile +31 -0
- README.md +419 -0
- app.py +223 -0
- app/__init__.py +25 -0
- app/api/__init__.py +1 -0
- app/api/mongodb_routes.py +276 -0
- app/api/pdf_routes.py +310 -0
- app/api/pdf_websocket.py +263 -0
- app/api/postgresql_routes.py +0 -0
- app/api/rag_routes.py +338 -0
- app/api/websocket_routes.py +303 -0
- app/database/__init__.py +1 -0
- app/database/models.py +204 -0
- app/database/mongodb.py +221 -0
- app/database/pinecone.py +573 -0
- app/database/postgresql.py +192 -0
- app/models/__init__.py +1 -0
- app/models/mongodb_models.py +55 -0
- app/models/pdf_models.py +52 -0
- app/models/rag_models.py +68 -0
- app/utils/__init__.py +1 -0
- app/utils/cache.py +193 -0
- app/utils/debug_utils.py +207 -0
- app/utils/middleware.py +109 -0
- app/utils/pdf_processor.py +292 -0
- app/utils/utils.py +478 -0
- docker-compose.yml +21 -0
- docs/api_documentation.md +581 -0
- pytest.ini +12 -0
- requirements.txt +47 -0
.dockerignore
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Git
|
| 2 |
+
.git
|
| 3 |
+
.gitignore
|
| 4 |
+
.gitattributes
|
| 5 |
+
|
| 6 |
+
# Environment files
|
| 7 |
+
.env
|
| 8 |
+
.env.*
|
| 9 |
+
!.env.example
|
| 10 |
+
|
| 11 |
+
# Python cache files
|
| 12 |
+
__pycache__/
|
| 13 |
+
*.py[cod]
|
| 14 |
+
*$py.class
|
| 15 |
+
*.so
|
| 16 |
+
.Python
|
| 17 |
+
.pytest_cache/
|
| 18 |
+
*.egg-info/
|
| 19 |
+
.installed.cfg
|
| 20 |
+
*.egg
|
| 21 |
+
|
| 22 |
+
# Logs
|
| 23 |
+
*.log
|
| 24 |
+
|
| 25 |
+
# Tests
|
| 26 |
+
tests/
|
| 27 |
+
|
| 28 |
+
# Docker related
|
| 29 |
+
Dockerfile
|
| 30 |
+
docker-compose.yml
|
| 31 |
+
.dockerignore
|
| 32 |
+
|
| 33 |
+
# Other files
|
| 34 |
+
.vscode/
|
| 35 |
+
.idea/
|
| 36 |
+
*.swp
|
| 37 |
+
*.swo
|
| 38 |
+
.DS_Store
|
| 39 |
+
.coverage
|
| 40 |
+
htmlcov/
|
| 41 |
+
.mypy_cache/
|
| 42 |
+
.tox/
|
| 43 |
+
.nox/
|
| 44 |
+
instance/
|
| 45 |
+
.webassets-cache
|
| 46 |
+
main.py
|
.env.example
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# PostgreSQL Configuration
|
| 2 |
+
DB_CONNECTION_MODE=aiven
|
| 3 |
+
AIVEN_DB_URL=postgresql://username:password@host:port/dbname?sslmode=require
|
| 4 |
+
|
| 5 |
+
# MongoDB Configuration
|
| 6 |
+
MONGODB_URL=mongodb+srv://username:password@cluster.mongodb.net/?retryWrites=true&w=majority
|
| 7 |
+
DB_NAME=Telegram
|
| 8 |
+
COLLECTION_NAME=session_chat
|
| 9 |
+
|
| 10 |
+
# Pinecone configuration
|
| 11 |
+
PINECONE_API_KEY=your-pinecone-api-key
|
| 12 |
+
PINECONE_INDEX_NAME=your-pinecone-index-name
|
| 13 |
+
PINECONE_ENVIRONMENT=gcp-starter
|
| 14 |
+
|
| 15 |
+
# Google Gemini API key
|
| 16 |
+
GOOGLE_API_KEY=your-google-api-key
|
| 17 |
+
|
| 18 |
+
# WebSocket configuration
|
| 19 |
+
WEBSOCKET_SERVER=localhost
|
| 20 |
+
WEBSOCKET_PORT=7860
|
| 21 |
+
WEBSOCKET_PATH=/notify
|
| 22 |
+
|
| 23 |
+
# Application settings
|
| 24 |
+
ENVIRONMENT=production
|
| 25 |
+
DEBUG=false
|
| 26 |
+
PORT=7860
|
| 27 |
+
|
| 28 |
+
# Cache Configuration
|
| 29 |
+
CACHE_TTL_SECONDS=300
|
| 30 |
+
CACHE_CLEANUP_INTERVAL=60
|
| 31 |
+
CACHE_MAX_SIZE=1000
|
| 32 |
+
HISTORY_QUEUE_SIZE=10
|
| 33 |
+
HISTORY_CACHE_TTL=3600
|
.gitignore
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
build/
|
| 8 |
+
develop-eggs/
|
| 9 |
+
dist/
|
| 10 |
+
downloads/
|
| 11 |
+
eggs/
|
| 12 |
+
.eggs/
|
| 13 |
+
lib/
|
| 14 |
+
lib64/
|
| 15 |
+
parts/
|
| 16 |
+
sdist/
|
| 17 |
+
var/
|
| 18 |
+
wheels/
|
| 19 |
+
*.egg-info/
|
| 20 |
+
.installed.cfg
|
| 21 |
+
*.egg
|
| 22 |
+
.pytest_cache/
|
| 23 |
+
htmlcov/
|
| 24 |
+
.coverage
|
| 25 |
+
.coverage.*
|
| 26 |
+
.cache/
|
| 27 |
+
coverage.xml
|
| 28 |
+
*.cover
|
| 29 |
+
.mypy_cache/
|
| 30 |
+
|
| 31 |
+
# Environment
|
| 32 |
+
.env
|
| 33 |
+
.venv
|
| 34 |
+
env/
|
| 35 |
+
venv/
|
| 36 |
+
ENV/
|
| 37 |
+
|
| 38 |
+
# VSCode
|
| 39 |
+
.vscode/
|
| 40 |
+
*.code-workspace
|
| 41 |
+
.history/
|
| 42 |
+
|
| 43 |
+
# PyCharm
|
| 44 |
+
.idea/
|
| 45 |
+
*.iml
|
| 46 |
+
*.iws
|
| 47 |
+
*.ipr
|
| 48 |
+
*.iws
|
| 49 |
+
out/
|
| 50 |
+
.idea_modules/
|
| 51 |
+
|
| 52 |
+
# Logs and databases
|
| 53 |
+
*.log
|
| 54 |
+
*.sql
|
| 55 |
+
*.sqlite
|
| 56 |
+
*.db
|
| 57 |
+
|
| 58 |
+
# Tests
|
| 59 |
+
tests/
|
| 60 |
+
|
| 61 |
+
Admin_bot/
|
| 62 |
+
|
| 63 |
+
Pix-Agent/
|
| 64 |
+
|
| 65 |
+
# Hugging Face Spaces
|
| 66 |
+
.gitattributes
|
| 67 |
+
|
| 68 |
+
# OS specific
|
| 69 |
+
.DS_Store
|
| 70 |
+
.DS_Store?
|
| 71 |
+
._*
|
| 72 |
+
.Spotlight-V100
|
| 73 |
+
.Trashes
|
| 74 |
+
Icon?
|
| 75 |
+
ehthumbs.db
|
| 76 |
+
Thumbs.db
|
| 77 |
+
|
| 78 |
+
# Project specific
|
| 79 |
+
*.log
|
| 80 |
+
.env
|
| 81 |
+
main.py
|
| 82 |
+
|
| 83 |
+
test/
|
Dockerfile
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Cài đặt các gói hệ thống cần thiết
|
| 6 |
+
RUN apt-get update && apt-get install -y \
|
| 7 |
+
build-essential \
|
| 8 |
+
curl \
|
| 9 |
+
software-properties-common \
|
| 10 |
+
git \
|
| 11 |
+
gcc \
|
| 12 |
+
python3-dev \
|
| 13 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 14 |
+
|
| 15 |
+
# Sao chép các file yêu cầu trước để tận dụng cache của Docker
|
| 16 |
+
COPY requirements.txt .
|
| 17 |
+
|
| 18 |
+
# Cài đặt các gói Python
|
| 19 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 20 |
+
|
| 21 |
+
# Ensure langchain-core is installed
|
| 22 |
+
RUN pip install --no-cache-dir langchain-core==0.1.19
|
| 23 |
+
|
| 24 |
+
# Sao chép toàn bộ code vào container
|
| 25 |
+
COPY . .
|
| 26 |
+
|
| 27 |
+
# Mở cổng mà ứng dụng sẽ chạy
|
| 28 |
+
EXPOSE 7860
|
| 29 |
+
|
| 30 |
+
# Chạy ứng dụng với uvicorn
|
| 31 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
ADDED
|
@@ -0,0 +1,419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: PIX Project Backend
|
| 3 |
+
emoji: 🤖
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
+
sdk: docker
|
| 7 |
+
sdk_version: "3.0.0"
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
| 13 |
+
|
| 14 |
+
# PIX Project Backend
|
| 15 |
+
|
| 16 |
+
[](https://fastapi.tiangolo.com/)
|
| 17 |
+
[](https://www.python.org/)
|
| 18 |
+
[](https://huggingface.co/spaces)
|
| 19 |
+
|
| 20 |
+
Backend API for PIX Project with MongoDB, PostgreSQL and RAG integration. This project provides a comprehensive backend solution for managing FAQ items, emergency contacts, events, and a RAG-based question answering system.
|
| 21 |
+
|
| 22 |
+
## Features
|
| 23 |
+
|
| 24 |
+
- **MongoDB Integration**: Store user sessions and conversation history
|
| 25 |
+
- **PostgreSQL Integration**: Manage FAQ items, emergency contacts, and events
|
| 26 |
+
- **Pinecone Vector Database**: Store and retrieve vector embeddings for RAG
|
| 27 |
+
- **RAG Question Answering**: Answer questions using relevant information from the vector database
|
| 28 |
+
- **WebSocket Notifications**: Real-time notifications for Admin Bot
|
| 29 |
+
- **API Documentation**: Automatic OpenAPI documentation via Swagger
|
| 30 |
+
- **Docker Support**: Easy deployment using Docker
|
| 31 |
+
- **Auto-Debugging**: Built-in debugging, error tracking, and performance monitoring
|
| 32 |
+
|
| 33 |
+
## API Endpoints
|
| 34 |
+
|
| 35 |
+
### MongoDB Endpoints
|
| 36 |
+
|
| 37 |
+
- `POST /mongodb/session`: Create a new session record
|
| 38 |
+
- `PUT /mongodb/session/{session_id}/response`: Update a session with a response
|
| 39 |
+
- `GET /mongodb/history`: Get user conversation history
|
| 40 |
+
- `GET /mongodb/health`: Check MongoDB connection health
|
| 41 |
+
|
| 42 |
+
### PostgreSQL Endpoints
|
| 43 |
+
|
| 44 |
+
- `GET /postgres/health`: Check PostgreSQL connection health
|
| 45 |
+
- `GET /postgres/faq`: Get FAQ items
|
| 46 |
+
- `POST /postgres/faq`: Create a new FAQ item
|
| 47 |
+
- `GET /postgres/faq/{faq_id}`: Get a specific FAQ item
|
| 48 |
+
- `PUT /postgres/faq/{faq_id}`: Update a specific FAQ item
|
| 49 |
+
- `DELETE /postgres/faq/{faq_id}`: Delete a specific FAQ item
|
| 50 |
+
- `GET /postgres/emergency`: Get emergency contact items
|
| 51 |
+
- `POST /postgres/emergency`: Create a new emergency contact item
|
| 52 |
+
- `GET /postgres/emergency/{emergency_id}`: Get a specific emergency contact
|
| 53 |
+
- `GET /postgres/events`: Get event items
|
| 54 |
+
|
| 55 |
+
### RAG Endpoints
|
| 56 |
+
|
| 57 |
+
- `POST /rag/chat`: Get answer for a question using RAG
|
| 58 |
+
- `POST /rag/embedding`: Generate embedding for text
|
| 59 |
+
- `GET /rag/health`: Check RAG services health
|
| 60 |
+
|
| 61 |
+
### WebSocket Endpoints
|
| 62 |
+
|
| 63 |
+
- `WebSocket /notify`: Receive real-time notifications for new sessions
|
| 64 |
+
|
| 65 |
+
### Debug Endpoints (Available in Debug Mode Only)
|
| 66 |
+
|
| 67 |
+
- `GET /debug/config`: Get configuration information
|
| 68 |
+
- `GET /debug/system`: Get system information (CPU, memory, disk usage)
|
| 69 |
+
- `GET /debug/database`: Check all database connections
|
| 70 |
+
- `GET /debug/errors`: View recent error logs
|
| 71 |
+
- `GET /debug/performance`: Get performance metrics
|
| 72 |
+
- `GET /debug/full`: Get comprehensive debug information
|
| 73 |
+
|
| 74 |
+
## WebSocket API
|
| 75 |
+
|
| 76 |
+
### Notifications for New Sessions
|
| 77 |
+
|
| 78 |
+
The backend provides a WebSocket endpoint for receiving notifications about new sessions that match specific criteria.
|
| 79 |
+
|
| 80 |
+
#### WebSocket Endpoint Configuration
|
| 81 |
+
|
| 82 |
+
The WebSocket endpoint is configured using environment variables:
|
| 83 |
+
|
| 84 |
+
```
|
| 85 |
+
# WebSocket configuration
|
| 86 |
+
WEBSOCKET_SERVER=localhost
|
| 87 |
+
WEBSOCKET_PORT=7860
|
| 88 |
+
WEBSOCKET_PATH=/notify
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
The full WebSocket URL will be:
|
| 92 |
+
```
|
| 93 |
+
ws://{WEBSOCKET_SERVER}:{WEBSOCKET_PORT}{WEBSOCKET_PATH}
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
For example: `ws://localhost:7860/notify`
|
| 97 |
+
|
| 98 |
+
#### Notification Criteria
|
| 99 |
+
|
| 100 |
+
A notification is sent when:
|
| 101 |
+
1. A new session is created with `factor` set to "RAG"
|
| 102 |
+
2. The message content starts with "I don't know"
|
| 103 |
+
|
| 104 |
+
#### Notification Format
|
| 105 |
+
|
| 106 |
+
```json
|
| 107 |
+
{
|
| 108 |
+
"type": "new_session",
|
| 109 |
+
"timestamp": "2025-04-15 22:30:45",
|
| 110 |
+
"data": {
|
| 111 |
+
"session_id": "123e4567-e89b-12d3-a456-426614174000",
|
| 112 |
+
"factor": "rag",
|
| 113 |
+
"action": "asking_freely",
|
| 114 |
+
"created_at": "2025-04-15 22:30:45",
|
| 115 |
+
"first_name": "John",
|
| 116 |
+
"last_name": "Doe",
|
| 117 |
+
"message": "I don't know how to find emergency contacts",
|
| 118 |
+
"user_id": "12345678",
|
| 119 |
+
"username": "johndoe"
|
| 120 |
+
}
|
| 121 |
+
}
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
#### Usage Example
|
| 125 |
+
|
| 126 |
+
Admin Bot should establish a WebSocket connection to this endpoint using the configured URL:
|
| 127 |
+
|
| 128 |
+
```python
|
| 129 |
+
import websocket
|
| 130 |
+
import json
|
| 131 |
+
import os
|
| 132 |
+
from dotenv import load_dotenv
|
| 133 |
+
|
| 134 |
+
# Load environment variables
|
| 135 |
+
load_dotenv()
|
| 136 |
+
|
| 137 |
+
# Get WebSocket configuration from environment variables
|
| 138 |
+
WEBSOCKET_SERVER = os.getenv("WEBSOCKET_SERVER", "localhost")
|
| 139 |
+
WEBSOCKET_PORT = os.getenv("WEBSOCKET_PORT", "7860")
|
| 140 |
+
WEBSOCKET_PATH = os.getenv("WEBSOCKET_PATH", "/notify")
|
| 141 |
+
|
| 142 |
+
# Create full URL
|
| 143 |
+
ws_url = f"ws://{WEBSOCKET_SERVER}:{WEBSOCKET_PORT}{WEBSOCKET_PATH}"
|
| 144 |
+
|
| 145 |
+
def on_message(ws, message):
|
| 146 |
+
data = json.loads(message)
|
| 147 |
+
print(f"Received notification: {data}")
|
| 148 |
+
# Forward to Telegram Admin
|
| 149 |
+
|
| 150 |
+
def on_error(ws, error):
|
| 151 |
+
print(f"Error: {error}")
|
| 152 |
+
|
| 153 |
+
def on_close(ws, close_status_code, close_msg):
|
| 154 |
+
print("Connection closed")
|
| 155 |
+
|
| 156 |
+
def on_open(ws):
|
| 157 |
+
print("Connection opened")
|
| 158 |
+
# Send keepalive message periodically
|
| 159 |
+
ws.send("keepalive")
|
| 160 |
+
|
| 161 |
+
# Connect to WebSocket
|
| 162 |
+
ws = websocket.WebSocketApp(
|
| 163 |
+
ws_url,
|
| 164 |
+
on_open=on_open,
|
| 165 |
+
on_message=on_message,
|
| 166 |
+
on_error=on_error,
|
| 167 |
+
on_close=on_close
|
| 168 |
+
)
|
| 169 |
+
ws.run_forever()
|
| 170 |
+
```
|
| 171 |
+
|
| 172 |
+
When a notification is received, Admin Bot should forward the content to the Telegram Admin.
|
| 173 |
+
|
| 174 |
+
## Environment Variables
|
| 175 |
+
|
| 176 |
+
Create a `.env` file with the following variables:
|
| 177 |
+
|
| 178 |
+
```
|
| 179 |
+
# PostgreSQL Configuration
|
| 180 |
+
DB_CONNECTION_MODE=aiven
|
| 181 |
+
AIVEN_DB_URL=postgresql://username:password@host:port/dbname?sslmode=require
|
| 182 |
+
|
| 183 |
+
# MongoDB Configuration
|
| 184 |
+
MONGODB_URL=mongodb+srv://username:password@cluster.mongodb.net/?retryWrites=true&w=majority
|
| 185 |
+
DB_NAME=Telegram
|
| 186 |
+
COLLECTION_NAME=session_chat
|
| 187 |
+
|
| 188 |
+
# Pinecone configuration
|
| 189 |
+
PINECONE_API_KEY=your-pinecone-api-key
|
| 190 |
+
PINECONE_INDEX_NAME=your-pinecone-index-name
|
| 191 |
+
PINECONE_ENVIRONMENT=gcp-starter
|
| 192 |
+
|
| 193 |
+
# Google Gemini API key
|
| 194 |
+
GOOGLE_API_KEY=your-google-api-key
|
| 195 |
+
|
| 196 |
+
# WebSocket configuration
|
| 197 |
+
WEBSOCKET_SERVER=localhost
|
| 198 |
+
WEBSOCKET_PORT=7860
|
| 199 |
+
WEBSOCKET_PATH=/notify
|
| 200 |
+
|
| 201 |
+
# Application settings
|
| 202 |
+
ENVIRONMENT=production
|
| 203 |
+
DEBUG=false
|
| 204 |
+
PORT=7860
|
| 205 |
+
```
|
| 206 |
+
|
| 207 |
+
## Installation and Setup
|
| 208 |
+
|
| 209 |
+
### Local Development
|
| 210 |
+
|
| 211 |
+
1. Clone the repository:
|
| 212 |
+
```bash
|
| 213 |
+
git clone https://github.com/ManTT-Data/PixAgent.git
|
| 214 |
+
cd PixAgent
|
| 215 |
+
```
|
| 216 |
+
|
| 217 |
+
2. Create a virtual environment and install dependencies:
|
| 218 |
+
```bash
|
| 219 |
+
python -m venv venv
|
| 220 |
+
source venv/bin/activate # On Windows: venv\Scripts\activate
|
| 221 |
+
pip install -r requirements.txt
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
3. Create a `.env` file with your configuration (see above)
|
| 225 |
+
|
| 226 |
+
4. Run the application:
|
| 227 |
+
```bash
|
| 228 |
+
uvicorn app:app --reload --port 7860
|
| 229 |
+
```
|
| 230 |
+
|
| 231 |
+
5. Open your browser and navigate to [http://localhost:7860/docs](http://localhost:7860/docs) to see the API documentation
|
| 232 |
+
|
| 233 |
+
### Docker Deployment
|
| 234 |
+
|
| 235 |
+
1. Build the Docker image:
|
| 236 |
+
```bash
|
| 237 |
+
docker build -t pix-project-backend .
|
| 238 |
+
```
|
| 239 |
+
|
| 240 |
+
2. Run the Docker container:
|
| 241 |
+
```bash
|
| 242 |
+
docker run -p 7860:7860 --env-file .env pix-project-backend
|
| 243 |
+
```
|
| 244 |
+
|
| 245 |
+
## Deployment to HuggingFace Spaces
|
| 246 |
+
|
| 247 |
+
1. Create a new Space on HuggingFace (Dockerfile type)
|
| 248 |
+
2. Link your GitHub repository or push directly to the HuggingFace repo
|
| 249 |
+
3. Add your environment variables in the Space settings
|
| 250 |
+
4. The deployment will use `app.py` as the entry point, which is the standard for HuggingFace Spaces
|
| 251 |
+
|
| 252 |
+
### Important Notes for HuggingFace Deployment
|
| 253 |
+
|
| 254 |
+
- The application uses `app.py` with the FastAPI instance named `app` to avoid the "Error loading ASGI app. Attribute 'app' not found in module 'app'" error
|
| 255 |
+
- Make sure all environment variables are set in the Space settings
|
| 256 |
+
- The Dockerfile is configured to expose port 7860, which is the default port for HuggingFace Spaces
|
| 257 |
+
|
| 258 |
+
## Project Structure
|
| 259 |
+
|
| 260 |
+
```
|
| 261 |
+
.
|
| 262 |
+
├── app # Main application package
|
| 263 |
+
│ ├── api # API endpoints
|
| 264 |
+
│ │ ├── mongodb_routes.py
|
| 265 |
+
│ │ ├── postgresql_routes.py
|
| 266 |
+
│ │ ├── rag_routes.py
|
| 267 |
+
│ │ └── websocket_routes.py
|
| 268 |
+
│ ├── database # Database connections
|
| 269 |
+
│ │ ├── mongodb.py
|
| 270 |
+
│ │ ├── pinecone.py
|
| 271 |
+
│ │ └── postgresql.py
|
| 272 |
+
│ ├── models # Pydantic models
|
| 273 |
+
│ │ ├── mongodb_models.py
|
| 274 |
+
│ │ ├── postgresql_models.py
|
| 275 |
+
│ │ └── rag_models.py
|
| 276 |
+
│ └── utils # Utility functions
|
| 277 |
+
│ ├── debug_utils.py
|
| 278 |
+
│ └── middleware.py
|
| 279 |
+
├── tests # Test directory
|
| 280 |
+
│ └── test_api_endpoints.py
|
| 281 |
+
├── .dockerignore # Docker ignore file
|
| 282 |
+
├── .env.example # Example environment file
|
| 283 |
+
├── .gitattributes # Git attributes
|
| 284 |
+
├── .gitignore # Git ignore file
|
| 285 |
+
├── app.py # Application entry point
|
| 286 |
+
├── docker-compose.yml # Docker compose configuration
|
| 287 |
+
├── Dockerfile # Docker configuration
|
| 288 |
+
├── pytest.ini # Pytest configuration
|
| 289 |
+
├── README.md # Project documentation
|
| 290 |
+
├── requirements.txt # Project dependencies
|
| 291 |
+
└── api_documentation.txt # API documentation for frontend engineers
|
| 292 |
+
```
|
| 293 |
+
|
| 294 |
+
## License
|
| 295 |
+
|
| 296 |
+
This project is licensed under the MIT License - see the LICENSE file for details.
|
| 297 |
+
|
| 298 |
+
# Advanced Retrieval System
|
| 299 |
+
|
| 300 |
+
This project now features an enhanced vector retrieval system that improves the quality and relevance of information retrieved from Pinecone using threshold-based filtering and multiple similarity metrics.
|
| 301 |
+
|
| 302 |
+
## Features
|
| 303 |
+
|
| 304 |
+
### 1. Threshold-Based Retrieval
|
| 305 |
+
|
| 306 |
+
The system implements a threshold-based approach to vector retrieval, which:
|
| 307 |
+
- Retrieves a larger candidate set from the vector database
|
| 308 |
+
- Applies a similarity threshold to filter out less relevant results
|
| 309 |
+
- Returns only the most relevant documents that exceed the threshold
|
| 310 |
+
|
| 311 |
+
### 2. Multiple Similarity Metrics
|
| 312 |
+
|
| 313 |
+
The system supports multiple similarity metrics:
|
| 314 |
+
- **Cosine Similarity** (default): Measures the cosine of the angle between vectors
|
| 315 |
+
- **Dot Product**: Calculates the dot product between vectors
|
| 316 |
+
- **Euclidean Distance**: Measures the straight-line distance between vectors
|
| 317 |
+
|
| 318 |
+
Each metric has different characteristics and may perform better for different types of data and queries.
|
| 319 |
+
|
| 320 |
+
### 3. Score Normalization
|
| 321 |
+
|
| 322 |
+
For metrics like Euclidean distance where lower values indicate higher similarity, the system automatically normalizes scores to a 0-1 scale where higher values always indicate higher similarity. This makes it easier to compare results across different metrics.
|
| 323 |
+
|
| 324 |
+
## Configuration
|
| 325 |
+
|
| 326 |
+
The retrieval system can be configured through environment variables:
|
| 327 |
+
|
| 328 |
+
```
|
| 329 |
+
# Pinecone retrieval configuration
|
| 330 |
+
PINECONE_DEFAULT_LIMIT_K=10 # Maximum number of candidates to retrieve
|
| 331 |
+
PINECONE_DEFAULT_TOP_K=6 # Number of results to return after filtering
|
| 332 |
+
PINECONE_DEFAULT_SIMILARITY_METRIC=cosine # Default similarity metric
|
| 333 |
+
PINECONE_DEFAULT_SIMILARITY_THRESHOLD=0.75 # Similarity threshold (0-1)
|
| 334 |
+
PINECONE_ALLOWED_METRICS=cosine,dotproduct,euclidean # Available metrics
|
| 335 |
+
```
|
| 336 |
+
|
| 337 |
+
## API Usage
|
| 338 |
+
|
| 339 |
+
You can customize the retrieval parameters when making API requests:
|
| 340 |
+
|
| 341 |
+
```json
|
| 342 |
+
{
|
| 343 |
+
"user_id": "user123",
|
| 344 |
+
"question": "What are the best restaurants in Da Nang?",
|
| 345 |
+
"similarity_top_k": 5,
|
| 346 |
+
"limit_k": 15,
|
| 347 |
+
"similarity_metric": "cosine",
|
| 348 |
+
"similarity_threshold": 0.8
|
| 349 |
+
}
|
| 350 |
+
```
|
| 351 |
+
|
| 352 |
+
## Benefits
|
| 353 |
+
|
| 354 |
+
1. **Quality Improvement**: Retrieves only the most relevant documents above a certain quality threshold
|
| 355 |
+
2. **Flexibility**: Different similarity metrics can be used for different types of queries
|
| 356 |
+
3. **Efficiency**: Avoids processing irrelevant documents, improving response time
|
| 357 |
+
4. **Configurability**: All parameters can be adjusted via environment variables or at request time
|
| 358 |
+
|
| 359 |
+
## Implementation Details
|
| 360 |
+
|
| 361 |
+
The system is implemented as a custom retriever class `ThresholdRetriever` that integrates with LangChain's retrieval infrastructure while providing enhanced functionality.
|
| 362 |
+
|
| 363 |
+
## In-Memory Cache
|
| 364 |
+
|
| 365 |
+
Dự án bao gồm một hệ thống cache trong bộ nhớ để giảm thiểu truy cập đến cơ sở dữ liệu PostgreSQL và MongoDB.
|
| 366 |
+
|
| 367 |
+
### Cấu hình Cache
|
| 368 |
+
|
| 369 |
+
Cache được cấu hình thông qua các biến môi trường:
|
| 370 |
+
|
| 371 |
+
```
|
| 372 |
+
# Cache Configuration
|
| 373 |
+
CACHE_TTL_SECONDS=300 # Thời gian tồn tại của cache item (giây)
|
| 374 |
+
CACHE_CLEANUP_INTERVAL=60 # Chu kỳ xóa cache hết hạn (giây)
|
| 375 |
+
CACHE_MAX_SIZE=1000 # Số lượng item tối đa trong cache
|
| 376 |
+
HISTORY_QUEUE_SIZE=10 # Số lượng item tối đa trong queue lịch sử người dùng
|
| 377 |
+
HISTORY_CACHE_TTL=3600 # Thời gian tồn tại của lịch sử người dùng (giây)
|
| 378 |
+
```
|
| 379 |
+
|
| 380 |
+
### Cơ chế Cache
|
| 381 |
+
|
| 382 |
+
Hệ thống cache kết hợp hai cơ chế hết hạn:
|
| 383 |
+
|
| 384 |
+
1. **Lazy Expiration**: Kiểm tra thời hạn khi truy cập cache item. Nếu item đã hết hạn, nó sẽ bị xóa và trả về kết quả là không tìm thấy.
|
| 385 |
+
|
| 386 |
+
2. **Active Expiration**: Một background thread định kỳ quét và xóa các item đã hết hạn. Điều này giúp tránh tình trạng cache quá lớn với các item không còn được sử dụng.
|
| 387 |
+
|
| 388 |
+
### Các loại dữ liệu được cache
|
| 389 |
+
|
| 390 |
+
- **Dữ liệu PostgreSQL**: Thông tin từ các bảng FAQ, Emergency Contacts, và Events.
|
| 391 |
+
- **Lịch sử người dùng từ MongoDB**: Lịch sử hội thoại người dùng được lưu trong queue với thời gian sống tính theo lần truy cập cuối cùng.
|
| 392 |
+
|
| 393 |
+
### API Cache
|
| 394 |
+
|
| 395 |
+
Dự án cung cấp các API endpoints để quản lý cache:
|
| 396 |
+
|
| 397 |
+
- `GET /cache/stats`: Xem thống kê về cache (tổng số item, bộ nhớ sử dụng, v.v.)
|
| 398 |
+
- `DELETE /cache/clear`: Xóa toàn bộ cache
|
| 399 |
+
- `GET /debug/cache`: (Chỉ trong chế độ debug) Xem thông tin chi tiết về cache, bao gồm các keys và cấu hình
|
| 400 |
+
|
| 401 |
+
### Cách hoạt động
|
| 402 |
+
|
| 403 |
+
1. Khi một request đến, hệ thống sẽ kiểm tra dữ liệu trong cache trước.
|
| 404 |
+
2. Nếu dữ liệu tồn tại và còn hạn, trả về từ cache.
|
| 405 |
+
3. Nếu dữ liệu không tồn tại hoặc đã hết hạn, truy vấn từ database và lưu kết quả vào cache.
|
| 406 |
+
4. Khi dữ liệu được cập nhật hoặc xóa, cache liên quan sẽ tự động được xóa.
|
| 407 |
+
|
| 408 |
+
### Lịch sử người dùng
|
| 409 |
+
|
| 410 |
+
Lịch sử hội thoại người dùng được lưu trong queue riêng với cơ chế đặc biệt:
|
| 411 |
+
|
| 412 |
+
- Mỗi người dùng có một queue riêng với kích thước giới hạn (`HISTORY_QUEUE_SIZE`).
|
| 413 |
+
- Thời gian sống của queue được làm mới mỗi khi có tương tác mới.
|
| 414 |
+
- Khi queue đầy, các item cũ nhất sẽ bị loại bỏ.
|
| 415 |
+
- Queue tự động bị xóa sau một thời gian không hoạt động.
|
| 416 |
+
|
| 417 |
+
## Tác giả
|
| 418 |
+
|
| 419 |
+
- **PIX Project Team**
|
app.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, Depends, Request, HTTPException, status
|
| 2 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
+
from contextlib import asynccontextmanager
|
| 4 |
+
import uvicorn
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
import logging
|
| 8 |
+
from dotenv import load_dotenv
|
| 9 |
+
|
| 10 |
+
# Cấu hình logging
|
| 11 |
+
logging.basicConfig(
|
| 12 |
+
level=logging.INFO,
|
| 13 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
| 14 |
+
handlers=[
|
| 15 |
+
logging.StreamHandler(sys.stdout),
|
| 16 |
+
]
|
| 17 |
+
)
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
# Load environment variables
|
| 21 |
+
load_dotenv()
|
| 22 |
+
DEBUG = os.getenv("DEBUG", "False").lower() in ("true", "1", "t")
|
| 23 |
+
|
| 24 |
+
# Kiểm tra các biến môi trường bắt buộc
|
| 25 |
+
required_env_vars = [
|
| 26 |
+
"AIVEN_DB_URL",
|
| 27 |
+
"MONGODB_URL",
|
| 28 |
+
"PINECONE_API_KEY",
|
| 29 |
+
"PINECONE_INDEX_NAME",
|
| 30 |
+
"GOOGLE_API_KEY"
|
| 31 |
+
]
|
| 32 |
+
|
| 33 |
+
missing_vars = [var for var in required_env_vars if not os.getenv(var)]
|
| 34 |
+
if missing_vars:
|
| 35 |
+
logger.error(f"Missing required environment variables: {', '.join(missing_vars)}")
|
| 36 |
+
if not DEBUG: # Chỉ thoát nếu không ở chế độ debug
|
| 37 |
+
sys.exit(1)
|
| 38 |
+
|
| 39 |
+
# Database health checks
|
| 40 |
+
def check_database_connections():
|
| 41 |
+
"""Kiểm tra kết nối các database khi khởi động"""
|
| 42 |
+
from app.database.postgresql import check_db_connection as check_postgresql
|
| 43 |
+
from app.database.mongodb import check_db_connection as check_mongodb
|
| 44 |
+
from app.database.pinecone import check_db_connection as check_pinecone
|
| 45 |
+
|
| 46 |
+
db_status = {
|
| 47 |
+
"postgresql": check_postgresql(),
|
| 48 |
+
"mongodb": check_mongodb(),
|
| 49 |
+
"pinecone": check_pinecone()
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
all_ok = all(db_status.values())
|
| 53 |
+
if not all_ok:
|
| 54 |
+
failed_dbs = [name for name, status in db_status.items() if not status]
|
| 55 |
+
logger.error(f"Failed to connect to databases: {', '.join(failed_dbs)}")
|
| 56 |
+
if not DEBUG: # Chỉ thoát nếu không ở chế độ debug
|
| 57 |
+
sys.exit(1)
|
| 58 |
+
|
| 59 |
+
return db_status
|
| 60 |
+
|
| 61 |
+
# Lifespan context: run database checks at startup, log at shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: verify connectivity to all databases
    logger.info("Starting application...")
    db_status = check_database_connections()

    # Create database tables if they do not exist yet.
    # Done only in debug mode, and only when every DB connection succeeded.
    if DEBUG and all(db_status.values()):
        from app.database.postgresql import create_tables
        if create_tables():
            logger.info("Database tables created or already exist")

    yield

    # Shutdown
    logger.info("Shutting down application...")
|
| 78 |
+
|
| 79 |
+
# Import routers. Failures here are fatal: the app cannot serve without them,
# so the ImportError is logged and re-raised.
try:
    from app.api.mongodb_routes import router as mongodb_router
    from app.api.postgresql_routes import router as postgresql_router
    from app.api.rag_routes import router as rag_router
    from app.api.pdf_routes import router as pdf_router
    from app.api.websocket_routes import router as websocket_router

    # Import middlewares
    from app.utils.middleware import RequestLoggingMiddleware, ErrorHandlingMiddleware, DatabaseCheckMiddleware

    # Import debug utilities
    from app.utils.debug_utils import debug_view, DebugInfo, error_tracker, performance_monitor

    # Import cache
    from app.utils.cache import get_cache

except ImportError as e:
    logger.error(f"Error importing routes or middlewares: {e}")
    raise
|
| 99 |
+
|
| 100 |
+
# Create FastAPI app
app = FastAPI(
    title="PIX Project Backend API",
    description="Backend API for PIX Project with MongoDB, PostgreSQL and RAG integration",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
    debug=DEBUG,
    lifespan=lifespan,
)

# Configure CORS
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# rejected by browsers under the CORS spec; consider an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Add middlewares (each add_middleware wraps the previously added ones)
app.add_middleware(ErrorHandlingMiddleware)
app.add_middleware(RequestLoggingMiddleware)
if not DEBUG:  # database availability check only in production
    app.add_middleware(DatabaseCheckMiddleware)

# Include routers
app.include_router(mongodb_router)
app.include_router(postgresql_router)
app.include_router(rag_router)
app.include_router(pdf_router)
app.include_router(websocket_router)
|
| 132 |
+
|
| 133 |
+
# Root endpoint
@app.get("/")
def read_root():
    """Landing endpoint that points clients at the interactive docs."""
    welcome = {
        "message": "Welcome to PIX Project Backend API",
        "documentation": "/docs",
    }
    return welcome
|
| 140 |
+
|
| 141 |
+
# Health check endpoint
@app.get("/health")
def health_check():
    """Report overall service health including per-database connectivity."""
    db_status = check_database_connections()
    healthy = all(db_status.values())

    return {
        "status": "healthy" if healthy else "degraded",
        "version": "1.0.0",
        "environment": os.environ.get("ENVIRONMENT", "production"),
        "databases": db_status,
    }
|
| 154 |
+
|
| 155 |
+
@app.get("/api/ping")
|
| 156 |
+
async def ping():
|
| 157 |
+
return {"status": "pong"}
|
| 158 |
+
|
| 159 |
+
# Cache stats endpoint
@app.get("/cache/stats")
def cache_stats():
    """Return statistics about the in-process cache."""
    return get_cache().stats()
|
| 165 |
+
|
| 166 |
+
# Cache clear endpoint
@app.delete("/cache/clear")
def cache_clear():
    """Drop every entry from the cache."""
    get_cache().clear()
    return {"message": "Cache cleared successfully"}
|
| 173 |
+
|
| 174 |
+
# Debug endpoints (registered only when DEBUG is enabled)
if DEBUG:
    @app.get("/debug/config")
    def debug_config():
        """Expose selected configuration values (debug mode only)."""
        config = {
            "environment": os.environ.get("ENVIRONMENT", "production"),
            "debug": DEBUG,
            "db_connection_mode": os.environ.get("DB_CONNECTION_MODE", "aiven"),
            # Only the host portion of each connection URL is exposed,
            # so credentials before the '@' never leak through this endpoint.
            "databases": {
                "postgresql": os.environ.get("AIVEN_DB_URL", "").split("@")[1].split("/")[0] if "@" in os.environ.get("AIVEN_DB_URL", "") else "N/A",
                "mongodb": os.environ.get("MONGODB_URL", "").split("@")[1].split("/?")[0] if "@" in os.environ.get("MONGODB_URL", "") else "N/A",
                "pinecone": os.environ.get("PINECONE_INDEX_NAME", "N/A"),
            }
        }
        return config

    @app.get("/debug/system")
    def debug_system():
        """Return host/system information (debug mode only)."""
        return DebugInfo.get_system_info()

    @app.get("/debug/database")
    def debug_database():
        """Return database status details (debug mode only)."""
        return DebugInfo.get_database_status()

    @app.get("/debug/errors")
    def debug_errors(limit: int = 10):
        """Return the most recent tracked errors (debug mode only)."""
        return error_tracker.get_recent_errors(limit)

    @app.get("/debug/performance")
    def debug_performance():
        """Return performance statistics (debug mode only)."""
        return performance_monitor.get_stats()

    @app.get("/debug/full")
    def debug_full_report(request: Request):
        """Return a full system debug report (debug mode only)."""
        return debug_view(request)

    @app.get("/debug/cache")
    def debug_cache():
        """Return cache statistics (debug mode only)."""
        return get_cache().stats()
|
| 220 |
+
|
| 221 |
+
if __name__ == "__main__":
|
| 222 |
+
PORT = int(os.getenv("PORT", "7860"))
|
| 223 |
+
uvicorn.run("app:app", host="0.0.0.0", port=PORT, reload=DEBUG)
|
app/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# PIX Project Backend
# Version: 1.0.0

__version__ = "1.0.0"

# Re-export the FastAPI ``app`` object from the top-level app.py so tests
# can resolve it via the ``app`` package.
import sys
import os

# Make the repository root importable regardless of how this package is loaded.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

try:
    # 'app.py' is not a valid module path — the module name is 'app' and
    # '.py' is only the file extension.
    # NOTE(review): because this package is itself named 'app', this import
    # resolves back to the partially-initialized package and is expected to
    # fail, falling through to the file-based loader below — confirm intended.
    from app import app
except ImportError:
    # Fallback: load app.py directly by file path when the plain import fails.
    import importlib.util
    spec = importlib.util.spec_from_file_location("app_module",
                                                  os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
                                                               "app.py"))
    app_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(app_module)
    app = app_module.app
|
app/api/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# API routes package
|
app/api/mongodb_routes.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, HTTPException, Depends, Query, status, Response
|
| 2 |
+
from typing import List, Optional, Dict
|
| 3 |
+
from pymongo.errors import PyMongoError
|
| 4 |
+
import logging
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
import traceback
|
| 7 |
+
import asyncio
|
| 8 |
+
|
| 9 |
+
from app.database.mongodb import (
|
| 10 |
+
save_session,
|
| 11 |
+
get_chat_history,
|
| 12 |
+
update_session_response,
|
| 13 |
+
check_db_connection,
|
| 14 |
+
session_collection
|
| 15 |
+
)
|
| 16 |
+
from app.models.mongodb_models import (
|
| 17 |
+
SessionCreate,
|
| 18 |
+
SessionResponse,
|
| 19 |
+
HistoryRequest,
|
| 20 |
+
HistoryResponse,
|
| 21 |
+
QuestionAnswer
|
| 22 |
+
)
|
| 23 |
+
from app.api.websocket_routes import send_notification
|
| 24 |
+
|
| 25 |
+
# Module-level logger
logger = logging.getLogger(__name__)

# Router exposing the MongoDB-backed session/history endpoints
router = APIRouter(prefix="/mongodb", tags=["MongoDB"])
|
| 33 |
+
|
| 34 |
+
@router.post("/session", response_model=SessionResponse, status_code=status.HTTP_201_CREATED)
|
| 35 |
+
async def create_session(session: SessionCreate, response: Response):
|
| 36 |
+
"""
|
| 37 |
+
Create a new session record in MongoDB.
|
| 38 |
+
|
| 39 |
+
- **session_id**: Unique identifier for the session (auto-generated if not provided)
|
| 40 |
+
- **factor**: Factor type (user, rag, etc.)
|
| 41 |
+
- **action**: Action type (start, events, faq, emergency, help, asking_freely, etc.)
|
| 42 |
+
- **first_name**: User's first name
|
| 43 |
+
- **last_name**: User's last name (optional)
|
| 44 |
+
- **message**: User's message (optional)
|
| 45 |
+
- **user_id**: User's ID from Telegram
|
| 46 |
+
- **username**: User's username (optional)
|
| 47 |
+
- **response**: Response from RAG (optional)
|
| 48 |
+
"""
|
| 49 |
+
try:
|
| 50 |
+
# Kiểm tra kết nối MongoDB
|
| 51 |
+
if not check_db_connection():
|
| 52 |
+
logger.error("MongoDB connection failed when trying to create session")
|
| 53 |
+
raise HTTPException(
|
| 54 |
+
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 55 |
+
detail="MongoDB connection failed"
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
# Create new session in MongoDB
|
| 59 |
+
result = save_session(
|
| 60 |
+
session_id=session.session_id,
|
| 61 |
+
factor=session.factor,
|
| 62 |
+
action=session.action,
|
| 63 |
+
first_name=session.first_name,
|
| 64 |
+
last_name=session.last_name,
|
| 65 |
+
message=session.message,
|
| 66 |
+
user_id=session.user_id,
|
| 67 |
+
username=session.username,
|
| 68 |
+
response=session.response
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
# Chuẩn bị response object
|
| 72 |
+
session_response = SessionResponse(
|
| 73 |
+
**session.model_dump(),
|
| 74 |
+
created_at=datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# Kiểm tra nếu session cần gửi thông báo (response bắt đầu bằng "I'm sorry")
|
| 78 |
+
if session.response and session.response.strip().lower().startswith("i'm sorry"):
|
| 79 |
+
# Gửi thông báo qua WebSocket
|
| 80 |
+
try:
|
| 81 |
+
notification_data = {
|
| 82 |
+
"session_id": session.session_id,
|
| 83 |
+
"factor": session.factor,
|
| 84 |
+
"action": session.action,
|
| 85 |
+
"message": session.message,
|
| 86 |
+
"user_id": session.user_id,
|
| 87 |
+
"username": session.username,
|
| 88 |
+
"first_name": session.first_name,
|
| 89 |
+
"last_name": session.last_name,
|
| 90 |
+
"response": session.response,
|
| 91 |
+
"created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
# Khởi tạo task để gửi thông báo - sử dụng asyncio.create_task để đảm bảo không block quá trình chính
|
| 95 |
+
asyncio.create_task(send_notification(notification_data))
|
| 96 |
+
logger.info(f"Notification queued for session {session.session_id} - response starts with 'I'm sorry'")
|
| 97 |
+
except Exception as e:
|
| 98 |
+
logger.error(f"Error queueing notification: {e}")
|
| 99 |
+
# Không dừng xử lý chính khi gửi thông báo thất bại
|
| 100 |
+
|
| 101 |
+
# Return response
|
| 102 |
+
return session_response
|
| 103 |
+
except PyMongoError as e:
|
| 104 |
+
logger.error(f"MongoDB error creating session: {e}")
|
| 105 |
+
logger.error(traceback.format_exc())
|
| 106 |
+
raise HTTPException(
|
| 107 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 108 |
+
detail=f"MongoDB error: {str(e)}"
|
| 109 |
+
)
|
| 110 |
+
except Exception as e:
|
| 111 |
+
logger.error(f"Unexpected error creating session: {e}")
|
| 112 |
+
logger.error(traceback.format_exc())
|
| 113 |
+
raise HTTPException(
|
| 114 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 115 |
+
detail=f"Failed to create session: {str(e)}"
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
@router.put("/session/{session_id}/response", status_code=status.HTTP_200_OK)
|
| 119 |
+
async def update_session_with_response(session_id: str, response_text: str):
|
| 120 |
+
"""
|
| 121 |
+
Update a session with the response.
|
| 122 |
+
|
| 123 |
+
- **session_id**: ID of the session to update
|
| 124 |
+
- **response_text**: Response to add to the session
|
| 125 |
+
"""
|
| 126 |
+
try:
|
| 127 |
+
# Kiểm tra kết nối MongoDB
|
| 128 |
+
if not check_db_connection():
|
| 129 |
+
logger.error("MongoDB connection failed when trying to update session response")
|
| 130 |
+
raise HTTPException(
|
| 131 |
+
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 132 |
+
detail="MongoDB connection failed"
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
# Update session in MongoDB
|
| 136 |
+
result = update_session_response(session_id, response_text)
|
| 137 |
+
|
| 138 |
+
if not result:
|
| 139 |
+
raise HTTPException(
|
| 140 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
| 141 |
+
detail=f"Session with ID {session_id} not found"
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
return {"status": "success", "message": "Response added to session"}
|
| 145 |
+
except PyMongoError as e:
|
| 146 |
+
logger.error(f"MongoDB error updating session response: {e}")
|
| 147 |
+
logger.error(traceback.format_exc())
|
| 148 |
+
raise HTTPException(
|
| 149 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 150 |
+
detail=f"MongoDB error: {str(e)}"
|
| 151 |
+
)
|
| 152 |
+
except HTTPException:
|
| 153 |
+
# Re-throw HTTP exceptions
|
| 154 |
+
raise
|
| 155 |
+
except Exception as e:
|
| 156 |
+
logger.error(f"Unexpected error updating session response: {e}")
|
| 157 |
+
logger.error(traceback.format_exc())
|
| 158 |
+
raise HTTPException(
|
| 159 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 160 |
+
detail=f"Failed to update session: {str(e)}"
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
@router.get("/history", response_model=HistoryResponse)
|
| 164 |
+
async def get_history(user_id: str, n: int = Query(3, ge=1, le=10)):
|
| 165 |
+
"""
|
| 166 |
+
Get user history for a specific user.
|
| 167 |
+
|
| 168 |
+
- **user_id**: User's ID from Telegram
|
| 169 |
+
- **n**: Number of most recent interactions to return (default: 3, min: 1, max: 10)
|
| 170 |
+
"""
|
| 171 |
+
try:
|
| 172 |
+
# Kiểm tra kết nối MongoDB
|
| 173 |
+
if not check_db_connection():
|
| 174 |
+
logger.error("MongoDB connection failed when trying to get user history")
|
| 175 |
+
raise HTTPException(
|
| 176 |
+
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 177 |
+
detail="MongoDB connection failed"
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
# Get user history from MongoDB
|
| 181 |
+
history_data = get_chat_history(user_id=user_id, n=n)
|
| 182 |
+
|
| 183 |
+
# Convert to response model
|
| 184 |
+
return HistoryResponse(history=history_data)
|
| 185 |
+
except PyMongoError as e:
|
| 186 |
+
logger.error(f"MongoDB error getting user history: {e}")
|
| 187 |
+
logger.error(traceback.format_exc())
|
| 188 |
+
raise HTTPException(
|
| 189 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 190 |
+
detail=f"MongoDB error: {str(e)}"
|
| 191 |
+
)
|
| 192 |
+
except Exception as e:
|
| 193 |
+
logger.error(f"Unexpected error getting user history: {e}")
|
| 194 |
+
logger.error(traceback.format_exc())
|
| 195 |
+
raise HTTPException(
|
| 196 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 197 |
+
detail=f"Failed to get user history: {str(e)}"
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
@router.get("/health")
|
| 201 |
+
async def health_check():
|
| 202 |
+
"""
|
| 203 |
+
Check health of MongoDB connection.
|
| 204 |
+
"""
|
| 205 |
+
try:
|
| 206 |
+
# Kiểm tra kết nối MongoDB
|
| 207 |
+
is_connected = check_db_connection()
|
| 208 |
+
|
| 209 |
+
if not is_connected:
|
| 210 |
+
return {
|
| 211 |
+
"status": "unhealthy",
|
| 212 |
+
"message": "MongoDB connection failed",
|
| 213 |
+
"timestamp": datetime.now().isoformat()
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
return {
|
| 217 |
+
"status": "healthy",
|
| 218 |
+
"message": "MongoDB connection is working",
|
| 219 |
+
"timestamp": datetime.now().isoformat()
|
| 220 |
+
}
|
| 221 |
+
except Exception as e:
|
| 222 |
+
logger.error(f"MongoDB health check failed: {e}")
|
| 223 |
+
logger.error(traceback.format_exc())
|
| 224 |
+
return {
|
| 225 |
+
"status": "error",
|
| 226 |
+
"message": f"MongoDB health check error: {str(e)}",
|
| 227 |
+
"timestamp": datetime.now().isoformat()
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
@router.get("/session/{session_id}")
|
| 231 |
+
async def get_session(session_id: str):
|
| 232 |
+
"""
|
| 233 |
+
Lấy thông tin session từ MongoDB theo session_id.
|
| 234 |
+
|
| 235 |
+
- **session_id**: ID của session cần lấy
|
| 236 |
+
"""
|
| 237 |
+
try:
|
| 238 |
+
# Kiểm tra kết nối MongoDB
|
| 239 |
+
if not check_db_connection():
|
| 240 |
+
logger.error("MongoDB connection failed when trying to get session")
|
| 241 |
+
raise HTTPException(
|
| 242 |
+
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 243 |
+
detail="MongoDB connection failed"
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
# Lấy thông tin từ MongoDB
|
| 247 |
+
session_data = session_collection.find_one({"session_id": session_id})
|
| 248 |
+
|
| 249 |
+
if not session_data:
|
| 250 |
+
raise HTTPException(
|
| 251 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
| 252 |
+
detail=f"Session with ID {session_id} not found"
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
# Chuyển _id thành string để có thể JSON serialize
|
| 256 |
+
if "_id" in session_data:
|
| 257 |
+
session_data["_id"] = str(session_data["_id"])
|
| 258 |
+
|
| 259 |
+
return session_data
|
| 260 |
+
except PyMongoError as e:
|
| 261 |
+
logger.error(f"MongoDB error getting session: {e}")
|
| 262 |
+
logger.error(traceback.format_exc())
|
| 263 |
+
raise HTTPException(
|
| 264 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 265 |
+
detail=f"MongoDB error: {str(e)}"
|
| 266 |
+
)
|
| 267 |
+
except HTTPException:
|
| 268 |
+
# Re-throw HTTP exceptions
|
| 269 |
+
raise
|
| 270 |
+
except Exception as e:
|
| 271 |
+
logger.error(f"Unexpected error getting session: {e}")
|
| 272 |
+
logger.error(traceback.format_exc())
|
| 273 |
+
raise HTTPException(
|
| 274 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 275 |
+
detail=f"Failed to get session: {str(e)}"
|
| 276 |
+
)
|
app/api/pdf_routes.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
import uuid
|
| 4 |
+
from fastapi import APIRouter, UploadFile, File, Form, HTTPException, BackgroundTasks, Depends
|
| 5 |
+
from fastapi.responses import JSONResponse
|
| 6 |
+
from typing import Optional, List, Dict, Any
|
| 7 |
+
from sqlalchemy.orm import Session
|
| 8 |
+
|
| 9 |
+
from app.utils.pdf_processor import PDFProcessor
|
| 10 |
+
from app.models.pdf_models import PDFResponse, DeleteDocumentRequest, DocumentsListResponse
|
| 11 |
+
from app.database.postgresql import get_db
|
| 12 |
+
from app.database.models import VectorDatabase, Document, VectorStatus, DocumentContent
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
from app.api.pdf_websocket import (
|
| 15 |
+
send_pdf_upload_started,
|
| 16 |
+
send_pdf_upload_progress,
|
| 17 |
+
send_pdf_upload_completed,
|
| 18 |
+
send_pdf_upload_failed,
|
| 19 |
+
send_pdf_delete_started,
|
| 20 |
+
send_pdf_delete_completed,
|
| 21 |
+
send_pdf_delete_failed
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
# Router exposing the PDF ingestion endpoints
router = APIRouter(prefix="/pdf", tags=["PDF Processing"])
| 29 |
+
|
| 30 |
+
# Temporary staging area and permanent storage for uploaded PDFs.
# /tmp is used to avoid permission errors in restricted runtimes
# (e.g. Hugging Face Spaces containers).
TEMP_UPLOAD_DIR = "/tmp/uploads/temp"
STORAGE_DIR = "/tmp/uploads/pdfs"

# Make sure both upload directories exist before any request is handled
os.makedirs(TEMP_UPLOAD_DIR, exist_ok=True)
os.makedirs(STORAGE_DIR, exist_ok=True)
|
| 37 |
+
|
| 38 |
+
# PDF upload & embedding endpoint
@router.post("/upload", response_model=PDFResponse)
async def upload_pdf(
    file: UploadFile = File(...),
    namespace: str = Form("Default"),
    index_name: str = Form("testbot768"),
    title: Optional[str] = Form(None),
    description: Optional[str] = Form(None),
    user_id: Optional[str] = Form(None),
    vector_database_id: Optional[int] = Form(None),
    background_tasks: BackgroundTasks = None,  # NOTE(review): never used in the body
    db: Session = Depends(get_db)
):
    """
    Upload and process a PDF: create embeddings and store them in Pinecone.

    - **file**: PDF file to process
    - **namespace**: Pinecone namespace for the embeddings (default: "Default")
    - **index_name**: Pinecone index name (default: "testbot768")
    - **title**: Document title (optional)
    - **description**: Document description (optional)
    - **user_id**: user ID for WebSocket progress updates (optional)
    - **vector_database_id**: ID of the vector database record in PostgreSQL (optional)
    """
    try:
        # Reject anything that is not a PDF (extension check only)
        if not file.filename.lower().endswith('.pdf'):
            raise HTTPException(status_code=400, detail="Chỉ chấp nhận file PDF")

        # When a vector_database_id is supplied, resolve its Pinecone settings
        api_key = None
        vector_db = None

        if vector_database_id:
            vector_db = db.query(VectorDatabase).filter(
                VectorDatabase.id == vector_database_id,
                VectorDatabase.status == "active"
            ).first()

            if not vector_db:
                raise HTTPException(status_code=404, detail="Vector database không tồn tại hoặc không hoạt động")

            # Override the index name / API key with the stored configuration
            api_key = vector_db.api_key
            index_name = vector_db.pinecone_index

        # Generate an upload ID and stage the file in the temp directory
        file_id = str(uuid.uuid4())
        temp_file_path = os.path.join(TEMP_UPLOAD_DIR, f"{file_id}.pdf")

        # Tell the client processing has started (WebSocket)
        if user_id:
            await send_pdf_upload_started(user_id, file.filename, file_id)

        # Persist the uploaded bytes to disk
        file_content = await file.read()
        with open(temp_file_path, "wb") as buffer:
            buffer.write(file_content)

        # Metadata stored alongside each embedded chunk
        metadata = {
            "filename": file.filename,
            "content_type": file.content_type
        }

        if title:
            metadata["title"] = title
        if description:
            metadata["description"] = description

        # Progress update: file staged
        if user_id:
            await send_pdf_upload_progress(
                user_id,
                file_id,
                "file_preparation",
                0.2,
                "File saved, preparing for processing"
            )

        # Record the document in PostgreSQL when tied to a vector database
        if vector_database_id and vector_db:
            # Create document record without file content
            document = Document(
                name=title or file.filename,
                file_type="pdf",
                content_type=file.content_type,
                size=len(file_content),
                is_embedded=False,
                vector_database_id=vector_database_id
            )
            db.add(document)
            db.commit()
            db.refresh(document)

            # Create document content record to store binary data separately
            document_content = DocumentContent(
                document_id=document.id,
                file_content=file_content
            )
            db.add(document_content)
            db.commit()

            # Track the embedding state of this document
            vector_status = VectorStatus(
                document_id=document.id,
                vector_database_id=vector_database_id,
                status="pending"
            )
            db.add(vector_status)
            db.commit()

        # Build the PDF processor (with the resolved API key when available)
        processor = PDFProcessor(index_name=index_name, namespace=namespace, api_key=api_key)

        # Progress update: embedding is about to start
        if user_id:
            await send_pdf_upload_progress(
                user_id,
                file_id,
                "embedding_start",
                0.4,
                "Starting to process PDF and create embeddings"
            )

        # Async callback forwarding per-step progress to the WebSocket channel
        async def progress_callback_wrapper(step, progress, message):
            if user_id:
                await send_progress_update(user_id, file_id, step, progress, message)

        # Run the PDF -> chunks -> embeddings pipeline
        result = await processor.process_pdf(
            file_path=temp_file_path,
            document_id=file_id,
            metadata=metadata,
            progress_callback=progress_callback_wrapper
        )

        # On success, move the staged file into permanent storage
        if result.get('success'):
            storage_path = os.path.join(STORAGE_DIR, f"{file_id}.pdf")
            shutil.move(temp_file_path, storage_path)

            # Mark the PostgreSQL records as embedded
            if vector_database_id and 'document' in locals() and 'vector_status' in locals():
                vector_status.status = "completed"
                vector_status.embedded_at = datetime.now()
                vector_status.vector_id = file_id
                document.is_embedded = True
                db.commit()

            # Notify completion over WebSocket
            if user_id:
                await send_pdf_upload_completed(
                    user_id,
                    file_id,
                    file.filename,
                    result.get('chunks_processed', 0)
                )
        else:
            # Record the failure in PostgreSQL
            if vector_database_id and 'vector_status' in locals():
                vector_status.status = "failed"
                vector_status.error_message = result.get('error', 'Unknown error')
                db.commit()

            # Notify failure over WebSocket
            if user_id:
                await send_pdf_upload_failed(
                    user_id,
                    file_id,
                    file.filename,
                    result.get('error', 'Unknown error')
                )

        # Cleanup: remove the temp file if it is still around
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)

        return result
    except Exception as e:
        # NOTE(review): this broad handler also catches the HTTPExceptions
        # raised above (400/404) and converts them into a 200 response with
        # success=False — confirm that is the intended API contract.
        # Cleanup the staged file on error
        if 'temp_file_path' in locals() and os.path.exists(temp_file_path):
            os.remove(temp_file_path)

        # Record the failure in PostgreSQL when applicable
        if 'vector_database_id' in locals() and vector_database_id and 'vector_status' in locals():
            vector_status.status = "failed"
            vector_status.error_message = str(e)
            db.commit()

        # Notify failure over WebSocket
        if 'user_id' in locals() and user_id and 'file_id' in locals():
            await send_pdf_upload_failed(
                user_id,
                file_id,
                file.filename,
                str(e)
            )

        return PDFResponse(
            success=False,
            error=str(e)
        )
|
| 243 |
+
|
| 244 |
+
# Function để gửi cập nhật tiến độ - được sử dụng trong callback
|
| 245 |
+
async def send_progress_update(user_id, document_id, step, progress, message):
|
| 246 |
+
if user_id:
|
| 247 |
+
await send_pdf_upload_progress(user_id, document_id, step, progress, message)
|
| 248 |
+
|
| 249 |
+
# Namespace deletion endpoint
@router.delete("/namespace", response_model=PDFResponse)
async def delete_namespace(
    namespace: str = "Default",
    index_name: str = "testbot768",
    user_id: Optional[str] = None
):
    """
    Delete every embedding in a Pinecone namespace (i.e. drop the namespace).

    - **namespace**: Pinecone namespace (default: "Default")
    - **index_name**: Pinecone index name (default: "testbot768")
    - **user_id**: user ID for WebSocket status updates (optional)
    """
    try:
        # Notify deletion start over WebSocket
        if user_id:
            await send_pdf_delete_started(user_id, namespace)

        processor = PDFProcessor(index_name=index_name, namespace=namespace)
        result = await processor.delete_namespace()

        # Relay the outcome over WebSocket
        if user_id:
            if result.get('success'):
                await send_pdf_delete_completed(user_id, namespace)
            else:
                await send_pdf_delete_failed(user_id, namespace, result.get('error', 'Unknown error'))

        return result
    except Exception as e:
        # Notify failure over WebSocket, then report the error in the payload
        if user_id:
            await send_pdf_delete_failed(user_id, namespace, str(e))

        return PDFResponse(
            success=False,
            error=str(e)
        )
|
| 288 |
+
|
| 289 |
+
# Endpoint for listing embedded documents
@router.get("/documents", response_model=DocumentsListResponse)
async def get_documents(namespace: str = "Default", index_name: str = "testbot768"):
    """
    Return information about every document that has been embedded.

    - **namespace**: Pinecone namespace (default: "Default")
    - **index_name**: Pinecone index name (default: "testbot768")
    """
    try:
        # Delegate the listing to the processor for this index/namespace pair.
        processor = PDFProcessor(index_name=index_name, namespace=namespace)
        return await processor.list_documents()
    except Exception as e:
        return DocumentsListResponse(
            success=False,
            error=str(e)
        )
|
app/api/pdf_websocket.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Dict, List, Optional, Any
|
| 3 |
+
from fastapi import WebSocket, WebSocketDisconnect, APIRouter
|
| 4 |
+
from pydantic import BaseModel
|
| 5 |
+
import json
|
| 6 |
+
import time
|
| 7 |
+
|
| 8 |
+
# Cấu hình logging
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
# Models for Swagger documentation
class ConnectionStatus(BaseModel):
    """WebSocket connection state for one user (see GET /ws/status/{user_id})."""
    user_id: str  # user the status refers to
    active: bool  # True when at least one socket is open
    connection_count: int  # number of open sockets for this user
    last_activity: Optional[float] = None  # unix timestamp; None when inactive
+
|
| 18 |
+
class UserConnection(BaseModel):
    """Per-user entry in the aggregate connection listing."""
    user_id: str  # user identifier
    connection_count: int  # open sockets held by this user
| 22 |
+
class AllConnectionsStatus(BaseModel):
    """Aggregate view over all active WebSocket connections (see GET /ws/status)."""
    total_users: int  # distinct users with at least one connection
    total_connections: int  # total sockets across all users
    users: List[UserConnection]  # per-user breakdown
|
| 27 |
+
# Initialize the router; all endpoints below are mounted under /ws
router = APIRouter(
    prefix="/ws",
    tags=["WebSockets"],
)
| 32 |
+
|
| 33 |
+
class ConnectionManager:
    """Registry of active WebSocket connections, keyed by user id.

    A user may hold several simultaneous connections (e.g. multiple tabs);
    each socket is tracked individually so messages reach all of them.
    """

    def __init__(self):
        # user_id -> list of open sockets for that user
        self.active_connections: Dict[str, List[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, user_id: str):
        """Accept a new WebSocket and register it under user_id."""
        await websocket.accept()
        if user_id not in self.active_connections:
            self.active_connections[user_id] = []
        self.active_connections[user_id].append(websocket)
        logger.info(f"New WebSocket connection for user {user_id}. Total connections: {len(self.active_connections[user_id])}")

    def disconnect(self, websocket: WebSocket, user_id: str):
        """Unregister a WebSocket; drop the user entry once no sockets remain."""
        if user_id in self.active_connections:
            if websocket in self.active_connections[user_id]:
                self.active_connections[user_id].remove(websocket)
            # Remove user_id from the dict when its last connection is gone
            if not self.active_connections[user_id]:
                del self.active_connections[user_id]
            logger.info(f"WebSocket disconnected for user {user_id}")

    async def send_message(self, message: Dict[str, Any], user_id: str):
        """Send a JSON-encoded message to every socket held by user_id.

        Sockets that raise on send are assumed dead and are unregistered
        after the loop (not during it, to avoid mutating the list mid-iteration).
        """
        if user_id in self.active_connections:
            disconnected_websockets = []
            for websocket in self.active_connections[user_id]:
                try:
                    await websocket.send_text(json.dumps(message))
                except Exception as e:
                    logger.error(f"Error sending message to WebSocket: {str(e)}")
                    disconnected_websockets.append(websocket)

            # Prune the connections that errored out
            for websocket in disconnected_websockets:
                self.disconnect(websocket, user_id)

    def get_connection_status(self, user_id: Optional[str] = None) -> Dict[str, Any]:
        """Report connection state for one user, or an aggregate for all users.

        Fixed: the parameter was annotated ``str`` with a ``None`` default;
        it is now explicitly ``Optional[str]``.
        """
        if user_id:
            # Status for a single user
            if user_id in self.active_connections:
                return {
                    "user_id": user_id,
                    "active": True,
                    "connection_count": len(self.active_connections[user_id]),
                    # NOTE(review): this is the time of the status query, not the
                    # socket's true last activity — no per-socket timestamp is kept.
                    "last_activity": time.time()
                }
            return {
                "user_id": user_id,
                "active": False,
                "connection_count": 0,
                "last_activity": None
            }

        # Aggregate view across all users
        result = {
            "total_users": len(self.active_connections),
            "total_connections": sum(len(connections) for connections in self.active_connections.values()),
            "users": []
        }
        for uid, connections in self.active_connections.items():
            result["users"].append({
                "user_id": uid,
                "connection_count": len(connections)
            })
        return result
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# Module-level singleton shared by the endpoints and the notification helpers
manager = ConnectionManager()
| 110 |
+
|
| 111 |
+
@router.websocket("/pdf/{user_id}")
async def websocket_endpoint(websocket: WebSocket, user_id: str):
    """WebSocket endpoint streaming PDF-processing progress to one user."""
    await manager.connect(websocket, user_id)
    try:
        # The client is not expected to send payloads; this read loop
        # merely keeps the connection alive until it drops.
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        manager.disconnect(websocket, user_id)
    except Exception as e:
        logger.error(f"WebSocket error: {str(e)}")
        manager.disconnect(websocket, user_id)
| 124 |
+
|
| 125 |
+
# API endpoints for inspecting WebSocket connection state
@router.get("/status", response_model=AllConnectionsStatus, responses={
    200: {
        "description": "Successful response",
        "content": {
            "application/json": {
                "example": {
                    "total_users": 2,
                    "total_connections": 3,
                    "users": [
                        {"user_id": "user1", "connection_count": 2},
                        {"user_id": "user2", "connection_count": 1}
                    ]
                }
            }
        }
    }
})
async def get_all_websocket_connections():
    """
    Get information about all current WebSocket connections.

    This endpoint returns:
    - Total number of connected users
    - Total number of WebSocket connections
    - List of users with each user's connection count
    """
    return manager.get_connection_status()
| 153 |
+
|
| 154 |
+
@router.get("/status/{user_id}", response_model=ConnectionStatus, responses={
|
| 155 |
+
200: {
|
| 156 |
+
"description": "Successful response for active connection",
|
| 157 |
+
"content": {
|
| 158 |
+
"application/json": {
|
| 159 |
+
"examples": {
|
| 160 |
+
"active_connection": {
|
| 161 |
+
"summary": "Active connection",
|
| 162 |
+
"value": {
|
| 163 |
+
"user_id": "user123",
|
| 164 |
+
"active": True,
|
| 165 |
+
"connection_count": 2,
|
| 166 |
+
"last_activity": 1634567890.123
|
| 167 |
+
}
|
| 168 |
+
},
|
| 169 |
+
"no_connection": {
|
| 170 |
+
"summary": "No active connection",
|
| 171 |
+
"value": {
|
| 172 |
+
"user_id": "user456",
|
| 173 |
+
"active": False,
|
| 174 |
+
"connection_count": 0,
|
| 175 |
+
"last_activity": None
|
| 176 |
+
}
|
| 177 |
+
}
|
| 178 |
+
}
|
| 179 |
+
}
|
| 180 |
+
}
|
| 181 |
+
}
|
| 182 |
+
})
|
| 183 |
+
async def get_user_websocket_status(user_id: str):
|
| 184 |
+
"""
|
| 185 |
+
Lấy thông tin về kết nối WebSocket của một người dùng cụ thể.
|
| 186 |
+
|
| 187 |
+
Parameters:
|
| 188 |
+
- **user_id**: ID của người dùng cần kiểm tra
|
| 189 |
+
|
| 190 |
+
Returns:
|
| 191 |
+
- Thông tin về trạng thái kết nối, bao gồm:
|
| 192 |
+
- active: Có đang kết nối hay không
|
| 193 |
+
- connection_count: Số lượng kết nối hiện tại
|
| 194 |
+
- last_activity: Thời gian hoạt động gần nhất
|
| 195 |
+
"""
|
| 196 |
+
return manager.get_connection_status(user_id)
|
| 197 |
+
|
| 198 |
+
# Status-notification helpers

async def send_pdf_upload_started(user_id: str, filename: str, document_id: str):
    """Notify the user's sockets that a PDF upload has started."""
    payload = {
        "type": "pdf_upload_started",
        "document_id": document_id,
        "filename": filename,
        "timestamp": int(time.time()),
    }
    await manager.send_message(payload, user_id)
| 208 |
+
|
| 209 |
+
async def send_pdf_upload_progress(user_id: str, document_id: str, step: str, progress: float, message: str):
    """Notify the user's sockets of PDF upload progress."""
    payload = {
        "type": "pdf_upload_progress",
        "document_id": document_id,
        "step": step,
        "progress": progress,
        "message": message,
        "timestamp": int(time.time()),
    }
    await manager.send_message(payload, user_id)
| 219 |
+
|
| 220 |
+
async def send_pdf_upload_completed(user_id: str, document_id: str, filename: str, chunks: int):
    """Notify the user's sockets that a PDF upload finished successfully."""
    payload = {
        "type": "pdf_upload_completed",
        "document_id": document_id,
        "filename": filename,
        "chunks": chunks,
        "timestamp": int(time.time()),
    }
    await manager.send_message(payload, user_id)
| 229 |
+
|
| 230 |
+
async def send_pdf_upload_failed(user_id: str, document_id: str, filename: str, error: str):
    """Notify the user's sockets that a PDF upload failed."""
    payload = {
        "type": "pdf_upload_failed",
        "document_id": document_id,
        "filename": filename,
        "error": error,
        "timestamp": int(time.time()),
    }
    await manager.send_message(payload, user_id)
| 239 |
+
|
| 240 |
+
async def send_pdf_delete_started(user_id: str, namespace: str):
    """Notify the user's sockets that a namespace deletion has started."""
    payload = {
        "type": "pdf_delete_started",
        "namespace": namespace,
        "timestamp": int(time.time()),
    }
    await manager.send_message(payload, user_id)
| 247 |
+
|
| 248 |
+
async def send_pdf_delete_completed(user_id: str, namespace: str):
    """Notify the user's sockets that a namespace deletion finished."""
    payload = {
        "type": "pdf_delete_completed",
        "namespace": namespace,
        "timestamp": int(time.time()),
    }
    await manager.send_message(payload, user_id)
| 255 |
+
|
| 256 |
+
async def send_pdf_delete_failed(user_id: str, namespace: str, error: str):
    """Notify the user's sockets that a namespace deletion failed."""
    payload = {
        "type": "pdf_delete_failed",
        "namespace": namespace,
        "error": error,
        "timestamp": int(time.time()),
    }
    await manager.send_message(payload, user_id)
|
app/api/postgresql_routes.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
app/api/rag_routes.py
ADDED
|
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, HTTPException, Depends, Query, BackgroundTasks, Request
|
| 2 |
+
from typing import List, Optional, Dict, Any
|
| 3 |
+
import logging
|
| 4 |
+
import time
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
import hashlib
|
| 8 |
+
import asyncio
|
| 9 |
+
import traceback
|
| 10 |
+
import google.generativeai as genai
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
from langchain.prompts import PromptTemplate
|
| 13 |
+
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
| 14 |
+
from app.utils.utils import timer_decorator
|
| 15 |
+
|
| 16 |
+
from app.database.mongodb import get_chat_history, get_request_history, session_collection
|
| 17 |
+
from app.database.pinecone import (
|
| 18 |
+
search_vectors,
|
| 19 |
+
get_chain,
|
| 20 |
+
DEFAULT_TOP_K,
|
| 21 |
+
DEFAULT_LIMIT_K,
|
| 22 |
+
DEFAULT_SIMILARITY_METRIC,
|
| 23 |
+
DEFAULT_SIMILARITY_THRESHOLD,
|
| 24 |
+
ALLOWED_METRICS
|
| 25 |
+
)
|
| 26 |
+
from app.models.rag_models import (
|
| 27 |
+
ChatRequest,
|
| 28 |
+
ChatResponse,
|
| 29 |
+
ChatResponseInternal,
|
| 30 |
+
SourceDocument,
|
| 31 |
+
EmbeddingRequest,
|
| 32 |
+
EmbeddingResponse,
|
| 33 |
+
UserMessageModel
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
# Configure logging
logger = logging.getLogger(__name__)

# Configure Google Gemini API
# NOTE(review): no check that GOOGLE_API_KEY is actually set; with a missing env
# var this configures the client with None and fails only at the first API call.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
genai.configure(api_key=GOOGLE_API_KEY)

# Create router; all endpoints below are mounted under /rag
router = APIRouter(
    prefix="/rag",
    tags=["RAG"],
)
| 48 |
+
|
| 49 |
+
# Keyword-extraction prompt: condenses the user's current message (plus chat
# history when relevant) into comma-separated keywords used as the retrieval query.
fix_request = PromptTemplate(
    template = """Goal:
Your task is to extract important keywords from the user's current request, optionally using chat history if relevant.
You will receive a conversation history and the user's current message.
Generate a **list of concise keywords** that best represent the user's intent.

Return Format:
Only return keywords (comma-separated, no extra explanation).
If the current message is NOT related to the chat history or if there is no chat history: Return keywords from the current message only.
If the current message IS related to the chat history: Return a refined set of keywords based on both history and current message.

Warning:
Only use chat history if the current message is clearly related to the prior context.

Conversation History:
{chat_history}

User current message:
{question}
""",
    input_variables=["chat_history", "question"],
)
| 71 |
+
|
| 72 |
+
# Create a prompt template with conversation history.
# Answer-generation prompt: the retrieved documents are injected as {context},
# chat history as {chat_history}, and the (keyword-refined) question as {question}.
prompt = PromptTemplate(
    template = """Goal:
You are a professional tour guide assistant that assists users in finding information about places in Da Nang, Vietnam.
You can provide details on restaurants, cafes, hotels, attractions, and other local venues.
You have to use core knowledge and conversation history to chat with users, who are Da Nang's tourists.

Return Format:
Respond in friendly, natural, concise and use only English like a real tour guide.
Always use HTML tags (e.g. <b> for bold) so that Telegram can render the special formatting correctly.

Warning:
Let's support users like a real tour guide, not a bot. The information in core knowledge is your own knowledge.
Your knowledge is provided in the Core Knowledge. All of information in Core Knowledge is about Da Nang, Vietnam.
You just care about current time that user mention when user ask about Solana event.
Only use core knowledge to answer. If you do not have enough information to answer user's question, please reply with "I'm sorry. I don't have information about that" and Give users some more options to ask.

Core knowledge:
{context}

Conversation History:
{chat_history}

User message:
{question}

Your message:
""",
    input_variables = ["context", "question", "chat_history"],
)
| 102 |
+
|
| 103 |
+
# Helper for embeddings
async def get_embedding(text: str):
    """Embed *text* with Google's embedding-001 model.

    Returns a dict with keys ``embedding``, ``text`` and ``model``.
    Raises HTTPException(500) when the embedding call fails.
    """
    try:
        embedder = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
        vector = await embedder.aembed_query(text)
        return {
            "embedding": vector,
            "text": text,
            "model": "embedding-001",
        }
    except Exception as e:
        logger.error(f"Error generating embedding: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to generate embedding: {str(e)}")
| 122 |
+
|
| 123 |
+
# Endpoint for generating embeddings
@router.post("/embedding", response_model=EmbeddingResponse)
async def create_embedding(request: EmbeddingRequest):
    """
    Generate embedding for text.

    - **text**: Text to generate embedding for
    """
    try:
        # Delegate to the shared helper and wrap the result in the response model.
        embedding_data = await get_embedding(request.text)
        return EmbeddingResponse(**embedding_data)
    except Exception as e:
        logger.error(f"Error generating embedding: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to generate embedding: {str(e)}")
| 140 |
+
|
| 141 |
+
@timer_decorator
@router.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest, background_tasks: BackgroundTasks):
    """
    Get answer for a question using RAG.

    - **user_id**: User's ID from Telegram
    - **question**: User's question
    - **include_history**: Whether to include user history in prompt (default: True)
    - **use_rag**: Whether to use RAG (default: True)
    - **similarity_top_k**: Number of top similar documents to return after filtering (default: 6)
    - **limit_k**: Maximum number of documents to retrieve from vector store (default: 10)
    - **similarity_metric**: Similarity metric to use - cosine, dotproduct, euclidean (default: cosine)
    - **similarity_threshold**: Threshold for vector similarity (default: 0.75)
    - **session_id**: Optional session ID for tracking conversations
    - **first_name**: User's first name
    - **last_name**: User's last name
    - **username**: User's username
    """
    start_time = time.time()
    try:
        # Session id for tracking; NOTE(review): currently unused below — confirm
        # whether persisting the user message was intended here.
        session_id = request.session_id or f"{request.user_id}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"

        # Build the retriever with the caller's tuning parameters.
        retriever = get_chain(
            top_k=request.similarity_top_k,
            limit_k=request.limit_k,
            similarity_metric=request.similarity_metric,
            similarity_threshold=request.similarity_threshold
        )
        if not retriever:
            raise HTTPException(status_code=500, detail="Failed to initialize retriever")

        # Get chat history (empty string when the caller opted out).
        chat_history = get_chat_history(request.user_id) if request.include_history else ""
        logger.info(f"Using chat history: {chat_history[:100]}...")

        # Initialize Gemini model
        generation_config = {
            "temperature": 0.9,
            "top_p": 1,
            "top_k": 1,
            "max_output_tokens": 2048,
        }

        safety_settings = [
            {
                "category": "HARM_CATEGORY_HARASSMENT",
                "threshold": "BLOCK_MEDIUM_AND_ABOVE"
            },
            {
                "category": "HARM_CATEGORY_HATE_SPEECH",
                "threshold": "BLOCK_MEDIUM_AND_ABOVE"
            },
            {
                "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "threshold": "BLOCK_MEDIUM_AND_ABOVE"
            },
            {
                "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                "threshold": "BLOCK_MEDIUM_AND_ABOVE"
            },
        ]

        model = genai.GenerativeModel(
            model_name='models/gemini-2.0-flash',
            generation_config=generation_config,
            safety_settings=safety_settings
        )

        # First LLM pass: rewrite the question into retrieval keywords.
        prompt_request = fix_request.format(
            question=request.question,
            chat_history=chat_history
        )

        final_request_start_time = time.time()
        final_request = model.generate_content(prompt_request)
        logger.info(f"Fixed Request: {final_request.text}")
        logger.info(f"Final request generation time: {time.time() - final_request_start_time:.2f} seconds")

        # Retrieve supporting documents with the refined query.
        retrieved_docs = retriever.invoke(final_request.text)
        logger.info(f"Retrieve: {retrieved_docs}")
        context = "\n".join([doc.page_content for doc in retrieved_docs])

        # Collect source metadata for each retrieved document.
        sources = []
        for doc in retrieved_docs:
            source = None
            metadata = {}
            # Fix: score/normalized_score were only bound inside the hasattr
            # branch, raising NameError for docs without a metadata attribute.
            score = None
            normalized_score = None

            if hasattr(doc, 'metadata'):
                source = doc.metadata.get('source', None)
                # Extract score information
                score = doc.metadata.get('score', None)
                normalized_score = doc.metadata.get('normalized_score', None)
                # Remove score info from metadata to avoid duplication
                metadata = {k: v for k, v in doc.metadata.items()
                            if k not in ['text', 'source', 'score', 'normalized_score']}

            sources.append(SourceDocument(
                text=doc.page_content,
                source=source,
                score=score,
                normalized_score=normalized_score,
                metadata=metadata
            ))

        # Second LLM pass: answer using context + history.
        prompt_text = prompt.format(
            context=context,
            question=final_request.text,
            chat_history=chat_history
        )
        logger.info(f"Full prompt with history and context: {prompt_text}")

        response = model.generate_content(prompt_text)
        answer = response.text

        # Calculate processing time
        processing_time = time.time() - start_time

        # Create response object for API (sources are intentionally not exposed).
        chat_response = ChatResponse(
            answer=answer,
            processing_time=processing_time
        )

        return chat_response
    except Exception as e:
        logger.error(f"Error processing chat request: {e}")
        # traceback is already imported at module level; the redundant
        # in-function import was removed.
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Failed to process chat request: {str(e)}")
| 282 |
+
|
| 283 |
+
# Health check endpoint
@router.get("/health")
async def health_check():
    """
    Check health of RAG services and retrieval system.

    Returns:
    - status: "healthy" if all services are working, "degraded" otherwise
    - services: Status of each service (gemini, pinecone)
    - retrieval_config: Current retrieval configuration
    - timestamp: Current time
    """
    services = {"gemini": False, "pinecone": False}

    # Probe Gemini with a trivial generation.
    try:
        model = genai.GenerativeModel("gemini-2.0-flash")
        model.generate_content("Hello")
        services["gemini"] = True
    except Exception as e:
        logger.error(f"Gemini health check failed: {e}")

    # Probe Pinecone by fetching the index handle.
    try:
        from app.database.pinecone import get_pinecone_index
        if get_pinecone_index():
            services["pinecone"] = True
    except Exception as e:
        logger.error(f"Pinecone health check failed: {e}")

    # Current retrieval defaults, surfaced for observability.
    retrieval_config = {
        "default_top_k": DEFAULT_TOP_K,
        "default_limit_k": DEFAULT_LIMIT_K,
        "default_similarity_metric": DEFAULT_SIMILARITY_METRIC,
        "default_similarity_threshold": DEFAULT_SIMILARITY_THRESHOLD,
        "allowed_metrics": ALLOWED_METRICS
    }

    overall = "healthy" if all(services.values()) else "degraded"
    return {
        "status": overall,
        "services": services,
        "retrieval_config": retrieval_config,
        "timestamp": datetime.now().isoformat()
    }
app/api/websocket_routes.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, status
|
| 2 |
+
from typing import List, Dict
|
| 3 |
+
import logging
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
import asyncio
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from dotenv import load_dotenv
|
| 9 |
+
from app.database.mongodb import session_collection
|
| 10 |
+
from app.utils.utils import get_local_time
|
| 11 |
+
|
| 12 |
+
# Load environment variables
|
| 13 |
+
load_dotenv()
|
| 14 |
+
|
| 15 |
+
# Get WebSocket configuration from environment variables
|
| 16 |
+
WEBSOCKET_SERVER = os.getenv("WEBSOCKET_SERVER", "localhost")
|
| 17 |
+
WEBSOCKET_PORT = os.getenv("WEBSOCKET_PORT", "7860")
|
| 18 |
+
WEBSOCKET_PATH = os.getenv("WEBSOCKET_PATH", "/notify")
|
| 19 |
+
|
| 20 |
+
# Configure logging
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
# Create router
|
| 24 |
+
router = APIRouter(
|
| 25 |
+
tags=["WebSocket"],
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
# Store active WebSocket connections
class ConnectionManager:
    """Registry of open WebSocket connections with JSON broadcast support.

    Used by the /notify endpoint to track admin-bot clients and by
    send_notification() to fan a message out to all of them.
    """

    def __init__(self):
        # Every currently-open connection accepted via connect().
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake and register the connection."""
        await websocket.accept()
        self.active_connections.append(websocket)
        client_info = f"{websocket.client.host}:{websocket.client.port}" if hasattr(websocket, 'client') else "Unknown"
        logger.info(f"New WebSocket connection from {client_info}. Total connections: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket):
        """Unregister a connection. Safe to call more than once per socket.

        Bug fix: the endpoint's WebSocketDisconnect and generic-exception
        handlers (and broadcast() pruning) can each try to remove the same
        socket; the previous unconditional list.remove() raised ValueError
        on the second call.
        """
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
        logger.info(f"WebSocket connection removed. Total connections: {len(self.active_connections)}")

    async def broadcast(self, message: Dict):
        """Send *message* as JSON to every connection, pruning dead ones.

        Args:
            message: JSON-serializable payload sent via send_json().
        """
        if not self.active_connections:
            logger.warning("No active WebSocket connections to broadcast to")
            return

        disconnected = []
        for connection in self.active_connections:
            try:
                await connection.send_json(message)
                logger.info(f"Message sent to WebSocket connection")
            except Exception as e:
                # A failed send means the peer is gone; remember it for cleanup.
                logger.error(f"Error sending message to WebSocket: {e}")
                disconnected.append(connection)

        # Remove disconnected connections
        for conn in disconnected:
            if conn in self.active_connections:
                self.active_connections.remove(conn)
                logger.info(f"Removed disconnected WebSocket. Remaining: {len(self.active_connections)}")
|
| 62 |
+
|
| 63 |
+
# Initialize connection manager
|
| 64 |
+
manager = ConnectionManager()
|
| 65 |
+
|
| 66 |
+
# Create full URL of WebSocket server from environment variables
def get_full_websocket_url(server_side=False):
    """Build the WebSocket URL clients should use to reach /notify.

    Args:
        server_side: When True, return only the relative path (the server
            mounts the route by path); otherwise return the full ws:// or
            wss:// URL built from the WEBSOCKET_* environment settings.

    Returns:
        str: The relative path or the full WebSocket URL.

    Raises:
        ValueError: If WEBSOCKET_PORT is not an integer string (unchanged
            from the original behavior).
    """
    if server_side:
        # Relative URL (for server side)
        return WEBSOCKET_PATH

    # Convert once instead of re-parsing the env string in every comparison.
    port = int(WEBSOCKET_PORT)

    # Port 443 implies TLS, hence the wss:// scheme.
    is_https = port == 443
    protocol = "wss" if is_https else "ws"

    # Omit the port when it is the default for the scheme
    # (443 for wss, 80 for ws) — same outcome as the original
    # (is_https and 443) / (not is_https and 80) double-check.
    if port in (443, 80):
        return f"{protocol}://{WEBSOCKET_SERVER}{WEBSOCKET_PATH}"
    return f"{protocol}://{WEBSOCKET_SERVER}:{WEBSOCKET_PORT}{WEBSOCKET_PATH}"
|
| 82 |
+
|
| 83 |
+
# Add GET endpoint to display WebSocket information in Swagger
# NOTE(review): this GET route exists only so the WebSocket contract shows up
# in the OpenAPI docs; the real connection is the @router.websocket("/notify")
# endpoint below it in this module.
@router.get("/notify",
    summary="WebSocket notifications for Admin Bot",
    description=f"""
    This is documentation for the WebSocket endpoint.

    To connect to WebSocket:
    1. Use the path `{get_full_websocket_url()}`
    2. Connect using a WebSocket client library
    3. When there are new sessions requiring attention, you will receive notifications through this connection

    Notifications are sent when:
    - Session response starts with "I'm sorry"
    - The system cannot answer the user's question

    Make sure to send a "keepalive" message every 5 minutes to maintain the connection.
    """,
    status_code=status.HTTP_200_OK
)
async def websocket_documentation():
    """
    Provides information about how to use the WebSocket endpoint /notify.
    This endpoint is for documentation purposes only. To use WebSocket, please connect to the WebSocket URL.
    """
    ws_url = get_full_websocket_url()
    # Everything below is static reference material returned to API consumers:
    # connection coordinates, the notification payload shape, and a runnable
    # client example (plain string, so the {braces} inside are literal).
    return {
        "websocket_endpoint": WEBSOCKET_PATH,
        "connection_type": "WebSocket",
        "protocol": "ws://",
        "server": WEBSOCKET_SERVER,
        "port": WEBSOCKET_PORT,
        "full_url": ws_url,
        "description": "Endpoint to receive notifications about new sessions requiring attention",
        "notification_format": {
            "type": "sorry_response",
            "timestamp": "YYYY-MM-DD HH:MM:SS",
            "data": {
                "session_id": "session id",
                "factor": "user",
                "action": "action type",
                "message": "User question",
                "response": "I'm sorry...",
                "user_id": "user id",
                "first_name": "user's first name",
                "last_name": "user's last name",
                "username": "username",
                "created_at": "creation time"
            }
        },
        "client_example": """
import websocket
import json
import os
import time
import threading
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Get WebSocket configuration from environment variables
WEBSOCKET_SERVER = os.getenv("WEBSOCKET_SERVER", "localhost")
WEBSOCKET_PORT = os.getenv("WEBSOCKET_PORT", "7860")
WEBSOCKET_PATH = os.getenv("WEBSOCKET_PATH", "/notify")

# Create full URL
ws_url = f"ws://{WEBSOCKET_SERVER}:{WEBSOCKET_PORT}{WEBSOCKET_PATH}"

# If using HTTPS, replace ws:// with wss://
# ws_url = f"wss://{WEBSOCKET_SERVER}{WEBSOCKET_PATH}"

# Send keepalive periodically
def send_keepalive(ws):
    while True:
        try:
            if ws.sock and ws.sock.connected:
                ws.send("keepalive")
                print("Sent keepalive message")
            time.sleep(300)  # 5 minutes
        except Exception as e:
            print(f"Error sending keepalive: {e}")
            time.sleep(60)

def on_message(ws, message):
    try:
        data = json.loads(message)
        print(f"Received notification: {data}")
        # Process notification, e.g.: send to Telegram Admin
        if data.get("type") == "sorry_response":
            session_data = data.get("data", {})
            user_question = session_data.get("message", "")
            user_name = session_data.get("first_name", "Unknown User")
            print(f"User {user_name} asked: {user_question}")
            # Code to send message to Telegram Admin
    except json.JSONDecodeError:
        print(f"Received non-JSON message: {message}")
    except Exception as e:
        print(f"Error processing message: {e}")

def on_error(ws, error):
    print(f"WebSocket error: {error}")

def on_close(ws, close_status_code, close_msg):
    print(f"WebSocket connection closed: code={close_status_code}, message={close_msg}")

def on_open(ws):
    print(f"WebSocket connection opened to {ws_url}")
    # Send keepalive messages periodically in a separate thread
    keepalive_thread = threading.Thread(target=send_keepalive, args=(ws,), daemon=True)
    keepalive_thread.start()

def run_forever_with_reconnect():
    while True:
        try:
            # Connect WebSocket with ping to maintain connection
            ws = websocket.WebSocketApp(
                ws_url,
                on_open=on_open,
                on_message=on_message,
                on_error=on_error,
                on_close=on_close
            )
            ws.run_forever(ping_interval=60, ping_timeout=30)
            print("WebSocket connection lost, reconnecting in 5 seconds...")
            time.sleep(5)
        except Exception as e:
            print(f"WebSocket connection error: {e}")
            time.sleep(5)

# Start WebSocket client in a separate thread
websocket_thread = threading.Thread(target=run_forever_with_reconnect, daemon=True)
websocket_thread.start()

# Keep the program running
try:
    while True:
        time.sleep(1)
except KeyboardInterrupt:
    print("Stopping WebSocket client...")
        """
    }
|
| 224 |
+
|
| 225 |
+
@router.websocket("/notify")
async def websocket_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint to receive notifications about new sessions.
    Admin Bot will connect to this endpoint to receive notifications when there are new sessions requiring attention.
    """
    # Register the socket so send_notification()/broadcast() can reach it.
    await manager.connect(websocket)
    try:
        while True:
            # Maintain WebSocket connection: block until the client sends
            # something (clients are asked to send "keepalive" periodically).
            data = await websocket.receive_text()
            # Echo back to keep connection active
            await websocket.send_json({"status": "connected", "echo": data, "timestamp": datetime.now().isoformat()})
            logger.info(f"Received message from WebSocket: {data}")
    except WebSocketDisconnect:
        # Normal client-initiated close.
        logger.info("WebSocket client disconnected")
        manager.disconnect(websocket)
    except Exception as e:
        # Unexpected failure (network error, bad frame, ...): drop the socket.
        logger.error(f"WebSocket error: {e}")
        manager.disconnect(websocket)
|
| 245 |
+
|
| 246 |
+
# Function to send notifications over WebSocket
async def send_notification(data: dict):
    """
    Send notification to all active WebSocket connections.

    This function is used to notify admin bots about new issues or questions that need attention.
    It's triggered when the system cannot answer a user's question (response starts with "I'm sorry").

    Args:
        data: The data to send as notification. Expected keys: session_id,
            user_id, message, response, first_name, last_name, username.

    Notes:
        Never raises: all errors are logged and swallowed so a notification
        failure cannot break the caller's request handling.
    """
    try:
        # Log number of active connections and notification attempt
        logger.info(f"Attempting to send notification. Active connections: {len(manager.active_connections)}")
        logger.info(f"Notification data: session_id={data.get('session_id')}, user_id={data.get('user_id')}")

        # Bug fix: validate the response BEFORE slicing it for logging.
        # Previously `data.get('response', '')[:50]` raised TypeError when the
        # 'response' key existed with a non-string value (e.g. None), so the
        # outer except logged a traceback instead of the intended warning.
        response = data.get('response', '')
        if not response or not isinstance(response, str):
            logger.warning(f"Invalid response format in notification data: {response}")
            return

        logger.info(f"Response: {response[:50]}...")

        # Only "I'm sorry" responses (case-insensitive) require admin attention.
        if not response.strip().lower().startswith("i'm sorry"):
            logger.info(f"Response doesn't start with 'I'm sorry', notification not needed: {response[:50]}...")
            return

        logger.info(f"Response starts with 'I'm sorry', sending notification")

        # Format the notification data for admin — the type must be
        # "sorry_response" so the Admin_bot client recognises the payload.
        notification_data = {
            "type": "sorry_response",
            "timestamp": get_local_time(),
            "user_id": data.get('user_id', 'unknown'),
            "message": data.get('message', ''),
            "response": response,
            "session_id": data.get('session_id', 'unknown'),
            "user_info": {
                "first_name": data.get('first_name', 'User'),
                "last_name": data.get('last_name', ''),
                "username": data.get('username', '')
            }
        }

        # Check if there are active connections
        if not manager.active_connections:
            logger.warning("No active WebSocket connections for notification broadcast")
            return

        # Broadcast notification to all active connections
        logger.info(f"Broadcasting notification to {len(manager.active_connections)} connections")
        await manager.broadcast(notification_data)
        logger.info("Notification broadcast completed successfully")

    except Exception as e:
        logger.error(f"Error sending notification: {e}")
        import traceback
        logger.error(traceback.format_exc())
|
app/database/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Database connections package
|
app/database/models.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey, Float, Text, LargeBinary, JSON
|
| 2 |
+
from sqlalchemy.sql import func
|
| 3 |
+
from sqlalchemy.orm import relationship
|
| 4 |
+
from .postgresql import Base
|
| 5 |
+
import datetime
|
| 6 |
+
|
| 7 |
+
class FAQItem(Base):
    """A frequently-asked question paired with its canned answer."""
    __tablename__ = "faq_item"

    id = Column(Integer, primary_key=True, index=True)
    question = Column(String, nullable=False)
    answer = Column(String, nullable=False)
    is_active = Column(Boolean, default=True)  # visibility flag (soft disable)
    created_at = Column(DateTime, server_default=func.now())
    updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
|
| 16 |
+
|
| 17 |
+
class EmergencyItem(Base):
    """An emergency contact entry (name + phone number, optionally geolocated)."""
    __tablename__ = "emergency_item"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, nullable=False)
    phone_number = Column(String, nullable=False)
    description = Column(String, nullable=True)
    address = Column(String, nullable=True)
    location = Column(String, nullable=True)  # Will be converted to/from PostGIS POINT type
    priority = Column(Integer, default=0)  # ordering weight — TODO confirm sort direction
    is_active = Column(Boolean, default=True)
    section = Column(String, nullable=True)  # Section field (16.1, 16.2.1, 16.2.2, 16.3)
    section_id = Column(Integer, nullable=True)  # Numeric identifier for section
    created_at = Column(DateTime, server_default=func.now())
    updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
|
| 32 |
+
|
| 33 |
+
class EventItem(Base):
    """An event listing with schedule, venue, and optional pricing/URL."""
    __tablename__ = "event_item"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, nullable=False)
    description = Column(Text, nullable=False)
    address = Column(String, nullable=False)
    location = Column(String, nullable=True)  # Will be converted to/from PostGIS POINT type
    date_start = Column(DateTime, nullable=False)
    date_end = Column(DateTime, nullable=True)  # open-ended events leave this NULL
    price = Column(JSON, nullable=True)  # arbitrary pricing structure — schema defined by callers
    url = Column(String, nullable=True)
    is_active = Column(Boolean, default=True)
    featured = Column(Boolean, default=False)  # highlight flag
    created_at = Column(DateTime, server_default=func.now())
    updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
|
| 49 |
+
|
| 50 |
+
class AboutPixity(Base):
    """Free-text 'about Pixity' content block."""
    __tablename__ = "about_pixity"

    id = Column(Integer, primary_key=True, index=True)
    content = Column(Text, nullable=False)
    created_at = Column(DateTime, server_default=func.now())
    updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
|
| 57 |
+
|
| 58 |
+
class SolanaSummit(Base):
    """Free-text content about the Solana Summit."""
    __tablename__ = "solana_summit"

    id = Column(Integer, primary_key=True, index=True)
    content = Column(Text, nullable=False)
    created_at = Column(DateTime, server_default=func.now())
    updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
|
| 65 |
+
|
| 66 |
+
class DaNangBucketList(Base):
    """Free-text 'Da Nang bucket list' content."""
    __tablename__ = "danang_bucket_list"

    id = Column(Integer, primary_key=True, index=True)
    content = Column(Text, nullable=False)
    created_at = Column(DateTime, server_default=func.now())
    updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
|
| 73 |
+
|
| 74 |
+
class VectorDatabase(Base):
    """A configured vector store backed by a Pinecone index, with an optional API key."""
    __tablename__ = "vector_database"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, nullable=False, unique=True)
    description = Column(String, nullable=True)
    pinecone_index = Column(String, nullable=False)  # name of the Pinecone index backing this DB
    api_key_id = Column(Integer, ForeignKey("api_key.id"), nullable=True)
    status = Column(String, default="active")
    created_at = Column(DateTime, server_default=func.now())
    updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())

    # Relationships
    documents = relationship("Document", back_populates="vector_database")
    vector_statuses = relationship("VectorStatus", back_populates="vector_database")
    engine_associations = relationship("EngineVectorDb", back_populates="vector_database")
    api_key_ref = relationship("ApiKey", foreign_keys=[api_key_id])
|
| 91 |
+
|
| 92 |
+
class Document(Base):
    """An uploaded document tracked for embedding into a vector database."""
    __tablename__ = "document"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, nullable=False)
    file_type = Column(String, nullable=True)
    content_type = Column(String, nullable=True)
    size = Column(Integer, nullable=True)  # file size — units not declared here; presumably bytes
    is_embedded = Column(Boolean, default=False)  # True once vectors have been produced
    vector_database_id = Column(Integer, ForeignKey("vector_database.id"), nullable=False)
    created_at = Column(DateTime, server_default=func.now())
    updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())

    # Relationships
    vector_database = relationship("VectorDatabase", back_populates="documents")
    vector_statuses = relationship("VectorStatus", back_populates="document")
    # Raw bytes live in a separate one-to-one table; deleting the document deletes them.
    file_content_ref = relationship("DocumentContent", back_populates="document", uselist=False, cascade="all, delete-orphan")
|
| 109 |
+
|
| 110 |
+
class DocumentContent(Base):
    """Raw binary content of a Document (one-to-one via the unique FK)."""
    __tablename__ = "document_content"

    id = Column(Integer, primary_key=True, index=True)
    document_id = Column(Integer, ForeignKey("document.id"), nullable=False, unique=True)
    file_content = Column(LargeBinary, nullable=True)
    created_at = Column(DateTime, server_default=func.now())

    # Relationships
    document = relationship("Document", back_populates="file_content_ref")
|
| 120 |
+
|
| 121 |
+
class VectorStatus(Base):
    """Embedding status for one (document, vector database) pair."""
    __tablename__ = "vector_status"

    id = Column(Integer, primary_key=True, index=True)
    document_id = Column(Integer, ForeignKey("document.id"), nullable=False)
    vector_database_id = Column(Integer, ForeignKey("vector_database.id"), nullable=False)
    vector_id = Column(String, nullable=True)  # ID assigned by the vector store once embedded
    status = Column(String, default="pending")
    error_message = Column(String, nullable=True)  # populated when embedding fails
    embedded_at = Column(DateTime, nullable=True)

    # Relationships
    document = relationship("Document", back_populates="vector_statuses")
    vector_database = relationship("VectorDatabase", back_populates="vector_statuses")
|
| 135 |
+
|
| 136 |
+
class TelegramBot(Base):
    """A Telegram bot identity (username + API token) linkable to chat engines."""
    __tablename__ = "telegram_bot"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, nullable=False)
    username = Column(String, nullable=False, unique=True)
    token = Column(String, nullable=False)  # Telegram API token — sensitive; stored in plain text here
    status = Column(String, default="inactive")
    created_at = Column(DateTime, server_default=func.now())
    updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())

    # Relationships
    bot_engines = relationship("BotEngine", back_populates="bot")
|
| 149 |
+
|
| 150 |
+
class ChatEngine(Base):
    """Configuration for a chat/RAG engine: model, prompt, and retrieval thresholds."""
    __tablename__ = "chat_engine"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, nullable=False)
    answer_model = Column(String, nullable=False)
    system_prompt = Column(Text, nullable=True)
    empty_response = Column(String, nullable=True)  # fallback reply when nothing is retrieved — TODO confirm
    similarity_top_k = Column(Integer, default=3)  # number of chunks retrieved per query
    vector_distance_threshold = Column(Float, default=0.75)
    grounding_threshold = Column(Float, default=0.2)
    use_public_information = Column(Boolean, default=False)
    status = Column(String, default="active")
    created_at = Column(DateTime, server_default=func.now())
    last_modified = Column(DateTime, server_default=func.now(), onupdate=func.now())

    # Relationships
    bot_engines = relationship("BotEngine", back_populates="engine")
    engine_vector_dbs = relationship("EngineVectorDb", back_populates="engine")
|
| 169 |
+
|
| 170 |
+
class BotEngine(Base):
    """Association row linking a TelegramBot to a ChatEngine (many-to-many)."""
    __tablename__ = "bot_engine"

    id = Column(Integer, primary_key=True, index=True)
    bot_id = Column(Integer, ForeignKey("telegram_bot.id"), nullable=False)
    engine_id = Column(Integer, ForeignKey("chat_engine.id"), nullable=False)
    created_at = Column(DateTime, server_default=func.now())

    # Relationships
    bot = relationship("TelegramBot", back_populates="bot_engines")
    engine = relationship("ChatEngine", back_populates="bot_engines")
|
| 181 |
+
|
| 182 |
+
class EngineVectorDb(Base):
    """Association row linking a ChatEngine to a VectorDatabase, with a priority weight."""
    __tablename__ = "engine_vector_db"

    id = Column(Integer, primary_key=True, index=True)
    engine_id = Column(Integer, ForeignKey("chat_engine.id"), nullable=False)
    vector_database_id = Column(Integer, ForeignKey("vector_database.id"), nullable=False)
    priority = Column(Integer, default=0)  # ordering weight among an engine's databases — TODO confirm direction

    # Relationships
    engine = relationship("ChatEngine", back_populates="engine_vector_dbs")
    vector_database = relationship("VectorDatabase", back_populates="engine_associations")
|
| 193 |
+
|
| 194 |
+
class ApiKey(Base):
    """A stored API credential (e.g. for Pinecone) with optional expiry."""
    __tablename__ = "api_key"

    id = Column(Integer, primary_key=True, index=True)
    key_type = Column(String, nullable=False)  # which service the key belongs to
    key_value = Column(Text, nullable=False)  # the secret itself — stored in plain text here
    description = Column(Text, nullable=True)
    created_at = Column(DateTime, server_default=func.now())
    last_used = Column(DateTime, nullable=True)
    expires_at = Column(DateTime, nullable=True)
    is_active = Column(Boolean, default=True)
|
app/database/mongodb.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pymongo import MongoClient
|
| 3 |
+
from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
from datetime import datetime, timedelta
|
| 6 |
+
import pytz
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
# Configure logging
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
# Load environment variables
|
| 13 |
+
load_dotenv()
|
| 14 |
+
|
| 15 |
+
# MongoDB connection string from .env
|
| 16 |
+
MONGODB_URL = os.getenv("MONGODB_URL")
|
| 17 |
+
DB_NAME = os.getenv("DB_NAME", "Telegram")
|
| 18 |
+
COLLECTION_NAME = os.getenv("COLLECTION_NAME", "session_chat")
|
| 19 |
+
|
| 20 |
+
# Set timeout for MongoDB connection
|
| 21 |
+
MONGODB_TIMEOUT = int(os.getenv("MONGODB_TIMEOUT", "5000")) # 5 seconds by default
|
| 22 |
+
|
| 23 |
+
# Legacy cache settings - now only used for configuration purposes
|
| 24 |
+
HISTORY_CACHE_TTL = int(os.getenv("HISTORY_CACHE_TTL", "3600")) # 1 hour by default
|
| 25 |
+
HISTORY_QUEUE_SIZE = int(os.getenv("HISTORY_QUEUE_SIZE", "10")) # 10 items by default
|
| 26 |
+
|
| 27 |
+
# Create MongoDB connection with timeout
|
| 28 |
+
try:
|
| 29 |
+
client = MongoClient(MONGODB_URL, serverSelectionTimeoutMS=MONGODB_TIMEOUT)
|
| 30 |
+
db = client[DB_NAME]
|
| 31 |
+
|
| 32 |
+
# Collections
|
| 33 |
+
session_collection = db[COLLECTION_NAME]
|
| 34 |
+
logger.info(f"MongoDB connection initialized to {DB_NAME}.{COLLECTION_NAME}")
|
| 35 |
+
|
| 36 |
+
except Exception as e:
|
| 37 |
+
logger.error(f"Failed to initialize MongoDB connection: {e}")
|
| 38 |
+
# Don't raise exception to avoid crash during startup, error handling will be done in functions
|
| 39 |
+
|
| 40 |
+
# Check MongoDB connection
def check_db_connection():
    """Check MongoDB connection"""
    try:
        # A ping round-trip confirms the server is reachable and responsive.
        client.admin.command('ping')
    except (ConnectionFailure, ServerSelectionTimeoutError) as exc:
        logger.error(f"MongoDB connection failed: {exc}")
        return False
    except Exception as exc:
        logger.error(f"Unknown error when checking MongoDB connection: {exc}")
        return False
    logger.info("MongoDB connection is working")
    return True
|
| 54 |
+
|
| 55 |
+
# Timezone for Asia/Ho_Chi_Minh
|
| 56 |
+
asia_tz = pytz.timezone('Asia/Ho_Chi_Minh')
|
| 57 |
+
|
| 58 |
+
def get_local_time():
    """Get current time in Asia/Ho_Chi_Minh timezone.

    Returns:
        str: Current local time formatted as "YYYY-MM-DD HH:MM:SS".
    """
    return datetime.now(asia_tz).strftime("%Y-%m-%d %H:%M:%S")
|
| 61 |
+
|
| 62 |
+
def get_local_datetime():
    """Get current datetime object in Asia/Ho_Chi_Minh timezone.

    Returns:
        datetime: Timezone-aware datetime in Asia/Ho_Chi_Minh.
    """
    return datetime.now(asia_tz)
|
| 65 |
+
|
| 66 |
+
# For backward compatibility
|
| 67 |
+
get_vietnam_time = get_local_time
|
| 68 |
+
get_vietnam_datetime = get_local_datetime
|
| 69 |
+
|
| 70 |
+
# Utility functions
|
| 71 |
+
def save_session(session_id, factor, action, first_name, last_name, message, user_id, username, response=None):
    """Save user session to MongoDB"""
    try:
        # Assemble the document once so it can also be echoed back to the caller.
        session_data = dict(
            session_id=session_id,
            factor=factor,
            action=action,
            created_at=get_local_time(),
            created_at_datetime=get_local_datetime(),
            first_name=first_name,
            last_name=last_name,
            message=message,
            user_id=user_id,
            username=username,
            response=response,
        )
        insert_result = session_collection.insert_one(session_data)
        logger.info(f"Session saved with ID: {insert_result.inserted_id}")
        return dict(
            acknowledged=insert_result.acknowledged,
            inserted_id=str(insert_result.inserted_id),
            session_data=session_data,
        )
    except Exception as e:
        logger.error(f"Error saving session: {e}")
        raise
|
| 98 |
+
|
| 99 |
+
def update_session_response(session_id, response):
    """Update a session with response.

    Args:
        session_id: Identifier of the session document to update.
        response: Response text to store on the session.

    Returns:
        bool: True if a session was found and updated, False if no session
        with the given ID exists.

    Raises:
        Exception: Re-raises any database error after logging it.
    """
    try:
        # Single round-trip: update directly and inspect matched_count
        # instead of the previous find_one() pre-check, which doubled the
        # queries and was racy between the check and the update.
        result = session_collection.update_one(
            {"session_id": session_id},
            {"$set": {"response": response}}
        )

        if result.matched_count == 0:
            logger.warning(f"No session found with ID: {session_id}")
            return False

        logger.info(f"Session {session_id} updated with response")
        return True
    except Exception as e:
        logger.error(f"Error updating session response: {e}")
        raise
|
| 119 |
+
|
| 120 |
+
def get_recent_sessions(user_id, action, n=3):
    """Get n most recent sessions for a specific user and action"""
    query = {"user_id": user_id, "action": action}
    # Project only the conversational fields; drop Mongo's _id.
    projection = {"_id": 0, "message": 1, "response": 1}
    try:
        cursor = (
            session_collection
            .find(query, projection)
            .sort("created_at_datetime", -1)
            .limit(n)
        )
        sessions = list(cursor)
        logger.debug(f"Retrieved {len(sessions)} recent sessions for user {user_id}, action {action}")
        return sessions
    except Exception as exc:
        logger.error(f"Error getting recent sessions: {exc}")
        return []
|
| 136 |
+
|
| 137 |
+
def get_chat_history(user_id, n=5) -> str:
    """Build the chat history string for user_id from MongoDB.

    The result is formatted as alternating lines:

        User: ...
        Bot: ...
        User: ...
        Bot: ...

    Only history after the most recent /start or /clear command is included.
    If no reset command is found, the n most recent sessions are used.

    Args:
        user_id: User identifier; compared as a string against "user_id".
        n: Number of recent sessions to fall back to when no reset exists.

    Returns:
        The joined conversation string, or "" on error / no data.
    """
    try:
        # Find the most recent /start or /clear session (the "reset" point).
        reset_session = session_collection.find_one(
            {
                "user_id": str(user_id),
                "$or": [
                    {"action": "start"},
                    {"action": "clear"}
                ]
            },
            sort=[("created_at_datetime", -1)]
        )

        if reset_session:
            reset_time = reset_session["created_at_datetime"]
            # Take every session recorded after the reset, oldest first.
            docs = list(
                session_collection.find({
                    "user_id": str(user_id),
                    "created_at_datetime": {"$gt": reset_time}
                }).sort("created_at_datetime", 1)
            )
            logger.info(f"Lấy {len(docs)} session sau lệnh {reset_session['action']} lúc {reset_time}")
        else:
            # No reset session found: take the n most recent sessions.
            # Fix: sort by "created_at_datetime" — every other query in this
            # module uses that field; the previous "created_at" was inconsistent
            # and could mis-order the history.
            docs = list(session_collection.find({"user_id": str(user_id)}).sort("created_at_datetime", -1).limit(n))
            # Reverse so entries run from oldest to newest.
            docs.reverse()
            logger.info(f"Không tìm thấy session reset, lấy {len(docs)} session gần nhất")

        if not docs:
            logger.info(f"Không tìm thấy dữ liệu cho user_id: {user_id}")
            return ""

        conversation_lines = []
        # Turn each qualifying document into a User/Bot line pair.
        for doc in docs:
            factor = doc.get("factor", "").lower()
            action = doc.get("action", "").lower()
            message = doc.get("message", "")
            response = doc.get("response", "")

            # Skip the start and clear commands themselves.
            if action in ["start", "clear"]:
                continue

            if factor == "user" and action == "asking_freely":
                conversation_lines.append(f"User: {message}")
                conversation_lines.append(f"Bot: {response}")

        # Join the collected lines into a single string.
        return "\n".join(conversation_lines)
    except Exception as e:
        logger.error(f"Lỗi khi lấy lịch sử chat cho user_id {user_id}: {e}")
        return ""
|
| 204 |
+
|
| 205 |
+
def get_request_history(user_id, n=3):
    """Return the user's recent questions joined into one context string.

    Extracts only the "User: ..." lines from the chat history produced by
    get_chat_history and concatenates their contents with spaces.
    """
    try:
        history = get_chat_history(user_id, n)

        # Keep only the question lines; strip the "User: " prefix (6 chars).
        prefix = "User: "
        questions = [
            entry[len(prefix):]
            for entry in history.split('\n')
            if entry.startswith(prefix)
        ]

        # Single string so callers can feed it to retrieval as context.
        return " ".join(questions)
    except Exception as e:
        logger.error(f"Error getting request history: {e}")
        return ""
|
app/database/pinecone.py
ADDED
|
@@ -0,0 +1,573 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pinecone import Pinecone
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
import logging
|
| 5 |
+
from typing import Optional, List, Dict, Any, Union, Tuple
|
| 6 |
+
import time
|
| 7 |
+
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
| 8 |
+
import google.generativeai as genai
|
| 9 |
+
from langchain_core.retrievers import BaseRetriever
|
| 10 |
+
from langchain.callbacks.manager import Callbacks
|
| 11 |
+
from langchain_core.documents import Document
|
| 12 |
+
from langchain_core.pydantic_v1 import Field
|
| 13 |
+
|
| 14 |
+
# Configure logging
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
# Load environment variables
|
| 18 |
+
load_dotenv()
|
| 19 |
+
|
| 20 |
+
# Pinecone API key and index name
|
| 21 |
+
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
|
| 22 |
+
PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME")
|
| 23 |
+
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
|
| 24 |
+
|
| 25 |
+
# Pinecone retrieval configuration
|
| 26 |
+
DEFAULT_LIMIT_K = int(os.getenv("PINECONE_DEFAULT_LIMIT_K", "10"))
|
| 27 |
+
DEFAULT_TOP_K = int(os.getenv("PINECONE_DEFAULT_TOP_K", "6"))
|
| 28 |
+
DEFAULT_SIMILARITY_METRIC = os.getenv("PINECONE_DEFAULT_SIMILARITY_METRIC", "cosine")
|
| 29 |
+
DEFAULT_SIMILARITY_THRESHOLD = float(os.getenv("PINECONE_DEFAULT_SIMILARITY_THRESHOLD", "0.75"))
|
| 30 |
+
ALLOWED_METRICS = os.getenv("PINECONE_ALLOWED_METRICS", "cosine,dotproduct,euclidean").split(",")
|
| 31 |
+
|
| 32 |
+
# Export constants for importing elsewhere
|
| 33 |
+
__all__ = [
|
| 34 |
+
'get_pinecone_index',
|
| 35 |
+
'check_db_connection',
|
| 36 |
+
'search_vectors',
|
| 37 |
+
'upsert_vectors',
|
| 38 |
+
'delete_vectors',
|
| 39 |
+
'fetch_metadata',
|
| 40 |
+
'get_chain',
|
| 41 |
+
'DEFAULT_TOP_K',
|
| 42 |
+
'DEFAULT_LIMIT_K',
|
| 43 |
+
'DEFAULT_SIMILARITY_METRIC',
|
| 44 |
+
'DEFAULT_SIMILARITY_THRESHOLD',
|
| 45 |
+
'ALLOWED_METRICS',
|
| 46 |
+
'ThresholdRetriever'
|
| 47 |
+
]
|
| 48 |
+
|
| 49 |
+
# Configure Google API
|
| 50 |
+
if GOOGLE_API_KEY:
|
| 51 |
+
genai.configure(api_key=GOOGLE_API_KEY)
|
| 52 |
+
|
| 53 |
+
# Initialize global variables to store instances of Pinecone and index
|
| 54 |
+
pc = None
|
| 55 |
+
index = None
|
| 56 |
+
_retriever_instance = None
|
| 57 |
+
|
| 58 |
+
# Check environment variables
|
| 59 |
+
if not PINECONE_API_KEY:
|
| 60 |
+
logger.error("PINECONE_API_KEY is not set in environment variables")
|
| 61 |
+
|
| 62 |
+
if not PINECONE_INDEX_NAME:
|
| 63 |
+
logger.error("PINECONE_INDEX_NAME is not set in environment variables")
|
| 64 |
+
|
| 65 |
+
# Initialize Pinecone
|
| 66 |
+
def init_pinecone():
    """Initialize pinecone connection using new API.

    Lazily creates the module-level Pinecone client and index handle.
    Returns the index on success, None on any failure.
    """
    global pc, index

    try:
        # Already initialized on a previous call: reuse the cached handle.
        if pc is not None:
            return index

        logger.info(f"Initializing Pinecone connection to index {PINECONE_INDEX_NAME}...")

        # Guard clauses: both credentials must be present before connecting.
        if not PINECONE_API_KEY:
            logger.error("PINECONE_API_KEY is not set in environment variables")
            return None
        if not PINECONE_INDEX_NAME:
            logger.error("PINECONE_INDEX_NAME is not set in environment variables")
            return None

        # Create the client using the new (non-deprecated) Pinecone API.
        pc = Pinecone(api_key=PINECONE_API_KEY)

        try:
            # Verify the configured index actually exists before opening it.
            catalog = pc.list_indexes()
            if not hasattr(catalog, 'names') or PINECONE_INDEX_NAME not in catalog.names():
                logger.error(f"Index {PINECONE_INDEX_NAME} does not exist in Pinecone")
                return None

            index = pc.Index(PINECONE_INDEX_NAME)
            logger.info(f"Pinecone connection established to index {PINECONE_INDEX_NAME}")
        except Exception as connection_error:
            logger.error(f"Error connecting to Pinecone index: {connection_error}")
            return None

        return index
    except ImportError as e:
        logger.error(f"Required package for Pinecone is missing: {e}")
        return None
    except Exception as e:
        logger.error(f"Unexpected error initializing Pinecone: {e}")
        return None
|
| 109 |
+
|
| 110 |
+
# Get Pinecone index singleton
|
| 111 |
+
def get_pinecone_index():
    """Return the shared Pinecone index handle, initializing it on first use."""
    global index
    if index is not None:
        return index
    # First call (or previous init failed): attempt initialization.
    index = init_pinecone()
    return index
|
| 117 |
+
|
| 118 |
+
# Check Pinecone connection
|
| 119 |
+
def check_db_connection():
    """Check Pinecone connection.

    Returns True when the index is reachable and stats can be read,
    False otherwise. Never raises.
    """
    try:
        idx = get_pinecone_index()
        if idx is None:
            return False

        # Reading index stats doubles as a connectivity probe.
        index_stats = idx.describe_index_stats()

        # Prefer the per-namespace counts when the stats object exposes them;
        # otherwise fall back to the top-level total.
        vector_total = index_stats.get('total_vector_count', 0)
        if hasattr(index_stats, 'namespaces'):
            vector_total = sum(
                ns.get('vector_count', 0)
                for ns in index_stats.namespaces.values()
            )

        logger.info(f"Pinecone connection is working. Total vectors: {vector_total}")
        return True
    except Exception as e:
        logger.error(f"Error in Pinecone connection: {e}")
        return False
|
| 140 |
+
|
| 141 |
+
# Convert similarity score based on the metric
|
| 142 |
+
def convert_score(score: float, metric: str) -> float:
    """
    Convert similarity score to a 0-1 scale based on the metric used.
    For metrics like euclidean distance where lower is better, we invert the score.

    Args:
        score: The raw similarity score
        metric: The similarity metric used

    Returns:
        A normalized score between 0-1 where higher means more similar
    """
    # Similarity metrics (cosine, dot product) are already "higher is better".
    if metric.lower() not in ("euclidean", "l2"):
        return score
    # Distance metric: invert and clamp, assuming a max reasonable distance
    # of 2.0 for normalized vectors.
    return max(0, 1 - score / 2.0)
|
| 161 |
+
|
| 162 |
+
# Filter results based on similarity threshold
|
| 163 |
+
def filter_by_threshold(results, threshold: float, metric: str) -> List[Dict]:
    """
    Filter query results based on similarity threshold.

    Args:
        results: The query results from Pinecone
        threshold: The similarity threshold (0-1)
        metric: The similarity metric used

    Returns:
        Filtered list of matches; each kept match gains a
        `normalized_score` attribute with the 0-1 scaled score.
    """
    # Results without a matches attribute yield no hits.
    if not hasattr(results, 'matches'):
        return []

    kept = []
    for candidate in results.matches:
        raw_score = getattr(candidate, 'score', 0)
        # Scale onto 0-1 so the threshold is metric-independent.
        scaled = convert_score(raw_score, metric)
        if scaled >= threshold:
            # Record the scaled value on the match for downstream consumers.
            candidate.normalized_score = scaled
            kept.append(candidate)

    return kept
|
| 194 |
+
|
| 195 |
+
# Search vectors in Pinecone with advanced options
|
| 196 |
+
async def search_vectors(
    query_vector,
    top_k: int = DEFAULT_TOP_K,
    limit_k: int = DEFAULT_LIMIT_K,
    similarity_metric: str = DEFAULT_SIMILARITY_METRIC,
    similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
    namespace: str = "Default",
    filter: Optional[Dict] = None
) -> Dict:
    """
    Search for most similar vectors in Pinecone with advanced filtering options.

    Args:
        query_vector: The query vector
        top_k: Number of results to return (after threshold filtering)
        limit_k: Maximum number of results to retrieve from Pinecone
        similarity_metric: Similarity metric to use (cosine, dotproduct, euclidean)
        similarity_threshold: Threshold for similarity (0-1)
        namespace: Namespace to search in
        filter: Filter query

    Returns:
        Search results with matches filtered by threshold, or None on error.
    """
    try:
        # Sanitize parameters: unknown metric falls back to the default,
        # and limit_k can never be smaller than top_k.
        if similarity_metric not in ALLOWED_METRICS:
            logger.warning(f"Invalid similarity metric: {similarity_metric}. Using default: {DEFAULT_SIMILARITY_METRIC}")
            similarity_metric = DEFAULT_SIMILARITY_METRIC

        if limit_k < top_k:
            logger.warning(f"limit_k ({limit_k}) must be greater than or equal to top_k ({top_k}). Setting limit_k to {top_k}")
            limit_k = top_k

        idx = get_pinecone_index()
        if idx is None:
            logger.error("Failed to get Pinecone index for search")
            return None

        # Over-fetch (limit_k) so threshold filtering still leaves enough
        # candidates; vector values are excluded to save bandwidth.
        # NOTE(review): the `metric` kwarg is forwarded as-is — confirm the
        # installed Pinecone client version accepts it on query().
        response = idx.query(
            vector=query_vector,
            top_k=limit_k,
            namespace=namespace,
            filter=filter,
            include_metadata=True,
            include_values=False,
            metric=similarity_metric
        )

        # Drop matches below the threshold, then trim to the requested top_k
        # and write the surviving matches back onto the response object.
        survivors = filter_by_threshold(response, similarity_threshold, similarity_metric)[:top_k]
        response.matches = survivors

        logger.info(f"Pinecone search returned {len(survivors)} matches after threshold filtering (metric: {similarity_metric}, threshold: {similarity_threshold}, namespace: {namespace})")

        return response
    except Exception as e:
        logger.error(f"Error searching vectors: {e}")
        return None
|
| 264 |
+
|
| 265 |
+
# Upsert vectors to Pinecone
|
| 266 |
+
async def upsert_vectors(vectors, namespace="Default"):
    """Upsert vectors to Pinecone index.

    Returns the raw upsert response on success, None on failure.
    """
    try:
        target = get_pinecone_index()
        if target is None:
            logger.error("Failed to get Pinecone index for upsert")
            return None

        result = target.upsert(vectors=vectors, namespace=namespace)

        # Report how many vectors the service acknowledged.
        count = result.get('upserted_count', 0)
        logger.info(f"Upserted {count} vectors to Pinecone")

        return result
    except Exception as e:
        logger.error(f"Error upserting vectors: {e}")
        return None
|
| 287 |
+
|
| 288 |
+
# Delete vectors from Pinecone
|
| 289 |
+
async def delete_vectors(ids, namespace="Default"):
    """Delete vectors from Pinecone index.

    Returns True when the delete call succeeds, False otherwise.
    """
    try:
        target = get_pinecone_index()
        if target is None:
            logger.error("Failed to get Pinecone index for delete")
            return False

        # The delete response carries no useful payload; success is boolean.
        target.delete(ids=ids, namespace=namespace)

        logger.info(f"Deleted vectors with IDs {ids} from Pinecone")
        return True
    except Exception as e:
        logger.error(f"Error deleting vectors: {e}")
        return False
|
| 307 |
+
|
| 308 |
+
# Fetch vector metadata from Pinecone
|
| 309 |
+
async def fetch_metadata(ids, namespace="Default"):
    """Fetch metadata for specific vector IDs.

    Returns the raw fetch response, or None when the index is unavailable
    or the fetch fails.
    """
    try:
        target = get_pinecone_index()
        if target is None:
            logger.error("Failed to get Pinecone index for fetch")
            return None

        return target.fetch(ids=ids, namespace=namespace)
    except Exception as e:
        logger.error(f"Error fetching vector metadata: {e}")
        return None
|
| 326 |
+
|
| 327 |
+
# Create a custom retriever class for Langchain integration
|
| 328 |
+
class ThresholdRetriever(BaseRetriever):
    """
    Custom retriever that supports threshold-based filtering and multiple similarity metrics.
    This integrates with the Langchain ecosystem while using our advanced retrieval logic.
    """

    # Pydantic fields. vectorstore/embeddings are typed `Any` because they
    # hold third-party objects (Langchain vectorstore, embeddings model).
    vectorstore: Any = Field(description="Vector store to use for retrieval")
    embeddings: Any = Field(description="Embeddings model to use for retrieval")
    search_kwargs: Dict[str, Any] = Field(default_factory=dict, description="Search kwargs for the vectorstore")
    top_k: int = Field(default=DEFAULT_TOP_K, description="Number of results to return after filtering")
    limit_k: int = Field(default=DEFAULT_LIMIT_K, description="Maximum number of results to retrieve from Pinecone")
    similarity_metric: str = Field(default=DEFAULT_SIMILARITY_METRIC, description="Similarity metric to use")
    similarity_threshold: float = Field(default=DEFAULT_SIMILARITY_THRESHOLD, description="Threshold for similarity")
    namespace: str = "Default"

    class Config:
        """Configuration for this pydantic object."""
        arbitrary_types_allowed = True

    # NOTE(review): despite the name, this method is declared `async` and must
    # be awaited — it is NOT a synchronous wrapper. It also appears unused in
    # this module; confirm callers before renaming or removing.
    async def search_vectors_sync(
        self, query_vector,
        top_k: int = DEFAULT_TOP_K,
        limit_k: int = DEFAULT_LIMIT_K,
        similarity_metric: str = DEFAULT_SIMILARITY_METRIC,
        similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
        namespace: str = "Default",
        filter: Optional[Dict] = None
    ) -> Dict:
        """Async wrapper around search_vectors (not actually synchronous despite the name)."""
        import asyncio
        try:
            # Get current event loop or create a new one
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

            # Use event loop to run async function
            if loop.is_running():
                # If we're in an event loop, use asyncio.create_task.
                # NOTE(review): creating a task and immediately awaiting it is
                # equivalent to awaiting the coroutine directly.
                task = asyncio.create_task(search_vectors(
                    query_vector=query_vector,
                    top_k=top_k,
                    limit_k=limit_k,
                    similarity_metric=similarity_metric,
                    similarity_threshold=similarity_threshold,
                    namespace=namespace,
                    filter=filter
                ))
                return await task
            else:
                # If not in an event loop, just await directly
                return await search_vectors(
                    query_vector=query_vector,
                    top_k=top_k,
                    limit_k=limit_k,
                    similarity_metric=similarity_metric,
                    similarity_threshold=similarity_threshold,
                    namespace=namespace,
                    filter=filter
                )
        except Exception as e:
            logger.error(f"Error in search_vectors_sync: {e}")
            return None

    def _get_relevant_documents(
        self, query: str, *, run_manager: Callbacks = None
    ) -> List[Document]:
        """
        Get documents relevant to the query using threshold-based retrieval.

        Called by Langchain's BaseRetriever machinery. Embeds the query,
        runs the async search_vectors safely from sync context, and converts
        Pinecone matches into Langchain Documents.

        Args:
            query: The query string
            run_manager: The callbacks manager

        Returns:
            List of relevant documents
        """
        # Generate embedding for query using the embeddings model
        try:
            # Use the embeddings model we stored in the class
            embedding = self.embeddings.embed_query(query)
        except Exception as e:
            logger.error(f"Error generating embedding: {e}")
            # Fallback to creating a new embedding model if needed
            embedding_model = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
            embedding = embedding_model.embed_query(query)

        # Perform search with advanced options - avoid asyncio.run()
        import asyncio

        # Get or create event loop
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        # Run asynchronous search in a safe way
        if loop.is_running():
            # We're inside an existing event loop (like in FastAPI)
            # Use a different approach - convert it to a synchronous call
            from concurrent.futures import ThreadPoolExecutor
            import functools  # NOTE(review): unused import

            # Define a wrapper function to run in a thread
            def run_async_in_thread():
                # Create a new event loop for this thread
                thread_loop = asyncio.new_event_loop()
                asyncio.set_event_loop(thread_loop)
                # Run the coroutine and return the result
                return thread_loop.run_until_complete(search_vectors(
                    query_vector=embedding,
                    top_k=self.top_k,
                    limit_k=self.limit_k,
                    similarity_metric=self.similarity_metric,
                    similarity_threshold=self.similarity_threshold,
                    namespace=self.namespace,
                    # filter=self.search_kwargs.get("filter", None)
                ))

            # Run the async function in a thread
            with ThreadPoolExecutor() as executor:
                search_result = executor.submit(run_async_in_thread).result()
        else:
            # No event loop running, we can use run_until_complete
            search_result = loop.run_until_complete(search_vectors(
                query_vector=embedding,
                top_k=self.top_k,
                limit_k=self.limit_k,
                similarity_metric=self.similarity_metric,
                similarity_threshold=self.similarity_threshold,
                namespace=self.namespace,
                # filter=self.search_kwargs.get("filter", None)
            ))

        # Convert to documents
        documents = []
        if search_result and hasattr(search_result, 'matches'):
            for match in search_result.matches:
                # Extract metadata
                metadata = {}
                if hasattr(match, 'metadata'):
                    metadata = match.metadata

                # Add score to metadata
                score = getattr(match, 'score', 0)
                normalized_score = getattr(match, 'normalized_score', score)
                metadata['score'] = score
                metadata['normalized_score'] = normalized_score

                # Extract text
                text = metadata.get('text', '')
                if 'text' in metadata:
                    del metadata['text']  # Remove from metadata since it's the content

                # Create Document
                doc = Document(
                    page_content=text,
                    metadata=metadata
                )
                documents.append(doc)

        return documents
|
| 493 |
+
|
| 494 |
+
# Get the retrieval chain with Pinecone vector store
|
| 495 |
+
def get_chain(
    index_name=PINECONE_INDEX_NAME,
    namespace="Default",
    top_k=DEFAULT_TOP_K,
    limit_k=DEFAULT_LIMIT_K,
    similarity_metric=DEFAULT_SIMILARITY_METRIC,
    similarity_threshold=DEFAULT_SIMILARITY_THRESHOLD
):
    """
    Get the retrieval chain with Pinecone vector store using threshold-based retrieval.

    The built retriever is cached at module level and rebuilt whenever the
    requested parameters differ from those used to build the cached instance.
    (The previous implementation returned the cached retriever even when the
    parameters changed, contradicting its own comment.)

    Args:
        index_name: Pinecone index name
        namespace: Pinecone namespace
        top_k: Number of results to return after filtering
        limit_k: Maximum number of results to retrieve from Pinecone
        similarity_metric: Similarity metric to use (cosine, dotproduct, euclidean)
        similarity_threshold: Threshold for similarity (0-1)

    Returns:
        ThresholdRetriever instance, or None on failure.
    """
    global _retriever_instance
    try:
        # Cache key covers every parameter that affects retriever construction,
        # so changing any of them forces a rebuild.
        cache_key = (index_name, namespace, top_k, limit_k,
                     similarity_metric, similarity_threshold)
        if _retriever_instance is not None and getattr(get_chain, "_cache_key", None) == cache_key:
            return _retriever_instance

        start_time = time.time()
        logger.info("Initializing new retriever chain with threshold-based filtering")

        # Initialize embeddings model
        embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")

        # Get index
        pinecone_index = get_pinecone_index()
        if not pinecone_index:
            logger.error("Failed to get Pinecone index for retriever chain")
            return None

        # Get statistics for logging (non-fatal on failure)
        try:
            stats = pinecone_index.describe_index_stats()
            total_vectors = stats.get('total_vector_count', 0)
            logger.info(f"Pinecone index stats - Total vectors: {total_vectors}")
        except Exception as e:
            logger.error(f"Error getting index stats: {e}")

        # Imported here to avoid a hard dependency at module import time.
        from langchain_community.vectorstores import Pinecone as LangchainPinecone

        logger.info(f"Creating Pinecone vectorstore with index: {index_name}, namespace: {namespace}")
        vectorstore = LangchainPinecone.from_existing_index(
            embedding=embeddings,
            index_name=index_name,
            namespace=namespace,
            text_key="text"
        )

        logger.info(f"Creating ThresholdRetriever with top_k={top_k}, limit_k={limit_k}, " +
                    f"metric={similarity_metric}, threshold={similarity_threshold}")

        # Fix: forward namespace to the retriever (previously it always used
        # the field default "Default", ignoring the namespace argument).
        _retriever_instance = ThresholdRetriever(
            vectorstore=vectorstore,
            embeddings=embeddings,  # Pass embeddings separately
            top_k=top_k,
            limit_k=limit_k,
            similarity_metric=similarity_metric,
            similarity_threshold=similarity_threshold,
            namespace=namespace
        )
        get_chain._cache_key = cache_key

        logger.info(f"Pinecone retriever initialized in {time.time() - start_time:.2f} seconds")

        return _retriever_instance
    except Exception as e:
        logger.error(f"Error creating retrieval chain: {e}")
        return None
|
app/database/postgresql.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from sqlalchemy import create_engine, text
|
| 3 |
+
from sqlalchemy.ext.declarative import declarative_base
|
| 4 |
+
from sqlalchemy.orm import sessionmaker
|
| 5 |
+
from sqlalchemy.exc import SQLAlchemyError, OperationalError
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
# Configure logging
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
# Load environment variables
|
| 13 |
+
load_dotenv()
|
| 14 |
+
|
| 15 |
+
# Get DB connection mode from environment
|
| 16 |
+
DB_CONNECTION_MODE = os.getenv("DB_CONNECTION_MODE", "aiven")
|
| 17 |
+
|
| 18 |
+
# Set connection string based on mode
|
| 19 |
+
if DB_CONNECTION_MODE == "aiven":
|
| 20 |
+
DATABASE_URL = os.getenv("AIVEN_DB_URL")
|
| 21 |
+
else:
|
| 22 |
+
# Default or other connection modes can be added here
|
| 23 |
+
DATABASE_URL = os.getenv("AIVEN_DB_URL")
|
| 24 |
+
|
| 25 |
+
if not DATABASE_URL:
|
| 26 |
+
logger.error("No database URL configured. Please set AIVEN_DB_URL environment variable.")
|
| 27 |
+
DATABASE_URL = "postgresql://localhost/test" # Fallback to avoid crash on startup
|
| 28 |
+
|
| 29 |
+
# Create SQLAlchemy engine with optimized settings.
# NOTE(review): if create_engine() raises, `engine` is never bound, so the
# later `sessionmaker(bind=engine)` call will fail with NameError anyway —
# the "don't crash on startup" intent below is only partially achieved;
# confirm whether a None fallback is wanted.
try:
    engine = create_engine(
        DATABASE_URL,
        pool_size=10,            # Limit max connections
        max_overflow=5,          # Allow temporary overflow of connections
        pool_timeout=30,         # Timeout waiting for connection from pool
        pool_recycle=300,        # Recycle connections every 5 minutes
        pool_pre_ping=True,      # Verify connection is still valid before using it
        connect_args={
            "connect_timeout": 5,          # Connection timeout in seconds
            "keepalives": 1,               # Enable TCP keepalives
            "keepalives_idle": 30,         # Time before sending keepalives
            "keepalives_interval": 10,     # Time between keepalives
            "keepalives_count": 5,         # Number of keepalive probes
            "application_name": "pixagent_api"  # Identify app in PostgreSQL logs
        },
        # Performance optimizations
        isolation_level="READ COMMITTED",  # Lower isolation level for better performance
        echo=False,       # Disable SQL echo to reduce overhead
        echo_pool=False,  # Disable pool logging
        future=True,      # Use SQLAlchemy 2.0 features
        # Execution options for common queries
        execution_options={
            "compiled_cache": {},   # Use an empty dict for compiled query caching
            "logging_token": "SQL", # Tag for query logging
        }
    )
    logger.info("PostgreSQL engine initialized with optimized settings")
except Exception as e:
    logger.error(f"Failed to initialize PostgreSQL engine: {e}")
    # Don't raise exception to avoid crash on startup
|
| 61 |
+
|
| 62 |
+
# Create optimized session factory
|
| 63 |
+
SessionLocal = sessionmaker(
|
| 64 |
+
autocommit=False,
|
| 65 |
+
autoflush=False,
|
| 66 |
+
bind=engine,
|
| 67 |
+
expire_on_commit=False # Prevent automatic reloading after commit
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# Base class for declarative models - use sqlalchemy.orm for SQLAlchemy 2.0 compatibility
|
| 71 |
+
from sqlalchemy.orm import declarative_base
|
| 72 |
+
Base = declarative_base()
|
| 73 |
+
|
| 74 |
+
# Check PostgreSQL connection
|
| 75 |
+
def check_db_connection():
|
| 76 |
+
"""Check PostgreSQL connection status"""
|
| 77 |
+
try:
|
| 78 |
+
# Simple query to verify connection
|
| 79 |
+
with engine.connect() as connection:
|
| 80 |
+
connection.execute(text("SELECT 1")).fetchone()
|
| 81 |
+
logger.info("PostgreSQL connection successful")
|
| 82 |
+
return True
|
| 83 |
+
except OperationalError as e:
|
| 84 |
+
logger.error(f"PostgreSQL connection failed: {e}")
|
| 85 |
+
return False
|
| 86 |
+
except Exception as e:
|
| 87 |
+
logger.error(f"Unknown error checking PostgreSQL connection: {e}")
|
| 88 |
+
return False
|
| 89 |
+
|
| 90 |
+
# Dependency to get DB session with improved error handling
def get_db():
    """Yield a PostgreSQL session for use as a FastAPI-style dependency.

    The session's connection is probed with ``SELECT 1`` before being handed
    to the caller; whatever happens, the session is closed (returning its
    connection to the pool) when the request finishes.

    Yields:
        Session: an open SQLAlchemy session.

    Raises:
        Exception: re-raises any error from the probe or the request body.
    """
    session = SessionLocal()
    try:
        # Fail fast if the pooled connection has gone stale.
        session.execute(text("SELECT 1")).fetchone()
        yield session
    except Exception as exc:
        logger.error(f"DB connection error: {exc}")
        raise
    finally:
        session.close()  # Ensure connection is closed and returned to pool
|
| 103 |
+
|
| 104 |
+
# Create tables in database if they don't exist
def create_tables():
    """Create every table declared on ``Base`` if it is missing.

    Returns:
        bool: True when ``create_all`` completes, False on any error
        (errors are logged, never raised, so startup can continue).
    """
    try:
        Base.metadata.create_all(bind=engine)
        logger.info("Database tables created or already exist")
        return True
    except SQLAlchemyError as sa_err:
        # Database-level problem (permissions, bad DDL, connectivity).
        logger.error(f"Failed to create database tables (SQLAlchemy error): {sa_err}")
        return False
    except Exception as other_err:
        # Defensive catch-all so startup never crashes here.
        logger.error(f"Failed to create database tables (unexpected error): {other_err}")
        return False
|
| 117 |
+
|
| 118 |
+
# Function to create indexes for better performance
def create_indexes():
    """Create (or verify) the indexes used by common query patterns.

    Bug fixed: the previous version issued plain ``CREATE INDEX`` statements
    and relied on catching ``SQLAlchemyError`` when an index already existed.
    On PostgreSQL the first failed statement aborts the open transaction, so
    every subsequent statement — and the final ``commit()`` — then failed with
    "current transaction is aborted". We now use ``CREATE INDEX IF NOT EXISTS``
    (PostgreSQL 9.5+) so re-runs are no-ops, and roll back after any genuine
    per-statement failure so the remaining statements can still run.

    Returns:
        bool: True when the index pass completed (individual statement
        failures are logged and skipped), False if the connection itself
        failed.
    """
    # (index name, DDL) pairs — kept data-driven so adding an index is one line.
    index_ddl = [
        # Event table: featured flag, active flag, start date, and the
        # combined featured+active filter.
        ("idx_event_featured",
         "CREATE INDEX IF NOT EXISTS idx_event_featured ON event_item(featured)"),
        ("idx_event_active",
         "CREATE INDEX IF NOT EXISTS idx_event_active ON event_item(is_active)"),
        ("idx_event_date_start",
         "CREATE INDEX IF NOT EXISTS idx_event_date_start ON event_item(date_start)"),
        ("idx_event_featured_active",
         "CREATE INDEX IF NOT EXISTS idx_event_featured_active ON event_item(featured, is_active)"),
        # FAQ and emergency-contact tables.
        ("idx_faq_active",
         "CREATE INDEX IF NOT EXISTS idx_faq_active ON faq_item(is_active)"),
        ("idx_emergency_active",
         "CREATE INDEX IF NOT EXISTS idx_emergency_active ON emergency_item(is_active)"),
        ("idx_emergency_priority",
         "CREATE INDEX IF NOT EXISTS idx_emergency_priority ON emergency_item(priority)"),
    ]

    try:
        with engine.connect() as conn:
            for index_name, ddl in index_ddl:
                try:
                    conn.execute(text(ddl))
                except SQLAlchemyError as stmt_err:
                    # Roll back so a failed statement does not poison the
                    # transaction for the statements that follow.
                    logger.warning(f"Could not create index {index_name}: {stmt_err}")
                    conn.rollback()

            conn.commit()

        logger.info("Database indexes created or verified")
        return True
    except SQLAlchemyError as e:
        logger.error(f"Failed to create indexes: {e}")
        return False
|
app/models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Pydantic models package
|
app/models/mongodb_models.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field, ConfigDict
|
| 2 |
+
from typing import Optional, List, Dict, Any
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
import uuid
|
| 5 |
+
|
| 6 |
+
class SessionBase(BaseModel):
|
| 7 |
+
"""Base model for session data"""
|
| 8 |
+
session_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
| 9 |
+
factor: str
|
| 10 |
+
action: str
|
| 11 |
+
first_name: str
|
| 12 |
+
last_name: Optional[str] = None
|
| 13 |
+
message: Optional[str] = None
|
| 14 |
+
user_id: str
|
| 15 |
+
username: Optional[str] = None
|
| 16 |
+
|
| 17 |
+
class SessionCreate(SessionBase):
|
| 18 |
+
"""Model for creating new session"""
|
| 19 |
+
response: Optional[str] = None
|
| 20 |
+
|
| 21 |
+
class SessionResponse(SessionBase):
|
| 22 |
+
"""Response model for session data"""
|
| 23 |
+
created_at: str
|
| 24 |
+
response: Optional[str] = None
|
| 25 |
+
|
| 26 |
+
model_config = ConfigDict(
|
| 27 |
+
json_schema_extra={
|
| 28 |
+
"example": {
|
| 29 |
+
"session_id": "123e4567-e89b-12d3-a456-426614174000",
|
| 30 |
+
"factor": "user",
|
| 31 |
+
"action": "asking_freely",
|
| 32 |
+
"created_at": "2023-06-01 14:30:45",
|
| 33 |
+
"first_name": "John",
|
| 34 |
+
"last_name": "Doe",
|
| 35 |
+
"message": "How can I find emergency contacts?",
|
| 36 |
+
"user_id": "12345678",
|
| 37 |
+
"username": "johndoe",
|
| 38 |
+
"response": "You can find emergency contacts in the Emergency section..."
|
| 39 |
+
}
|
| 40 |
+
}
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
class HistoryRequest(BaseModel):
|
| 44 |
+
"""Request model for history"""
|
| 45 |
+
user_id: str
|
| 46 |
+
n: int = 3
|
| 47 |
+
|
| 48 |
+
class QuestionAnswer(BaseModel):
|
| 49 |
+
"""Model for question-answer pair"""
|
| 50 |
+
question: str
|
| 51 |
+
answer: str
|
| 52 |
+
|
| 53 |
+
class HistoryResponse(BaseModel):
|
| 54 |
+
"""Response model for history"""
|
| 55 |
+
history: List[QuestionAnswer]
|
app/models/pdf_models.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field
|
| 2 |
+
from typing import Optional, List, Dict, Any
|
| 3 |
+
|
| 4 |
+
class PDFUploadRequest(BaseModel):
|
| 5 |
+
"""Request model cho upload PDF"""
|
| 6 |
+
namespace: Optional[str] = Field("Default", description="Namespace trong Pinecone")
|
| 7 |
+
index_name: Optional[str] = Field("testbot768", description="Tên index trong Pinecone")
|
| 8 |
+
title: Optional[str] = Field(None, description="Tiêu đề của tài liệu")
|
| 9 |
+
description: Optional[str] = Field(None, description="Mô tả về tài liệu")
|
| 10 |
+
vector_database_id: Optional[int] = Field(None, description="ID của vector database trong PostgreSQL để sử dụng")
|
| 11 |
+
|
| 12 |
+
class PDFResponse(BaseModel):
    """Response model for PDF processing results.

    Carries the processing outcome plus optional document statistics; the
    field descriptions (shown in the OpenAPI schema) are intentionally kept
    in Vietnamese to match the rest of the API surface.
    """
    success: bool = Field(..., description="Trạng thái xử lý thành công hay không")
    document_id: Optional[str] = Field(None, description="ID của tài liệu")
    chunks_processed: Optional[int] = Field(None, description="Số lượng chunks đã xử lý")
    total_text_length: Optional[int] = Field(None, description="Tổng độ dài văn bản")
    error: Optional[str] = Field(None, description="Thông báo lỗi nếu có")

    # Pydantic v2: `schema_extra` inside a nested `class Config` is the v1
    # spelling and is ignored by v2, so the example never reached the OpenAPI
    # docs. Use `model_config` with `json_schema_extra`, consistent with
    # `SessionResponse` in app/models/mongodb_models.py.
    model_config = {
        "json_schema_extra": {
            "example": {
                "success": True,
                "document_id": "550e8400-e29b-41d4-a716-446655440000",
                "chunks_processed": 25,
                "total_text_length": 50000
            }
        }
    }
|
| 29 |
+
|
| 30 |
+
class DeleteDocumentRequest(BaseModel):
|
| 31 |
+
"""Request model cho xóa document"""
|
| 32 |
+
document_id: str = Field(..., description="ID của tài liệu cần xóa")
|
| 33 |
+
namespace: Optional[str] = Field("Default", description="Namespace trong Pinecone")
|
| 34 |
+
index_name: Optional[str] = Field("testbot768", description="Tên index trong Pinecone")
|
| 35 |
+
|
| 36 |
+
class DocumentsListResponse(BaseModel):
    """Response model for listing document/vector statistics.

    Reports the vector count for the queried Pinecone namespace/index; field
    descriptions remain in Vietnamese to match the API's existing schema text.
    """
    success: bool = Field(..., description="Trạng thái xử lý thành công hay không")
    total_vectors: Optional[int] = Field(None, description="Tổng số vectors trong index")
    namespace: Optional[str] = Field(None, description="Namespace đang sử dụng")
    index_name: Optional[str] = Field(None, description="Tên index đang sử dụng")
    error: Optional[str] = Field(None, description="Thông báo lỗi nếu có")

    # Pydantic v2: the nested `class Config: schema_extra` form is v1-only and
    # silently ignored, so the example was dropped from the generated schema.
    # `model_config`/`json_schema_extra` matches mongodb_models.py.
    model_config = {
        "json_schema_extra": {
            "example": {
                "success": True,
                "total_vectors": 5000,
                "namespace": "Default",
                "index_name": "testbot768"
            }
        }
    }
|
app/models/rag_models.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field
|
| 2 |
+
from typing import Optional, List, Dict, Any
|
| 3 |
+
|
| 4 |
+
class ChatRequest(BaseModel):
    """Request model for the chat endpoint.

    Bundles the user's question, the retrieval-tuning knobs and optional
    user/session metadata used for conversation tracking.
    """
    user_id: str = Field(..., description="User ID from Telegram")
    question: str = Field(..., description="User's question")
    include_history: bool = Field(True, description="Whether to include user history in prompt")
    use_rag: bool = Field(True, description="Whether to use RAG")

    # Advanced retrieval parameters: up to `limit_k` candidates are fetched
    # from the vector store, then filtered by `similarity_threshold` down to
    # at most `similarity_top_k` results (per the field descriptions below).
    similarity_top_k: int = Field(6, description="Number of top similar documents to return (after filtering)")
    limit_k: int = Field(10, description="Maximum number of documents to retrieve from vector store")
    similarity_metric: str = Field("cosine", description="Similarity metric to use (cosine, dotproduct, euclidean)")
    similarity_threshold: float = Field(0.75, description="Threshold for vector similarity (0-1)")

    # User information (all optional; used for session/history bookkeeping)
    session_id: Optional[str] = Field(None, description="Session ID for tracking conversations")
    first_name: Optional[str] = Field(None, description="User's first name")
    last_name: Optional[str] = Field(None, description="User's last name")
    username: Optional[str] = Field(None, description="User's username")
|
| 22 |
+
|
| 23 |
+
class SourceDocument(BaseModel):
|
| 24 |
+
"""Model for source documents"""
|
| 25 |
+
text: str = Field(..., description="Text content of the document")
|
| 26 |
+
source: Optional[str] = Field(None, description="Source of the document")
|
| 27 |
+
score: Optional[float] = Field(None, description="Raw similarity score of the document")
|
| 28 |
+
normalized_score: Optional[float] = Field(None, description="Normalized similarity score (0-1)")
|
| 29 |
+
metadata: Optional[Dict[str, Any]] = Field(None, description="Metadata of the document")
|
| 30 |
+
|
| 31 |
+
class ChatResponse(BaseModel):
|
| 32 |
+
"""Response model for chat endpoint"""
|
| 33 |
+
answer: str = Field(..., description="Generated answer")
|
| 34 |
+
processing_time: float = Field(..., description="Processing time in seconds")
|
| 35 |
+
|
| 36 |
+
class ChatResponseInternal(BaseModel):
|
| 37 |
+
"""Internal model for chat response with sources - used only for logging"""
|
| 38 |
+
answer: str
|
| 39 |
+
sources: Optional[List[SourceDocument]] = Field(None, description="Source documents used for generating answer")
|
| 40 |
+
processing_time: Optional[float] = None
|
| 41 |
+
|
| 42 |
+
class EmbeddingRequest(BaseModel):
|
| 43 |
+
"""Request model for embedding endpoint"""
|
| 44 |
+
text: str = Field(..., description="Text to generate embedding for")
|
| 45 |
+
|
| 46 |
+
class EmbeddingResponse(BaseModel):
|
| 47 |
+
"""Response model for embedding endpoint"""
|
| 48 |
+
embedding: List[float] = Field(..., description="Generated embedding")
|
| 49 |
+
text: str = Field(..., description="Text that was embedded")
|
| 50 |
+
model: str = Field(..., description="Model used for embedding")
|
| 51 |
+
|
| 52 |
+
class HealthResponse(BaseModel):
|
| 53 |
+
"""Response model for health endpoint"""
|
| 54 |
+
status: str
|
| 55 |
+
services: Dict[str, bool]
|
| 56 |
+
timestamp: str
|
| 57 |
+
|
| 58 |
+
class UserMessageModel(BaseModel):
|
| 59 |
+
"""Model for user messages sent to the RAG API"""
|
| 60 |
+
user_id: str = Field(..., description="User ID from the client application")
|
| 61 |
+
session_id: str = Field(..., description="Session ID for tracking the conversation")
|
| 62 |
+
message: str = Field(..., description="User's message/question")
|
| 63 |
+
|
| 64 |
+
# Advanced retrieval parameters (optional)
|
| 65 |
+
similarity_top_k: Optional[int] = Field(None, description="Number of top similar documents to return (after filtering)")
|
| 66 |
+
limit_k: Optional[int] = Field(None, description="Maximum number of documents to retrieve from vector store")
|
| 67 |
+
similarity_metric: Optional[str] = Field(None, description="Similarity metric to use (cosine, dotproduct, euclidean)")
|
| 68 |
+
similarity_threshold: Optional[float] = Field(None, description="Threshold for vector similarity (0-1)")
|
app/utils/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Utility functions package
|
app/utils/cache.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import threading
|
| 4 |
+
import logging
|
| 5 |
+
from typing import Dict, Any, Optional, Tuple, List, Callable, Generic, TypeVar, Union
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
import json
|
| 9 |
+
|
| 10 |
+
# Thiết lập logging
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
# Load biến môi trường
|
| 14 |
+
load_dotenv()
|
| 15 |
+
|
| 16 |
+
# Cấu hình cache từ biến môi trường
|
| 17 |
+
DEFAULT_CACHE_TTL = int(os.getenv("CACHE_TTL_SECONDS", "300")) # Mặc định 5 phút
|
| 18 |
+
DEFAULT_CACHE_CLEANUP_INTERVAL = int(os.getenv("CACHE_CLEANUP_INTERVAL", "60")) # Mặc định 1 phút
|
| 19 |
+
DEFAULT_CACHE_MAX_SIZE = int(os.getenv("CACHE_MAX_SIZE", "1000")) # Mặc định 1000 phần tử
|
| 20 |
+
|
| 21 |
+
# Generic type để có thể sử dụng cho nhiều loại giá trị khác nhau
|
| 22 |
+
T = TypeVar('T')
|
| 23 |
+
|
| 24 |
+
# Cấu trúc cho một phần tử trong cache
|
| 25 |
+
# Structure for a single element in the cache
class CacheItem(Generic[T]):
    """Wraps one cached value together with its expiry and LRU bookkeeping."""

    def __init__(self, value: T, ttl: int = DEFAULT_CACHE_TTL):
        now = time.time()
        self.value = value
        self.expire_at = now + ttl      # absolute expiry timestamp
        self.last_accessed = now        # used for LRU eviction ordering

    def is_expired(self) -> bool:
        """Return True once this entry's TTL has elapsed."""
        return self.expire_at < time.time()

    def touch(self) -> None:
        """Record an access so LRU eviction keeps hot entries alive."""
        self.last_accessed = time.time()

    def extend(self, ttl: int = DEFAULT_CACHE_TTL) -> None:
        """Push the expiry `ttl` seconds into the future from now."""
        self.expire_at = time.time() + ttl
|
| 42 |
+
|
| 43 |
+
# Lớp cache chính
|
| 44 |
+
# Main cache class
class InMemoryCache:
    """Thread-safe in-memory cache with per-item TTL and LRU eviction.

    Expired entries are removed both lazily (on access) and actively, by a
    daemon thread that wakes every ``cleanup_interval`` seconds.

    Fixes vs. the previous version:
      * ``get_or_set`` used ``value is None`` as its miss test, so a callback
        that legitimately returned ``None`` was re-invoked on every call and
        its result never cached.  A private sentinel now distinguishes
        "missing/expired" from a cached falsy value.
      * removed unused ``now`` locals in ``_remove_expired_items``/``stats``.
      * ``_estimate_memory_usage`` no longer uses a bare ``except:``.
    """

    # Sentinel telling "key missing or expired" apart from a cached None.
    _MISSING = object()

    def __init__(
        self,
        ttl: int = DEFAULT_CACHE_TTL,
        cleanup_interval: int = DEFAULT_CACHE_CLEANUP_INTERVAL,
        max_size: int = DEFAULT_CACHE_MAX_SIZE
    ):
        self.cache: Dict[str, CacheItem] = {}
        self.ttl = ttl
        self.cleanup_interval = cleanup_interval
        self.max_size = max_size
        # RLock so re-entrant calls (e.g. get_or_set -> get -> set) can't deadlock.
        self.lock = threading.RLock()

        # Background thread performing periodic cleanup (active expiration).
        self.cleanup_thread = threading.Thread(target=self._cleanup_task, daemon=True)
        self.cleanup_thread.start()

    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        """Store ``value`` under ``key``; ``ttl`` overrides the default when given."""
        with self.lock:
            ttl_value = ttl if ttl is not None else self.ttl

            # When inserting a brand-new key into a full cache, evict the
            # least-recently-used entries first.
            if len(self.cache) >= self.max_size and key not in self.cache:
                self._evict_lru_items()

            self.cache[key] = CacheItem(value, ttl_value)
            logger.debug(f"Cache set: {key} (expires in {ttl_value}s)")

    def get(self, key: str, default: Any = None) -> Any:
        """Return the cached value or ``default``.

        Applies lazy expiration: an expired entry found here is deleted.
        """
        with self.lock:
            item = self.cache.get(key)

            if item is None or item.is_expired():
                if item is not None:
                    # Lazy expiration: drop the stale entry on access.
                    logger.debug(f"Cache miss (expired): {key}")
                    del self.cache[key]
                else:
                    logger.debug(f"Cache miss (not found): {key}")
                return default

            item.touch()  # keep hot entries at the MRU end
            logger.debug(f"Cache hit: {key}")
            return item.value

    def delete(self, key: str) -> bool:
        """Remove ``key``; return True if it was present."""
        with self.lock:
            if key in self.cache:
                del self.cache[key]
                logger.debug(f"Cache delete: {key}")
                return True
            return False

    def clear(self) -> None:
        """Drop every entry in the cache."""
        with self.lock:
            self.cache.clear()
            logger.debug("Cache cleared")

    def get_or_set(self, key: str, callback: Callable[[], T], ttl: Optional[int] = None) -> T:
        """Return the cached value for ``key``, computing and caching it on a miss.

        Uses a sentinel rather than ``None`` as the miss marker so callbacks
        that return ``None`` (or any falsy value) are cached correctly.
        """
        with self.lock:
            value = self.get(key, self._MISSING)
            if value is self._MISSING:
                value = callback()
                self.set(key, value, ttl)
            return value

    def _cleanup_task(self) -> None:
        """Daemon-thread loop removing expired items (active expiration)."""
        while True:
            time.sleep(self.cleanup_interval)
            try:
                self._remove_expired_items()
            except Exception as e:
                logger.error(f"Error in cache cleanup task: {e}")

    def _remove_expired_items(self) -> None:
        """Delete every expired entry currently in the cache."""
        with self.lock:
            expired_keys = [k for k, v in self.cache.items() if v.is_expired()]
            for key in expired_keys:
                del self.cache[key]

            if expired_keys:
                logger.debug(f"Cleaned up {len(expired_keys)} expired cache items")

    def _evict_lru_items(self, count: int = 1) -> None:
        """Evict the ``count`` least-recently-used entries (cache-full path)."""
        items = sorted(self.cache.items(), key=lambda x: x[1].last_accessed)
        for i in range(min(count, len(items))):
            del self.cache[items[i][0]]
        logger.debug(f"Evicted {min(count, len(items))} least recently used items from cache")

    def stats(self) -> Dict[str, Any]:
        """Return item counts and an approximate memory footprint."""
        with self.lock:
            total_items = len(self.cache)
            expired_items = sum(1 for item in self.cache.values() if item.is_expired())
            memory_usage = self._estimate_memory_usage()
            return {
                "total_items": total_items,
                "expired_items": expired_items,
                "active_items": total_items - expired_items,
                "memory_usage_bytes": memory_usage,
                "memory_usage_mb": memory_usage / (1024 * 1024),
                "max_size": self.max_size
            }

    def _estimate_memory_usage(self) -> int:
        """Roughly estimate the cache's memory footprint in bytes."""
        # Approximation based on key lengths plus serialized value sizes.
        cache_size = sum(len(k) for k in self.cache.keys())
        for item in self.cache.values():
            try:
                if isinstance(item.value, (str, bytes)):
                    cache_size += len(item.value)
                elif isinstance(item.value, (dict, list)):
                    cache_size += len(json.dumps(item.value))
                else:
                    # Flat default for values we cannot cheaply size.
                    cache_size += 100
            except Exception:
                # e.g. non-JSON-serializable dict/list contents.
                cache_size += 100

        return cache_size
|
| 184 |
+
|
| 185 |
+
# Singleton instance (module-level; created lazily by get_cache)
_cache_instance = None

def get_cache() -> InMemoryCache:
    """Return the process-wide singleton InMemoryCache, creating it on first use.

    NOTE(review): the lazy init is not guarded by a lock — two threads racing
    here could each construct an InMemoryCache (each starting its own cleanup
    thread) with one instance leaked. Harmless for a cache, but confirm this
    is only called from single-threaded startup paths.
    """
    global _cache_instance
    if _cache_instance is None:
        _cache_instance = InMemoryCache()
    return _cache_instance
|
app/utils/debug_utils.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import logging
|
| 4 |
+
import traceback
|
| 5 |
+
import json
|
| 6 |
+
import time
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
import platform
|
| 9 |
+
|
| 10 |
+
# Try to import psutil, provide fallback if not available
|
| 11 |
+
try:
|
| 12 |
+
import psutil
|
| 13 |
+
PSUTIL_AVAILABLE = True
|
| 14 |
+
except ImportError:
|
| 15 |
+
PSUTIL_AVAILABLE = False
|
| 16 |
+
logging.warning("psutil module not available. System monitoring features will be limited.")
|
| 17 |
+
|
| 18 |
+
# Configure logging
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
class DebugInfo:
    """Static helpers that collect diagnostic information for debug endpoints."""

    @staticmethod
    def get_system_info():
        """Return OS/Python/resource details as a dict.

        Resource metrics (memory, CPU, disk) require psutil; when it is not
        installed each metric is replaced with the string "psutil not available".
        """
        try:
            details = {
                "os": platform.system(),
                "os_version": platform.version(),
                "python_version": platform.python_version(),
                "cpu_count": os.cpu_count(),
                "timestamp": datetime.now().isoformat()
            }

            if PSUTIL_AVAILABLE:
                gib = 1024 * 1024 * 1024
                details.update({
                    "total_memory": round(psutil.virtual_memory().total / gib, 2),      # GB
                    "available_memory": round(psutil.virtual_memory().available / gib, 2),  # GB
                    "cpu_usage": psutil.cpu_percent(interval=0.1),
                    "memory_usage": psutil.virtual_memory().percent,
                    "disk_usage": psutil.disk_usage('/').percent,
                })
            else:
                unavailable = "psutil not available"
                details.update({
                    "total_memory": unavailable,
                    "available_memory": unavailable,
                    "cpu_usage": unavailable,
                    "memory_usage": unavailable,
                    "disk_usage": unavailable,
                })

            return details
        except Exception as e:
            logger.error(f"Error getting system info: {e}")
            return {"error": str(e)}

    @staticmethod
    def get_env_info():
        """Return all environment variables, masking values of sensitive keys."""
        try:
            # Substrings marking an environment variable as sensitive.
            sensitive_markers = [
                "API_KEY", "SECRET", "PASSWORD", "TOKEN", "AUTH", "MONGODB_URL",
                "AIVEN_DB_URL", "PINECONE_API_KEY", "GOOGLE_API_KEY"
            ]

            masked_env = {}
            for name, value in os.environ.items():
                looks_sensitive = any(marker in name.upper() for marker in sensitive_markers)

                if looks_sensitive and value:
                    # Keep only the first 4 characters for identification.
                    masked_env[name] = value[:4] + "****" if len(value) > 4 else "****"
                else:
                    masked_env[name] = value

            return masked_env
        except Exception as e:
            logger.error(f"Error getting environment info: {e}")
            return {"error": str(e)}

    @staticmethod
    def get_database_status():
        """Return connectivity booleans for PostgreSQL, MongoDB and Pinecone."""
        try:
            # Imported lazily to avoid circular imports at module load time.
            from app.database.postgresql import check_db_connection as check_postgresql
            from app.database.mongodb import check_db_connection as check_mongodb
            from app.database.pinecone import check_db_connection as check_pinecone

            return {
                "postgresql": check_postgresql(),
                "mongodb": check_mongodb(),
                "pinecone": check_pinecone(),
                "timestamp": datetime.now().isoformat()
            }
        except Exception as e:
            logger.error(f"Error getting database status: {e}")
            return {"error": str(e)}
|
| 103 |
+
|
| 104 |
+
class PerformanceMonitor:
    """Collects named timing checkpoints measured from object creation."""

    def __init__(self):
        # Reference point for all elapsed-time calculations.
        self.start_time = time.time()
        self.checkpoints = []

    def checkpoint(self, name):
        """Record a named checkpoint and return seconds elapsed since start."""
        now = time.time()
        since_start = now - self.start_time
        self.checkpoints.append({
            "name": name,
            "time": now,
            "elapsed": since_start,
        })
        logger.debug(f"Checkpoint '{name}' at {since_start:.4f}s")
        return since_start

    def get_report(self):
        """Summarise total runtime plus the interval between checkpoints."""
        if not self.checkpoints:
            return {"error": "No checkpoints recorded"}

        intervals = []
        previous_time = self.start_time
        for cp in self.checkpoints:
            intervals.append({
                "name": cp["name"],
                "interval": cp["time"] - previous_time,
                "elapsed": cp["elapsed"],
            })
            previous_time = cp["time"]

        return {
            "total_time": time.time() - self.start_time,
            "checkpoint_count": len(self.checkpoints),
            "intervals": intervals,
        }
|
| 148 |
+
|
| 149 |
+
class ErrorTracker:
    """Keeps a bounded, chronological record of recent exceptions."""

    def __init__(self, max_errors=100):
        self.errors = []            # oldest first
        self.max_errors = max_errors

    def track_error(self, error, context=None):
        """Capture details of *error*, store them, and return the record."""
        record = {
            "error_type": type(error).__name__,
            "error_message": str(error),
            "traceback": traceback.format_exc(),
            "timestamp": datetime.now().isoformat(),
            "context": context or {},
        }
        self.errors.append(record)
        # Drop the oldest entry once the bound is exceeded.
        if len(self.errors) > self.max_errors:
            self.errors.pop(0)
        return record

    def get_errors(self, limit=None):
        """Return all stored errors, or only the *limit* most recent ones."""
        if limit is None or limit >= len(self.errors):
            return self.errors
        return self.errors[-limit:]
|
| 180 |
+
|
| 181 |
+
# Initialize global objects
# Module-level singletons shared across the application: one error log and
# one performance timer started at import time.
error_tracker = ErrorTracker()
performance_monitor = PerformanceMonitor()
|
| 184 |
+
|
| 185 |
+
def debug_view(request=None):
    """Assemble a full debug report (system, DBs, perf, recent errors).

    When *request* is supplied, its method/url/headers/client details are
    attached under the "request" key.
    """
    report = {
        "system_info": DebugInfo.get_system_info(),
        "database_status": DebugInfo.get_database_status(),
        "performance": performance_monitor.get_report(),
        "recent_errors": error_tracker.get_errors(limit=10),
        "timestamp": datetime.now().isoformat(),
    }

    if request:
        client = request.client
        report["request"] = {
            "method": request.method,
            "url": str(request.url),
            "headers": dict(request.headers),
            "client": {
                "host": client.host if client else "unknown",
                "port": client.port if client else "unknown",
            },
        }

    return report
|
app/utils/middleware.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import Request, status
|
| 2 |
+
from fastapi.responses import JSONResponse
|
| 3 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 4 |
+
import logging
|
| 5 |
+
import time
|
| 6 |
+
import traceback
|
| 7 |
+
import uuid
|
| 8 |
+
from .utils import get_local_time
|
| 9 |
+
|
| 10 |
+
# Configure logging
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Logs every request/response pair, tagged with a unique request ID.

    The ID and processing time are also exposed to the client via the
    X-Request-ID and X-Process-Time response headers.
    """

    async def dispatch(self, request: Request, call_next):
        # Tag the request so every log line can be correlated.
        request_id = str(uuid.uuid4())
        request.state.request_id = request_id

        origin = request.client.host if request.client else "unknown"
        logger.info(f"Request [{request_id}]: {request.method} {request.url.path} from {origin}")

        started = time.time()
        try:
            response = await call_next(request)

            duration = time.time() - started
            logger.info(f"Response [{request_id}]: {response.status_code} processed in {duration:.4f}s")

            # Expose correlation ID and timing to the client.
            response.headers["X-Request-ID"] = request_id
            response.headers["X-Process-Time"] = str(duration)
            return response
        except Exception as e:
            duration = time.time() - started
            logger.error(f"Error [{request_id}] after {duration:.4f}s: {str(e)}")
            logger.error(traceback.format_exc())

            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={
                    "detail": "Internal server error",
                    "request_id": request_id,
                    "timestamp": get_local_time()
                }
            )
|
| 56 |
+
|
| 57 |
+
class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Middleware to handle uncaught exceptions in the application.

    Acts as a last-resort boundary: any exception escaping downstream
    handlers is logged (with traceback) and converted to a JSON 500
    response carrying the correlation request_id.
    """

    async def dispatch(self, request: Request, call_next):
        try:
            return await call_next(request)
        except Exception as e:
            # Get request_id if available (set by RequestLoggingMiddleware);
            # otherwise generate a fresh one so the response is still traceable.
            request_id = getattr(request.state, "request_id", str(uuid.uuid4()))

            # Log error
            logger.error(f"Uncaught exception [{request_id}]: {str(e)}")
            logger.error(traceback.format_exc())

            # Return error response
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={
                    "detail": "Internal server error",
                    "request_id": request_id,
                    "timestamp": get_local_time()
                }
            )
|
| 80 |
+
|
| 81 |
+
class DatabaseCheckMiddleware(BaseHTTPMiddleware):
    """Middleware to check database connections before each request.

    NOTE(review): no explicit database check is currently performed in the
    try block — it simply forwards the request. Consequently the except
    clause catches *any* downstream exception and reports it as a 503
    "Database connection failed", which may mislabel unrelated errors.
    Confirm intended behavior before relying on this middleware.
    """

    async def dispatch(self, request: Request, call_next):
        # Skip paths that don't need database checks
        skip_paths = ["/", "/health", "/docs", "/redoc", "/openapi.json"]
        if request.url.path in skip_paths:
            return await call_next(request)

        # Check database connections
        try:
            # TODO: Add checks for MongoDB and Pinecone if needed
            # PostgreSQL check is already done in route handler with get_db() method

            # Process request normally
            return await call_next(request)

        except Exception as e:
            # Log error
            logger.error(f"Database connection check failed: {str(e)}")

            # Return error response
            return JSONResponse(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                content={
                    "detail": "Database connection failed",
                    "timestamp": get_local_time()
                }
            )
|
app/utils/pdf_processor.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import uuid
|
| 4 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 5 |
+
from langchain_community.document_loaders import PyPDFLoader
|
| 6 |
+
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
| 7 |
+
import logging
|
| 8 |
+
from pinecone import Pinecone
|
| 9 |
+
|
| 10 |
+
from app.database.pinecone import get_pinecone_index, init_pinecone
|
| 11 |
+
from app.database.postgresql import get_db
|
| 12 |
+
from app.database.models import VectorDatabase
|
| 13 |
+
|
| 14 |
+
# Configure logging
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
# Initialize embeddings model
|
| 18 |
+
embeddings_model = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
|
| 19 |
+
|
| 20 |
+
class PDFProcessor:
    """Processes PDF files into overlapping text chunks, embeds them with
    Google Generative AI embeddings and stores the vectors in Pinecone.

    Supports a mock mode that simulates the Pinecone index so the pipeline
    can be exercised in tests without network access or credentials.
    """

    def __init__(self, index_name="testbot768", namespace="Default", api_key=None, vector_db_id=None, mock_mode=False):
        """Initialize with Pinecone index name, namespace and API key.

        Args:
            index_name: Name of the target Pinecone index.
            namespace: Pinecone namespace used for all vector operations.
            api_key: Pinecone API key; if omitted it is resolved from the
                database via vector_db_id.
            vector_db_id: Row ID in the vector_database table used to look
                up the API key when api_key is not given.
            mock_mode: When True, no real Pinecone connection is made.
        """
        self.index_name = index_name
        self.namespace = namespace
        self.pinecone_index = None
        self.api_key = api_key
        self.vector_db_id = vector_db_id
        self.pinecone_client = None
        self.mock_mode = mock_mode  # Add mock mode for testing

    def _get_api_key_from_db(self):
        """Return the Pinecone API key: the one passed to __init__ if any,
        otherwise the key on the vector_database row's api_key relationship."""
        if self.api_key:
            return self.api_key

        if not self.vector_db_id:
            logger.error("No API key provided and no vector_db_id to fetch from database")
            return None

        try:
            # Get database session
            db = next(get_db())

            # Get vector database
            vector_db = db.query(VectorDatabase).filter(
                VectorDatabase.id == self.vector_db_id
            ).first()

            if not vector_db:
                logger.error(f"Vector database with ID {self.vector_db_id} not found")
                return None

            # Get API key from relationship (api_key table)
            if hasattr(vector_db, 'api_key_ref') and vector_db.api_key_ref and hasattr(vector_db.api_key_ref, 'key_value'):
                logger.info(f"Using API key from api_key table for vector database ID {self.vector_db_id}")
                return vector_db.api_key_ref.key_value

            logger.error(f"No API key found for vector database ID {self.vector_db_id}. Make sure the api_key_id is properly set.")
            return None
        except Exception as e:
            logger.error(f"Error fetching API key from database: {e}")
            return None

    def _init_pinecone_connection(self):
        """Connect to Pinecone (or build a mock index in mock mode).

        Returns:
            The index object on success, otherwise None.
        """
        try:
            # If in mock mode, return a mock index
            if self.mock_mode:
                logger.info("Running in mock mode - simulating Pinecone connection")
                # Capture the processor's namespace in a local: inside the
                # mock's methods `self` is the MockPineconeIndex instance,
                # which has no `namespace` attribute (referencing
                # self.namespace there raised AttributeError).
                outer_namespace = self.namespace

                class MockPineconeIndex:
                    def upsert(self, vectors, namespace=None):
                        logger.info(f"Mock upsert: {len(vectors)} vectors to namespace '{namespace}'")
                        return {"upserted_count": len(vectors)}

                    def delete(self, ids=None, delete_all=False, namespace=None):
                        # Guard len() against ids=None when delete_all is False.
                        logger.info(f"Mock delete: {'all vectors' if delete_all else f'{len(ids or [])} vectors'} from namespace '{namespace}'")
                        return {"deleted_count": 10 if delete_all else len(ids or [])}

                    def describe_index_stats(self):
                        logger.info("Mock describe_index_stats")
                        return {"total_vector_count": 100, "namespaces": {outer_namespace: {"vector_count": 50}}}

                return MockPineconeIndex()

            # Get API key from database if not provided
            api_key = self._get_api_key_from_db()

            if not api_key or not self.index_name:
                logger.error("Pinecone API key or index name not available")
                return None

            # Initialize Pinecone client using the new API
            self.pinecone_client = Pinecone(api_key=api_key)

            # list_indexes() returns an object with .names() on new clients;
            # fall back to an empty list for unexpected shapes.
            index_list = self.pinecone_client.list_indexes()
            existing_indexes = index_list.names() if hasattr(index_list, 'names') else []

            if self.index_name not in existing_indexes:
                logger.error(f"Index {self.index_name} does not exist in Pinecone")
                return None

            # Connect to the index
            index = self.pinecone_client.Index(self.index_name)
            logger.info(f"Connected to Pinecone index: {self.index_name}")
            return index
        except Exception as e:
            logger.error(f"Error connecting to Pinecone: {e}")
            return None

    async def process_pdf(self, file_path, document_id=None, metadata=None, progress_callback=None):
        """
        Process PDF file, split into chunks and create embeddings

        Args:
            file_path (str): Path to the PDF file
            document_id (str, optional): Document ID, if not provided a new ID will be created
            metadata (dict, optional): Additional metadata for the document
            progress_callback (callable, optional): Async callback awaited as
                (stage: str, progress: float, message: str) for progress updates

        Returns:
            dict: Processing result information including document_id and processed chunks count
        """
        try:
            # Initialize Pinecone connection if not already done
            self.pinecone_index = self._init_pinecone_connection()
            if not self.pinecone_index:
                return {"success": False, "error": "Could not connect to Pinecone"}

            # Create document_id if not provided
            if not document_id:
                document_id = str(uuid.uuid4())

            # Load PDF using PyPDFLoader
            logger.info(f"Reading PDF file: {file_path}")
            if progress_callback:
                await progress_callback("pdf_loading", 0.5, "Loading PDF file")

            loader = PyPDFLoader(file_path)
            pages = loader.load()

            # Extract and concatenate text from all pages
            all_text = ""
            for page in pages:
                all_text += page.page_content + "\n"

            if progress_callback:
                await progress_callback("text_extraction", 0.6, "Extracted text from PDF")

            # Overlapping chunks preserve context across chunk boundaries.
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=300)
            chunks = text_splitter.split_text(all_text)

            logger.info(f"Split PDF file into {len(chunks)} chunks")
            if progress_callback:
                await progress_callback("chunking", 0.7, f"Split document into {len(chunks)} chunks")

            # Process embeddings for each chunk and upsert to Pinecone
            vectors = []
            for i, chunk in enumerate(chunks):
                # Update embedding progress every 5 chunks to avoid too many notifications
                if progress_callback and i % 5 == 0:
                    embedding_progress = 0.7 + (0.3 * (i / len(chunks)))
                    await progress_callback("embedding", embedding_progress, f"Processing chunk {i+1}/{len(chunks)}")

                # Create vector embedding for each chunk
                vector = embeddings_model.embed_query(chunk)

                # Core metadata always wins over caller-supplied keys.
                vector_metadata = {
                    "document_id": document_id,
                    "chunk_index": i,
                    "text": chunk
                }
                if metadata:
                    for key, value in metadata.items():
                        if key not in vector_metadata:
                            vector_metadata[key] = value

                # Add vector to list for upserting
                vectors.append({
                    "id": f"{document_id}_{i}",
                    "values": vector,
                    "metadata": vector_metadata
                })

                # Upsert in batches of 100 to avoid overloading
                if len(vectors) >= 100:
                    await self._upsert_vectors(vectors)
                    vectors = []

            # Upsert any remaining vectors
            if vectors:
                await self._upsert_vectors(vectors)

            logger.info(f"Embedded and saved {len(chunks)} chunks from PDF with document_id: {document_id}")

            # Final progress update
            if progress_callback:
                await progress_callback("completed", 1.0, "PDF processing complete")

            return {
                "success": True,
                "document_id": document_id,
                "chunks_processed": len(chunks),
                "total_text_length": len(all_text)
            }

        except Exception as e:
            logger.error(f"Error processing PDF: {str(e)}")
            if progress_callback:
                await progress_callback("error", 0, f"Error processing PDF: {str(e)}")
            return {
                "success": False,
                "error": str(e)
            }

    async def _upsert_vectors(self, vectors):
        """Upsert a batch of vectors to Pinecone in the configured namespace."""
        try:
            if not vectors:
                return

            # Ensure we have a valid pinecone_index
            if not self.pinecone_index:
                self.pinecone_index = self._init_pinecone_connection()
                if not self.pinecone_index:
                    raise Exception("Cannot connect to Pinecone")

            result = self.pinecone_index.upsert(
                vectors=vectors,
                namespace=self.namespace
            )

            logger.info(f"Upserted {len(vectors)} vectors to Pinecone")
            return result
        except Exception as e:
            logger.error(f"Error upserting vectors: {str(e)}")
            raise

    async def delete_namespace(self):
        """
        Delete all vectors in the current namespace (equivalent to deleting the namespace).

        Returns:
            dict: {"success": True, "detail": ...} or {"success": False, "error": ...}
        """
        # Initialize connection if needed
        self.pinecone_index = self._init_pinecone_connection()
        if not self.pinecone_index:
            return {"success": False, "error": "Could not connect to Pinecone"}

        try:
            # delete_all=True will delete all vectors in the namespace
            result = self.pinecone_index.delete(
                delete_all=True,
                namespace=self.namespace
            )
            logger.info(f"Deleted namespace '{self.namespace}' (all vectors).")
            return {"success": True, "detail": result}
        except Exception as e:
            logger.error(f"Error deleting namespace '{self.namespace}': {e}")
            return {"success": False, "error": str(e)}

    async def list_documents(self):
        """Return summary stats for the index (vector count, namespace, name).

        Note: Pinecone has no cheap way to enumerate unique document_ids;
        maintain that list in a separate database if needed.
        """
        try:
            # Initialize Pinecone connection if not already done
            self.pinecone_index = self._init_pinecone_connection()
            if not self.pinecone_index:
                return {"success": False, "error": "Could not connect to Pinecone"}

            # Get index information
            stats = self.pinecone_index.describe_index_stats()

            return {
                "success": True,
                "total_vectors": stats.get('total_vector_count', 0),
                "namespace": self.namespace,
                "index_name": self.index_name
            }
        except Exception as e:
            logger.error(f"Error getting document list: {str(e)}")
            return {
                "success": False,
                "error": str(e)
            }
|
app/utils/utils.py
ADDED
|
@@ -0,0 +1,478 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import time
|
| 3 |
+
import uuid
|
| 4 |
+
import threading
|
| 5 |
+
import os
|
| 6 |
+
from functools import wraps
|
| 7 |
+
from datetime import datetime, timedelta
|
| 8 |
+
import pytz
|
| 9 |
+
from typing import Callable, Any, Dict, Optional, List, Tuple, Set
|
| 10 |
+
import gc
|
| 11 |
+
import heapq
|
| 12 |
+
|
| 13 |
+
# Configure logging
|
| 14 |
+
logging.basicConfig(
|
| 15 |
+
level=logging.INFO,
|
| 16 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
| 17 |
+
)
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
# Asia/Ho_Chi_Minh timezone
|
| 21 |
+
asia_tz = pytz.timezone('Asia/Ho_Chi_Minh')
|
| 22 |
+
|
| 23 |
+
def generate_uuid():
    """Return a fresh random UUID4 rendered as a string."""
    new_id = uuid.uuid4()
    return f"{new_id}"
|
| 26 |
+
|
| 27 |
+
def get_current_time():
    """Return the current local (naive) time as an ISO-8601 string."""
    now = datetime.now()
    return now.isoformat()
|
| 30 |
+
|
| 31 |
+
def get_local_time():
    """Current time in Asia/Ho_Chi_Minh, formatted as YYYY-MM-DD HH:MM:SS."""
    local_now = datetime.now(asia_tz)
    return local_now.strftime("%Y-%m-%d %H:%M:%S")
|
| 34 |
+
|
| 35 |
+
def get_local_datetime():
    """Current timezone-aware datetime in Asia/Ho_Chi_Minh."""
    return datetime.now(tz=asia_tz)
|
| 38 |
+
|
| 39 |
+
# For backward compatibility
# Older call sites still use the Vietnam-specific names; both aliases point
# at the Asia/Ho_Chi_Minh helpers defined above.
get_vietnam_time = get_local_time
get_vietnam_datetime = get_local_datetime
|
| 42 |
+
|
| 43 |
+
def timer_decorator(func: Callable) -> Callable:
    """
    Decorator to time function execution and log results.

    Previously this always returned an async wrapper, so decorating a
    plain (sync) function broke it — callers received a coroutine.
    Now the wrapper kind matches the wrapped callable: coroutine
    functions get an async wrapper (unchanged behavior), sync functions
    get a sync wrapper. On success the elapsed time is logged at INFO;
    on failure it is logged at ERROR and the exception is re-raised.
    """
    import asyncio  # local import: keeps the module's top-level deps unchanged

    if asyncio.iscoroutinefunction(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            start_time = time.time()
            try:
                result = await func(*args, **kwargs)
                elapsed_time = time.time() - start_time
                logger.info(f"Function {func.__name__} executed in {elapsed_time:.4f} seconds")
                return result
            except Exception as e:
                elapsed_time = time.time() - start_time
                logger.error(f"Function {func.__name__} failed after {elapsed_time:.4f} seconds: {e}")
                raise
        return wrapper

    @wraps(func)
    def sync_wrapper(*args, **kwargs):
        start_time = time.time()
        try:
            result = func(*args, **kwargs)
            elapsed_time = time.time() - start_time
            logger.info(f"Function {func.__name__} executed in {elapsed_time:.4f} seconds")
            return result
        except Exception as e:
            elapsed_time = time.time() - start_time
            logger.error(f"Function {func.__name__} failed after {elapsed_time:.4f} seconds: {e}")
            raise
    return sync_wrapper
|
| 60 |
+
|
| 61 |
+
def sanitize_input(text):
    """Normalise user-supplied text: None/empty becomes "", otherwise
    surrounding whitespace is stripped."""
    if text:
        # Remove potential dangerous characters or patterns
        return text.strip()
    return ""
|
| 67 |
+
|
| 68 |
+
def truncate_text(text, max_length=100):
    """
    Truncate text to at most *max_length* characters, appending "..."
    when anything was cut. Falsy inputs (None, "") pass through unchanged.
    """
    needs_cut = bool(text) and len(text) > max_length
    return text[:max_length] + "..." if needs_cut else text
|
| 75 |
+
|
| 76 |
+
class CacheStrategy:
    """Cache loading strategy enumeration."""

    # Only load items into cache when requested
    LAZY = "lazy"
    # Preload items into cache at initialization
    EAGER = "eager"
    # Preload high-priority items, lazy load the rest
    MIXED = "mixed"
|
| 81 |
+
|
| 82 |
+
class CacheItem:
    """Represents an item in the cache with metadata.

    Tracks a TTL-based expiry time, an eviction priority, and access
    statistics; the comparison operator orders items so that the least
    valuable (low priority, rarely/oldest accessed) sort first for eviction.
    """

    def __init__(self, key: str, value: Any, ttl: int = 300, priority: int = 1):
        self.key = key
        self.value = value
        # Absolute expiry instant, computed once from the TTL in seconds.
        self.expiry = datetime.now() + timedelta(seconds=ttl)
        self.priority = priority      # Higher number = higher priority
        self.access_count = 0         # Track number of accesses
        self.last_accessed = datetime.now()

    def is_expired(self) -> bool:
        """Check if the item is expired"""
        return datetime.now() > self.expiry

    def touch(self):
        """Update last accessed time and access count"""
        self.last_accessed = datetime.now()
        self.access_count += 1

    def __lt__(self, other) -> bool:
        """For heap comparisons - lower priority items are evicted first"""
        # First compare priority
        if self.priority != other.priority:
            return self.priority < other.priority
        # Then compare access frequency (less frequently accessed items are evicted first)
        if self.access_count != other.access_count:
            return self.access_count < other.access_count
        # Finally compare last access time (oldest accessed first)
        return self.last_accessed < other.last_accessed

    def get_size(self) -> int:
        """Approximate memory size of the cache item in bytes"""
        try:
            import sys
            return sys.getsizeof(self.value) + sys.getsizeof(self.key) + 64  # Additional overhead
        except Exception:
            # Fix: a bare `except:` here also swallowed KeyboardInterrupt and
            # SystemExit; catch Exception and fall back to a rough estimate.
            return 1024
|
| 120 |
+
|
| 121 |
+
# Enhanced in-memory cache implementation
|
| 122 |
+
class EnhancedCache:
    """In-memory cache with TTL, priorities, namespaces and size limits.

    Eviction order: lowest priority first, then least-accessed, then least
    recently used (delegated to ``CacheItem.__lt__``). An optional daemon
    thread periodically removes expired entries. All mutations are guarded
    by an ``RLock``.

    BUG FIX vs. previous revision: ``get_stats`` computed the average
    ``set()`` latency by dividing by the *eviction* count (an unrelated
    counter that stays zero until the cache overflows). A dedicated
    ``_sets`` counter is now used.
    """

    def __init__(self,
                 strategy: str = "lazy",
                 max_items: int = 10000,
                 max_size_mb: int = 100,
                 cleanup_interval: int = 60,
                 stats_enabled: bool = True):
        """
        Initialize enhanced cache with configurable strategy.

        Args:
            strategy: Cache loading strategy (lazy, eager, mixed).
                Note: stored for callers to inspect; the strategy is not
                otherwise acted on inside this class.
            max_items: Maximum number of items to store in cache.
            max_size_mb: Maximum size of cache in MB.
            cleanup_interval: Interval in seconds to run cleanup
                (``<= 0`` disables the background cleanup thread).
            stats_enabled: Whether to collect cache statistics.
        """
        self._cache: Dict[str, CacheItem] = {}
        self._namespace_cache: Dict[str, Set[str]] = {}  # Tracking keys by namespace
        self._strategy = strategy
        self._max_items = max_items
        self._max_size_bytes = max_size_mb * 1024 * 1024
        self._current_size_bytes = 0
        self._stats_enabled = stats_enabled

        # Statistics
        self._hits = 0
        self._misses = 0
        self._evictions = 0
        self._sets = 0  # number of set() calls; denominator for avg set latency
        self._total_get_time = 0
        self._total_set_time = 0

        # Cleanup bookkeeping
        self._last_cleanup = datetime.now()
        self._cleanup_interval = cleanup_interval
        self._lock = threading.RLock()

        if cleanup_interval > 0:
            self._start_cleanup_thread(cleanup_interval)

        logger.info(f"Enhanced cache initialized with strategy={strategy}, max_items={max_items}, max_size={max_size_mb}MB")

    def _start_cleanup_thread(self, interval: int) -> None:
        """Start a daemon thread that calls cleanup() every `interval` seconds."""
        def cleanup_worker():
            while True:
                time.sleep(interval)
                try:
                    self.cleanup()
                except Exception as e:
                    logger.error(f"Error in cache cleanup: {e}")

        thread = threading.Thread(target=cleanup_worker, daemon=True)
        thread.start()
        logger.info(f"Cache cleanup thread started with interval {interval}s")

    def get(self, key: str, namespace: str = None) -> Optional[Any]:
        """Get value from cache if it exists and hasn't expired.

        Returns None on a miss or on an expired entry (which is removed
        as a side effect). Note that a cached value of None is therefore
        indistinguishable from a miss.
        """
        if self._stats_enabled:
            start_time = time.time()

        # Use namespaced key if namespace is provided
        cache_key = f"{namespace}:{key}" if namespace else key

        with self._lock:
            cache_item = self._cache.get(cache_key)

            if cache_item:
                if cache_item.is_expired():
                    # Clean up expired key eagerly on access
                    self._remove_item(cache_key, namespace)
                    if self._stats_enabled:
                        self._misses += 1
                    value = None
                else:
                    # Update access metadata (feeds eviction ordering)
                    cache_item.touch()
                    if self._stats_enabled:
                        self._hits += 1
                    value = cache_item.value
            else:
                if self._stats_enabled:
                    self._misses += 1
                value = None

        if self._stats_enabled:
            self._total_get_time += time.time() - start_time

        return value

    def set(self, key: str, value: Any, ttl: int = 300, priority: int = 1, namespace: str = None) -> None:
        """Set a value in the cache with TTL in seconds.

        Evicts lower-priority items first when the item-count or byte
        budget would be exceeded.
        """
        if self._stats_enabled:
            start_time = time.time()

        # Use namespaced key if namespace is provided
        cache_key = f"{namespace}:{key}" if namespace else key

        with self._lock:
            # Create cache item
            cache_item = CacheItem(cache_key, value, ttl, priority)
            item_size = cache_item.get_size()

            # Check if we need to make room
            if (len(self._cache) >= self._max_items or
                self._current_size_bytes + item_size > self._max_size_bytes):
                self._evict_items(item_size)

            # Update size tracking
            if cache_key in self._cache:
                # If replacing, subtract old size first
                self._current_size_bytes -= self._cache[cache_key].get_size()
            self._current_size_bytes += item_size

            # Store the item
            self._cache[cache_key] = cache_item

            # Update namespace tracking
            if namespace:
                if namespace not in self._namespace_cache:
                    self._namespace_cache[namespace] = set()
                self._namespace_cache[namespace].add(cache_key)

        if self._stats_enabled:
            self._sets += 1  # count set() calls for avg latency stat
            self._total_set_time += time.time() - start_time

    def delete(self, key: str, namespace: str = None) -> None:
        """Delete a key from the cache (no-op if absent)."""
        # Use namespaced key if namespace is provided
        cache_key = f"{namespace}:{key}" if namespace else key

        with self._lock:
            self._remove_item(cache_key, namespace)

    def _remove_item(self, key: str, namespace: str = None) -> None:
        """Internal method to remove an item and update size/namespace tracking."""
        if key in self._cache:
            # Update size tracking
            self._current_size_bytes -= self._cache[key].get_size()
            # Remove from cache
            del self._cache[key]

            # Update namespace tracking
            if namespace and namespace in self._namespace_cache:
                if key in self._namespace_cache[namespace]:
                    self._namespace_cache[namespace].remove(key)
                    # Cleanup empty sets
                    if not self._namespace_cache[namespace]:
                        del self._namespace_cache[namespace]

    def _evict_items(self, needed_space: int = 0) -> None:
        """Evict items to make room, targeting 90% of both budgets.

        Items with priority > 9 are skipped while cheaper victims remain.
        """
        if not self._cache:
            return

        with self._lock:
            # Sort by priority, access count, and last accessed time
            items = list(self._cache.values())
            items.sort()  # Uses the __lt__ method of CacheItem

            space_freed = 0
            evicted_count = 0

            for item in items:
                # Stop once we are back under 90% of both limits and have
                # freed the explicitly requested amount of space.
                if (len(self._cache) - evicted_count <= self._max_items * 0.9 and
                    (space_freed >= needed_space or
                     self._current_size_bytes - space_freed <= self._max_size_bytes * 0.9)):
                    break

                # Skip high priority items unless absolutely necessary
                if item.priority > 9 and evicted_count < len(items) // 2:
                    continue

                # Evict this item. NOTE(review): the namespace is inferred
                # from the key prefix, so an un-namespaced key containing
                # ':' would be mis-parsed — acceptable given key conventions.
                item_size = item.get_size()
                namespace = item.key.split(':', 1)[0] if ':' in item.key else None
                self._remove_item(item.key, namespace)

                space_freed += item_size
                evicted_count += 1
                if self._stats_enabled:
                    self._evictions += 1

            logger.info(f"Cache eviction: removed {evicted_count} items, freed {space_freed / 1024:.2f}KB")

    def clear(self, namespace: str = None) -> None:
        """Clear the entire cache, or only the keys in `namespace`."""
        with self._lock:
            if namespace:
                # Clear only keys in the specified namespace
                if namespace in self._namespace_cache:
                    keys_to_remove = list(self._namespace_cache[namespace])
                    for key in keys_to_remove:
                        self._remove_item(key, namespace)
                    # The namespace set is auto-cleaned in _remove_item
            else:
                # Clear the entire cache
                self._cache.clear()
                self._namespace_cache.clear()
                self._current_size_bytes = 0

        logger.info(f"Cache cleared{' for namespace ' + namespace if namespace else ''}")

    def cleanup(self) -> None:
        """Remove expired items; throttled to once per cleanup_interval."""
        with self._lock:
            now = datetime.now()
            # Only run if it's been at least cleanup_interval since last cleanup
            if (now - self._last_cleanup).total_seconds() < self._cleanup_interval:
                return

            # Find expired items (collect first: can't mutate while iterating)
            expired_keys = []
            for key, item in self._cache.items():
                if item.is_expired():
                    expired_keys.append((key, key.split(':', 1)[0] if ':' in key else None))

            # Remove expired items
            for key, namespace in expired_keys:
                self._remove_item(key, namespace)

            self._last_cleanup = now

            # Run garbage collection if we removed several items
            if len(expired_keys) > 100:
                gc.collect()

            logger.info(f"Cache cleanup: removed {len(expired_keys)} expired items")

    def get_stats(self) -> Dict:
        """Get cache statistics (hit rate, sizes, average latencies)."""
        with self._lock:
            if not self._stats_enabled:
                return {"stats_enabled": False}

            # Calculate hit rate
            total_requests = self._hits + self._misses
            hit_rate = (self._hits / total_requests) * 100 if total_requests > 0 else 0

            # Calculate average times (ms). Fixed: divide set time by the
            # number of set() calls, not by the eviction count.
            avg_get_time = (self._total_get_time / total_requests) * 1000 if total_requests > 0 else 0
            avg_set_time = (self._total_set_time / self._sets) * 1000 if self._sets > 0 else 0

            return {
                "stats_enabled": True,
                "item_count": len(self._cache),
                "max_items": self._max_items,
                "size_bytes": self._current_size_bytes,
                "max_size_bytes": self._max_size_bytes,
                "hits": self._hits,
                "misses": self._misses,
                "hit_rate_percent": round(hit_rate, 2),
                "evictions": self._evictions,
                "avg_get_time_ms": round(avg_get_time, 3),
                "avg_set_time_ms": round(avg_set_time, 3),
                "namespace_count": len(self._namespace_cache),
                "namespaces": list(self._namespace_cache.keys())
            }

    def preload(self, items: List[Tuple[str, Any, int, int]], namespace: str = None) -> None:
        """
        Preload a list of items into the cache.

        Args:
            items: List of (key, value, ttl, priority) tuples.
            namespace: Optional namespace for all items.
        """
        for key, value, ttl, priority in items:
            self.set(key, value, ttl, priority, namespace)

        logger.info(f"Preloaded {len(items)} items into cache{' namespace ' + namespace if namespace else ''}")

    def get_or_load(self, key: str, loader_func: Callable[[], Any],
                    ttl: int = 300, priority: int = 1, namespace: str = None) -> Any:
        """
        Get from cache or load using the provided function.

        Args:
            key: Cache key.
            loader_func: Function to call if cache miss occurs.
            ttl: TTL in seconds.
            priority: Item priority.
            namespace: Optional namespace.

        Returns:
            Cached or freshly loaded value. A loader returning None is
            not cached, so it will be re-invoked on every call.
        """
        # Try to get from cache first
        value = self.get(key, namespace)

        # If not in cache, load it
        if value is None:
            value = loader_func()
            # Only cache if we got a valid value
            if value is not None:
                self.set(key, value, ttl, priority, namespace)

        return value
|
| 427 |
+
|
| 428 |
+
# Load cache configuration from environment variables.
# Defaults: mixed strategy, 10k items / 100 MB budget, 60s cleanup, stats on.
CACHE_STRATEGY = os.getenv("CACHE_STRATEGY", "mixed")
CACHE_MAX_ITEMS = int(os.getenv("CACHE_MAX_ITEMS", "10000"))
CACHE_MAX_SIZE_MB = int(os.getenv("CACHE_MAX_SIZE_MB", "100"))
CACHE_CLEANUP_INTERVAL = int(os.getenv("CACHE_CLEANUP_INTERVAL", "60"))
# Accepts "true"/"1"/"yes" (case-insensitive) as truthy.
CACHE_STATS_ENABLED = os.getenv("CACHE_STATS_ENABLED", "true").lower() in ("true", "1", "yes")

# Initialize the enhanced cache.
# Module-level singleton; importing this module starts the cleanup thread
# when CACHE_CLEANUP_INTERVAL > 0 (side effect at import time).
cache = EnhancedCache(
    strategy=CACHE_STRATEGY,
    max_items=CACHE_MAX_ITEMS,
    max_size_mb=CACHE_MAX_SIZE_MB,
    cleanup_interval=CACHE_CLEANUP_INTERVAL,
    stats_enabled=CACHE_STATS_ENABLED
)
|
| 443 |
+
|
| 444 |
+
# Backward compatibility for SimpleCache - for a transition period
|
| 445 |
+
class SimpleCache:
    """Deprecated facade kept for backward compatibility.

    Every call is forwarded to the module-level ``cache`` singleton
    (an ``EnhancedCache``); no state is held here.
    """

    def __init__(self):
        """Legacy SimpleCache implementation that uses EnhancedCache underneath"""
        logger.warning("SimpleCache is deprecated, please use EnhancedCache directly")

    def get(self, key: str) -> Optional[Any]:
        """Get value from cache if it exists and hasn't expired"""
        value = cache.get(key)
        return value

    def set(self, key: str, value: Any, ttl: int = 300) -> None:
        """Set a value in the cache with TTL in seconds"""
        cache.set(key, value, ttl)

    def delete(self, key: str) -> None:
        """Delete a key from the cache"""
        cache.delete(key)

    def clear(self) -> None:
        """Clear the entire cache"""
        cache.clear()
|
| 465 |
+
|
| 466 |
+
def get_host_url(request) -> str:
    """
    Get the host URL from a request object.

    Honors the `x-forwarded-proto` header set by reverse proxies; falls
    back to "http://localhost" when headers are absent.
    """
    headers = request.headers
    scheme = headers.get("x-forwarded-proto", "http")
    host = headers.get("host", "localhost")
    return "{}://{}".format(scheme, host)
|
| 473 |
+
|
| 474 |
+
def format_time(timestamp):
    """
    Format a timestamp into a human-readable string.

    Renders a datetime as "YYYY-MM-DD HH:MM:SS".
    """
    formatted = timestamp.strftime("%Y-%m-%d %H:%M:%S")
    return formatted
|
docker-compose.yml
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version: '3'
|
| 2 |
+
|
| 3 |
+
services:
|
| 4 |
+
backend:
|
| 5 |
+
build:
|
| 6 |
+
context: .
|
| 7 |
+
dockerfile: Dockerfile
|
| 8 |
+
ports:
|
| 9 |
+
- "7860:7860"
|
| 10 |
+
env_file:
|
| 11 |
+
- .env
|
| 12 |
+
restart: unless-stopped
|
| 13 |
+
healthcheck:
|
| 14 |
+
test: ["CMD", "curl", "-f", "http://localhost:7860/health"]
|
| 15 |
+
interval: 30s
|
| 16 |
+
timeout: 10s
|
| 17 |
+
retries: 3
|
| 18 |
+
start_period: 40s
|
| 19 |
+
volumes:
|
| 20 |
+
- ./app:/app/app
|
| 21 |
+
command: uvicorn app:app --host 0.0.0.0 --port 7860 --reload
|
docs/api_documentation.md
ADDED
|
@@ -0,0 +1,581 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# API Documentation
|
| 2 |
+
|
| 3 |
+
## Frontend Setup
|
| 4 |
+
|
| 5 |
+
```javascript
|
| 6 |
+
// Basic Axios setup
|
| 7 |
+
import axios from 'axios';
|
| 8 |
+
|
| 9 |
+
const api = axios.create({
|
| 10 |
+
baseURL: 'https://api.your-domain.com',
|
| 11 |
+
timeout: 10000,
|
| 12 |
+
headers: {
|
| 13 |
+
'Content-Type': 'application/json',
|
| 14 |
+
'Accept': 'application/json'
|
| 15 |
+
}
|
| 16 |
+
});
|
| 17 |
+
|
| 18 |
+
// Error handling
|
| 19 |
+
api.interceptors.response.use(
|
| 20 |
+
response => response.data,
|
| 21 |
+
error => {
|
| 22 |
+
const errorMessage = error.response?.data?.detail || 'An error occurred';
|
| 23 |
+
console.error('API Error:', errorMessage);
|
| 24 |
+
return Promise.reject(errorMessage);
|
| 25 |
+
}
|
| 26 |
+
);
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
## Caching System
|
| 30 |
+
|
| 31 |
+
- All GET endpoints support `use_cache=true` parameter (default)
|
| 32 |
+
- Cache TTL: 300 seconds (5 minutes)
|
| 33 |
+
- Cache is automatically invalidated on data changes
|
| 34 |
+
|
| 35 |
+
## Authentication
|
| 36 |
+
|
| 37 |
+
Currently no authentication is required. If implemented in the future, use JWT Bearer tokens:
|
| 38 |
+
|
| 39 |
+
```javascript
|
| 40 |
+
const api = axios.create({
|
| 41 |
+
// ...other config
|
| 42 |
+
headers: {
|
| 43 |
+
// ...other headers
|
| 44 |
+
'Authorization': `Bearer ${token}`
|
| 45 |
+
}
|
| 46 |
+
});
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
## Error Codes
|
| 50 |
+
|
| 51 |
+
| Code | Description |
|
| 52 |
+
|------|-------------|
|
| 53 |
+
| 400 | Bad Request |
|
| 54 |
+
| 404 | Not Found |
|
| 55 |
+
| 500 | Internal Server Error |
|
| 56 |
+
| 503 | Service Unavailable |
|
| 57 |
+
|
| 58 |
+
## PostgreSQL Endpoints
|
| 59 |
+
|
| 60 |
+
### FAQ Endpoints
|
| 61 |
+
|
| 62 |
+
#### Get FAQs List
|
| 63 |
+
```
|
| 64 |
+
GET /postgres/faq
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
Parameters:
|
| 68 |
+
- `skip`: Number of items to skip (default: 0)
|
| 69 |
+
- `limit`: Maximum items to return (default: 100)
|
| 70 |
+
- `active_only`: Return only active items (default: false)
|
| 71 |
+
- `use_cache`: Use cached data if available (default: true)
|
| 72 |
+
|
| 73 |
+
Response:
|
| 74 |
+
```json
|
| 75 |
+
[
|
| 76 |
+
{
|
| 77 |
+
"question": "How do I book a hotel?",
|
| 78 |
+
"answer": "You can book a hotel through our app or website.",
|
| 79 |
+
"is_active": true,
|
| 80 |
+
"id": 1,
|
| 81 |
+
"created_at": "2023-01-01T00:00:00",
|
| 82 |
+
"updated_at": "2023-01-01T00:00:00"
|
| 83 |
+
}
|
| 84 |
+
]
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
Example:
|
| 88 |
+
```javascript
|
| 89 |
+
async function getFAQs() {
|
| 90 |
+
try {
|
| 91 |
+
const data = await api.get('/postgres/faq', {
|
| 92 |
+
params: { active_only: true, limit: 20 }
|
| 93 |
+
});
|
| 94 |
+
return data;
|
| 95 |
+
} catch (error) {
|
| 96 |
+
console.error('Error fetching FAQs:', error);
|
| 97 |
+
throw error;
|
| 98 |
+
}
|
| 99 |
+
}
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
#### Create FAQ
|
| 103 |
+
```
|
| 104 |
+
POST /postgres/faq
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
Request Body:
|
| 108 |
+
```json
|
| 109 |
+
{
|
| 110 |
+
"question": "How do I book a hotel?",
|
| 111 |
+
"answer": "You can book a hotel through our app or website.",
|
| 112 |
+
"is_active": true
|
| 113 |
+
}
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
Response: Created FAQ object
|
| 117 |
+
|
| 118 |
+
#### Get FAQ Detail
|
| 119 |
+
```
|
| 120 |
+
GET /postgres/faq/{faq_id}
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
Parameters:
|
| 124 |
+
- `faq_id`: ID of FAQ (required)
|
| 125 |
+
- `use_cache`: Use cached data if available (default: true)
|
| 126 |
+
|
| 127 |
+
Response: FAQ object
|
| 128 |
+
|
| 129 |
+
#### Update FAQ
|
| 130 |
+
```
|
| 131 |
+
PUT /postgres/faq/{faq_id}
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
Parameters:
|
| 135 |
+
- `faq_id`: ID of FAQ to update (required)
|
| 136 |
+
|
| 137 |
+
Request Body: Partial or complete FAQ object
|
| 138 |
+
Response: Updated FAQ object
|
| 139 |
+
|
| 140 |
+
#### Delete FAQ
|
| 141 |
+
```
|
| 142 |
+
DELETE /postgres/faq/{faq_id}
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
Parameters:
|
| 146 |
+
- `faq_id`: ID of FAQ to delete (required)
|
| 147 |
+
|
| 148 |
+
Response:
|
| 149 |
+
```json
|
| 150 |
+
{
|
| 151 |
+
"status": "success",
|
| 152 |
+
"message": "FAQ item 1 deleted"
|
| 153 |
+
}
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
#### Batch Operations
|
| 157 |
+
|
| 158 |
+
Create multiple FAQs:
|
| 159 |
+
```
|
| 160 |
+
POST /postgres/faqs/batch
|
| 161 |
+
```
|
| 162 |
+
|
| 163 |
+
Update status of multiple FAQs:
|
| 164 |
+
```
|
| 165 |
+
PUT /postgres/faqs/batch-update-status
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
Delete multiple FAQs:
|
| 169 |
+
```
|
| 170 |
+
DELETE /postgres/faqs/batch
|
| 171 |
+
```
|
| 172 |
+
|
| 173 |
+
### Emergency Contact Endpoints
|
| 174 |
+
|
| 175 |
+
#### Get Emergency Contacts
|
| 176 |
+
```
|
| 177 |
+
GET /postgres/emergency
|
| 178 |
+
```
|
| 179 |
+
|
| 180 |
+
Parameters:
|
| 181 |
+
- `skip`: Number of items to skip (default: 0)
|
| 182 |
+
- `limit`: Maximum items to return (default: 100)
|
| 183 |
+
- `active_only`: Return only active items (default: false)
|
| 184 |
+
- `use_cache`: Use cached data if available (default: true)
|
| 185 |
+
|
| 186 |
+
Response: Array of Emergency Contact objects
|
| 187 |
+
|
| 188 |
+
#### Create Emergency Contact
|
| 189 |
+
```
|
| 190 |
+
POST /postgres/emergency
|
| 191 |
+
```
|
| 192 |
+
|
| 193 |
+
Request Body:
|
| 194 |
+
```json
|
| 195 |
+
{
|
| 196 |
+
"name": "Fire Department",
|
| 197 |
+
"phone_number": "114",
|
| 198 |
+
"description": "Fire rescue services",
|
| 199 |
+
"address": "Da Nang",
|
| 200 |
+
"location": "16.0544, 108.2022",
|
| 201 |
+
"priority": 1,
|
| 202 |
+
"is_active": true
|
| 203 |
+
}
|
| 204 |
+
```
|
| 205 |
+
|
| 206 |
+
Response: Created Emergency Contact object
|
| 207 |
+
|
| 208 |
+
#### Get Emergency Contact
|
| 209 |
+
```
|
| 210 |
+
GET /postgres/emergency/{emergency_id}
|
| 211 |
+
```
|
| 212 |
+
|
| 213 |
+
#### Update Emergency Contact
|
| 214 |
+
```
|
| 215 |
+
PUT /postgres/emergency/{emergency_id}
|
| 216 |
+
```
|
| 217 |
+
|
| 218 |
+
#### Delete Emergency Contact
|
| 219 |
+
```
|
| 220 |
+
DELETE /postgres/emergency/{emergency_id}
|
| 221 |
+
```
|
| 222 |
+
|
| 223 |
+
#### Batch Operations
|
| 224 |
+
|
| 225 |
+
Create multiple Emergency Contacts:
|
| 226 |
+
```
|
| 227 |
+
POST /postgres/emergency/batch
|
| 228 |
+
```
|
| 229 |
+
|
| 230 |
+
Update status of multiple Emergency Contacts:
|
| 231 |
+
```
|
| 232 |
+
PUT /postgres/emergency/batch-update-status
|
| 233 |
+
```
|
| 234 |
+
|
| 235 |
+
Delete multiple Emergency Contacts:
|
| 236 |
+
```
|
| 237 |
+
DELETE /postgres/emergency/batch
|
| 238 |
+
```
|
| 239 |
+
|
| 240 |
+
### Event Endpoints
|
| 241 |
+
|
| 242 |
+
#### Get Events
|
| 243 |
+
```
|
| 244 |
+
GET /postgres/events
|
| 245 |
+
```
|
| 246 |
+
|
| 247 |
+
Parameters:
|
| 248 |
+
- `skip`: Number of items to skip (default: 0)
|
| 249 |
+
- `limit`: Maximum items to return (default: 100)
|
| 250 |
+
- `active_only`: Return only active items (default: false)
|
| 251 |
+
- `featured_only`: Return only featured items (default: false)
|
| 252 |
+
- `use_cache`: Use cached data if available (default: true)
|
| 253 |
+
|
| 254 |
+
Response: Array of Event objects
|
| 255 |
+
|
| 256 |
+
#### Create Event
|
| 257 |
+
```
|
| 258 |
+
POST /postgres/events
|
| 259 |
+
```
|
| 260 |
+
|
| 261 |
+
Request Body:
|
| 262 |
+
```json
|
| 263 |
+
{
|
| 264 |
+
"name": "Da Nang Fireworks Festival",
|
| 265 |
+
"description": "International Fireworks Festival Da Nang 2023",
|
| 266 |
+
"address": "Dragon Bridge, Da Nang",
|
| 267 |
+
"location": "16.0610, 108.2277",
|
| 268 |
+
"date_start": "2023-06-01T19:00:00",
|
| 269 |
+
"date_end": "2023-06-01T22:00:00",
|
| 270 |
+
"price": [
|
| 271 |
+
{"type": "VIP", "amount": 500000},
|
| 272 |
+
{"type": "Standard", "amount": 300000}
|
| 273 |
+
],
|
| 274 |
+
"url": "https://danangfireworks.com",
|
| 275 |
+
"is_active": true,
|
| 276 |
+
"featured": true
|
| 277 |
+
}
|
| 278 |
+
```
|
| 279 |
+
|
| 280 |
+
Response: Created Event object
|
| 281 |
+
|
| 282 |
+
#### Get Event
|
| 283 |
+
```
|
| 284 |
+
GET /postgres/events/{event_id}
|
| 285 |
+
```
|
| 286 |
+
|
| 287 |
+
#### Update Event
|
| 288 |
+
```
|
| 289 |
+
PUT /postgres/events/{event_id}
|
| 290 |
+
```
|
| 291 |
+
|
| 292 |
+
#### Delete Event
|
| 293 |
+
```
|
| 294 |
+
DELETE /postgres/events/{event_id}
|
| 295 |
+
```
|
| 296 |
+
|
| 297 |
+
#### Batch Operations
|
| 298 |
+
|
| 299 |
+
Create multiple Events:
|
| 300 |
+
```
|
| 301 |
+
POST /postgres/events/batch
|
| 302 |
+
```
|
| 303 |
+
|
| 304 |
+
Update status of multiple Events:
|
| 305 |
+
```
|
| 306 |
+
PUT /postgres/events/batch-update-status
|
| 307 |
+
```
|
| 308 |
+
|
| 309 |
+
Delete multiple Events:
|
| 310 |
+
```
|
| 311 |
+
DELETE /postgres/events/batch
|
| 312 |
+
```
|
| 313 |
+
|
| 314 |
+
### About Pixity Endpoints
|
| 315 |
+
|
| 316 |
+
#### Get About Pixity
|
| 317 |
+
```
|
| 318 |
+
GET /postgres/about-pixity
|
| 319 |
+
```
|
| 320 |
+
|
| 321 |
+
Response:
|
| 322 |
+
```json
|
| 323 |
+
{
|
| 324 |
+
"content": "PiXity is your smart, AI-powered local companion...",
|
| 325 |
+
"id": 1,
|
| 326 |
+
"created_at": "2023-01-01T00:00:00",
|
| 327 |
+
"updated_at": "2023-01-01T00:00:00"
|
| 328 |
+
}
|
| 329 |
+
```
|
| 330 |
+
|
| 331 |
+
#### Update About Pixity
|
| 332 |
+
```
|
| 333 |
+
PUT /postgres/about-pixity
|
| 334 |
+
```
|
| 335 |
+
|
| 336 |
+
Request Body:
|
| 337 |
+
```json
|
| 338 |
+
{
|
| 339 |
+
"content": "PiXity is your smart, AI-powered local companion..."
|
| 340 |
+
}
|
| 341 |
+
```
|
| 342 |
+
|
| 343 |
+
Response: Updated About Pixity object
|
| 344 |
+
|
| 345 |
+
### Da Nang Bucket List Endpoints
|
| 346 |
+
|
| 347 |
+
#### Get Da Nang Bucket List
|
| 348 |
+
```
|
| 349 |
+
GET /postgres/danang-bucket-list
|
| 350 |
+
```
|
| 351 |
+
|
| 352 |
+
Response: Bucket List object with JSON content string
|
| 353 |
+
|
| 354 |
+
#### Update Da Nang Bucket List
|
| 355 |
+
```
|
| 356 |
+
PUT /postgres/danang-bucket-list
|
| 357 |
+
```
|
| 358 |
+
|
| 359 |
+
### Solana Summit Endpoints
|
| 360 |
+
|
| 361 |
+
#### Get Solana Summit
|
| 362 |
+
```
|
| 363 |
+
GET /postgres/solana-summit
|
| 364 |
+
```
|
| 365 |
+
|
| 366 |
+
Response: Solana Summit object with JSON content string
|
| 367 |
+
|
| 368 |
+
#### Update Solana Summit
|
| 369 |
+
```
|
| 370 |
+
PUT /postgres/solana-summit
|
| 371 |
+
```
|
| 372 |
+
|
| 373 |
+
### Health Check
|
| 374 |
+
```
|
| 375 |
+
GET /postgres/health
|
| 376 |
+
```
|
| 377 |
+
|
| 378 |
+
Response:
|
| 379 |
+
```json
|
| 380 |
+
{
|
| 381 |
+
"status": "healthy",
|
| 382 |
+
"message": "PostgreSQL connection is working",
|
| 383 |
+
"timestamp": "2023-01-01T00:00:00"
|
| 384 |
+
}
|
| 385 |
+
```
|
| 386 |
+
|
| 387 |
+
## MongoDB Endpoints
|
| 388 |
+
|
| 389 |
+
### Session Endpoints
|
| 390 |
+
|
| 391 |
+
#### Create Session
|
| 392 |
+
```
|
| 393 |
+
POST /session
|
| 394 |
+
```
|
| 395 |
+
|
| 396 |
+
Request Body:
|
| 397 |
+
```json
|
| 398 |
+
{
|
| 399 |
+
"user_id": "user123",
|
| 400 |
+
"query": "How do I book a room?",
|
| 401 |
+
"timestamp": "2023-01-01T00:00:00",
|
| 402 |
+
"metadata": {
|
| 403 |
+
"client_info": "web",
|
| 404 |
+
"location": "Da Nang"
|
| 405 |
+
}
|
| 406 |
+
}
|
| 407 |
+
```
|
| 408 |
+
|
| 409 |
+
Response: Created Session object with session_id
|
| 410 |
+
|
| 411 |
+
#### Update Session with Response
|
| 412 |
+
```
|
| 413 |
+
PUT /session/{session_id}/response
|
| 414 |
+
```
|
| 415 |
+
|
| 416 |
+
Request Body:
|
| 417 |
+
```json
|
| 418 |
+
{
|
| 419 |
+
"response": "You can book a room through our app or website.",
|
| 420 |
+
"response_timestamp": "2023-01-01T00:00:05",
|
| 421 |
+
"metadata": {
|
| 422 |
+
"response_time_ms": 234,
|
| 423 |
+
"model_version": "gpt-4"
|
| 424 |
+
}
|
| 425 |
+
}
|
| 426 |
+
```
|
| 427 |
+
|
| 428 |
+
Response: Updated Session object
|
| 429 |
+
|
| 430 |
+
#### Get Session
|
| 431 |
+
```
|
| 432 |
+
GET /session/{session_id}
|
| 433 |
+
```
|
| 434 |
+
|
| 435 |
+
Response: Session object
|
| 436 |
+
|
| 437 |
+
#### Get User History
|
| 438 |
+
```
|
| 439 |
+
GET /history
|
| 440 |
+
```
|
| 441 |
+
|
| 442 |
+
Parameters:
|
| 443 |
+
- `user_id`: User ID (required)
|
| 444 |
+
- `limit`: Maximum sessions to return (default: 10)
|
| 445 |
+
- `skip`: Number of sessions to skip (default: 0)
|
| 446 |
+
|
| 447 |
+
Response:
|
| 448 |
+
```json
|
| 449 |
+
{
|
| 450 |
+
"user_id": "user123",
|
| 451 |
+
"sessions": [
|
| 452 |
+
{
|
| 453 |
+
"session_id": "60f7a8b9c1d2e3f4a5b6c7d8",
|
| 454 |
+
"query": "How do I book a room?",
|
| 455 |
+
"timestamp": "2023-01-01T00:00:00",
|
| 456 |
+
"response": "You can book a room through our app or website.",
|
| 457 |
+
"response_timestamp": "2023-01-01T00:00:05"
|
| 458 |
+
}
|
| 459 |
+
],
|
| 460 |
+
"total_count": 1
|
| 461 |
+
}
|
| 462 |
+
```
|
| 463 |
+
|
| 464 |
+
#### Health Check
|
| 465 |
+
```
|
| 466 |
+
GET /health
|
| 467 |
+
```
|
| 468 |
+
|
| 469 |
+
## RAG Endpoints
|
| 470 |
+
|
| 471 |
+
### Create Embedding
|
| 472 |
+
```
|
| 473 |
+
POST /embedding
|
| 474 |
+
```
|
| 475 |
+
|
| 476 |
+
Request Body:
|
| 477 |
+
```json
|
| 478 |
+
{
|
| 479 |
+
"text": "Text to embed"
|
| 480 |
+
}
|
| 481 |
+
```
|
| 482 |
+
|
| 483 |
+
Response:
|
| 484 |
+
```json
|
| 485 |
+
{
|
| 486 |
+
"embedding": [0.1, 0.2, 0.3, ...],
|
| 487 |
+
"dimensions": 1536
|
| 488 |
+
}
|
| 489 |
+
```
|
| 490 |
+
|
| 491 |
+
### Process Chat Request
|
| 492 |
+
```
|
| 493 |
+
POST /chat
|
| 494 |
+
```
|
| 495 |
+
|
| 496 |
+
Request Body:
|
| 497 |
+
```json
|
| 498 |
+
{
|
| 499 |
+
"query": "Can you tell me about Pixity?",
|
| 500 |
+
"chat_history": [
|
| 501 |
+
{"role": "user", "content": "Hello"},
|
| 502 |
+
{"role": "assistant", "content": "Hello! How can I help you?"}
|
| 503 |
+
]
|
| 504 |
+
}
|
| 505 |
+
```
|
| 506 |
+
|
| 507 |
+
Response:
|
| 508 |
+
```json
|
| 509 |
+
{
|
| 510 |
+
"answer": "Pixity is a platform...",
|
| 511 |
+
"sources": [
|
| 512 |
+
{
|
| 513 |
+
"document_id": "doc123",
|
| 514 |
+
"chunk_id": "chunk456",
|
| 515 |
+
"chunk_text": "Pixity was founded in...",
|
| 516 |
+
"relevance_score": 0.92
|
| 517 |
+
}
|
| 518 |
+
]
|
| 519 |
+
}
|
| 520 |
+
```
|
| 521 |
+
|
| 522 |
+
### Direct RAG Query
|
| 523 |
+
```
|
| 524 |
+
POST /rag
|
| 525 |
+
```
|
| 526 |
+
|
| 527 |
+
Request Body:
|
| 528 |
+
```json
|
| 529 |
+
{
|
| 530 |
+
"query": "Can you tell me about Pixity?",
|
| 531 |
+
"namespace": "about_pixity",
|
| 532 |
+
"top_k": 3
|
| 533 |
+
}
|
| 534 |
+
```
|
| 535 |
+
|
| 536 |
+
Response: Query results with relevance scores
|
| 537 |
+
|
| 538 |
+
### Health Check
|
| 539 |
+
```
|
| 540 |
+
GET /health
|
| 541 |
+
```
|
| 542 |
+
|
| 543 |
+
## PDF Processing Endpoints
|
| 544 |
+
|
| 545 |
+
### Upload and Process PDF
|
| 546 |
+
```
|
| 547 |
+
POST /pdf/upload
|
| 548 |
+
```
|
| 549 |
+
|
| 550 |
+
Form Data:
|
| 551 |
+
- `file`: PDF file (required)
|
| 552 |
+
- `namespace`: Vector database namespace (default: "Default")
|
| 553 |
+
- `index_name`: Vector database index name (default: "testbot768")
|
| 554 |
+
- `title`: Document title (optional)
|
| 555 |
+
- `description`: Document description (optional)
|
| 556 |
+
- `user_id`: User ID for WebSocket updates (optional)
|
| 557 |
+
|
| 558 |
+
Response: Processing results with document_id
|
| 559 |
+
|
| 560 |
+
### Delete Documents in Namespace
|
| 561 |
+
```
|
| 562 |
+
DELETE /pdf/namespace
|
| 563 |
+
```
|
| 564 |
+
|
| 565 |
+
Parameters:
|
| 566 |
+
- `namespace`: Vector database namespace (default: "Default")
|
| 567 |
+
- `index_name`: Vector database index name (default: "testbot768")
|
| 568 |
+
- `user_id`: User ID for WebSocket updates (optional)
|
| 569 |
+
|
| 570 |
+
Response: Deletion results
|
| 571 |
+
|
| 572 |
+
### Get Documents List
|
| 573 |
+
```
|
| 574 |
+
GET /pdf/documents
|
| 575 |
+
```
|
| 576 |
+
|
| 577 |
+
Parameters:
|
| 578 |
+
- `namespace`: Vector database namespace (default: "Default")
|
| 579 |
+
- `index_name`: Vector database index name (default: "testbot768")
|
| 580 |
+
|
| 581 |
+
Response: List of documents in the namespace
|
pytest.ini
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[pytest]
|
| 2 |
+
# Ignore assert-rewrite warnings from the anyio module and internal deprecation warnings
|
| 3 |
+
filterwarnings =
|
| 4 |
+
ignore::pytest.PytestAssertRewriteWarning:.*anyio
|
| 5 |
+
ignore:.*general_plain_validator_function.* is deprecated.*:DeprecationWarning
|
| 6 |
+
ignore:.*with_info_plain_validator_function.*:DeprecationWarning
|
| 7 |
+
|
| 8 |
+
# Other basic configuration
|
| 9 |
+
testpaths = tests
|
| 10 |
+
python_files = test_*.py
|
| 11 |
+
python_classes = Test*
|
| 12 |
+
python_functions = test_*
|
requirements.txt
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FastAPI
|
| 2 |
+
fastapi==0.103.1
|
| 3 |
+
uvicorn[standard]==0.23.2
|
| 4 |
+
pydantic==2.4.2
|
| 5 |
+
python-dotenv==1.0.0
|
| 6 |
+
websockets==11.0.3
|
| 7 |
+
|
| 8 |
+
# MongoDB
|
| 9 |
+
pymongo==4.6.1
|
| 10 |
+
dnspython==2.4.2
|
| 11 |
+
|
| 12 |
+
# PostgreSQL
|
| 13 |
+
sqlalchemy==2.0.20
|
| 14 |
+
pydantic-settings==2.0.3
|
| 15 |
+
psycopg2-binary==2.9.7
|
| 16 |
+
|
| 17 |
+
# Pinecone & RAG
|
| 18 |
+
pinecone-client==3.0.0
|
| 19 |
+
langchain==0.1.4
|
| 20 |
+
langchain-core==0.1.19
|
| 21 |
+
langchain-community==0.0.14
|
| 22 |
+
langchain-google-genai==0.0.5
|
| 23 |
+
langchain-pinecone==0.0.1
|
| 24 |
+
faiss-cpu==1.7.4
|
| 25 |
+
google-generativeai==0.3.1
|
| 26 |
+
|
| 27 |
+
# Extras
|
| 28 |
+
pytz==2023.3
|
| 29 |
+
python-multipart==0.0.6
|
| 30 |
+
httpx==0.25.1
|
| 31 |
+
requests==2.31.0
|
| 32 |
+
beautifulsoup4==4.12.2
|
| 33 |
+
redis==5.0.1
|
| 34 |
+
|
| 35 |
+
# Testing, monitoring & dev tooling
|
| 36 |
+
prometheus-client==0.17.1
|
| 37 |
+
pytest==7.4.0
|
| 38 |
+
pytest-cov==4.1.0
|
| 39 |
+
watchfiles==0.21.0
|
| 40 |
+
|
| 41 |
+
# Core dependencies
|
| 42 |
+
starlette==0.27.0
|
| 43 |
+
psutil==5.9.6
|
| 44 |
+
|
| 45 |
+
# Upload PDF
|
| 46 |
+
pypdf==3.17.4
|
| 47 |
+
|