Cuong2004 committed on
Commit 7be47d4 · 0 Parent(s):

clear history and fix bug websocket connection
.dockerignore ADDED
@@ -0,0 +1,46 @@
# Git
.git
.gitignore
.gitattributes

# Environment files
.env
.env.*
!.env.example

# Python cache files
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
.pytest_cache/
*.egg-info/
.installed.cfg
*.egg

# Logs
*.log

# Tests
tests/

# Docker related
Dockerfile
docker-compose.yml
.dockerignore

# Other files
.vscode/
.idea/
*.swp
*.swo
.DS_Store
.coverage
htmlcov/
.mypy_cache/
.tox/
.nox/
instance/
.webassets-cache
main.py
.env.example ADDED
@@ -0,0 +1,33 @@
# PostgreSQL Configuration
DB_CONNECTION_MODE=aiven
AIVEN_DB_URL=postgresql://username:password@host:port/dbname?sslmode=require

# MongoDB Configuration
MONGODB_URL=mongodb+srv://username:password@cluster.mongodb.net/?retryWrites=true&w=majority
DB_NAME=Telegram
COLLECTION_NAME=session_chat

# Pinecone configuration
PINECONE_API_KEY=your-pinecone-api-key
PINECONE_INDEX_NAME=your-pinecone-index-name
PINECONE_ENVIRONMENT=gcp-starter

# Google Gemini API key
GOOGLE_API_KEY=your-google-api-key

# WebSocket configuration
WEBSOCKET_SERVER=localhost
WEBSOCKET_PORT=7860
WEBSOCKET_PATH=/notify

# Application settings
ENVIRONMENT=production
DEBUG=false
PORT=7860

# Cache Configuration
CACHE_TTL_SECONDS=300
CACHE_CLEANUP_INTERVAL=60
CACHE_MAX_SIZE=1000
HISTORY_QUEUE_SIZE=10
HISTORY_CACHE_TTL=3600
.gitignore ADDED
@@ -0,0 +1,83 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
.pytest_cache/
htmlcov/
.coverage
.coverage.*
.cache/
coverage.xml
*.cover
.mypy_cache/

# Environment
.env
.venv
env/
venv/
ENV/

# VSCode
.vscode/
*.code-workspace
.history/

# PyCharm
.idea/
*.iml
*.iws
*.ipr
*.iws
out/
.idea_modules/

# Logs and databases
*.log
*.sql
*.sqlite
*.db

# Tests
tests/

Admin_bot/

Pix-Agent/

# Hugging Face Spaces
.gitattributes

# OS specific
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
Icon?
ehthumbs.db
Thumbs.db

# Project specific
*.log
.env
main.py

test/
Dockerfile ADDED
@@ -0,0 +1,31 @@
FROM python:3.11-slim

WORKDIR /app

# Install required system packages
RUN apt-get update && apt-get install -y \
    build-essential \
    curl \
    software-properties-common \
    git \
    gcc \
    python3-dev \
    && rm -rf /var/lib/apt/lists/*

# Copy the requirements file first to take advantage of Docker's layer cache
COPY requirements.txt .

# Install Python packages
RUN pip install --no-cache-dir -r requirements.txt

# Ensure langchain-core is installed
RUN pip install --no-cache-dir langchain-core==0.1.19

# Copy all of the source code into the container
COPY . .

# Expose the port the application runs on
EXPOSE 7860

# Run the application with uvicorn
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,419 @@
---
title: PIX Project Backend
emoji: 🤖
colorFrom: blue
colorTo: green
sdk: docker
sdk_version: "3.0.0"
app_file: app.py
pinned: false
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

# PIX Project Backend

[![FastAPI](https://img.shields.io/badge/FastAPI-0.103.1-009688?style=flat&logo=fastapi&logoColor=white)](https://fastapi.tiangolo.com/)
[![Python 3.11](https://img.shields.io/badge/Python-3.11-3776AB?style=flat&logo=python&logoColor=white)](https://www.python.org/)
[![HuggingFace Spaces](https://img.shields.io/badge/HuggingFace-Spaces-yellow?style=flat&logo=huggingface&logoColor=white)](https://huggingface.co/spaces)

Backend API for the PIX Project with MongoDB, PostgreSQL, and RAG integration. This project provides a comprehensive backend solution for managing FAQ items, emergency contacts, events, and a RAG-based question answering system.

## Features

- **MongoDB Integration**: Store user sessions and conversation history
- **PostgreSQL Integration**: Manage FAQ items, emergency contacts, and events
- **Pinecone Vector Database**: Store and retrieve vector embeddings for RAG
- **RAG Question Answering**: Answer questions using relevant information from the vector database
- **WebSocket Notifications**: Real-time notifications for the Admin Bot
- **API Documentation**: Automatic OpenAPI documentation via Swagger
- **Docker Support**: Easy deployment using Docker
- **Auto-Debugging**: Built-in debugging, error tracking, and performance monitoring

## API Endpoints

### MongoDB Endpoints

- `POST /mongodb/session`: Create a new session record
- `PUT /mongodb/session/{session_id}/response`: Update a session with a response
- `GET /mongodb/history`: Get user conversation history
- `GET /mongodb/health`: Check MongoDB connection health

### PostgreSQL Endpoints

- `GET /postgres/health`: Check PostgreSQL connection health
- `GET /postgres/faq`: Get FAQ items
- `POST /postgres/faq`: Create a new FAQ item
- `GET /postgres/faq/{faq_id}`: Get a specific FAQ item
- `PUT /postgres/faq/{faq_id}`: Update a specific FAQ item
- `DELETE /postgres/faq/{faq_id}`: Delete a specific FAQ item
- `GET /postgres/emergency`: Get emergency contact items
- `POST /postgres/emergency`: Create a new emergency contact item
- `GET /postgres/emergency/{emergency_id}`: Get a specific emergency contact
- `GET /postgres/events`: Get event items

### RAG Endpoints

- `POST /rag/chat`: Get an answer for a question using RAG
- `POST /rag/embedding`: Generate an embedding for text
- `GET /rag/health`: Check RAG services health

### WebSocket Endpoints

- `WebSocket /notify`: Receive real-time notifications for new sessions

### Debug Endpoints (Available in Debug Mode Only)

- `GET /debug/config`: Get configuration information
- `GET /debug/system`: Get system information (CPU, memory, disk usage)
- `GET /debug/database`: Check all database connections
- `GET /debug/errors`: View recent error logs
- `GET /debug/performance`: Get performance metrics
- `GET /debug/full`: Get comprehensive debug information

## WebSocket API

### Notifications for New Sessions

The backend provides a WebSocket endpoint for receiving notifications about new sessions that match specific criteria.

#### WebSocket Endpoint Configuration

The WebSocket endpoint is configured using environment variables:

```
# WebSocket configuration
WEBSOCKET_SERVER=localhost
WEBSOCKET_PORT=7860
WEBSOCKET_PATH=/notify
```

The full WebSocket URL will be:
```
ws://{WEBSOCKET_SERVER}:{WEBSOCKET_PORT}{WEBSOCKET_PATH}
```

For example: `ws://localhost:7860/notify`

#### Notification Criteria

A notification is sent when:
1. A new session is created with `factor` set to "RAG"
2. The message content starts with "I don't know"

#### Notification Format

```json
{
  "type": "new_session",
  "timestamp": "2025-04-15 22:30:45",
  "data": {
    "session_id": "123e4567-e89b-12d3-a456-426614174000",
    "factor": "rag",
    "action": "asking_freely",
    "created_at": "2025-04-15 22:30:45",
    "first_name": "John",
    "last_name": "Doe",
    "message": "I don't know how to find emergency contacts",
    "user_id": "12345678",
    "username": "johndoe"
  }
}
```

#### Usage Example

The Admin Bot should establish a WebSocket connection to this endpoint using the configured URL:

```python
import websocket
import json
import os
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Get WebSocket configuration from environment variables
WEBSOCKET_SERVER = os.getenv("WEBSOCKET_SERVER", "localhost")
WEBSOCKET_PORT = os.getenv("WEBSOCKET_PORT", "7860")
WEBSOCKET_PATH = os.getenv("WEBSOCKET_PATH", "/notify")

# Create the full URL
ws_url = f"ws://{WEBSOCKET_SERVER}:{WEBSOCKET_PORT}{WEBSOCKET_PATH}"

def on_message(ws, message):
    data = json.loads(message)
    print(f"Received notification: {data}")
    # Forward to the Telegram Admin

def on_error(ws, error):
    print(f"Error: {error}")

def on_close(ws, close_status_code, close_msg):
    print("Connection closed")

def on_open(ws):
    print("Connection opened")
    # Send an initial keepalive message
    ws.send("keepalive")

# Connect to the WebSocket
ws = websocket.WebSocketApp(
    ws_url,
    on_open=on_open,
    on_message=on_message,
    on_error=on_error,
    on_close=on_close
)
ws.run_forever()
```

When a notification is received, the Admin Bot should forward the content to the Telegram Admin.

## Environment Variables

Create a `.env` file with the following variables:

```
# PostgreSQL Configuration
DB_CONNECTION_MODE=aiven
AIVEN_DB_URL=postgresql://username:password@host:port/dbname?sslmode=require

# MongoDB Configuration
MONGODB_URL=mongodb+srv://username:password@cluster.mongodb.net/?retryWrites=true&w=majority
DB_NAME=Telegram
COLLECTION_NAME=session_chat

# Pinecone configuration
PINECONE_API_KEY=your-pinecone-api-key
PINECONE_INDEX_NAME=your-pinecone-index-name
PINECONE_ENVIRONMENT=gcp-starter

# Google Gemini API key
GOOGLE_API_KEY=your-google-api-key

# WebSocket configuration
WEBSOCKET_SERVER=localhost
WEBSOCKET_PORT=7860
WEBSOCKET_PATH=/notify

# Application settings
ENVIRONMENT=production
DEBUG=false
PORT=7860
```

## Installation and Setup

### Local Development

1. Clone the repository:
```bash
git clone https://github.com/ManTT-Data/PixAgent.git
cd PixAgent
```

2. Create a virtual environment and install dependencies:
```bash
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate
pip install -r requirements.txt
```

3. Create a `.env` file with your configuration (see above)

4. Run the application:
```bash
uvicorn app:app --reload --port 7860
```

5. Open your browser and navigate to [http://localhost:7860/docs](http://localhost:7860/docs) to see the API documentation

### Docker Deployment

1. Build the Docker image:
```bash
docker build -t pix-project-backend .
```

2. Run the Docker container:
```bash
docker run -p 7860:7860 --env-file .env pix-project-backend
```

## Deployment to HuggingFace Spaces

1. Create a new Space on HuggingFace (Dockerfile type)
2. Link your GitHub repository or push directly to the HuggingFace repo
3. Add your environment variables in the Space settings
4. The deployment will use `app.py` as the entry point, which is the standard for HuggingFace Spaces

### Important Notes for HuggingFace Deployment

- The application uses `app.py` with the FastAPI instance named `app` to avoid the "Error loading ASGI app. Attribute 'app' not found in module 'app'" error
- Make sure all environment variables are set in the Space settings
- The Dockerfile is configured to expose port 7860, which is the default port for HuggingFace Spaces

## Project Structure

```
.
├── app                         # Main application package
│   ├── api                     # API endpoints
│   │   ├── mongodb_routes.py
│   │   ├── postgresql_routes.py
│   │   ├── rag_routes.py
│   │   └── websocket_routes.py
│   ├── database                # Database connections
│   │   ├── mongodb.py
│   │   ├── pinecone.py
│   │   └── postgresql.py
│   ├── models                  # Pydantic models
│   │   ├── mongodb_models.py
│   │   ├── postgresql_models.py
│   │   └── rag_models.py
│   └── utils                   # Utility functions
│       ├── debug_utils.py
│       └── middleware.py
├── tests                       # Test directory
│   └── test_api_endpoints.py
├── .dockerignore               # Docker ignore file
├── .env.example                # Example environment file
├── .gitattributes              # Git attributes
├── .gitignore                  # Git ignore file
├── app.py                      # Application entry point
├── docker-compose.yml          # Docker compose configuration
├── Dockerfile                  # Docker configuration
├── pytest.ini                  # Pytest configuration
├── README.md                   # Project documentation
├── requirements.txt            # Project dependencies
└── api_documentation.txt       # API documentation for frontend engineers
```

## License

This project is licensed under the MIT License - see the LICENSE file for details.

# Advanced Retrieval System

This project now features an enhanced vector retrieval system that improves the quality and relevance of information retrieved from Pinecone using threshold-based filtering and multiple similarity metrics.

## Features

### 1. Threshold-Based Retrieval

The system implements a threshold-based approach to vector retrieval, which:
- Retrieves a larger candidate set from the vector database
- Applies a similarity threshold to filter out less relevant results
- Returns only the most relevant documents that exceed the threshold

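The candidate-then-filter flow described above can be sketched as follows (a minimal illustration only; the project's actual `ThresholdRetriever` is not shown in this commit, and the `(document, score)` pair format is an assumption):

```python
def threshold_retrieve(candidates, top_k=6, threshold=0.75):
    """Keep candidates at or above `threshold`, then return the best `top_k`.

    `candidates` is assumed to be a list of (document, score) pairs where a
    higher score always means a more similar document.
    """
    # Drop candidates that fall below the similarity threshold
    relevant = [(doc, score) for doc, score in candidates if score >= threshold]
    # Sort the survivors by score, highest first, and trim to top_k
    relevant.sort(key=lambda pair: pair[1], reverse=True)
    return relevant[:top_k]

# With limit_k=5 candidates, the threshold removes two and top_k keeps the best two
candidates = [("a", 0.9), ("b", 0.6), ("c", 0.8), ("d", 0.7), ("e", 0.95)]
print(threshold_retrieve(candidates, top_k=2))  # [('e', 0.95), ('a', 0.9)]
```

This mirrors the `limit_k` / `similarity_top_k` / `similarity_threshold` parameters described later in this README.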
### 2. Multiple Similarity Metrics

The system supports multiple similarity metrics:
- **Cosine Similarity** (default): Measures the cosine of the angle between vectors
- **Dot Product**: Calculates the dot product between vectors
- **Euclidean Distance**: Measures the straight-line distance between vectors

Each metric has different characteristics and may perform better for different types of data and queries.

### 3. Score Normalization

For metrics like Euclidean distance, where lower values indicate higher similarity, the system automatically normalizes scores to a 0-1 scale on which higher values always indicate higher similarity. This makes it easier to compare results across different metrics.

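For illustration, the three metrics and a normalization step might look like this in plain Python (the `1 / (1 + distance)` mapping is one common normalization choice and an assumption here, not necessarily the exact formula the project uses):

```python
import math

def cosine(u, v):
    # Cosine of the angle between u and v; 1.0 means identical direction
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def dot_product(u, v):
    return sum(a * b for a, b in zip(u, v))

def euclidean_similarity(u, v):
    # Euclidean distance is "lower is better"; map it to a 0-1 score
    # where higher is better, so it can be compared against cosine scores.
    d = math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))
    return 1.0 / (1.0 + d)

u, v = [1.0, 0.0], [1.0, 0.0]
print(cosine(u, v), dot_product(u, v), euclidean_similarity(u, v))  # identical vectors score highest on every metric
```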
## Configuration

The retrieval system can be configured through environment variables:

```
# Pinecone retrieval configuration
PINECONE_DEFAULT_LIMIT_K=10                          # Maximum number of candidates to retrieve
PINECONE_DEFAULT_TOP_K=6                             # Number of results to return after filtering
PINECONE_DEFAULT_SIMILARITY_METRIC=cosine            # Default similarity metric
PINECONE_DEFAULT_SIMILARITY_THRESHOLD=0.75           # Similarity threshold (0-1)
PINECONE_ALLOWED_METRICS=cosine,dotproduct,euclidean # Available metrics
```

## API Usage

You can customize the retrieval parameters when making API requests:

```json
{
  "user_id": "user123",
  "question": "What are the best restaurants in Da Nang?",
  "similarity_top_k": 5,
  "limit_k": 15,
  "similarity_metric": "cosine",
  "similarity_threshold": 0.8
}
```

## Benefits

1. **Quality Improvement**: Retrieves only the most relevant documents above a certain quality threshold
2. **Flexibility**: Different similarity metrics can be used for different types of queries
3. **Efficiency**: Avoids processing irrelevant documents, improving response time
4. **Configurability**: All parameters can be adjusted via environment variables or at request time

## Implementation Details

The system is implemented as a custom retriever class `ThresholdRetriever` that integrates with LangChain's retrieval infrastructure while providing enhanced functionality.

## In-Memory Cache

The project includes an in-memory cache system to reduce the number of queries sent to the PostgreSQL and MongoDB databases.

### Cache Configuration

The cache is configured through environment variables:

```
# Cache Configuration
CACHE_TTL_SECONDS=300      # Time-to-live of a cache item (seconds)
CACHE_CLEANUP_INTERVAL=60  # Interval between purges of expired items (seconds)
CACHE_MAX_SIZE=1000        # Maximum number of items in the cache
HISTORY_QUEUE_SIZE=10      # Maximum number of items in a user's history queue
HISTORY_CACHE_TTL=3600     # Time-to-live of user history (seconds)
```

### Cache Mechanisms

The cache combines two expiration mechanisms:

1. **Lazy Expiration**: The TTL is checked whenever a cache item is accessed. If the item has expired, it is deleted and the lookup is treated as a miss.

2. **Active Expiration**: A background thread periodically scans for and deletes expired items. This keeps the cache from growing too large with items that are no longer used.

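A minimal sketch of the lazy + active expiration combination described above (illustrative only; the project's actual cache module is not shown in this commit, so all names here are assumptions):

```python
import threading
import time

class TTLCache:
    def __init__(self, ttl_seconds=300):
        self.ttl = ttl_seconds
        self._store = {}  # key -> (value, expires_at)
        self._lock = threading.Lock()

    def set(self, key, value):
        with self._lock:
            self._store[key] = (value, time.monotonic() + self.ttl)

    def get(self, key):
        # Lazy expiration: check the deadline on access
        with self._lock:
            entry = self._store.get(key)
            if entry is None:
                return None
            value, expires_at = entry
            if time.monotonic() >= expires_at:
                del self._store[key]  # expired: delete and report a miss
                return None
            return value

    def cleanup(self):
        # Active expiration: intended to be called periodically
        # from a background thread every CACHE_CLEANUP_INTERVAL seconds
        now = time.monotonic()
        with self._lock:
            for key in [k for k, (_, exp) in self._store.items() if now >= exp]:
                del self._store[key]

cache = TTLCache(ttl_seconds=0.05)
cache.set("faq", ["item"])
print(cache.get("faq"))   # hit while the item is fresh
time.sleep(0.1)
print(cache.get("faq"))   # None once the TTL has passed
```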
### Cached Data Types

- **PostgreSQL data**: Records from the FAQ, Emergency Contacts, and Events tables.
- **User history from MongoDB**: User conversation history is kept in a queue whose time-to-live is measured from the last access.

### Cache API

The project provides API endpoints for managing the cache:

- `GET /cache/stats`: View cache statistics (total items, memory usage, etc.)
- `DELETE /cache/clear`: Clear the entire cache
- `GET /debug/cache`: (Debug mode only) View detailed cache information, including keys and configuration

### How It Works

1. When a request arrives, the system first checks the cache.
2. If the data exists and has not expired, it is returned from the cache.
3. If the data is missing or expired, it is fetched from the database and the result is stored in the cache.
4. When data is updated or deleted, the related cache entries are invalidated automatically.

### User History

User conversation history is kept in a dedicated per-user queue with special behavior:

- Each user has their own queue with a bounded size (`HISTORY_QUEUE_SIZE`).
- The queue's time-to-live is refreshed on every new interaction.
- When the queue is full, the oldest items are evicted.
- The queue is deleted automatically after a period of inactivity.
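The bounded per-user queue behaves like a fixed-size FIFO; a rough sketch of the eviction behavior using `collections.deque` (an illustration, not the project's actual implementation):

```python
from collections import deque

HISTORY_QUEUE_SIZE = 3  # small value for the demo; the configured default above is 10

history = deque(maxlen=HISTORY_QUEUE_SIZE)
for turn in ["q1", "q2", "q3", "q4"]:
    history.append(turn)  # once full, the oldest item is evicted automatically

print(list(history))  # ['q2', 'q3', 'q4'] - 'q1' was pushed out
```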

## Authors

- **PIX Project Team**
app.py ADDED
@@ -0,0 +1,223 @@
from fastapi import FastAPI, Depends, Request, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
import uvicorn
import os
import sys
import logging
from dotenv import load_dotenv

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
    ]
)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()
DEBUG = os.getenv("DEBUG", "False").lower() in ("true", "1", "t")

# Check required environment variables
required_env_vars = [
    "AIVEN_DB_URL",
    "MONGODB_URL",
    "PINECONE_API_KEY",
    "PINECONE_INDEX_NAME",
    "GOOGLE_API_KEY"
]

missing_vars = [var for var in required_env_vars if not os.getenv(var)]
if missing_vars:
    logger.error(f"Missing required environment variables: {', '.join(missing_vars)}")
    if not DEBUG:  # Only exit when not in debug mode
        sys.exit(1)

# Database health checks
def check_database_connections():
    """Check database connections at startup"""
    from app.database.postgresql import check_db_connection as check_postgresql
    from app.database.mongodb import check_db_connection as check_mongodb
    from app.database.pinecone import check_db_connection as check_pinecone

    db_status = {
        "postgresql": check_postgresql(),
        "mongodb": check_mongodb(),
        "pinecone": check_pinecone()
    }

    all_ok = all(db_status.values())
    if not all_ok:
        failed_dbs = [name for name, status in db_status.items() if not status]
        logger.error(f"Failed to connect to databases: {', '.join(failed_dbs)}")
        if not DEBUG:  # Only exit when not in debug mode
            sys.exit(1)

    return db_status

# Lifespan handler that checks database connections at startup
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: check database connections
    logger.info("Starting application...")
    db_status = check_database_connections()

    # Create database tables (if they do not already exist)
    if DEBUG and all(db_status.values()):  # Only create tables in debug mode and when all DB connections succeed
        from app.database.postgresql import create_tables
        if create_tables():
            logger.info("Database tables created or already exist")

    yield

    # Shutdown
    logger.info("Shutting down application...")

# Import routers
try:
    from app.api.mongodb_routes import router as mongodb_router
    from app.api.postgresql_routes import router as postgresql_router
    from app.api.rag_routes import router as rag_router
    from app.api.pdf_routes import router as pdf_router
    from app.api.websocket_routes import router as websocket_router

    # Import middlewares
    from app.utils.middleware import RequestLoggingMiddleware, ErrorHandlingMiddleware, DatabaseCheckMiddleware

    # Import debug utilities
    from app.utils.debug_utils import debug_view, DebugInfo, error_tracker, performance_monitor

    # Import cache
    from app.utils.cache import get_cache

except ImportError as e:
    logger.error(f"Error importing routes or middlewares: {e}")
    raise

# Create FastAPI app
app = FastAPI(
    title="PIX Project Backend API",
    description="Backend API for PIX Project with MongoDB, PostgreSQL and RAG integration",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
    debug=DEBUG,
    lifespan=lifespan,
)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Add middlewares
app.add_middleware(ErrorHandlingMiddleware)
app.add_middleware(RequestLoggingMiddleware)
if not DEBUG:  # Only add the database-check middleware in production
    app.add_middleware(DatabaseCheckMiddleware)

# Include routers
app.include_router(mongodb_router)
app.include_router(postgresql_router)
app.include_router(rag_router)
app.include_router(pdf_router)
app.include_router(websocket_router)

# Root endpoint
@app.get("/")
def read_root():
    return {
        "message": "Welcome to PIX Project Backend API",
        "documentation": "/docs",
    }

# Health check endpoint
@app.get("/health")
def health_check():
    # Check database connections
    db_status = check_database_connections()
    all_db_ok = all(db_status.values())

    return {
        "status": "healthy" if all_db_ok else "degraded",
        "version": "1.0.0",
        "environment": os.environ.get("ENVIRONMENT", "production"),
        "databases": db_status
    }

@app.get("/api/ping")
async def ping():
    return {"status": "pong"}

# Cache stats endpoint
@app.get("/cache/stats")
def cache_stats():
    """Return cache statistics"""
    cache = get_cache()
    return cache.stats()

# Cache clear endpoint
@app.delete("/cache/clear")
def cache_clear():
    """Clear all cached data"""
    cache = get_cache()
    cache.clear()
    return {"message": "Cache cleared successfully"}

# Debug endpoints (debug mode only)
if DEBUG:
    @app.get("/debug/config")
    def debug_config():
        """Show configuration information (debug mode only)"""
        config = {
            "environment": os.environ.get("ENVIRONMENT", "production"),
            "debug": DEBUG,
            "db_connection_mode": os.environ.get("DB_CONNECTION_MODE", "aiven"),
            "databases": {
                "postgresql": os.environ.get("AIVEN_DB_URL", "").split("@")[1].split("/")[0] if "@" in os.environ.get("AIVEN_DB_URL", "") else "N/A",
                "mongodb": os.environ.get("MONGODB_URL", "").split("@")[1].split("/?")[0] if "@" in os.environ.get("MONGODB_URL", "") else "N/A",
                "pinecone": os.environ.get("PINECONE_INDEX_NAME", "N/A"),
            }
        }
        return config

    @app.get("/debug/system")
    def debug_system():
        """Show system information (debug mode only)"""
        return DebugInfo.get_system_info()

    @app.get("/debug/database")
    def debug_database():
        """Show database status (debug mode only)"""
        return DebugInfo.get_database_status()

    @app.get("/debug/errors")
    def debug_errors(limit: int = 10):
        """Show recent errors (debug mode only)"""
        return error_tracker.get_recent_errors(limit)

    @app.get("/debug/performance")
    def debug_performance():
        """Show performance statistics (debug mode only)"""
        return performance_monitor.get_stats()

    @app.get("/debug/full")
    def debug_full_report(request: Request):
        """Show a full system report (debug mode only)"""
        return debug_view(request)

    @app.get("/debug/cache")
    def debug_cache():
        """Show cache statistics (debug mode only)"""
        return get_cache().stats()

if __name__ == "__main__":
    PORT = int(os.getenv("PORT", "7860"))
    uvicorn.run("app:app", host="0.0.0.0", port=PORT, reload=DEBUG)
app/__init__.py ADDED
@@ -0,0 +1,25 @@
# PIX Project Backend
# Version: 1.0.0

__version__ = "1.0.0"

# Import app from app.py so that tests can find it
import sys
import os

# Add the project root to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

try:
    # Correct import form: 'app.py' is not a valid module name;
    # 'app' is the module name and '.py' is the file extension
    from app import app
except ImportError:
    # Fall back if the direct import does not work
    import importlib.util
    spec = importlib.util.spec_from_file_location(
        "app_module",
        os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "app.py")
    )
    app_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(app_module)
    app = app_module.app
app/api/__init__.py ADDED
@@ -0,0 +1 @@
# API routes package
app/api/mongodb_routes.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ from fastapi import APIRouter, HTTPException, Depends, Query, status, Response
+ from typing import List, Optional, Dict
+ from pymongo.errors import PyMongoError
+ import logging
+ from datetime import datetime
+ import traceback
+ import asyncio
+
+ from app.database.mongodb import (
+     save_session,
+     get_chat_history,
+     update_session_response,
+     check_db_connection,
+     session_collection
+ )
+ from app.models.mongodb_models import (
+     SessionCreate,
+     SessionResponse,
+     HistoryRequest,
+     HistoryResponse,
+     QuestionAnswer
+ )
+ from app.api.websocket_routes import send_notification
+
+ # Configure logging
+ logger = logging.getLogger(__name__)
+
+ # Create router
+ router = APIRouter(
+     prefix="/mongodb",
+     tags=["MongoDB"],
+ )
+
+ @router.post("/session", response_model=SessionResponse, status_code=status.HTTP_201_CREATED)
+ async def create_session(session: SessionCreate, response: Response):
+     """
+     Create a new session record in MongoDB.
+
+     - **session_id**: Unique identifier for the session (auto-generated if not provided)
+     - **factor**: Factor type (user, rag, etc.)
+     - **action**: Action type (start, events, faq, emergency, help, asking_freely, etc.)
+     - **first_name**: User's first name
+     - **last_name**: User's last name (optional)
+     - **message**: User's message (optional)
+     - **user_id**: User's ID from Telegram
+     - **username**: User's username (optional)
+     - **response**: Response from RAG (optional)
+     """
+     try:
+         # Check MongoDB connection
+         if not check_db_connection():
+             logger.error("MongoDB connection failed when trying to create session")
+             raise HTTPException(
+                 status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                 detail="MongoDB connection failed"
+             )
+
+         # Create new session in MongoDB
+         result = save_session(
+             session_id=session.session_id,
+             factor=session.factor,
+             action=session.action,
+             first_name=session.first_name,
+             last_name=session.last_name,
+             message=session.message,
+             user_id=session.user_id,
+             username=session.username,
+             response=session.response
+         )
+
+         # Prepare the response object
+         session_response = SessionResponse(
+             **session.model_dump(),
+             created_at=datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+         )
+
+         # Check whether this session needs a notification (response starts with "I'm sorry")
+         if session.response and session.response.strip().lower().startswith("i'm sorry"):
+             # Send a notification over WebSocket
+             try:
+                 notification_data = {
+                     "session_id": session.session_id,
+                     "factor": session.factor,
+                     "action": session.action,
+                     "message": session.message,
+                     "user_id": session.user_id,
+                     "username": session.username,
+                     "first_name": session.first_name,
+                     "last_name": session.last_name,
+                     "response": session.response,
+                     "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+                 }
+
+                 # Schedule the notification with asyncio.create_task so it does not block the main flow
+                 asyncio.create_task(send_notification(notification_data))
+                 logger.info(f"Notification queued for session {session.session_id} - response starts with 'I'm sorry'")
+             except Exception as e:
+                 logger.error(f"Error queueing notification: {e}")
+                 # A failed notification must not abort the main flow
+
+         # Return response
+         return session_response
+     except PyMongoError as e:
+         logger.error(f"MongoDB error creating session: {e}")
+         logger.error(traceback.format_exc())
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"MongoDB error: {str(e)}"
+         )
+     except HTTPException:
+         # Re-throw HTTP exceptions (e.g. the 503 above) instead of converting them to 500
+         raise
+     except Exception as e:
+         logger.error(f"Unexpected error creating session: {e}")
+         logger.error(traceback.format_exc())
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"Failed to create session: {str(e)}"
+         )
+
+ @router.put("/session/{session_id}/response", status_code=status.HTTP_200_OK)
+ async def update_session_with_response(session_id: str, response_text: str):
+     """
+     Update a session with the response.
+
+     - **session_id**: ID of the session to update
+     - **response_text**: Response to add to the session
+     """
+     try:
+         # Check MongoDB connection
+         if not check_db_connection():
+             logger.error("MongoDB connection failed when trying to update session response")
+             raise HTTPException(
+                 status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                 detail="MongoDB connection failed"
+             )
+
+         # Update session in MongoDB
+         result = update_session_response(session_id, response_text)
+
+         if not result:
+             raise HTTPException(
+                 status_code=status.HTTP_404_NOT_FOUND,
+                 detail=f"Session with ID {session_id} not found"
+             )
+
+         return {"status": "success", "message": "Response added to session"}
+     except PyMongoError as e:
+         logger.error(f"MongoDB error updating session response: {e}")
+         logger.error(traceback.format_exc())
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"MongoDB error: {str(e)}"
+         )
+     except HTTPException:
+         # Re-throw HTTP exceptions
+         raise
+     except Exception as e:
+         logger.error(f"Unexpected error updating session response: {e}")
+         logger.error(traceback.format_exc())
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"Failed to update session: {str(e)}"
+         )
+
+ @router.get("/history", response_model=HistoryResponse)
+ async def get_history(user_id: str, n: int = Query(3, ge=1, le=10)):
+     """
+     Get user history for a specific user.
+
+     - **user_id**: User's ID from Telegram
+     - **n**: Number of most recent interactions to return (default: 3, min: 1, max: 10)
+     """
+     try:
+         # Check MongoDB connection
+         if not check_db_connection():
+             logger.error("MongoDB connection failed when trying to get user history")
+             raise HTTPException(
+                 status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                 detail="MongoDB connection failed"
+             )
+
+         # Get user history from MongoDB
+         history_data = get_chat_history(user_id=user_id, n=n)
+
+         # Convert to response model
+         return HistoryResponse(history=history_data)
+     except PyMongoError as e:
+         logger.error(f"MongoDB error getting user history: {e}")
+         logger.error(traceback.format_exc())
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"MongoDB error: {str(e)}"
+         )
+     except HTTPException:
+         # Re-throw HTTP exceptions
+         raise
+     except Exception as e:
+         logger.error(f"Unexpected error getting user history: {e}")
+         logger.error(traceback.format_exc())
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"Failed to get user history: {str(e)}"
+         )
+
+ @router.get("/health")
+ async def health_check():
+     """
+     Check health of MongoDB connection.
+     """
+     try:
+         # Check MongoDB connection
+         is_connected = check_db_connection()
+
+         if not is_connected:
+             return {
+                 "status": "unhealthy",
+                 "message": "MongoDB connection failed",
+                 "timestamp": datetime.now().isoformat()
+             }
+
+         return {
+             "status": "healthy",
+             "message": "MongoDB connection is working",
+             "timestamp": datetime.now().isoformat()
+         }
+     except Exception as e:
+         logger.error(f"MongoDB health check failed: {e}")
+         logger.error(traceback.format_exc())
+         return {
+             "status": "error",
+             "message": f"MongoDB health check error: {str(e)}",
+             "timestamp": datetime.now().isoformat()
+         }
+
+ @router.get("/session/{session_id}")
+ async def get_session(session_id: str):
+     """
+     Get session details from MongoDB by session_id.
+
+     - **session_id**: ID of the session to retrieve
+     """
+     try:
+         # Check MongoDB connection
+         if not check_db_connection():
+             logger.error("MongoDB connection failed when trying to get session")
+             raise HTTPException(
+                 status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                 detail="MongoDB connection failed"
+             )
+
+         # Fetch the session from MongoDB
+         session_data = session_collection.find_one({"session_id": session_id})
+
+         if not session_data:
+             raise HTTPException(
+                 status_code=status.HTTP_404_NOT_FOUND,
+                 detail=f"Session with ID {session_id} not found"
+             )
+
+         # Convert _id to a string so the document is JSON serializable
+         if "_id" in session_data:
+             session_data["_id"] = str(session_data["_id"])
+
+         return session_data
+     except PyMongoError as e:
+         logger.error(f"MongoDB error getting session: {e}")
+         logger.error(traceback.format_exc())
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"MongoDB error: {str(e)}"
+         )
+     except HTTPException:
+         # Re-throw HTTP exceptions
+         raise
+     except Exception as e:
+         logger.error(f"Unexpected error getting session: {e}")
+         logger.error(traceback.format_exc())
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"Failed to get session: {str(e)}"
+         )
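The create_session route above only queues a WebSocket notification when the stored RAG response begins with "I'm sorry". A minimal sketch of that predicate (the helper name `needs_notification` is hypothetical, not part of the module):

```python
from typing import Optional

def needs_notification(response: Optional[str]) -> bool:
    """Mirror the route's check: notify only when the RAG response is a
    non-empty string that starts with "I'm sorry", case-insensitively
    and ignoring surrounding whitespace."""
    return bool(response) and response.strip().lower().startswith("i'm sorry")

print(needs_notification("I'm sorry. I don't have information about that"))  # True
print(needs_notification("  i'M SORRY, could you rephrase?"))                # True
print(needs_notification("Here are three cafes in Da Nang."))                # False
print(needs_notification(None))                                              # False
```

Note that the prefix match uses a straight apostrophe; a response containing a curly apostrophe ("I’m sorry") would not trigger the notification.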
app/api/pdf_routes.py ADDED
@@ -0,0 +1,310 @@
+ import os
+ import shutil
+ import uuid
+ from fastapi import APIRouter, UploadFile, File, Form, HTTPException, BackgroundTasks, Depends
+ from fastapi.responses import JSONResponse
+ from typing import Optional, List, Dict, Any
+ from sqlalchemy.orm import Session
+
+ from app.utils.pdf_processor import PDFProcessor
+ from app.models.pdf_models import PDFResponse, DeleteDocumentRequest, DocumentsListResponse
+ from app.database.postgresql import get_db
+ from app.database.models import VectorDatabase, Document, VectorStatus, DocumentContent
+ from datetime import datetime
+ from app.api.pdf_websocket import (
+     send_pdf_upload_started,
+     send_pdf_upload_progress,
+     send_pdf_upload_completed,
+     send_pdf_upload_failed,
+     send_pdf_delete_started,
+     send_pdf_delete_completed,
+     send_pdf_delete_failed
+ )
+
+ # Initialize router
+ router = APIRouter(
+     prefix="/pdf",
+     tags=["PDF Processing"],
+ )
+
+ # Temporary upload directories - use /tmp to avoid permission errors
+ TEMP_UPLOAD_DIR = "/tmp/uploads/temp"
+ STORAGE_DIR = "/tmp/uploads/pdfs"
+
+ # Make sure the upload directories exist
+ os.makedirs(TEMP_UPLOAD_DIR, exist_ok=True)
+ os.makedirs(STORAGE_DIR, exist_ok=True)
+
+ # Endpoint to upload and process a PDF
+ @router.post("/upload", response_model=PDFResponse)
+ async def upload_pdf(
+     file: UploadFile = File(...),
+     namespace: str = Form("Default"),
+     index_name: str = Form("testbot768"),
+     title: Optional[str] = Form(None),
+     description: Optional[str] = Form(None),
+     user_id: Optional[str] = Form(None),
+     vector_database_id: Optional[int] = Form(None),
+     background_tasks: BackgroundTasks = None,
+     db: Session = Depends(get_db)
+ ):
+     """
+     Upload and process a PDF file to create embeddings and store them in Pinecone.
+
+     - **file**: PDF file to process
+     - **namespace**: Pinecone namespace to store the embeddings in (default: "Default")
+     - **index_name**: Pinecone index name (default: "testbot768")
+     - **title**: Document title (optional)
+     - **description**: Document description (optional)
+     - **user_id**: User ID for status updates over WebSocket
+     - **vector_database_id**: ID of the vector database in PostgreSQL (optional)
+     """
+     try:
+         # Make sure the file is a PDF
+         if not file.filename.lower().endswith('.pdf'):
+             raise HTTPException(status_code=400, detail="Only PDF files are accepted")
+
+         # If vector_database_id is provided, look up its details in PostgreSQL
+         api_key = None
+         vector_db = None
+
+         if vector_database_id:
+             vector_db = db.query(VectorDatabase).filter(
+                 VectorDatabase.id == vector_database_id,
+                 VectorDatabase.status == "active"
+             ).first()
+
+             if not vector_db:
+                 raise HTTPException(status_code=404, detail="Vector database does not exist or is inactive")
+
+             # Use the vector database's settings
+             api_key = vector_db.api_key
+             index_name = vector_db.pinecone_index
+
+         # Create a file_id and save the file temporarily
+         file_id = str(uuid.uuid4())
+         temp_file_path = os.path.join(TEMP_UPLOAD_DIR, f"{file_id}.pdf")
+
+         # Notify over WebSocket that processing has started, if user_id is provided
+         if user_id:
+             await send_pdf_upload_started(user_id, file.filename, file_id)
+
+         # Save the file
+         file_content = await file.read()
+         with open(temp_file_path, "wb") as buffer:
+             buffer.write(file_content)
+
+         # Build metadata
+         metadata = {
+             "filename": file.filename,
+             "content_type": file.content_type
+         }
+
+         if title:
+             metadata["title"] = title
+         if description:
+             metadata["description"] = description
+
+         # Send a progress update over WebSocket
+         if user_id:
+             await send_pdf_upload_progress(
+                 user_id,
+                 file_id,
+                 "file_preparation",
+                 0.2,
+                 "File saved, preparing for processing"
+             )
+
+         # Store document info in PostgreSQL if vector_database_id is provided
+         if vector_database_id and vector_db:
+             # Create document record without file content
+             document = Document(
+                 name=title or file.filename,
+                 file_type="pdf",
+                 content_type=file.content_type,
+                 size=len(file_content),
+                 is_embedded=False,
+                 vector_database_id=vector_database_id
+             )
+             db.add(document)
+             db.commit()
+             db.refresh(document)
+
+             # Create document content record to store binary data separately
+             document_content = DocumentContent(
+                 document_id=document.id,
+                 file_content=file_content
+             )
+             db.add(document_content)
+             db.commit()
+
+             # Create vector status record
+             vector_status = VectorStatus(
+                 document_id=document.id,
+                 vector_database_id=vector_database_id,
+                 status="pending"
+             )
+             db.add(vector_status)
+             db.commit()
+
+         # Initialize the PDF processor with the API key if one is available
+         processor = PDFProcessor(index_name=index_name, namespace=namespace, api_key=api_key)
+
+         # Notify over WebSocket that embedding is starting
+         if user_id:
+             await send_pdf_upload_progress(
+                 user_id,
+                 file_id,
+                 "embedding_start",
+                 0.4,
+                 "Starting to process PDF and create embeddings"
+             )
+
+         # Process the PDF and create embeddings
+         # Callback to forward progress updates
+         async def progress_callback_wrapper(step, progress, message):
+             if user_id:
+                 await send_progress_update(user_id, file_id, step, progress, message)
+
+         # Process the PDF and create embeddings with the properly wrapped callback
+         result = await processor.process_pdf(
+             file_path=temp_file_path,
+             document_id=file_id,
+             metadata=metadata,
+             progress_callback=progress_callback_wrapper
+         )
+
+         # On success, move the file into storage
+         if result.get('success'):
+             storage_path = os.path.join(STORAGE_DIR, f"{file_id}.pdf")
+             shutil.move(temp_file_path, storage_path)
+
+             # Update status in PostgreSQL if vector_database_id is provided
+             if vector_database_id and 'document' in locals() and 'vector_status' in locals():
+                 vector_status.status = "completed"
+                 vector_status.embedded_at = datetime.now()
+                 vector_status.vector_id = file_id
+                 document.is_embedded = True
+                 db.commit()
+
+             # Notify completion over WebSocket
+             if user_id:
+                 await send_pdf_upload_completed(
+                     user_id,
+                     file_id,
+                     file.filename,
+                     result.get('chunks_processed', 0)
+                 )
+         else:
+             # Record the failure in PostgreSQL if vector_database_id is provided
+             if vector_database_id and 'vector_status' in locals():
+                 vector_status.status = "failed"
+                 vector_status.error_message = result.get('error', 'Unknown error')
+                 db.commit()
+
+             # Notify failure over WebSocket
+             if user_id:
+                 await send_pdf_upload_failed(
+                     user_id,
+                     file_id,
+                     file.filename,
+                     result.get('error', 'Unknown error')
+                 )
+
+         # Cleanup: remove the temp file if it still exists
+         if os.path.exists(temp_file_path):
+             os.remove(temp_file_path)
+
+         return result
+     except Exception as e:
+         # Cleanup on error
+         if 'temp_file_path' in locals() and os.path.exists(temp_file_path):
+             os.remove(temp_file_path)
+
+         # Record the failure in PostgreSQL if vector_database_id is provided
+         if 'vector_database_id' in locals() and vector_database_id and 'vector_status' in locals():
+             vector_status.status = "failed"
+             vector_status.error_message = str(e)
+             db.commit()
+
+         # Notify failure over WebSocket
+         if 'user_id' in locals() and user_id and 'file_id' in locals():
+             await send_pdf_upload_failed(
+                 user_id,
+                 file_id,
+                 file.filename,
+                 str(e)
+             )
+
+         return PDFResponse(
+             success=False,
+             error=str(e)
+         )
+
+ # Helper to forward progress updates - used by the processing callback
+ async def send_progress_update(user_id, document_id, step, progress, message):
+     if user_id:
+         await send_pdf_upload_progress(user_id, document_id, step, progress, message)
+
+ # Endpoint to delete a namespace
+ @router.delete("/namespace", response_model=PDFResponse)
+ async def delete_namespace(
+     namespace: str = "Default",
+     index_name: str = "testbot768",
+     user_id: Optional[str] = None
+ ):
+     """
+     Delete all embeddings in a Pinecone namespace (effectively deleting the namespace).
+
+     - **namespace**: Pinecone namespace (default: "Default")
+     - **index_name**: Pinecone index name (default: "testbot768")
+     - **user_id**: User ID for status updates over WebSocket
+     """
+     try:
+         # Notify over WebSocket that deletion has started
+         if user_id:
+             await send_pdf_delete_started(user_id, namespace)
+
+         processor = PDFProcessor(index_name=index_name, namespace=namespace)
+         result = await processor.delete_namespace()
+
+         # Send the result over WebSocket
+         if user_id:
+             if result.get('success'):
+                 await send_pdf_delete_completed(user_id, namespace)
+             else:
+                 await send_pdf_delete_failed(user_id, namespace, result.get('error', 'Unknown error'))
+
+         return result
+     except Exception as e:
+         # Notify failure over WebSocket
+         if user_id:
+             await send_pdf_delete_failed(user_id, namespace, str(e))
+
+         return PDFResponse(
+             success=False,
+             error=str(e)
+         )
+
+ # Endpoint to list documents
+ @router.get("/documents", response_model=DocumentsListResponse)
+ async def get_documents(namespace: str = "Default", index_name: str = "testbot768"):
+     """
+     Get information about every embedded document.
+
+     - **namespace**: Pinecone namespace (default: "Default")
+     - **index_name**: Pinecone index name (default: "testbot768")
+     """
+     try:
+         # Initialize the PDF processor
+         processor = PDFProcessor(index_name=index_name, namespace=namespace)
+
+         # Fetch the document list
+         result = await processor.list_documents()
+
+         return result
+     except Exception as e:
+         return DocumentsListResponse(
+             success=False,
+             error=str(e)
+         )
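The upload route hands `process_pdf` an async `progress_callback` that it awaits at each stage. The callback chain can be exercised locally without Pinecone or a WebSocket; `fake_process_pdf` below is a hypothetical stand-in for `PDFProcessor.process_pdf`, reporting the same stage names the route uses:

```python
import asyncio

# Hypothetical stand-in for PDFProcessor.process_pdf: awaits the
# progress callback at each stage, like the real processor is expected to.
async def fake_process_pdf(progress_callback):
    stages = [
        ("file_preparation", 0.2, "File saved, preparing for processing"),
        ("embedding_start", 0.4, "Starting to process PDF and create embeddings"),
        ("completed", 1.0, "Done"),
    ]
    for step, progress, message in stages:
        await progress_callback(step, progress, message)
    return {"success": True, "chunks_processed": 3}

async def main():
    updates = []

    async def progress_callback(step, progress, message):
        # In the route this forwards to send_pdf_upload_progress over WebSocket.
        updates.append((step, progress))

    result = await fake_process_pdf(progress_callback)
    return updates, result

updates, result = asyncio.run(main())
print(updates[-1])        # ('completed', 1.0)
print(result["success"])  # True
```

Because the callback is awaited inside the processor, slow WebSocket sends would slow down processing; the route sidesteps the worst of this by keeping each send small.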
app/api/pdf_websocket.py ADDED
@@ -0,0 +1,263 @@
+ import logging
+ from typing import Dict, List, Optional, Any
+ from fastapi import WebSocket, WebSocketDisconnect, APIRouter
+ from pydantic import BaseModel
+ import json
+ import time
+
+ # Configure logging
+ logger = logging.getLogger(__name__)
+
+ # Models for Swagger documentation
+ class ConnectionStatus(BaseModel):
+     user_id: str
+     active: bool
+     connection_count: int
+     last_activity: Optional[float] = None
+
+ class UserConnection(BaseModel):
+     user_id: str
+     connection_count: int
+
+ class AllConnectionsStatus(BaseModel):
+     total_users: int
+     total_connections: int
+     users: List[UserConnection]
+
+ # Initialize router
+ router = APIRouter(
+     prefix="/ws",
+     tags=["WebSockets"],
+ )
+
+ class ConnectionManager:
+     """Manage WebSocket connections"""
+
+     def __init__(self):
+         # Store connections keyed by user_id
+         self.active_connections: Dict[str, List[WebSocket]] = {}
+
+     async def connect(self, websocket: WebSocket, user_id: str):
+         """Accept a new WebSocket connection"""
+         await websocket.accept()
+         if user_id not in self.active_connections:
+             self.active_connections[user_id] = []
+         self.active_connections[user_id].append(websocket)
+         logger.info(f"New WebSocket connection for user {user_id}. Total connections: {len(self.active_connections[user_id])}")
+
+     def disconnect(self, websocket: WebSocket, user_id: str):
+         """Disconnect a WebSocket"""
+         if user_id in self.active_connections:
+             if websocket in self.active_connections[user_id]:
+                 self.active_connections[user_id].remove(websocket)
+             # Remove user_id from the dict once it has no connections left
+             if not self.active_connections[user_id]:
+                 del self.active_connections[user_id]
+             logger.info(f"WebSocket disconnected for user {user_id}")
+
+     async def send_message(self, message: Dict[str, Any], user_id: str):
+         """Send a message to every connection of a user"""
+         if user_id in self.active_connections:
+             disconnected_websockets = []
+             for websocket in self.active_connections[user_id]:
+                 try:
+                     await websocket.send_text(json.dumps(message))
+                 except Exception as e:
+                     logger.error(f"Error sending message to WebSocket: {str(e)}")
+                     disconnected_websockets.append(websocket)
+
+             # Drop the connections that went away
+             for websocket in disconnected_websockets:
+                 self.disconnect(websocket, user_id)
+
+     def get_connection_status(self, user_id: Optional[str] = None) -> Dict[str, Any]:
+         """Get WebSocket connection status"""
+         if user_id:
+             # Return connection info for a specific user
+             if user_id in self.active_connections:
+                 return {
+                     "user_id": user_id,
+                     "active": True,
+                     "connection_count": len(self.active_connections[user_id]),
+                     "last_activity": time.time()
+                 }
+             else:
+                 return {
+                     "user_id": user_id,
+                     "active": False,
+                     "connection_count": 0,
+                     "last_activity": None
+                 }
+         else:
+             # Return info about all connections
+             result = {
+                 "total_users": len(self.active_connections),
+                 "total_connections": sum(len(connections) for connections in self.active_connections.values()),
+                 "users": []
+             }
+
+             for uid, connections in self.active_connections.items():
+                 result["users"].append({
+                     "user_id": uid,
+                     "connection_count": len(connections)
+                 })
+
+             return result
+
+
+ # Create the ConnectionManager instance
+ manager = ConnectionManager()
+
+ @router.websocket("/pdf/{user_id}")
+ async def websocket_endpoint(websocket: WebSocket, user_id: str):
+     """WebSocket endpoint for PDF processing progress updates"""
+     await manager.connect(websocket, user_id)
+     try:
+         while True:
+             # Wait for messages from the client (keep-alive only)
+             await websocket.receive_text()
+     except WebSocketDisconnect:
+         manager.disconnect(websocket, user_id)
+     except Exception as e:
+         logger.error(f"WebSocket error: {str(e)}")
+         manager.disconnect(websocket, user_id)
+
+ # API endpoints for checking WebSocket status
+ @router.get("/status", response_model=AllConnectionsStatus, responses={
+     200: {
+         "description": "Successful response",
+         "content": {
+             "application/json": {
+                 "example": {
+                     "total_users": 2,
+                     "total_connections": 3,
+                     "users": [
+                         {"user_id": "user1", "connection_count": 2},
+                         {"user_id": "user2", "connection_count": 1}
+                     ]
+                 }
+             }
+         }
+     }
+ })
+ async def get_all_websocket_connections():
+     """
+     Get information about all current WebSocket connections.
+
+     This endpoint returns:
+     - The total number of connected users
+     - The total number of WebSocket connections
+     - The list of users with each user's connection count
+     """
+     return manager.get_connection_status()
+
+ @router.get("/status/{user_id}", response_model=ConnectionStatus, responses={
+     200: {
+         "description": "Successful response for active connection",
+         "content": {
+             "application/json": {
+                 "examples": {
+                     "active_connection": {
+                         "summary": "Active connection",
+                         "value": {
+                             "user_id": "user123",
+                             "active": True,
+                             "connection_count": 2,
+                             "last_activity": 1634567890.123
+                         }
+                     },
+                     "no_connection": {
+                         "summary": "No active connection",
+                         "value": {
+                             "user_id": "user456",
+                             "active": False,
+                             "connection_count": 0,
+                             "last_activity": None
+                         }
+                     }
+                 }
+             }
+         }
+     }
+ })
+ async def get_user_websocket_status(user_id: str):
+     """
+     Get information about a specific user's WebSocket connections.
+
+     Parameters:
+     - **user_id**: ID of the user to check
+
+     Returns:
+     - Connection status details, including:
+       - active: whether the user is currently connected
+       - connection_count: number of current connections
+       - last_activity: time of the most recent activity
+     """
+     return manager.get_connection_status(user_id)
+
+ # Helpers that push status update notifications
+
+ async def send_pdf_upload_started(user_id: str, filename: str, document_id: str):
+     """Notify that a PDF upload has started"""
+     await manager.send_message({
+         "type": "pdf_upload_started",
+         "document_id": document_id,
+         "filename": filename,
+         "timestamp": int(time.time())
+     }, user_id)
+
+ async def send_pdf_upload_progress(user_id: str, document_id: str, step: str, progress: float, message: str):
+     """Notify about PDF upload progress"""
+     await manager.send_message({
+         "type": "pdf_upload_progress",
+         "document_id": document_id,
+         "step": step,
+         "progress": progress,
+         "message": message,
+         "timestamp": int(time.time())
+     }, user_id)
+
+ async def send_pdf_upload_completed(user_id: str, document_id: str, filename: str, chunks: int):
+     """Notify that a PDF upload has completed"""
+     await manager.send_message({
+         "type": "pdf_upload_completed",
+         "document_id": document_id,
+         "filename": filename,
+         "chunks": chunks,
+         "timestamp": int(time.time())
+     }, user_id)
+
+ async def send_pdf_upload_failed(user_id: str, document_id: str, filename: str, error: str):
+     """Notify that a PDF upload has failed"""
+     await manager.send_message({
+         "type": "pdf_upload_failed",
+         "document_id": document_id,
+         "filename": filename,
+         "error": error,
+         "timestamp": int(time.time())
+     }, user_id)
+
+ async def send_pdf_delete_started(user_id: str, namespace: str):
+     """Notify that a PDF deletion has started"""
+     await manager.send_message({
+         "type": "pdf_delete_started",
+         "namespace": namespace,
+         "timestamp": int(time.time())
+     }, user_id)
+
+ async def send_pdf_delete_completed(user_id: str, namespace: str):
+     """Notify that a PDF deletion has completed"""
+     await manager.send_message({
+         "type": "pdf_delete_completed",
+         "namespace": namespace,
+         "timestamp": int(time.time())
+     }, user_id)
+
+ async def send_pdf_delete_failed(user_id: str, namespace: str, error: str):
+     """Notify that a PDF deletion has failed"""
+     await manager.send_message({
+         "type": "pdf_delete_failed",
+         "namespace": namespace,
+         "error": error,
+         "timestamp": int(time.time())
+     }, user_id)
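The ConnectionManager keeps a per-user list of sockets and drops a user's key once the last socket disconnects. That bookkeeping does not depend on real WebSocket objects, so it can be sanity-checked with plain placeholders; the sketch below re-implements just that dict logic (it is an illustration, not code from the module):

```python
from typing import Any, Dict, List

# Per-user connection bookkeeping, mirroring ConnectionManager's dict:
# any hashable-by-identity object stands in for a WebSocket.
active: Dict[str, List[Any]] = {}

def connect(sock: Any, user_id: str) -> None:
    active.setdefault(user_id, []).append(sock)

def disconnect(sock: Any, user_id: str) -> None:
    conns = active.get(user_id)
    if conns and sock in conns:
        conns.remove(sock)
        if not conns:  # drop the key once the last socket is gone
            del active[user_id]

a, b = object(), object()
connect(a, "user1")
connect(b, "user1")
disconnect(a, "user1")
remaining = len(active.get("user1", []))  # 1
disconnect(b, "user1")
fully_gone = "user1" not in active        # True
print(remaining, fully_gone)
```

Dropping empty keys keeps `get_connection_status()` honest: `total_users` only counts users with at least one live socket.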
app/api/postgresql_routes.py ADDED
The diff for this file is too large to render. See raw diff
 
app/api/rag_routes.py ADDED
@@ -0,0 +1,338 @@
1
+ from fastapi import APIRouter, HTTPException, Depends, Query, BackgroundTasks, Request
2
+ from typing import List, Optional, Dict, Any
3
+ import logging
4
+ import time
5
+ import os
6
+ import json
7
+ import hashlib
8
+ import asyncio
9
+ import traceback
10
+ import google.generativeai as genai
11
+ from datetime import datetime
12
+ from langchain.prompts import PromptTemplate
13
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings
14
+ from app.utils.utils import timer_decorator
15
+
16
+ from app.database.mongodb import get_chat_history, get_request_history, session_collection
17
+ from app.database.pinecone import (
18
+ search_vectors,
19
+ get_chain,
20
+ DEFAULT_TOP_K,
21
+ DEFAULT_LIMIT_K,
22
+ DEFAULT_SIMILARITY_METRIC,
23
+ DEFAULT_SIMILARITY_THRESHOLD,
24
+ ALLOWED_METRICS
25
+ )
26
+ from app.models.rag_models import (
27
+ ChatRequest,
28
+ ChatResponse,
29
+ ChatResponseInternal,
30
+ SourceDocument,
31
+ EmbeddingRequest,
32
+ EmbeddingResponse,
33
+ UserMessageModel
34
+ )
35
+
36
+ # Configure logging
37
+ logger = logging.getLogger(__name__)
38
+
39
+ # Configure Google Gemini API
40
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
41
+ genai.configure(api_key=GOOGLE_API_KEY)
42
+
43
+ # Create router
+ router = APIRouter(
+     prefix="/rag",
+     tags=["RAG"],
+ )
+
+ fix_request = PromptTemplate(
+     template="""Goal:
+ Your task is to extract important keywords from the user's current request, optionally using chat history if relevant.
+ You will receive a conversation history and the user's current message.
+ Generate a **list of concise keywords** that best represent the user's intent.
+
+ Return Format:
+ Only return keywords (comma-separated, no extra explanation).
+ If the current message is NOT related to the chat history, or if there is no chat history: return keywords from the current message only.
+ If the current message IS related to the chat history: return a refined set of keywords based on both the history and the current message.
+
+ Warning:
+ Only use the chat history if the current message is clearly related to the prior context.
+
+ Conversation History:
+ {chat_history}
+
+ User current message:
+ {question}
+ """,
+     input_variables=["chat_history", "question"],
+ )
+
+ # Create a prompt template with conversation history
+ prompt = PromptTemplate(
+     template="""Goal:
+ You are a professional tour guide assistant that helps users find information about places in Da Nang, Vietnam.
+ You can provide details on restaurants, cafes, hotels, attractions, and other local venues.
+ Use the core knowledge and conversation history to chat with users, who are tourists visiting Da Nang.
+
+ Return Format:
+ Respond in a friendly, natural, and concise way, in English only, like a real tour guide.
+ Always use HTML tags (e.g. <b> for bold) so that Telegram can render the special formatting correctly.
+
+ Warning:
+ Support users like a real tour guide, not a bot. Treat the information in the core knowledge as your own knowledge.
+ Your knowledge is provided in the Core Knowledge. All information in the Core Knowledge is about Da Nang, Vietnam.
+ Only consider the current time the user mentions when they ask about Solana events.
+ Only use the core knowledge to answer. If you do not have enough information to answer the user's question, reply with "I'm sorry. I don't have information about that" and give the user some more options to ask about.
+
+ Core knowledge:
+ {context}
+
+ Conversation History:
+ {chat_history}
+
+ User message:
+ {question}
+
+ Your message:
+ """,
+     input_variables=["context", "question", "chat_history"],
+ )
+
+ # Helper for embeddings
+ async def get_embedding(text: str):
+     """Get embedding from Google Gemini API"""
+     try:
+         # Initialize embedding model
+         embedding_model = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
+
+         # Generate embedding
+         result = await embedding_model.aembed_query(text)
+
+         # Return embedding
+         return {
+             "embedding": result,
+             "text": text,
+             "model": "embedding-001"
+         }
+     except Exception as e:
+         logger.error(f"Error generating embedding: {e}")
+         raise HTTPException(status_code=500, detail=f"Failed to generate embedding: {str(e)}")
+
+ # Endpoint for generating embeddings
+ @router.post("/embedding", response_model=EmbeddingResponse)
+ async def create_embedding(request: EmbeddingRequest):
+     """
+     Generate embedding for text.
+
+     - **text**: Text to generate embedding for
+     """
+     try:
+         # Get embedding
+         embedding_data = await get_embedding(request.text)
+
+         # Return embedding
+         return EmbeddingResponse(**embedding_data)
+     except Exception as e:
+         logger.error(f"Error generating embedding: {e}")
+         raise HTTPException(status_code=500, detail=f"Failed to generate embedding: {str(e)}")
+
+ # Note: @timer_decorator must be applied *before* @router.post (i.e. listed
+ # below it) so that the timed wrapper is what gets registered as the route
+ # handler; with the decorators in the opposite order the timer never runs.
+ @router.post("/chat", response_model=ChatResponse)
+ @timer_decorator
+ async def chat(request: ChatRequest, background_tasks: BackgroundTasks):
+     """
+     Get answer for a question using RAG.
+
+     - **user_id**: User's ID from Telegram
+     - **question**: User's question
+     - **include_history**: Whether to include user history in prompt (default: True)
+     - **use_rag**: Whether to use RAG (default: True)
+     - **similarity_top_k**: Number of top similar documents to return after filtering (default: 6)
+     - **limit_k**: Maximum number of documents to retrieve from vector store (default: 10)
+     - **similarity_metric**: Similarity metric to use - cosine, dotproduct, euclidean (default: cosine)
+     - **similarity_threshold**: Threshold for vector similarity (default: 0.75)
+     - **session_id**: Optional session ID for tracking conversations
+     - **first_name**: User's first name
+     - **last_name**: User's last name
+     - **username**: User's username
+     """
+     start_time = time.time()
+     try:
+         # Save user message first (so it's available for user history)
+         session_id = request.session_id or f"{request.user_id}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
+         # logger.info(f"Processing chat request for user {request.user_id}, session {session_id}")
+
+         retriever = get_chain(
+             top_k=request.similarity_top_k,
+             limit_k=request.limit_k,
+             similarity_metric=request.similarity_metric,
+             similarity_threshold=request.similarity_threshold
+         )
+         if not retriever:
+             raise HTTPException(status_code=500, detail="Failed to initialize retriever")
+
+         # Get chat history
+         chat_history = get_chat_history(request.user_id) if request.include_history else ""
+         logger.info(f"Using chat history: {chat_history[:100]}...")
+
+         # Initialize Gemini model
+         generation_config = {
+             "temperature": 0.9,
+             "top_p": 1,
+             "top_k": 1,
+             "max_output_tokens": 2048,
+         }
+
+         safety_settings = [
+             {
+                 "category": "HARM_CATEGORY_HARASSMENT",
+                 "threshold": "BLOCK_MEDIUM_AND_ABOVE"
+             },
+             {
+                 "category": "HARM_CATEGORY_HATE_SPEECH",
+                 "threshold": "BLOCK_MEDIUM_AND_ABOVE"
+             },
+             {
+                 "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+                 "threshold": "BLOCK_MEDIUM_AND_ABOVE"
+             },
+             {
+                 "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+                 "threshold": "BLOCK_MEDIUM_AND_ABOVE"
+             },
+         ]
+
+         model = genai.GenerativeModel(
+             model_name='models/gemini-2.0-flash',
+             generation_config=generation_config,
+             safety_settings=safety_settings
+         )
+
+         prompt_request = fix_request.format(
+             question=request.question,
+             chat_history=chat_history
+         )
+
+         # Log the start time of the query-rewriting step
+         final_request_start_time = time.time()
+         final_request = model.generate_content(prompt_request)
+         # Log how long the query-rewriting step took
+         logger.info(f"Fixed Request: {final_request.text}")
+         logger.info(f"Final request generation time: {time.time() - final_request_start_time:.2f} seconds")
+         # print(final_request.text)
+
+         retrieved_docs = retriever.invoke(final_request.text)
+         logger.info(f"Retrieve: {retrieved_docs}")
+         context = "\n".join([doc.page_content for doc in retrieved_docs])
+
+         sources = []
+         for doc in retrieved_docs:
+             source = None
+             metadata = {}
+             # Initialize scores so they are defined even when a document
+             # has no metadata attribute
+             score = None
+             normalized_score = None
+
+             if hasattr(doc, 'metadata'):
+                 source = doc.metadata.get('source', None)
+                 # Extract score information
+                 score = doc.metadata.get('score', None)
+                 normalized_score = doc.metadata.get('normalized_score', None)
+                 # Remove score info from metadata to avoid duplication
+                 metadata = {k: v for k, v in doc.metadata.items()
+                             if k not in ['text', 'source', 'score', 'normalized_score']}
+
+             sources.append(SourceDocument(
+                 text=doc.page_content,
+                 source=source,
+                 score=score,
+                 normalized_score=normalized_score,
+                 metadata=metadata
+             ))
+
+         # Generate the prompt using the template
+         prompt_text = prompt.format(
+             context=context,
+             question=final_request.text,
+             chat_history=chat_history
+         )
+         logger.info(f"Full prompt with history and context: {prompt_text}")
+
+         # Generate response
+         response = model.generate_content(prompt_text)
+         answer = response.text
+
+         # Calculate processing time
+         processing_time = time.time() - start_time
+
+         # Log full response with sources
+         # logger.info(f"Generated response for user {request.user_id}: {answer}")
+
+         # Create response object for API (without sources)
+         chat_response = ChatResponse(
+             answer=answer,
+             processing_time=processing_time
+         )
+
+         # Return response
+         return chat_response
+     except Exception as e:
+         logger.error(f"Error processing chat request: {e}")
+         import traceback
+         logger.error(traceback.format_exc())
+         raise HTTPException(status_code=500, detail=f"Failed to process chat request: {str(e)}")
+
+ # Health check endpoint
+ @router.get("/health")
+ async def health_check():
+     """
+     Check health of RAG services and retrieval system.
+
+     Returns:
+     - status: "healthy" if all services are working, "degraded" otherwise
+     - services: Status of each service (gemini, pinecone)
+     - retrieval_config: Current retrieval configuration
+     - timestamp: Current time
+     """
+     services = {
+         "gemini": False,
+         "pinecone": False
+     }
+
+     # Check Gemini
+     try:
+         # Initialize a simple model
+         model = genai.GenerativeModel("gemini-2.0-flash")
+         # Test generation
+         response = model.generate_content("Hello")
+         services["gemini"] = True
+     except Exception as e:
+         logger.error(f"Gemini health check failed: {e}")
+
+     # Check Pinecone
+     try:
+         # Import pinecone function
+         from app.database.pinecone import get_pinecone_index
+         # Get index
+         index = get_pinecone_index()
+         # Check if index exists
+         if index:
+             services["pinecone"] = True
+     except Exception as e:
+         logger.error(f"Pinecone health check failed: {e}")
+
+     # Get retrieval configuration
+     retrieval_config = {
+         "default_top_k": DEFAULT_TOP_K,
+         "default_limit_k": DEFAULT_LIMIT_K,
+         "default_similarity_metric": DEFAULT_SIMILARITY_METRIC,
+         "default_similarity_threshold": DEFAULT_SIMILARITY_THRESHOLD,
+         "allowed_metrics": ALLOWED_METRICS
+     }
+
+     # Return health status
+     status = "healthy" if all(services.values()) else "degraded"
+     return {
+         "status": status,
+         "services": services,
+         "retrieval_config": retrieval_config,
+         "timestamp": datetime.now().isoformat()
+     }
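The `/rag/chat` handler above falls back to a derived `session_id` when the client does not supply one. A minimal standalone sketch of that fallback (the name `resolve_session_id` is hypothetical, used only for illustration):

```python
from datetime import datetime
from typing import Optional

def resolve_session_id(user_id: str, session_id: Optional[str] = None) -> str:
    # Mirrors the fallback in the /rag/chat handler: reuse the client's
    # session_id when provided, otherwise derive one from the user id
    # and the current timestamp.
    return session_id or f"{user_id}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"

print(resolve_session_id("12345", "abc"))  # explicit id is kept
print(resolve_session_id("12345"))         # derived id, e.g. 12345_2024-01-01_12:00:00
```

Note that the derived id embeds the request time, so two requests from the same user without an explicit `session_id` land in different sessions unless they arrive within the same second.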
app/api/websocket_routes.py ADDED
@@ -0,0 +1,303 @@
+ from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, status
+ from typing import List, Dict
+ import logging
+ from datetime import datetime
+ import asyncio
+ import json
+ import os
+ from dotenv import load_dotenv
+ from app.database.mongodb import session_collection
+ from app.utils.utils import get_local_time
+
+ # Load environment variables
+ load_dotenv()
+
+ # Get WebSocket configuration from environment variables
+ WEBSOCKET_SERVER = os.getenv("WEBSOCKET_SERVER", "localhost")
+ WEBSOCKET_PORT = os.getenv("WEBSOCKET_PORT", "7860")
+ WEBSOCKET_PATH = os.getenv("WEBSOCKET_PATH", "/notify")
+
+ # Configure logging
+ logger = logging.getLogger(__name__)
+
+ # Create router
+ router = APIRouter(
+     tags=["WebSocket"],
+ )
+
+ # Store active WebSocket connections
+ class ConnectionManager:
+     def __init__(self):
+         self.active_connections: List[WebSocket] = []
+
+     async def connect(self, websocket: WebSocket):
+         await websocket.accept()
+         self.active_connections.append(websocket)
+         client_info = f"{websocket.client.host}:{websocket.client.port}" if hasattr(websocket, 'client') else "Unknown"
+         logger.info(f"New WebSocket connection from {client_info}. Total connections: {len(self.active_connections)}")
+
+     def disconnect(self, websocket: WebSocket):
+         # Guard against double removal: both the endpoint's exception
+         # handlers and broadcast() may try to drop the same connection
+         if websocket in self.active_connections:
+             self.active_connections.remove(websocket)
+         logger.info(f"WebSocket connection removed. Total connections: {len(self.active_connections)}")
+
+     async def broadcast(self, message: Dict):
+         if not self.active_connections:
+             logger.warning("No active WebSocket connections to broadcast to")
+             return
+
+         disconnected = []
+         for connection in self.active_connections:
+             try:
+                 await connection.send_json(message)
+                 logger.info("Message sent to WebSocket connection")
+             except Exception as e:
+                 logger.error(f"Error sending message to WebSocket: {e}")
+                 disconnected.append(connection)
+
+         # Remove disconnected connections
+         for conn in disconnected:
+             if conn in self.active_connections:
+                 self.active_connections.remove(conn)
+                 logger.info(f"Removed disconnected WebSocket. Remaining: {len(self.active_connections)}")
+
+ # Initialize connection manager
+ manager = ConnectionManager()
+
+ # Build the full URL of the WebSocket server from environment variables
+ def get_full_websocket_url(server_side=False):
+     if server_side:
+         # Relative URL (for server side)
+         return WEBSOCKET_PATH
+     else:
+         # Full URL (for clients)
+         # Use wss:// when serving over HTTPS (assumed for port 443)
+         is_https = int(WEBSOCKET_PORT) == 443
+         protocol = "wss" if is_https else "ws"
+
+         # If using the default port for the protocol, omit it from the URL
+         if (is_https and int(WEBSOCKET_PORT) == 443) or (not is_https and int(WEBSOCKET_PORT) == 80):
+             return f"{protocol}://{WEBSOCKET_SERVER}{WEBSOCKET_PATH}"
+         else:
+             return f"{protocol}://{WEBSOCKET_SERVER}:{WEBSOCKET_PORT}{WEBSOCKET_PATH}"
+
+ # Add GET endpoint to display WebSocket information in Swagger
+ @router.get("/notify",
+     summary="WebSocket notifications for Admin Bot",
+     description=f"""
+     This is documentation for the WebSocket endpoint.
+
+     To connect to the WebSocket:
+     1. Use the path `{get_full_websocket_url()}`
+     2. Connect using a WebSocket client library
+     3. When there are new sessions requiring attention, you will receive notifications through this connection
+
+     Notifications are sent when:
+     - The session response starts with "I'm sorry"
+     - The system cannot answer the user's question
+
+     Make sure to send a "keepalive" message every 5 minutes to maintain the connection.
+     """,
+     status_code=status.HTTP_200_OK
+ )
+ async def websocket_documentation():
+     """
+     Provides information about how to use the WebSocket endpoint /notify.
+     This endpoint is for documentation purposes only. To use the WebSocket, please connect to the WebSocket URL.
+     """
+     ws_url = get_full_websocket_url()
+     return {
+         "websocket_endpoint": WEBSOCKET_PATH,
+         "connection_type": "WebSocket",
+         "protocol": "ws://",
+         "server": WEBSOCKET_SERVER,
+         "port": WEBSOCKET_PORT,
+         "full_url": ws_url,
+         "description": "Endpoint to receive notifications about new sessions requiring attention",
+         "notification_format": {
+             "type": "sorry_response",
+             "timestamp": "YYYY-MM-DD HH:MM:SS",
+             "data": {
+                 "session_id": "session id",
+                 "factor": "user",
+                 "action": "action type",
+                 "message": "User question",
+                 "response": "I'm sorry...",
+                 "user_id": "user id",
+                 "first_name": "user's first name",
+                 "last_name": "user's last name",
+                 "username": "username",
+                 "created_at": "creation time"
+             }
+         },
+         "client_example": """
+ import websocket
+ import json
+ import os
+ import time
+ import threading
+ from dotenv import load_dotenv
+
+ # Load environment variables
+ load_dotenv()
+
+ # Get WebSocket configuration from environment variables
+ WEBSOCKET_SERVER = os.getenv("WEBSOCKET_SERVER", "localhost")
+ WEBSOCKET_PORT = os.getenv("WEBSOCKET_PORT", "7860")
+ WEBSOCKET_PATH = os.getenv("WEBSOCKET_PATH", "/notify")
+
+ # Create the full URL
+ ws_url = f"ws://{WEBSOCKET_SERVER}:{WEBSOCKET_PORT}{WEBSOCKET_PATH}"
+
+ # If using HTTPS, replace ws:// with wss://
+ # ws_url = f"wss://{WEBSOCKET_SERVER}{WEBSOCKET_PATH}"
+
+ # Send keepalive periodically
+ def send_keepalive(ws):
+     while True:
+         try:
+             if ws.sock and ws.sock.connected:
+                 ws.send("keepalive")
+                 print("Sent keepalive message")
+             time.sleep(300)  # 5 minutes
+         except Exception as e:
+             print(f"Error sending keepalive: {e}")
+             time.sleep(60)
+
+ def on_message(ws, message):
+     try:
+         data = json.loads(message)
+         print(f"Received notification: {data}")
+         # Process the notification, e.g. send it to the Telegram admin
+         if data.get("type") == "sorry_response":
+             session_data = data.get("data", {})
+             user_question = session_data.get("message", "")
+             user_name = session_data.get("first_name", "Unknown User")
+             print(f"User {user_name} asked: {user_question}")
+             # Code to send the message to the Telegram admin
+     except json.JSONDecodeError:
+         print(f"Received non-JSON message: {message}")
+     except Exception as e:
+         print(f"Error processing message: {e}")
+
+ def on_error(ws, error):
+     print(f"WebSocket error: {error}")
+
+ def on_close(ws, close_status_code, close_msg):
+     print(f"WebSocket connection closed: code={close_status_code}, message={close_msg}")
+
+ def on_open(ws):
+     print(f"WebSocket connection opened to {ws_url}")
+     # Send keepalive messages periodically in a separate thread
+     keepalive_thread = threading.Thread(target=send_keepalive, args=(ws,), daemon=True)
+     keepalive_thread.start()
+
+ def run_forever_with_reconnect():
+     while True:
+         try:
+             # Connect the WebSocket with pings to maintain the connection
+             ws = websocket.WebSocketApp(
+                 ws_url,
+                 on_open=on_open,
+                 on_message=on_message,
+                 on_error=on_error,
+                 on_close=on_close
+             )
+             ws.run_forever(ping_interval=60, ping_timeout=30)
+             print("WebSocket connection lost, reconnecting in 5 seconds...")
+             time.sleep(5)
+         except Exception as e:
+             print(f"WebSocket connection error: {e}")
+             time.sleep(5)
+
+ # Start the WebSocket client in a separate thread
+ websocket_thread = threading.Thread(target=run_forever_with_reconnect, daemon=True)
+ websocket_thread.start()
+
+ # Keep the program running
+ try:
+     while True:
+         time.sleep(1)
+ except KeyboardInterrupt:
+     print("Stopping WebSocket client...")
+ """
+     }
+
+ @router.websocket("/notify")
+ async def websocket_endpoint(websocket: WebSocket):
+     """
+     WebSocket endpoint to receive notifications about new sessions.
+     The Admin Bot connects to this endpoint to receive notifications when there are new sessions requiring attention.
+     """
+     await manager.connect(websocket)
+     try:
+         while True:
+             # Maintain the WebSocket connection
+             data = await websocket.receive_text()
+             # Echo back to keep the connection active
+             await websocket.send_json({"status": "connected", "echo": data, "timestamp": datetime.now().isoformat()})
+             logger.info(f"Received message from WebSocket: {data}")
+     except WebSocketDisconnect:
+         logger.info("WebSocket client disconnected")
+         manager.disconnect(websocket)
+     except Exception as e:
+         logger.error(f"WebSocket error: {e}")
+         manager.disconnect(websocket)
+
+ # Function to send notifications over WebSocket
+ async def send_notification(data: dict):
+     """
+     Send notification to all active WebSocket connections.
+
+     This function is used to notify admin bots about new issues or questions that need attention.
+     It's triggered when the system cannot answer a user's question (response starts with "I'm sorry").
+
+     Args:
+         data: The data to send as notification
+     """
+     try:
+         # Log number of active connections and the notification attempt
+         logger.info(f"Attempting to send notification. Active connections: {len(manager.active_connections)}")
+         logger.info(f"Notification data: session_id={data.get('session_id')}, user_id={data.get('user_id')}")
+         logger.info(f"Response: {data.get('response', '')[:50]}...")
+
+         # Check if the response starts with "I'm sorry"
+         response = data.get('response', '')
+         if not response or not isinstance(response, str):
+             logger.warning(f"Invalid response format in notification data: {response}")
+             return
+
+         if not response.strip().lower().startswith("i'm sorry"):
+             logger.info(f"Response doesn't start with 'I'm sorry', notification not needed: {response[:50]}...")
+             return
+
+         logger.info("Response starts with 'I'm sorry', sending notification")
+
+         # Format the notification data for the admin, following the Admin_bot convention
+         notification_data = {
+             "type": "sorry_response",  # Use "sorry_response" so Admin_bot recognizes the message
+             "timestamp": get_local_time(),
+             "user_id": data.get('user_id', 'unknown'),
+             "message": data.get('message', ''),
+             "response": response,
+             "session_id": data.get('session_id', 'unknown'),
+             "user_info": {
+                 "first_name": data.get('first_name', 'User'),
+                 "last_name": data.get('last_name', ''),
+                 "username": data.get('username', '')
+             }
+         }
+
+         # Check if there are active connections
+         if not manager.active_connections:
+             logger.warning("No active WebSocket connections for notification broadcast")
+             return
+
+         # Broadcast notification to all active connections
+         logger.info(f"Broadcasting notification to {len(manager.active_connections)} connections")
+         await manager.broadcast(notification_data)
+         logger.info("Notification broadcast completed successfully")
+
+     except Exception as e:
+         logger.error(f"Error sending notification: {e}")
+         import traceback
+         logger.error(traceback.format_exc())
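The guard at the top of `send_notification` decides whether a broadcast is warranted. Isolated as a pure predicate for clarity (the function name `needs_admin_attention` is hypothetical, not part of the codebase):

```python
def needs_admin_attention(response) -> bool:
    # Mirrors the guard in send_notification: broadcast only when the
    # response is a non-empty string that starts with "I'm sorry"
    # (case-insensitive, ignoring surrounding whitespace).
    if not response or not isinstance(response, str):
        return False
    return response.strip().lower().startswith("i'm sorry")

print(needs_admin_attention("I'm sorry. I don't have information about that"))  # True
print(needs_admin_attention("Da Nang's beaches are beautiful year-round!"))     # False
```

Keeping the filter as a pure function like this would also make the "sorry" detection unit-testable without a running WebSocket server.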
app/database/__init__.py ADDED
@@ -0,0 +1 @@
+ # Database connections package
app/database/models.py ADDED
@@ -0,0 +1,204 @@
+ from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey, Float, Text, LargeBinary, JSON
+ from sqlalchemy.sql import func
+ from sqlalchemy.orm import relationship
+ from .postgresql import Base
+ import datetime
+
+ class FAQItem(Base):
+     __tablename__ = "faq_item"
+
+     id = Column(Integer, primary_key=True, index=True)
+     question = Column(String, nullable=False)
+     answer = Column(String, nullable=False)
+     is_active = Column(Boolean, default=True)
+     created_at = Column(DateTime, server_default=func.now())
+     updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
+
+ class EmergencyItem(Base):
+     __tablename__ = "emergency_item"
+
+     id = Column(Integer, primary_key=True, index=True)
+     name = Column(String, nullable=False)
+     phone_number = Column(String, nullable=False)
+     description = Column(String, nullable=True)
+     address = Column(String, nullable=True)
+     location = Column(String, nullable=True)  # Will be converted to/from PostGIS POINT type
+     priority = Column(Integer, default=0)
+     is_active = Column(Boolean, default=True)
+     section = Column(String, nullable=True)  # Section field (16.1, 16.2.1, 16.2.2, 16.3)
+     section_id = Column(Integer, nullable=True)  # Numeric identifier for section
+     created_at = Column(DateTime, server_default=func.now())
+     updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
+
+ class EventItem(Base):
+     __tablename__ = "event_item"
+
+     id = Column(Integer, primary_key=True, index=True)
+     name = Column(String, nullable=False)
+     description = Column(Text, nullable=False)
+     address = Column(String, nullable=False)
+     location = Column(String, nullable=True)  # Will be converted to/from PostGIS POINT type
+     date_start = Column(DateTime, nullable=False)
+     date_end = Column(DateTime, nullable=True)
+     price = Column(JSON, nullable=True)
+     url = Column(String, nullable=True)
+     is_active = Column(Boolean, default=True)
+     featured = Column(Boolean, default=False)
+     created_at = Column(DateTime, server_default=func.now())
+     updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
+
+ class AboutPixity(Base):
+     __tablename__ = "about_pixity"
+
+     id = Column(Integer, primary_key=True, index=True)
+     content = Column(Text, nullable=False)
+     created_at = Column(DateTime, server_default=func.now())
+     updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
+
+ class SolanaSummit(Base):
+     __tablename__ = "solana_summit"
+
+     id = Column(Integer, primary_key=True, index=True)
+     content = Column(Text, nullable=False)
+     created_at = Column(DateTime, server_default=func.now())
+     updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
+
+ class DaNangBucketList(Base):
+     __tablename__ = "danang_bucket_list"
+
+     id = Column(Integer, primary_key=True, index=True)
+     content = Column(Text, nullable=False)
+     created_at = Column(DateTime, server_default=func.now())
+     updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
+
+ class VectorDatabase(Base):
+     __tablename__ = "vector_database"
+
+     id = Column(Integer, primary_key=True, index=True)
+     name = Column(String, nullable=False, unique=True)
+     description = Column(String, nullable=True)
+     pinecone_index = Column(String, nullable=False)
+     api_key_id = Column(Integer, ForeignKey("api_key.id"), nullable=True)
+     status = Column(String, default="active")
+     created_at = Column(DateTime, server_default=func.now())
+     updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
+
+     # Relationships
+     documents = relationship("Document", back_populates="vector_database")
+     vector_statuses = relationship("VectorStatus", back_populates="vector_database")
+     engine_associations = relationship("EngineVectorDb", back_populates="vector_database")
+     api_key_ref = relationship("ApiKey", foreign_keys=[api_key_id])
+
+ class Document(Base):
+     __tablename__ = "document"
+
+     id = Column(Integer, primary_key=True, index=True)
+     name = Column(String, nullable=False)
+     file_type = Column(String, nullable=True)
+     content_type = Column(String, nullable=True)
+     size = Column(Integer, nullable=True)
+     is_embedded = Column(Boolean, default=False)
+     vector_database_id = Column(Integer, ForeignKey("vector_database.id"), nullable=False)
+     created_at = Column(DateTime, server_default=func.now())
+     updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
+
+     # Relationships
+     vector_database = relationship("VectorDatabase", back_populates="documents")
+     vector_statuses = relationship("VectorStatus", back_populates="document")
+     file_content_ref = relationship("DocumentContent", back_populates="document", uselist=False, cascade="all, delete-orphan")
+
+ class DocumentContent(Base):
+     __tablename__ = "document_content"
+
+     id = Column(Integer, primary_key=True, index=True)
+     document_id = Column(Integer, ForeignKey("document.id"), nullable=False, unique=True)
+     file_content = Column(LargeBinary, nullable=True)
+     created_at = Column(DateTime, server_default=func.now())
+
+     # Relationships
+     document = relationship("Document", back_populates="file_content_ref")
+
+ class VectorStatus(Base):
+     __tablename__ = "vector_status"
+
+     id = Column(Integer, primary_key=True, index=True)
+     document_id = Column(Integer, ForeignKey("document.id"), nullable=False)
+     vector_database_id = Column(Integer, ForeignKey("vector_database.id"), nullable=False)
+     vector_id = Column(String, nullable=True)
+     status = Column(String, default="pending")
+     error_message = Column(String, nullable=True)
+     embedded_at = Column(DateTime, nullable=True)
+
+     # Relationships
+     document = relationship("Document", back_populates="vector_statuses")
+     vector_database = relationship("VectorDatabase", back_populates="vector_statuses")
+
+ class TelegramBot(Base):
+     __tablename__ = "telegram_bot"
+
+     id = Column(Integer, primary_key=True, index=True)
+     name = Column(String, nullable=False)
+     username = Column(String, nullable=False, unique=True)
+     token = Column(String, nullable=False)
+     status = Column(String, default="inactive")
+     created_at = Column(DateTime, server_default=func.now())
+     updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
+
+     # Relationships
+     bot_engines = relationship("BotEngine", back_populates="bot")
+
+ class ChatEngine(Base):
+     __tablename__ = "chat_engine"
+
+     id = Column(Integer, primary_key=True, index=True)
+     name = Column(String, nullable=False)
+     answer_model = Column(String, nullable=False)
+     system_prompt = Column(Text, nullable=True)
+     empty_response = Column(String, nullable=True)
+     similarity_top_k = Column(Integer, default=3)
+     vector_distance_threshold = Column(Float, default=0.75)
+     grounding_threshold = Column(Float, default=0.2)
+     use_public_information = Column(Boolean, default=False)
+     status = Column(String, default="active")
+     created_at = Column(DateTime, server_default=func.now())
+     last_modified = Column(DateTime, server_default=func.now(), onupdate=func.now())
+
+     # Relationships
+     bot_engines = relationship("BotEngine", back_populates="engine")
+     engine_vector_dbs = relationship("EngineVectorDb", back_populates="engine")
+
+ class BotEngine(Base):
+     __tablename__ = "bot_engine"
+
+     id = Column(Integer, primary_key=True, index=True)
+     bot_id = Column(Integer, ForeignKey("telegram_bot.id"), nullable=False)
+     engine_id = Column(Integer, ForeignKey("chat_engine.id"), nullable=False)
+     created_at = Column(DateTime, server_default=func.now())
+
+     # Relationships
+     bot = relationship("TelegramBot", back_populates="bot_engines")
+     engine = relationship("ChatEngine", back_populates="bot_engines")
+
+ class EngineVectorDb(Base):
+     __tablename__ = "engine_vector_db"
+
+     id = Column(Integer, primary_key=True, index=True)
+     engine_id = Column(Integer, ForeignKey("chat_engine.id"), nullable=False)
+     vector_database_id = Column(Integer, ForeignKey("vector_database.id"), nullable=False)
+     priority = Column(Integer, default=0)
+
+     # Relationships
+     engine = relationship("ChatEngine", back_populates="engine_vector_dbs")
+     vector_database = relationship("VectorDatabase", back_populates="engine_associations")
+
+ class ApiKey(Base):
+     __tablename__ = "api_key"
+
+     id = Column(Integer, primary_key=True, index=True)
+     key_type = Column(String, nullable=False)
+     key_value = Column(Text, nullable=False)
+     description = Column(Text, nullable=True)
+     created_at = Column(DateTime, server_default=func.now())
+     last_used = Column(DateTime, nullable=True)
+     expires_at = Column(DateTime, nullable=True)
+     is_active = Column(Boolean, default=True)
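The `Document` / `DocumentContent` pair above keeps binary payloads in a separate one-to-one table, so listing or querying document metadata never loads file blobs. A conceptual sketch of the same layout using only the stdlib `sqlite3` module (not the app's PostgreSQL/SQLAlchemy stack; the table shapes are simplified for illustration):

```python
import sqlite3

# In-memory database: metadata in `document`, blobs in `document_content`,
# joined one-to-one via the UNIQUE document_id foreign key.
conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE document (id INTEGER PRIMARY KEY, name TEXT NOT NULL);
    CREATE TABLE document_content (
        id INTEGER PRIMARY KEY,
        document_id INTEGER NOT NULL UNIQUE REFERENCES document(id),
        file_content BLOB
    );
""")
conn.execute("INSERT INTO document (id, name) VALUES (1, 'guide.pdf')")
conn.execute(
    "INSERT INTO document_content (document_id, file_content) VALUES (1, ?)",
    (b"%PDF-1.4 ...",),
)
row = conn.execute(
    "SELECT d.name, length(c.file_content) FROM document d "
    "JOIN document_content c ON c.document_id = d.id"
).fetchone()
print(row)  # ('guide.pdf', 12)
```

The `uselist=False` plus `unique=True` combination in the SQLAlchemy models enforces the same one-to-one shape, and `cascade="all, delete-orphan"` deletes the blob row when its document row goes away.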
app/database/mongodb.py ADDED
@@ -0,0 +1,221 @@
+ import os
+ from pymongo import MongoClient
+ from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError
+ from dotenv import load_dotenv
+ from datetime import datetime, timedelta
+ import pytz
+ import logging
+
+ # Configure logging
+ logger = logging.getLogger(__name__)
+
+ # Load environment variables
+ load_dotenv()
+
+ # MongoDB connection string from .env
+ MONGODB_URL = os.getenv("MONGODB_URL")
+ DB_NAME = os.getenv("DB_NAME", "Telegram")
+ COLLECTION_NAME = os.getenv("COLLECTION_NAME", "session_chat")
+
+ # Timeout for MongoDB server selection
+ MONGODB_TIMEOUT = int(os.getenv("MONGODB_TIMEOUT", "5000"))  # 5 seconds by default
+
+ # Legacy cache settings - now only used for configuration purposes
+ HISTORY_CACHE_TTL = int(os.getenv("HISTORY_CACHE_TTL", "3600"))  # 1 hour by default
+ HISTORY_QUEUE_SIZE = int(os.getenv("HISTORY_QUEUE_SIZE", "10"))  # 10 items by default
+
+ # Create MongoDB connection with timeout
+ try:
+     client = MongoClient(MONGODB_URL, serverSelectionTimeoutMS=MONGODB_TIMEOUT)
+     db = client[DB_NAME]
+
+     # Collections
+     session_collection = db[COLLECTION_NAME]
+     logger.info(f"MongoDB connection initialized to {DB_NAME}.{COLLECTION_NAME}")
+ except Exception as e:
+     logger.error(f"Failed to initialize MongoDB connection: {e}")
+     # Don't raise here to avoid crashing on startup; each function handles errors itself
+
+ # Check MongoDB connection
+ def check_db_connection():
+     """Check MongoDB connection"""
+     try:
+         # Issue a ping to confirm a successful connection
+         client.admin.command('ping')
+         logger.info("MongoDB connection is working")
+         return True
+     except (ConnectionFailure, ServerSelectionTimeoutError) as e:
+         logger.error(f"MongoDB connection failed: {e}")
+         return False
+     except Exception as e:
+         logger.error(f"Unknown error when checking MongoDB connection: {e}")
+         return False
+
+ # Timezone for Asia/Ho_Chi_Minh
+ asia_tz = pytz.timezone('Asia/Ho_Chi_Minh')
+
+ def get_local_time():
+     """Get current time in Asia/Ho_Chi_Minh timezone"""
+     return datetime.now(asia_tz).strftime("%Y-%m-%d %H:%M:%S")
+
+ def get_local_datetime():
+     """Get current datetime object in Asia/Ho_Chi_Minh timezone"""
+     return datetime.now(asia_tz)
+
+ # For backward compatibility
+ get_vietnam_time = get_local_time
+ get_vietnam_datetime = get_local_datetime
+
+ # Utility functions
+ def save_session(session_id, factor, action, first_name, last_name, message, user_id, username, response=None):
+     """Save a user session to MongoDB"""
+     try:
+         session_data = {
+             "session_id": session_id,
+             "factor": factor,
+             "action": action,
+             "created_at": get_local_time(),
+             "created_at_datetime": get_local_datetime(),
+             "first_name": first_name,
+             "last_name": last_name,
+             "message": message,
+             "user_id": user_id,
+             "username": username,
+             "response": response
+         }
+         result = session_collection.insert_one(session_data)
+         logger.info(f"Session saved with ID: {result.inserted_id}")
+
+         return {
+             "acknowledged": result.acknowledged,
+             "inserted_id": str(result.inserted_id),
+             "session_data": session_data
+         }
+     except Exception as e:
+         logger.error(f"Error saving session: {e}")
+         raise
+
+ def update_session_response(session_id, response):
+     """Update a session with its response"""
+     try:
+         # Fetch the existing session
+         existing_session = session_collection.find_one({"session_id": session_id})
+
+         if not existing_session:
+             logger.warning(f"No session found with ID: {session_id}")
+             return False
+
+         result = session_collection.update_one(
+             {"session_id": session_id},
+             {"$set": {"response": response}}
+         )
+
+         logger.info(f"Session {session_id} updated with response")
+         return True
+     except Exception as e:
+         logger.error(f"Error updating session response: {e}")
+         raise
+
+ def get_recent_sessions(user_id, action, n=3):
+     """Get the n most recent sessions for a specific user and action"""
+     try:
+         # Query MongoDB directly
+         result = list(
+             session_collection.find(
+                 {"user_id": user_id, "action": action},
+                 {"_id": 0, "message": 1, "response": 1}
+             ).sort("created_at_datetime", -1).limit(n)
+         )
+
+         logger.debug(f"Retrieved {len(result)} recent sessions for user {user_id}, action {action}")
+         return result
+     except Exception as e:
+         logger.error(f"Error getting recent sessions: {e}")
+         return []
+
+ def get_chat_history(user_id, n=5) -> str:
+     """
+     Fetch the chat history for user_id from MongoDB and join it into a string of the form:
+
+     User: ...
+     Bot: ...
+     User: ...
+     Bot: ...
+
+     Only history recorded after the most recent /start or /clear command is included.
+     """
+     try:
+         # Find the most recent /start or /clear session
+         reset_session = session_collection.find_one(
+             {
+                 "user_id": str(user_id),
+                 "$or": [
+                     {"action": "start"},
+                     {"action": "clear"}
+                 ]
+             },
+             sort=[("created_at_datetime", -1)]
+         )
+
+         if reset_session:
+             reset_time = reset_session["created_at_datetime"]
+             # Fetch the sessions created after reset_time
+             docs = list(
+                 session_collection.find({
+                     "user_id": str(user_id),
+                     "created_at_datetime": {"$gt": reset_time}
+                 }).sort("created_at_datetime", 1)
+             )
+             logger.info(f"Fetched {len(docs)} sessions after the {reset_session['action']} command at {reset_time}")
+         else:
+             # No reset session found, so fall back to the n most recent sessions
+             docs = list(session_collection.find({"user_id": str(user_id)}).sort("created_at_datetime", -1).limit(n))
+             # Reverse so the order runs from oldest to newest
+             docs.reverse()
+             logger.info(f"No reset session found, fetched the {len(docs)} most recent sessions")
+
+         if not docs:
+             logger.info(f"No data found for user_id: {user_id}")
+             return ""
+
+         conversation_lines = []
+         # Process each document according to the new structure
+         for doc in docs:
+             factor = doc.get("factor", "").lower()
+             action = doc.get("action", "").lower()
+             message = doc.get("message", "")
+             response = doc.get("response", "")
+
+             # Skip start and clear commands
+             if action in ["start", "clear"]:
+                 continue
+
+             if factor == "user" and action == "asking_freely":
+                 conversation_lines.append(f"User: {message}")
+                 conversation_lines.append(f"Bot: {response}")
+
+         # Join the lines into a single string
+         return "\n".join(conversation_lines)
+     except Exception as e:
+         logger.error(f"Error fetching chat history for user_id {user_id}: {e}")
+         return ""
+
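The reset logic in `get_chat_history` keeps only sessions recorded after the most recent `/start` or `/clear`. A minimal, self-contained sketch of that filtering step, with an in-memory list and made-up timestamps standing in for the MongoDB query:

```python
from datetime import datetime

# Toy session log; timestamps and actions are illustrative only
sessions = [
    {"action": "asking_freely", "t": datetime(2024, 1, 1, 10, 0)},
    {"action": "clear",         "t": datetime(2024, 1, 1, 11, 0)},
    {"action": "asking_freely", "t": datetime(2024, 1, 1, 12, 0)},
]

# Find the most recent reset (/start or /clear), then keep only strictly later sessions
resets = [s["t"] for s in sessions if s["action"] in ("start", "clear")]
cutoff = max(resets) if resets else None
kept = [s for s in sessions if cutoff is None or s["t"] > cutoff]
print(len(kept))  # 1
```

The reset session itself is excluded because the comparison is strict (`$gt` in the real query).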
+ def get_request_history(user_id, n=3):
+     """Get the most recent user requests to use as context for retrieval"""
+     try:
+         # Query MongoDB directly
+         history = get_chat_history(user_id, n)
+
+         # Just extract the questions for context
+         requests = []
+         for line in history.split('\n'):
+             if line.startswith("User: "):
+                 requests.append(line[6:])  # Take the content after "User: "
+
+         # Join all recent requests into a single string for context
+         return " ".join(requests)
+     except Exception as e:
+         logger.error(f"Error getting request history: {e}")
+         return ""
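`get_request_history` reduces the interleaved `User:`/`Bot:` history string to just the user turns. The parsing step in isolation (the sample history is made up):

```python
def extract_user_requests(history: str) -> str:
    """Collect the text of every 'User: ' line into one context string."""
    requests = [line[6:] for line in history.split("\n") if line.startswith("User: ")]
    return " ".join(requests)

history = "User: hi\nBot: hello\nUser: any events today?\nBot: yes"
print(extract_user_requests(history))  # hi any events today?
```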
app/database/pinecone.py ADDED
@@ -0,0 +1,573 @@
+ import os
+ from pinecone import Pinecone
+ from dotenv import load_dotenv
+ import logging
+ from typing import Optional, List, Dict, Any, Union, Tuple
+ import time
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings
+ import google.generativeai as genai
+ from langchain_core.retrievers import BaseRetriever
+ from langchain.callbacks.manager import Callbacks
+ from langchain_core.documents import Document
+ from langchain_core.pydantic_v1 import Field
+
+ # Configure logging
+ logger = logging.getLogger(__name__)
+
+ # Load environment variables
+ load_dotenv()
+
+ # Pinecone API key and index name
+ PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
+ PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME")
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
+
+ # Pinecone retrieval configuration
+ DEFAULT_LIMIT_K = int(os.getenv("PINECONE_DEFAULT_LIMIT_K", "10"))
+ DEFAULT_TOP_K = int(os.getenv("PINECONE_DEFAULT_TOP_K", "6"))
+ DEFAULT_SIMILARITY_METRIC = os.getenv("PINECONE_DEFAULT_SIMILARITY_METRIC", "cosine")
+ DEFAULT_SIMILARITY_THRESHOLD = float(os.getenv("PINECONE_DEFAULT_SIMILARITY_THRESHOLD", "0.75"))
+ ALLOWED_METRICS = os.getenv("PINECONE_ALLOWED_METRICS", "cosine,dotproduct,euclidean").split(",")
+
+ # Export constants for importing elsewhere
+ __all__ = [
+     'get_pinecone_index',
+     'check_db_connection',
+     'search_vectors',
+     'upsert_vectors',
+     'delete_vectors',
+     'fetch_metadata',
+     'get_chain',
+     'DEFAULT_TOP_K',
+     'DEFAULT_LIMIT_K',
+     'DEFAULT_SIMILARITY_METRIC',
+     'DEFAULT_SIMILARITY_THRESHOLD',
+     'ALLOWED_METRICS',
+     'ThresholdRetriever'
+ ]
+
+ # Configure Google API
+ if GOOGLE_API_KEY:
+     genai.configure(api_key=GOOGLE_API_KEY)
+
+ # Global singletons for the Pinecone client, index, and retriever
+ pc = None
+ index = None
+ _retriever_instance = None
+
+ # Check environment variables
+ if not PINECONE_API_KEY:
+     logger.error("PINECONE_API_KEY is not set in environment variables")
+
+ if not PINECONE_INDEX_NAME:
+     logger.error("PINECONE_INDEX_NAME is not set in environment variables")
+
+ # Initialize Pinecone
+ def init_pinecone():
+     """Initialize the Pinecone connection using the new API"""
+     global pc, index
+
+     try:
+         # Only initialize if not already initialized
+         if pc is None:
+             logger.info(f"Initializing Pinecone connection to index {PINECONE_INDEX_NAME}...")
+
+             # Check if the API key and index name are set
+             if not PINECONE_API_KEY:
+                 logger.error("PINECONE_API_KEY is not set in environment variables")
+                 return None
+
+             if not PINECONE_INDEX_NAME:
+                 logger.error("PINECONE_INDEX_NAME is not set in environment variables")
+                 return None
+
+             # Initialize the Pinecone client using the new API
+             pc = Pinecone(api_key=PINECONE_API_KEY)
+
+             try:
+                 # Check if the index exists
+                 index_list = pc.list_indexes()
+
+                 if not hasattr(index_list, 'names') or PINECONE_INDEX_NAME not in index_list.names():
+                     logger.error(f"Index {PINECONE_INDEX_NAME} does not exist in Pinecone")
+                     return None
+
+                 # Get the existing index
+                 index = pc.Index(PINECONE_INDEX_NAME)
+                 logger.info(f"Pinecone connection established to index {PINECONE_INDEX_NAME}")
+             except Exception as connection_error:
+                 logger.error(f"Error connecting to Pinecone index: {connection_error}")
+                 return None
+
+         return index
+     except ImportError as e:
+         logger.error(f"Required package for Pinecone is missing: {e}")
+         return None
+     except Exception as e:
+         logger.error(f"Unexpected error initializing Pinecone: {e}")
+         return None
+
+ # Get Pinecone index singleton
+ def get_pinecone_index():
+     """Get the Pinecone index"""
+     global index
+     if index is None:
+         index = init_pinecone()
+     return index
+
+ # Check Pinecone connection
+ def check_db_connection():
+     """Check the Pinecone connection"""
+     try:
+         pinecone_index = get_pinecone_index()
+         if pinecone_index is None:
+             return False
+
+         # Check index statistics to confirm the connection is working
+         stats = pinecone_index.describe_index_stats()
+
+         # Get the total vector count from the new result structure
+         total_vectors = stats.get('total_vector_count', 0)
+         if hasattr(stats, 'namespaces'):
+             # If there are namespaces, sum the vector counts across namespaces
+             total_vectors = sum(ns.get('vector_count', 0) for ns in stats.namespaces.values())
+
+         logger.info(f"Pinecone connection is working. Total vectors: {total_vectors}")
+         return True
+     except Exception as e:
+         logger.error(f"Error in Pinecone connection: {e}")
+         return False
+
+ # Convert similarity score based on the metric
+ def convert_score(score: float, metric: str) -> float:
+     """
+     Convert a similarity score to a 0-1 scale based on the metric used.
+     For metrics like euclidean distance where lower is better, the score is inverted.
+
+     Args:
+         score: The raw similarity score
+         metric: The similarity metric used
+
+     Returns:
+         A normalized score between 0-1 where higher means more similar
+     """
+     if metric.lower() in ["euclidean", "l2"]:
+         # For distance metrics (lower is better), invert and normalize,
+         # assuming a maximum reasonable distance of 2.0 for normalized vectors
+         return max(0, 1 - (score / 2.0))
+     else:
+         # For cosine and dot product (higher is better), return as is
+         return score
+
+ # Filter results based on similarity threshold
+ def filter_by_threshold(results, threshold: float, metric: str) -> List[Dict]:
+     """
+     Filter query results based on a similarity threshold.
+
+     Args:
+         results: The query results from Pinecone
+         threshold: The similarity threshold (0-1)
+         metric: The similarity metric used
+
+     Returns:
+         Filtered list of matches
+     """
+     filtered_matches = []
+
+     if not hasattr(results, 'matches'):
+         return filtered_matches
+
+     for match in results.matches:
+         # Get the score
+         score = getattr(match, 'score', 0)
+
+         # Convert the score based on the metric
+         normalized_score = convert_score(score, metric)
+
+         # Filter based on the threshold
+         if normalized_score >= threshold:
+             # Attach the normalized score as an additional attribute
+             match.normalized_score = normalized_score
+             filtered_matches.append(match)
+
+     return filtered_matches
+
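The two helpers above combine into a normalize-then-filter step. A standalone sketch with plain tuples standing in for Pinecone match objects (the scores are made up):

```python
def convert_score(score: float, metric: str) -> float:
    # Distance metrics (lower is better) are inverted and normalized;
    # cosine/dotproduct scores pass through unchanged
    if metric.lower() in ("euclidean", "l2"):
        return max(0.0, 1 - score / 2.0)
    return score

matches = [("doc-a", 0.91), ("doc-b", 0.52), ("doc-c", 0.80)]
threshold = 0.75
kept = [(mid, s) for mid, s in matches if convert_score(s, "cosine") >= threshold]
print(kept)  # [('doc-a', 0.91), ('doc-c', 0.8)]
```

For euclidean, a distance of 0.5 maps to a normalized similarity of 0.75, so it would just clear the default threshold.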
+ # Search vectors in Pinecone with advanced options
+ async def search_vectors(
+     query_vector,
+     top_k: int = DEFAULT_TOP_K,
+     limit_k: int = DEFAULT_LIMIT_K,
+     similarity_metric: str = DEFAULT_SIMILARITY_METRIC,
+     similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
+     namespace: str = "Default",
+     filter: Optional[Dict] = None
+ ) -> Dict:
+     """
+     Search for the most similar vectors in Pinecone with advanced filtering options.
+
+     Args:
+         query_vector: The query vector
+         top_k: Number of results to return (after threshold filtering)
+         limit_k: Maximum number of results to retrieve from Pinecone
+         similarity_metric: Similarity metric to use (cosine, dotproduct, euclidean)
+         similarity_threshold: Threshold for similarity (0-1)
+         namespace: Namespace to search in
+         filter: Filter query
+
+     Returns:
+         Search results with matches filtered by threshold
+     """
+     try:
+         # Validate parameters
+         if similarity_metric not in ALLOWED_METRICS:
+             logger.warning(f"Invalid similarity metric: {similarity_metric}. Using default: {DEFAULT_SIMILARITY_METRIC}")
+             similarity_metric = DEFAULT_SIMILARITY_METRIC
+
+         if limit_k < top_k:
+             logger.warning(f"limit_k ({limit_k}) must be greater than or equal to top_k ({top_k}). Setting limit_k to {top_k}")
+             limit_k = top_k
+
+         # Perform the search directly without cache
+         pinecone_index = get_pinecone_index()
+         if pinecone_index is None:
+             logger.error("Failed to get Pinecone index for search")
+             return None
+
+         # Query Pinecone with a higher limit_k to leave room for threshold filtering.
+         # Note: the distance metric is normally fixed when the index is created; the
+         # metric argument here mainly controls how scores are normalized below.
+         results = pinecone_index.query(
+             vector=query_vector,
+             top_k=limit_k,  # Retrieve more results than needed to allow for threshold filtering
+             namespace=namespace,
+             filter=filter,
+             include_metadata=True,
+             include_values=False,  # No need to return vector values; saves bandwidth
+             metric=similarity_metric
+         )
+
+         # Filter results by threshold
+         filtered_matches = filter_by_threshold(results, similarity_threshold, similarity_metric)
+
+         # Limit to top_k after filtering
+         filtered_matches = filtered_matches[:top_k]
+
+         # Replace the matches on the results object with the filtered list
+         results.matches = filtered_matches
+
+         # Log search result metrics
+         match_count = len(filtered_matches)
+         logger.info(f"Pinecone search returned {match_count} matches after threshold filtering (metric: {similarity_metric}, threshold: {similarity_threshold}, namespace: {namespace})")
+
+         return results
+     except Exception as e:
+         logger.error(f"Error searching vectors: {e}")
+         return None
+
+ # Upsert vectors to Pinecone
+ async def upsert_vectors(vectors, namespace="Default"):
+     """Upsert vectors to the Pinecone index"""
+     try:
+         pinecone_index = get_pinecone_index()
+         if pinecone_index is None:
+             logger.error("Failed to get Pinecone index for upsert")
+             return None
+
+         response = pinecone_index.upsert(
+             vectors=vectors,
+             namespace=namespace
+         )
+
+         # Log upsert metrics
+         upserted_count = response.get('upserted_count', 0)
+         logger.info(f"Upserted {upserted_count} vectors to Pinecone")
+
+         return response
+     except Exception as e:
+         logger.error(f"Error upserting vectors: {e}")
+         return None
+
+ # Delete vectors from Pinecone
+ async def delete_vectors(ids, namespace="Default"):
+     """Delete vectors from the Pinecone index"""
+     try:
+         pinecone_index = get_pinecone_index()
+         if pinecone_index is None:
+             logger.error("Failed to get Pinecone index for delete")
+             return False
+
+         response = pinecone_index.delete(
+             ids=ids,
+             namespace=namespace
+         )
+
+         logger.info(f"Deleted vectors with IDs {ids} from Pinecone")
+         return True
+     except Exception as e:
+         logger.error(f"Error deleting vectors: {e}")
+         return False
+
+ # Fetch vector metadata from Pinecone
+ async def fetch_metadata(ids, namespace="Default"):
+     """Fetch metadata for specific vector IDs"""
+     try:
+         pinecone_index = get_pinecone_index()
+         if pinecone_index is None:
+             logger.error("Failed to get Pinecone index for fetch")
+             return None
+
+         response = pinecone_index.fetch(
+             ids=ids,
+             namespace=namespace
+         )
+
+         return response
+     except Exception as e:
+         logger.error(f"Error fetching vector metadata: {e}")
+         return None
+
+ # Custom retriever class for Langchain integration
+ class ThresholdRetriever(BaseRetriever):
+     """
+     Custom retriever that supports threshold-based filtering and multiple similarity metrics.
+     This integrates with the Langchain ecosystem while using our advanced retrieval logic.
+     """
+
+     vectorstore: Any = Field(description="Vector store to use for retrieval")
+     embeddings: Any = Field(description="Embeddings model to use for retrieval")
+     search_kwargs: Dict[str, Any] = Field(default_factory=dict, description="Search kwargs for the vectorstore")
+     top_k: int = Field(default=DEFAULT_TOP_K, description="Number of results to return after filtering")
+     limit_k: int = Field(default=DEFAULT_LIMIT_K, description="Maximum number of results to retrieve from Pinecone")
+     similarity_metric: str = Field(default=DEFAULT_SIMILARITY_METRIC, description="Similarity metric to use")
+     similarity_threshold: float = Field(default=DEFAULT_SIMILARITY_THRESHOLD, description="Threshold for similarity")
+     namespace: str = "Default"
+
+     class Config:
+         """Configuration for this pydantic object."""
+         arbitrary_types_allowed = True
+
+     async def search_vectors_sync(
+         self, query_vector,
+         top_k: int = DEFAULT_TOP_K,
+         limit_k: int = DEFAULT_LIMIT_K,
+         similarity_metric: str = DEFAULT_SIMILARITY_METRIC,
+         similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
+         namespace: str = "Default",
+         filter: Optional[Dict] = None
+     ) -> Dict:
+         """Wrapper for search_vectors (kept under its historical name) that handles event-loop setup"""
+         import asyncio
+         try:
+             # Get the current event loop or create a new one
+             try:
+                 loop = asyncio.get_event_loop()
+             except RuntimeError:
+                 loop = asyncio.new_event_loop()
+                 asyncio.set_event_loop(loop)
+
+             # Use the event loop to run the async function
+             if loop.is_running():
+                 # If we're in an event loop, use asyncio.create_task
+                 task = asyncio.create_task(search_vectors(
+                     query_vector=query_vector,
+                     top_k=top_k,
+                     limit_k=limit_k,
+                     similarity_metric=similarity_metric,
+                     similarity_threshold=similarity_threshold,
+                     namespace=namespace,
+                     filter=filter
+                 ))
+                 return await task
+             else:
+                 # If not in an event loop, just await directly
+                 return await search_vectors(
+                     query_vector=query_vector,
+                     top_k=top_k,
+                     limit_k=limit_k,
+                     similarity_metric=similarity_metric,
+                     similarity_threshold=similarity_threshold,
+                     namespace=namespace,
+                     filter=filter
+                 )
+         except Exception as e:
+             logger.error(f"Error in search_vectors_sync: {e}")
+             return None
+
+     def _get_relevant_documents(
+         self, query: str, *, run_manager: Callbacks = None
+     ) -> List[Document]:
+         """
+         Get documents relevant to the query using threshold-based retrieval.
+
+         Args:
+             query: The query string
+             run_manager: The callbacks manager
+
+         Returns:
+             List of relevant documents
+         """
+         # Generate an embedding for the query using the embeddings model
+         try:
+             # Use the embeddings model stored on the class
+             embedding = self.embeddings.embed_query(query)
+         except Exception as e:
+             logger.error(f"Error generating embedding: {e}")
+             # Fall back to creating a new embedding model if needed
+             embedding_model = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
+             embedding = embedding_model.embed_query(query)
+
+         # Perform the search with advanced options - avoid asyncio.run()
+         import asyncio
+
+         # Get or create an event loop
+         try:
+             loop = asyncio.get_event_loop()
+         except RuntimeError:
+             loop = asyncio.new_event_loop()
+             asyncio.set_event_loop(loop)
+
+         # Run the asynchronous search safely
+         if loop.is_running():
+             # We're inside an existing event loop (like in FastAPI),
+             # so convert it to a synchronous call via a worker thread
+             from concurrent.futures import ThreadPoolExecutor
+             import functools
+
+             # Define a wrapper function to run in a thread
+             def run_async_in_thread():
+                 # Create a new event loop for this thread
+                 thread_loop = asyncio.new_event_loop()
+                 asyncio.set_event_loop(thread_loop)
+                 # Run the coroutine and return the result
+                 return thread_loop.run_until_complete(search_vectors(
+                     query_vector=embedding,
+                     top_k=self.top_k,
+                     limit_k=self.limit_k,
+                     similarity_metric=self.similarity_metric,
+                     similarity_threshold=self.similarity_threshold,
+                     namespace=self.namespace,
+                     # filter=self.search_kwargs.get("filter", None)
+                 ))
+
+             # Run the async function in a thread
+             with ThreadPoolExecutor() as executor:
+                 search_result = executor.submit(run_async_in_thread).result()
+         else:
+             # No event loop running, so run_until_complete is safe
+             search_result = loop.run_until_complete(search_vectors(
+                 query_vector=embedding,
+                 top_k=self.top_k,
+                 limit_k=self.limit_k,
+                 similarity_metric=self.similarity_metric,
+                 similarity_threshold=self.similarity_threshold,
+                 namespace=self.namespace,
+                 # filter=self.search_kwargs.get("filter", None)
+             ))
+
+         # Convert matches to documents
+         documents = []
+         if search_result and hasattr(search_result, 'matches'):
+             for match in search_result.matches:
+                 # Extract metadata
+                 metadata = {}
+                 if hasattr(match, 'metadata'):
+                     metadata = match.metadata
+
+                 # Add scores to the metadata
+                 score = getattr(match, 'score', 0)
+                 normalized_score = getattr(match, 'normalized_score', score)
+                 metadata['score'] = score
+                 metadata['normalized_score'] = normalized_score
+
+                 # Extract the text content
+                 text = metadata.get('text', '')
+                 if 'text' in metadata:
+                     del metadata['text']  # Remove from metadata since it's the page content
+
+                 # Create the Document
+                 doc = Document(
+                     page_content=text,
+                     metadata=metadata
+                 )
+                 documents.append(doc)
+
+         return documents
+
+ # Get the retrieval chain with Pinecone vector store
+ def get_chain(
+     index_name=PINECONE_INDEX_NAME,
+     namespace="Default",
+     top_k=DEFAULT_TOP_K,
+     limit_k=DEFAULT_LIMIT_K,
+     similarity_metric=DEFAULT_SIMILARITY_METRIC,
+     similarity_threshold=DEFAULT_SIMILARITY_THRESHOLD
+ ):
+     """
+     Get the retrieval chain with Pinecone vector store using threshold-based retrieval.
+
+     Args:
+         index_name: Pinecone index name
+         namespace: Pinecone namespace
+         top_k: Number of results to return after filtering
+         limit_k: Maximum number of results to retrieve from Pinecone
+         similarity_metric: Similarity metric to use (cosine, dotproduct, euclidean)
+         similarity_threshold: Threshold for similarity (0-1)
+
+     Returns:
+         ThresholdRetriever instance
+     """
+     global _retriever_instance
+     try:
+         # If a retriever is already initialized, return the cached instance
+         # (the parameters of the first call win)
+         if _retriever_instance is not None:
+             return _retriever_instance
+
+         start_time = time.time()
+         logger.info("Initializing new retriever chain with threshold-based filtering")
+
+         # Initialize the embeddings model
+         embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
+
+         # Get the index
+         pinecone_index = get_pinecone_index()
+         if not pinecone_index:
+             logger.error("Failed to get Pinecone index for retriever chain")
+             return None
+
+         # Get statistics for logging
+         try:
+             stats = pinecone_index.describe_index_stats()
+             total_vectors = stats.get('total_vector_count', 0)
+             logger.info(f"Pinecone index stats - Total vectors: {total_vectors}")
+         except Exception as e:
+             logger.error(f"Error getting index stats: {e}")
+
+         # Use Pinecone from langchain_community.vectorstores
+         from langchain_community.vectorstores import Pinecone as LangchainPinecone
+
+         logger.info(f"Creating Pinecone vectorstore with index: {index_name}, namespace: {namespace}")
+         vectorstore = LangchainPinecone.from_existing_index(
+             embedding=embeddings,
+             index_name=index_name,
+             namespace=namespace,
+             text_key="text"
+         )
+
+         # Create the threshold-based retriever
+         logger.info(f"Creating ThresholdRetriever with top_k={top_k}, limit_k={limit_k}, "
+                     f"metric={similarity_metric}, threshold={similarity_threshold}")
+
+         # Create the ThresholdRetriever with both the vectorstore and the embeddings
+         _retriever_instance = ThresholdRetriever(
+             vectorstore=vectorstore,
+             embeddings=embeddings,  # Pass embeddings separately
+             top_k=top_k,
+             limit_k=limit_k,
+             similarity_metric=similarity_metric,
+             similarity_threshold=similarity_threshold
+         )
+
+         logger.info(f"Pinecone retriever initialized in {time.time() - start_time:.2f} seconds")
+
+         return _retriever_instance
+     except Exception as e:
+         logger.error(f"Error creating retrieval chain: {e}")
+         return None
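`_get_relevant_documents` has to call an async function from synchronous code, possibly while another event loop is already running (as under FastAPI); it does so by running a fresh loop in a worker thread. The pattern, reduced to a self-contained sketch with a stand-in coroutine:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

async def fake_search():
    # Stand-in for the real async search_vectors call
    await asyncio.sleep(0)
    return 42

def run_async_in_thread():
    # New event loop per worker thread, mirroring the retriever's fallback path
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(fake_search())
    finally:
        loop.close()

with ThreadPoolExecutor() as executor:
    result = executor.submit(run_async_in_thread).result()
print(result)  # 42
```

Running the coroutine in a dedicated thread avoids the "event loop is already running" error that `run_until_complete` would raise on the caller's loop.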
app/database/postgresql.py ADDED
@@ -0,0 +1,192 @@
1
+ import os
2
+ from sqlalchemy import create_engine, text
3
+ from sqlalchemy.ext.declarative import declarative_base
4
+ from sqlalchemy.orm import sessionmaker
5
+ from sqlalchemy.exc import SQLAlchemyError, OperationalError
6
+ from dotenv import load_dotenv
7
+ import logging
8
+
9
+ # Configure logging
10
+ logger = logging.getLogger(__name__)
11
+
12
+ # Load environment variables
13
+ load_dotenv()
14
+
15
+ # Get DB connection mode from environment
16
+ DB_CONNECTION_MODE = os.getenv("DB_CONNECTION_MODE", "aiven")
17
+
18
+ # Set connection string based on mode
19
+ if DB_CONNECTION_MODE == "aiven":
20
+ DATABASE_URL = os.getenv("AIVEN_DB_URL")
21
+ else:
22
+ # Default or other connection modes can be added here
23
+ DATABASE_URL = os.getenv("AIVEN_DB_URL")
24
+
25
+ if not DATABASE_URL:
26
+ logger.error("No database URL configured. Please set AIVEN_DB_URL environment variable.")
27
+ DATABASE_URL = "postgresql://localhost/test" # Fallback to avoid crash on startup
28
+
29
+ # Create SQLAlchemy engine with optimized settings
30
+ try:
31
+ engine = create_engine(
32
+ DATABASE_URL,
33
+ pool_size=10, # Limit max connections
34
+ max_overflow=5, # Allow temporary overflow of connections
35
+ pool_timeout=30, # Timeout waiting for connection from pool
36
+ pool_recycle=300, # Recycle connections every 5 minutes
37
+ pool_pre_ping=True, # Verify connection is still valid before using it
38
+ connect_args={
39
+ "connect_timeout": 5, # Connection timeout in seconds
40
+ "keepalives": 1, # Enable TCP keepalives
41
+ "keepalives_idle": 30, # Time before sending keepalives
42
+ "keepalives_interval": 10, # Time between keepalives
43
+ "keepalives_count": 5, # Number of keepalive probes
44
+ "application_name": "pixagent_api" # Identify app in PostgreSQL logs
45
+ },
46
+ # Performance optimizations
47
+ isolation_level="READ COMMITTED", # Lower isolation level for better performance
48
+ echo=False, # Disable SQL echo to reduce overhead
49
+ echo_pool=False, # Disable pool logging
50
+ future=True, # Use SQLAlchemy 2.0 features
51
+ # Execution options for common queries
52
+ execution_options={
53
+ "compiled_cache": {}, # Use an empty dict for compiled query caching
54
+ "logging_token": "SQL", # Tag for query logging
55
+ }
56
+ )
57
+ logger.info("PostgreSQL engine initialized with optimized settings")
58
+ except Exception as e:
59
+ logger.error(f"Failed to initialize PostgreSQL engine: {e}")
60
+ # Don't raise exception to avoid crash on startup
61
+
62
+ # Create optimized session factory
63
+ SessionLocal = sessionmaker(
64
+ autocommit=False,
65
+ autoflush=False,
66
+ bind=engine,
67
+ expire_on_commit=False # Prevent automatic reloading after commit
68
+ )
69
+
70
+ # Base class for declarative models - use sqlalchemy.orm for SQLAlchemy 2.0 compatibility
71
+ from sqlalchemy.orm import declarative_base
72
+ Base = declarative_base()
73
+ 
+ # Check PostgreSQL connection
+ def check_db_connection():
+     """Check PostgreSQL connection status"""
+     try:
+         # Simple query to verify the connection
+         with engine.connect() as connection:
+             connection.execute(text("SELECT 1")).fetchone()
+         logger.info("PostgreSQL connection successful")
+         return True
+     except OperationalError as e:
+         logger.error(f"PostgreSQL connection failed: {e}")
+         return False
+     except Exception as e:
+         logger.error(f"Unknown error checking PostgreSQL connection: {e}")
+         return False
+ 
+ # Dependency to get a DB session with improved error handling
+ def get_db():
+     """Get a PostgreSQL database session"""
+     db = SessionLocal()
+     try:
+         # Test the connection
+         db.execute(text("SELECT 1")).fetchone()
+         yield db
+     except Exception as e:
+         logger.error(f"DB connection error: {e}")
+         raise
+     finally:
+         db.close()  # Ensure the connection is closed and returned to the pool
+ 
+ # Create tables in the database if they don't exist
+ def create_tables():
+     """Create tables in the database"""
+     try:
+         Base.metadata.create_all(bind=engine)
+         logger.info("Database tables created or already exist")
+         return True
+     except SQLAlchemyError as e:
+         logger.error(f"Failed to create database tables (SQLAlchemy error): {e}")
+         return False
+     except Exception as e:
+         logger.error(f"Failed to create database tables (unexpected error): {e}")
+         return False
+ 
+ # Function to create indexes for better performance
+ def create_indexes():
+     """Create indexes for better query performance"""
+     try:
+         with engine.connect() as conn:
+             # "IF NOT EXISTS" makes these statements idempotent; relying on a
+             # caught error instead would abort the surrounding transaction in
+             # PostgreSQL and make every following statement fail.
+ 
+             # Index for featured events
+             conn.execute(text("""
+                 CREATE INDEX IF NOT EXISTS idx_event_featured
+                 ON event_item(featured)
+             """))
+ 
+             # Index for active events
+             conn.execute(text("""
+                 CREATE INDEX IF NOT EXISTS idx_event_active
+                 ON event_item(is_active)
+             """))
+ 
+             # Index for date filtering
+             conn.execute(text("""
+                 CREATE INDEX IF NOT EXISTS idx_event_date_start
+                 ON event_item(date_start)
+             """))
+ 
+             # Composite index for combined filtering
+             conn.execute(text("""
+                 CREATE INDEX IF NOT EXISTS idx_event_featured_active
+                 ON event_item(featured, is_active)
+             """))
+ 
+             # Indexes for the FAQ and Emergency tables
+             conn.execute(text("""
+                 CREATE INDEX IF NOT EXISTS idx_faq_active
+                 ON faq_item(is_active)
+             """))
+ 
+             conn.execute(text("""
+                 CREATE INDEX IF NOT EXISTS idx_emergency_active
+                 ON emergency_item(is_active)
+             """))
+ 
+             conn.execute(text("""
+                 CREATE INDEX IF NOT EXISTS idx_emergency_priority
+                 ON emergency_item(priority)
+             """))
+ 
+             conn.commit()
+ 
+         logger.info("Database indexes created or verified")
+         return True
+     except SQLAlchemyError as e:
+         logger.error(f"Failed to create indexes: {e}")
+         return False
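The `get_db()` dependency above relies on the generator's `finally` block to return the session to the pool. A minimal stdlib-only sketch (with a hypothetical `FakeSession` standing in for `SessionLocal` — not part of this module) shows why driving the generator the way FastAPI does guarantees cleanup even when the handler raises:

```python
class FakeSession:
    """Hypothetical stand-in for a SQLAlchemy session."""
    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True

def get_db_sketch(factory):
    # Mirrors get_db(): yield the session, close it in `finally`
    db = factory()
    try:
        yield db
    finally:
        db.close()  # always runs, even if the request handler raised

# Drive the generator the way FastAPI drives a yield-dependency
gen = get_db_sketch(FakeSession)
session = next(gen)
# ... the request handler would use `session` here ...
gen.close()  # FastAPI finalizes the dependency after the response

print(session.closed)  # True
```

Closing the generator raises `GeneratorExit` at the `yield`, which is what triggers the `finally` block.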
app/models/__init__.py ADDED
@@ -0,0 +1 @@
+ # Pydantic models package
app/models/mongodb_models.py ADDED
@@ -0,0 +1,55 @@
+ from pydantic import BaseModel, Field, ConfigDict
+ from typing import Optional, List, Dict, Any
+ from datetime import datetime
+ import uuid
+ 
+ class SessionBase(BaseModel):
+     """Base model for session data"""
+     session_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+     factor: str
+     action: str
+     first_name: str
+     last_name: Optional[str] = None
+     message: Optional[str] = None
+     user_id: str
+     username: Optional[str] = None
+ 
+ class SessionCreate(SessionBase):
+     """Model for creating a new session"""
+     response: Optional[str] = None
+ 
+ class SessionResponse(SessionBase):
+     """Response model for session data"""
+     created_at: str
+     response: Optional[str] = None
+ 
+     model_config = ConfigDict(
+         json_schema_extra={
+             "example": {
+                 "session_id": "123e4567-e89b-12d3-a456-426614174000",
+                 "factor": "user",
+                 "action": "asking_freely",
+                 "created_at": "2023-06-01 14:30:45",
+                 "first_name": "John",
+                 "last_name": "Doe",
+                 "message": "How can I find emergency contacts?",
+                 "user_id": "12345678",
+                 "username": "johndoe",
+                 "response": "You can find emergency contacts in the Emergency section..."
+             }
+         }
+     )
+ 
+ class HistoryRequest(BaseModel):
+     """Request model for history"""
+     user_id: str
+     n: int = 3
+ 
+ class QuestionAnswer(BaseModel):
+     """Model for a question-answer pair"""
+     question: str
+     answer: str
+ 
+ class HistoryResponse(BaseModel):
+     """Response model for history"""
+     history: List[QuestionAnswer]
app/models/pdf_models.py ADDED
@@ -0,0 +1,52 @@
+ from pydantic import BaseModel, Field, ConfigDict
+ from typing import Optional, List, Dict, Any
+ 
+ class PDFUploadRequest(BaseModel):
+     """Request model for PDF upload"""
+     namespace: Optional[str] = Field("Default", description="Namespace in Pinecone")
+     index_name: Optional[str] = Field("testbot768", description="Index name in Pinecone")
+     title: Optional[str] = Field(None, description="Title of the document")
+     description: Optional[str] = Field(None, description="Description of the document")
+     vector_database_id: Optional[int] = Field(None, description="ID of the vector database in PostgreSQL to use")
+ 
+ class PDFResponse(BaseModel):
+     """Response model for PDF processing"""
+     success: bool = Field(..., description="Whether processing succeeded")
+     document_id: Optional[str] = Field(None, description="ID of the document")
+     chunks_processed: Optional[int] = Field(None, description="Number of chunks processed")
+     total_text_length: Optional[int] = Field(None, description="Total text length")
+     error: Optional[str] = Field(None, description="Error message, if any")
+ 
+     model_config = ConfigDict(
+         json_schema_extra={
+             "example": {
+                 "success": True,
+                 "document_id": "550e8400-e29b-41d4-a716-446655440000",
+                 "chunks_processed": 25,
+                 "total_text_length": 50000
+             }
+         }
+     )
+ 
+ class DeleteDocumentRequest(BaseModel):
+     """Request model for deleting a document"""
+     document_id: str = Field(..., description="ID of the document to delete")
+     namespace: Optional[str] = Field("Default", description="Namespace in Pinecone")
+     index_name: Optional[str] = Field("testbot768", description="Index name in Pinecone")
+ 
+ class DocumentsListResponse(BaseModel):
+     """Response model for listing documents"""
+     success: bool = Field(..., description="Whether the request succeeded")
+     total_vectors: Optional[int] = Field(None, description="Total number of vectors in the index")
+     namespace: Optional[str] = Field(None, description="Namespace in use")
+     index_name: Optional[str] = Field(None, description="Index name in use")
+     error: Optional[str] = Field(None, description="Error message, if any")
+ 
+     model_config = ConfigDict(
+         json_schema_extra={
+             "example": {
+                 "success": True,
+                 "total_vectors": 5000,
+                 "namespace": "Default",
+                 "index_name": "testbot768"
+             }
+         }
+     )
app/models/rag_models.py ADDED
@@ -0,0 +1,68 @@
+ from pydantic import BaseModel, Field
+ from typing import Optional, List, Dict, Any
+ 
+ class ChatRequest(BaseModel):
+     """Request model for chat endpoint"""
+     user_id: str = Field(..., description="User ID from Telegram")
+     question: str = Field(..., description="User's question")
+     include_history: bool = Field(True, description="Whether to include user history in the prompt")
+     use_rag: bool = Field(True, description="Whether to use RAG")
+ 
+     # Advanced retrieval parameters
+     similarity_top_k: int = Field(6, description="Number of top similar documents to return (after filtering)")
+     limit_k: int = Field(10, description="Maximum number of documents to retrieve from the vector store")
+     similarity_metric: str = Field("cosine", description="Similarity metric to use (cosine, dotproduct, euclidean)")
+     similarity_threshold: float = Field(0.75, description="Threshold for vector similarity (0-1)")
+ 
+     # User information
+     session_id: Optional[str] = Field(None, description="Session ID for tracking conversations")
+     first_name: Optional[str] = Field(None, description="User's first name")
+     last_name: Optional[str] = Field(None, description="User's last name")
+     username: Optional[str] = Field(None, description="User's username")
+ 
+ class SourceDocument(BaseModel):
+     """Model for source documents"""
+     text: str = Field(..., description="Text content of the document")
+     source: Optional[str] = Field(None, description="Source of the document")
+     score: Optional[float] = Field(None, description="Raw similarity score of the document")
+     normalized_score: Optional[float] = Field(None, description="Normalized similarity score (0-1)")
+     metadata: Optional[Dict[str, Any]] = Field(None, description="Metadata of the document")
+ 
+ class ChatResponse(BaseModel):
+     """Response model for chat endpoint"""
+     answer: str = Field(..., description="Generated answer")
+     processing_time: float = Field(..., description="Processing time in seconds")
+ 
+ class ChatResponseInternal(BaseModel):
+     """Internal model for chat response with sources - used only for logging"""
+     answer: str
+     sources: Optional[List[SourceDocument]] = Field(None, description="Source documents used for generating the answer")
+     processing_time: Optional[float] = None
+ 
+ class EmbeddingRequest(BaseModel):
+     """Request model for embedding endpoint"""
+     text: str = Field(..., description="Text to generate an embedding for")
+ 
+ class EmbeddingResponse(BaseModel):
+     """Response model for embedding endpoint"""
+     embedding: List[float] = Field(..., description="Generated embedding")
+     text: str = Field(..., description="Text that was embedded")
+     model: str = Field(..., description="Model used for embedding")
+ 
+ class HealthResponse(BaseModel):
+     """Response model for health endpoint"""
+     status: str
+     services: Dict[str, bool]
+     timestamp: str
+ 
+ class UserMessageModel(BaseModel):
+     """Model for user messages sent to the RAG API"""
+     user_id: str = Field(..., description="User ID from the client application")
+     session_id: str = Field(..., description="Session ID for tracking the conversation")
+     message: str = Field(..., description="User's message/question")
+ 
+     # Advanced retrieval parameters (optional)
+     similarity_top_k: Optional[int] = Field(None, description="Number of top similar documents to return (after filtering)")
+     limit_k: Optional[int] = Field(None, description="Maximum number of documents to retrieve from the vector store")
+     similarity_metric: Optional[str] = Field(None, description="Similarity metric to use (cosine, dotproduct, euclidean)")
+     similarity_threshold: Optional[float] = Field(None, description="Threshold for vector similarity (0-1)")
app/utils/__init__.py ADDED
@@ -0,0 +1 @@
+ # Utility functions package
app/utils/cache.py ADDED
@@ -0,0 +1,193 @@
+ import os
+ import time
+ import threading
+ import logging
+ from typing import Dict, Any, Optional, Tuple, List, Callable, Generic, TypeVar, Union
+ from datetime import datetime
+ from dotenv import load_dotenv
+ import json
+ 
+ # Set up logging
+ logger = logging.getLogger(__name__)
+ 
+ # Load environment variables
+ load_dotenv()
+ 
+ # Cache configuration from environment variables
+ DEFAULT_CACHE_TTL = int(os.getenv("CACHE_TTL_SECONDS", "300"))  # Default: 5 minutes
+ DEFAULT_CACHE_CLEANUP_INTERVAL = int(os.getenv("CACHE_CLEANUP_INTERVAL", "60"))  # Default: 1 minute
+ DEFAULT_CACHE_MAX_SIZE = int(os.getenv("CACHE_MAX_SIZE", "1000"))  # Default: 1000 items
+ 
+ # Generic type so the cache can hold values of any type
+ T = TypeVar('T')
+ 
+ # Structure for a single cache entry
+ class CacheItem(Generic[T]):
+     def __init__(self, value: T, ttl: int = DEFAULT_CACHE_TTL):
+         self.value = value
+         self.expire_at = time.time() + ttl
+         self.last_accessed = time.time()
+ 
+     def is_expired(self) -> bool:
+         """Check whether the item has expired"""
+         return time.time() > self.expire_at
+ 
+     def touch(self) -> None:
+         """Update the last-accessed time"""
+         self.last_accessed = time.time()
+ 
+     def extend(self, ttl: int = DEFAULT_CACHE_TTL) -> None:
+         """Extend the item's time to live"""
+         self.expire_at = time.time() + ttl
+ 
+ # Main cache class
+ class InMemoryCache:
+     def __init__(
+         self,
+         ttl: int = DEFAULT_CACHE_TTL,
+         cleanup_interval: int = DEFAULT_CACHE_CLEANUP_INTERVAL,
+         max_size: int = DEFAULT_CACHE_MAX_SIZE
+     ):
+         self.cache: Dict[str, CacheItem] = {}
+         self.ttl = ttl
+         self.cleanup_interval = cleanup_interval
+         self.max_size = max_size
+         self.lock = threading.RLock()  # Use an RLock to avoid deadlocks
+ 
+         # Start the periodic cleanup thread (active expiration)
+         self.cleanup_thread = threading.Thread(target=self._cleanup_task, daemon=True)
+         self.cleanup_thread.start()
+ 
+     def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
+         """Store a value in the cache"""
+         with self.lock:
+             ttl_value = ttl if ttl is not None else self.ttl
+ 
+             # If the cache is full, evict the least recently used items
+             if len(self.cache) >= self.max_size and key not in self.cache:
+                 self._evict_lru_items()
+ 
+             self.cache[key] = CacheItem(value, ttl_value)
+             logger.debug(f"Cache set: {key} (expires in {ttl_value}s)")
+ 
+     def get(self, key: str, default: Any = None) -> Any:
+         """
+         Get a value from the cache. If the key does not exist or has expired,
+         return the default value. Applies lazy expiration: expired items are
+         checked and removed on access.
+         """
+         with self.lock:
+             item = self.cache.get(key)
+ 
+             # Key not found, or the item has expired
+             if item is None or item.is_expired():
+                 # If the item exists but has expired, remove it (lazy expiration)
+                 if item is not None:
+                     logger.debug(f"Cache miss (expired): {key}")
+                     del self.cache[key]
+                 else:
+                     logger.debug(f"Cache miss (not found): {key}")
+                 return default
+ 
+             # Update the access time
+             item.touch()
+             logger.debug(f"Cache hit: {key}")
+             return item.value
+ 
+     def delete(self, key: str) -> bool:
+         """Remove a key from the cache"""
+         with self.lock:
+             if key in self.cache:
+                 del self.cache[key]
+                 logger.debug(f"Cache delete: {key}")
+                 return True
+             return False
+ 
+     def clear(self) -> None:
+         """Remove all data from the cache"""
+         with self.lock:
+             self.cache.clear()
+             logger.debug("Cache cleared")
+ 
+     def get_or_set(self, key: str, callback: Callable[[], T], ttl: Optional[int] = None) -> T:
+         """
+         Get a value from the cache if it exists; otherwise call the callback
+         to produce the value and store it in the cache before returning it.
+         """
+         with self.lock:
+             value = self.get(key)
+             if value is None:
+                 value = callback()
+                 self.set(key, value, ttl)
+             return value
+ 
+     def _cleanup_task(self) -> None:
+         """Thread that removes expired items (active expiration)"""
+         while True:
+             time.sleep(self.cleanup_interval)
+             try:
+                 self._remove_expired_items()
+             except Exception as e:
+                 logger.error(f"Error in cache cleanup task: {e}")
+ 
+     def _remove_expired_items(self) -> None:
+         """Remove all expired items from the cache"""
+         with self.lock:
+             expired_keys = [k for k, v in self.cache.items() if v.is_expired()]
+             for key in expired_keys:
+                 del self.cache[key]
+ 
+             if expired_keys:
+                 logger.debug(f"Cleaned up {len(expired_keys)} expired cache items")
+ 
+     def _evict_lru_items(self, count: int = 1) -> None:
+         """Evict the least recently used items when the cache is full"""
+         items = sorted(self.cache.items(), key=lambda x: x[1].last_accessed)
+         for i in range(min(count, len(items))):
+             del self.cache[items[i][0]]
+         logger.debug(f"Evicted {min(count, len(items))} least recently used items from cache")
+ 
+     def stats(self) -> Dict[str, Any]:
+         """Return cache statistics"""
+         with self.lock:
+             total_items = len(self.cache)
+             expired_items = sum(1 for item in self.cache.values() if item.is_expired())
+             memory_usage = self._estimate_memory_usage()
+             return {
+                 "total_items": total_items,
+                 "expired_items": expired_items,
+                 "active_items": total_items - expired_items,
+                 "memory_usage_bytes": memory_usage,
+                 "memory_usage_mb": memory_usage / (1024 * 1024),
+                 "max_size": self.max_size
+             }
+ 
+     def _estimate_memory_usage(self) -> int:
+         """Estimate the cache's memory usage (approximate)"""
+         # Estimate based on the sizes of keys and values
+         cache_size = sum(len(k) for k in self.cache.keys())
+         for item in self.cache.values():
+             try:
+                 # Approximate the size of the value
+                 if isinstance(item.value, (str, bytes)):
+                     cache_size += len(item.value)
+                 elif isinstance(item.value, (dict, list)):
+                     cache_size += len(json.dumps(item.value))
+                 else:
+                     # Default size for other data types
+                     cache_size += 100
+             except Exception:
+                 cache_size += 100
+ 
+         return cache_size
+ 
+ # Singleton instance
+ _cache_instance = None
+ 
+ def get_cache() -> InMemoryCache:
+     """Return the singleton InMemoryCache instance"""
+     global _cache_instance
+     if _cache_instance is None:
+         _cache_instance = InMemoryCache()
+     return _cache_instance
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import logging
4
+ import traceback
5
+ import json
6
+ import time
7
+ from datetime import datetime
8
+ import platform
9
+
10
+ # Try to import psutil, provide fallback if not available
11
+ try:
12
+ import psutil
13
+ PSUTIL_AVAILABLE = True
14
+ except ImportError:
15
+ PSUTIL_AVAILABLE = False
16
+ logging.warning("psutil module not available. System monitoring features will be limited.")
17
+
18
+ # Configure logging
19
+ logger = logging.getLogger(__name__)
20
+
21
+ class DebugInfo:
22
+ """Class containing debug information"""
23
+
24
+ @staticmethod
25
+ def get_system_info():
26
+ """Get system information"""
27
+ try:
28
+ info = {
29
+ "os": platform.system(),
30
+ "os_version": platform.version(),
31
+ "python_version": platform.python_version(),
32
+ "cpu_count": os.cpu_count(),
33
+ "timestamp": datetime.now().isoformat()
34
+ }
35
+
36
+ # Add information from psutil if available
37
+ if PSUTIL_AVAILABLE:
38
+ info.update({
39
+ "total_memory": round(psutil.virtual_memory().total / (1024 * 1024 * 1024), 2), # GB
40
+ "available_memory": round(psutil.virtual_memory().available / (1024 * 1024 * 1024), 2), # GB
41
+ "cpu_usage": psutil.cpu_percent(interval=0.1),
42
+ "memory_usage": psutil.virtual_memory().percent,
43
+ "disk_usage": psutil.disk_usage('/').percent,
44
+ })
45
+ else:
46
+ info.update({
47
+ "total_memory": "psutil not available",
48
+ "available_memory": "psutil not available",
49
+ "cpu_usage": "psutil not available",
50
+ "memory_usage": "psutil not available",
51
+ "disk_usage": "psutil not available",
52
+ })
53
+
54
+ return info
55
+ except Exception as e:
56
+ logger.error(f"Error getting system info: {e}")
57
+ return {"error": str(e)}
58
+
59
+ @staticmethod
60
+ def get_env_info():
61
+ """Get environment variable information (masking sensitive information)"""
62
+ try:
63
+ # List of environment variables to mask values
64
+ sensitive_vars = [
65
+ "API_KEY", "SECRET", "PASSWORD", "TOKEN", "AUTH", "MONGODB_URL",
66
+ "AIVEN_DB_URL", "PINECONE_API_KEY", "GOOGLE_API_KEY"
67
+ ]
68
+
69
+ env_vars = {}
70
+ for key, value in os.environ.items():
71
+ # Check if environment variable contains sensitive words
72
+ is_sensitive = any(s in key.upper() for s in sensitive_vars)
73
+
74
+ if is_sensitive and value:
75
+ # Mask value displaying only the first 4 characters
76
+ masked_value = value[:4] + "****" if len(value) > 4 else "****"
77
+ env_vars[key] = masked_value
78
+ else:
79
+ env_vars[key] = value
80
+
81
+ return env_vars
82
+ except Exception as e:
83
+ logger.error(f"Error getting environment info: {e}")
84
+ return {"error": str(e)}
85
+
86
+ @staticmethod
87
+ def get_database_status():
88
+ """Get database connection status"""
89
+ try:
90
+ from app.database.postgresql import check_db_connection as check_postgresql
91
+ from app.database.mongodb import check_db_connection as check_mongodb
92
+ from app.database.pinecone import check_db_connection as check_pinecone
93
+
94
+ return {
95
+ "postgresql": check_postgresql(),
96
+ "mongodb": check_mongodb(),
97
+ "pinecone": check_pinecone(),
98
+ "timestamp": datetime.now().isoformat()
99
+ }
100
+ except Exception as e:
101
+ logger.error(f"Error getting database status: {e}")
102
+ return {"error": str(e)}
103
+
104
+ class PerformanceMonitor:
105
+ """Performance monitoring class"""
106
+
107
+ def __init__(self):
108
+ self.start_time = time.time()
109
+ self.checkpoints = []
110
+
111
+ def checkpoint(self, name):
112
+ """Mark a checkpoint and record the time"""
113
+ current_time = time.time()
114
+ elapsed = current_time - self.start_time
115
+ self.checkpoints.append({
116
+ "name": name,
117
+ "time": current_time,
118
+ "elapsed": elapsed
119
+ })
120
+ logger.debug(f"Checkpoint '{name}' at {elapsed:.4f}s")
121
+ return elapsed
122
+
123
+ def get_report(self):
124
+ """Generate performance report"""
125
+ if not self.checkpoints:
126
+ return {"error": "No checkpoints recorded"}
127
+
128
+ total_time = time.time() - self.start_time
129
+
130
+ # Calculate time between checkpoints
131
+ intervals = []
132
+ prev_time = self.start_time
133
+
134
+ for checkpoint in self.checkpoints:
135
+ interval = checkpoint["time"] - prev_time
136
+ intervals.append({
137
+ "name": checkpoint["name"],
138
+ "interval": interval,
139
+ "elapsed": checkpoint["elapsed"]
140
+ })
141
+ prev_time = checkpoint["time"]
142
+
143
+ return {
144
+ "total_time": total_time,
145
+ "checkpoint_count": len(self.checkpoints),
146
+ "intervals": intervals
147
+ }
148
+
149
+ class ErrorTracker:
150
+ """Class to track and record errors"""
151
+
152
+ def __init__(self, max_errors=100):
153
+ self.errors = []
154
+ self.max_errors = max_errors
155
+
156
+ def track_error(self, error, context=None):
157
+ """Record error information"""
158
+ error_info = {
159
+ "error_type": type(error).__name__,
160
+ "error_message": str(error),
161
+ "traceback": traceback.format_exc(),
162
+ "timestamp": datetime.now().isoformat(),
163
+ "context": context or {}
164
+ }
165
+
166
+ # Add to error list
167
+ self.errors.append(error_info)
168
+
169
+ # Limit the number of stored errors
170
+ if len(self.errors) > self.max_errors:
171
+ self.errors.pop(0) # Remove oldest error
172
+
173
+ return error_info
174
+
175
+ def get_errors(self, limit=None):
176
+ """Get list of recorded errors"""
177
+ if limit is None or limit >= len(self.errors):
178
+ return self.errors
179
+ return self.errors[-limit:] # Return most recent errors
180
+
181
+ # Initialize global objects
182
+ error_tracker = ErrorTracker()
183
+ performance_monitor = PerformanceMonitor()
184
+
185
+ def debug_view(request=None):
186
+ """Create a full debug report"""
187
+ debug_data = {
188
+ "system_info": DebugInfo.get_system_info(),
189
+ "database_status": DebugInfo.get_database_status(),
190
+ "performance": performance_monitor.get_report(),
191
+ "recent_errors": error_tracker.get_errors(limit=10),
192
+ "timestamp": datetime.now().isoformat()
193
+ }
194
+
195
+ # Add request information if available
196
+ if request:
197
+ debug_data["request"] = {
198
+ "method": request.method,
199
+ "url": str(request.url),
200
+ "headers": dict(request.headers),
201
+ "client": {
202
+ "host": request.client.host if request.client else "unknown",
203
+ "port": request.client.port if request.client else "unknown"
204
+ }
205
+ }
206
+
207
+ return debug_data
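The checkpoint/interval logic in `PerformanceMonitor.get_report()` can be sketched stand-alone; this stdlib-only version (function and checkpoint names are illustrative, not the module's API) records elapsed times and derives per-step intervals the same way:

```python
import time

# Stand-alone run of the checkpoint pattern used by PerformanceMonitor
start = time.time()
checkpoints = []

def checkpoint(name):
    # Record elapsed time since monitoring started
    elapsed = time.time() - start
    checkpoints.append({"name": name, "elapsed": elapsed})
    return elapsed

checkpoint("load_model")
time.sleep(0.01)
checkpoint("run_query")

# Intervals between consecutive checkpoints, as get_report() computes them
intervals = [
    checkpoints[i]["elapsed"] - (checkpoints[i - 1]["elapsed"] if i else 0.0)
    for i in range(len(checkpoints))
]
print(len(intervals))  # 2
```

The second interval reflects the `sleep` between the two checkpoints, so each entry isolates the cost of one step rather than the cumulative elapsed time.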
app/utils/middleware.py ADDED
@@ -0,0 +1,109 @@
+ from fastapi import Request, status
+ from fastapi.responses import JSONResponse
+ from starlette.middleware.base import BaseHTTPMiddleware
+ import logging
+ import time
+ import traceback
+ import uuid
+ from .utils import get_local_time
+ 
+ # Configure logging
+ logger = logging.getLogger(__name__)
+ 
+ class RequestLoggingMiddleware(BaseHTTPMiddleware):
+     """Middleware to log requests and responses"""
+ 
+     async def dispatch(self, request: Request, call_next):
+         request_id = str(uuid.uuid4())
+         request.state.request_id = request_id
+ 
+         # Log request information
+         client_host = request.client.host if request.client else "unknown"
+         logger.info(f"Request [{request_id}]: {request.method} {request.url.path} from {client_host}")
+ 
+         # Measure processing time
+         start_time = time.time()
+ 
+         try:
+             # Process the request
+             response = await call_next(request)
+ 
+             # Calculate processing time
+             process_time = time.time() - start_time
+             logger.info(f"Response [{request_id}]: {response.status_code} processed in {process_time:.4f}s")
+ 
+             # Add headers
+             response.headers["X-Request-ID"] = request_id
+             response.headers["X-Process-Time"] = str(process_time)
+ 
+             return response
+ 
+         except Exception as e:
+             # Log the error
+             process_time = time.time() - start_time
+             logger.error(f"Error [{request_id}] after {process_time:.4f}s: {str(e)}")
+             logger.error(traceback.format_exc())
+ 
+             # Return an error response
+             return JSONResponse(
+                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                 content={
+                     "detail": "Internal server error",
+                     "request_id": request_id,
+                     "timestamp": get_local_time()
+                 }
+             )
+ 
+ class ErrorHandlingMiddleware(BaseHTTPMiddleware):
+     """Middleware to handle uncaught exceptions in the application"""
+ 
+     async def dispatch(self, request: Request, call_next):
+         try:
+             return await call_next(request)
+         except Exception as e:
+             # Get the request_id if available
+             request_id = getattr(request.state, "request_id", str(uuid.uuid4()))
+ 
+             # Log the error
+             logger.error(f"Uncaught exception [{request_id}]: {str(e)}")
+             logger.error(traceback.format_exc())
+ 
+             # Return an error response
+             return JSONResponse(
+                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                 content={
+                     "detail": "Internal server error",
+                     "request_id": request_id,
+                     "timestamp": get_local_time()
+                 }
+             )
+ 
+ class DatabaseCheckMiddleware(BaseHTTPMiddleware):
+     """Middleware to check database connections before each request"""
+ 
+     async def dispatch(self, request: Request, call_next):
+         # Skip paths that don't need database checks
+         skip_paths = ["/", "/health", "/docs", "/redoc", "/openapi.json"]
+         if request.url.path in skip_paths:
+             return await call_next(request)
+ 
+         # Check database connections
+         try:
+             # TODO: Add checks for MongoDB and Pinecone if needed
+             # The PostgreSQL check is already done in the route handler via get_db()
+ 
+             # Process the request normally
+             return await call_next(request)
+ 
+         except Exception as e:
+             # Log the error
+             logger.error(f"Database connection check failed: {str(e)}")
+ 
+             # Return an error response
+             return JSONResponse(
+                 status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                 content={
+                     "detail": "Database connection failed",
+                     "timestamp": get_local_time()
+                 }
+             )
app/utils/pdf_processor.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import os
+ import time
+ import uuid
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings
+ import logging
+ from pinecone import Pinecone
+
+ from app.database.pinecone import get_pinecone_index, init_pinecone
+ from app.database.postgresql import get_db
+ from app.database.models import VectorDatabase
+
+ # Configure logging
+ logger = logging.getLogger(__name__)
+
+ # Initialize embeddings model
+ embeddings_model = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
+
+ class PDFProcessor:
+     """Class for processing PDF files and creating embeddings"""
+
+     def __init__(self, index_name="testbot768", namespace="Default", api_key=None, vector_db_id=None, mock_mode=False):
+         """Initialize with Pinecone index name, namespace and API key"""
+         self.index_name = index_name
+         self.namespace = namespace
+         self.pinecone_index = None
+         self.api_key = api_key
+         self.vector_db_id = vector_db_id
+         self.pinecone_client = None
+         self.mock_mode = mock_mode  # Mock mode for testing
+
+     def _get_api_key_from_db(self):
+         """Get API key from database if not provided directly"""
+         if self.api_key:
+             return self.api_key
+
+         if not self.vector_db_id:
+             logger.error("No API key provided and no vector_db_id to fetch from database")
+             return None
+
+         try:
+             # Get database session
+             db = next(get_db())
+
+             # Get vector database
+             vector_db = db.query(VectorDatabase).filter(
+                 VectorDatabase.id == self.vector_db_id
+             ).first()
+
+             if not vector_db:
+                 logger.error(f"Vector database with ID {self.vector_db_id} not found")
+                 return None
+
+             # Get API key from relationship
+             if hasattr(vector_db, 'api_key_ref') and vector_db.api_key_ref and hasattr(vector_db.api_key_ref, 'key_value'):
+                 logger.info(f"Using API key from api_key table for vector database ID {self.vector_db_id}")
+                 return vector_db.api_key_ref.key_value
+
+             logger.error(f"No API key found for vector database ID {self.vector_db_id}. Make sure the api_key_id is properly set.")
+             return None
+         except Exception as e:
+             logger.error(f"Error fetching API key from database: {e}")
+             return None
+
+     def _init_pinecone_connection(self):
+         """Initialize connection to Pinecone with the new API"""
+         try:
+             # If in mock mode, return a mock index
+             if self.mock_mode:
+                 logger.info("Running in mock mode - simulating Pinecone connection")
+                 mock_namespace = self.namespace  # captured for the nested class below
+
+                 class MockPineconeIndex:
+                     def upsert(self, vectors, namespace=None):
+                         logger.info(f"Mock upsert: {len(vectors)} vectors to namespace '{namespace}'")
+                         return {"upserted_count": len(vectors)}
+
+                     def delete(self, ids=None, delete_all=False, namespace=None):
+                         logger.info(f"Mock delete: {'all vectors' if delete_all else f'{len(ids)} vectors'} from namespace '{namespace}'")
+                         return {"deleted_count": 10 if delete_all else len(ids or [])}
+
+                     def describe_index_stats(self):
+                         logger.info("Mock describe_index_stats")
+                         return {"total_vector_count": 100, "namespaces": {mock_namespace: {"vector_count": 50}}}
+
+                 return MockPineconeIndex()
+
+             # Get API key from database if not provided
+             api_key = self._get_api_key_from_db()
+
+             if not api_key or not self.index_name:
+                 logger.error("Pinecone API key or index name not available")
+                 return None
+
+             # Initialize Pinecone client using the new API
+             self.pinecone_client = Pinecone(api_key=api_key)
+
+             # Get the index
+             index_list = self.pinecone_client.list_indexes()
+             existing_indexes = index_list.names() if hasattr(index_list, 'names') else []
+
+             if self.index_name not in existing_indexes:
+                 logger.error(f"Index {self.index_name} does not exist in Pinecone")
+                 return None
+
+             # Connect to the index
+             index = self.pinecone_client.Index(self.index_name)
+             logger.info(f"Connected to Pinecone index: {self.index_name}")
+             return index
+         except Exception as e:
+             logger.error(f"Error connecting to Pinecone: {e}")
+             return None
+
+     async def process_pdf(self, file_path, document_id=None, metadata=None, progress_callback=None):
+         """
+         Process a PDF file, split it into chunks and create embeddings.
+
+         Args:
+             file_path (str): Path to the PDF file
+             document_id (str, optional): Document ID; a new ID is created if not provided
+             metadata (dict, optional): Additional metadata for the document
+             progress_callback (callable, optional): Callback function for progress updates
+
+         Returns:
+             dict: Processing result, including document_id and number of processed chunks
+         """
+         try:
+             # Initialize Pinecone connection if not already done
+             self.pinecone_index = self._init_pinecone_connection()
+             if not self.pinecone_index:
+                 return {"success": False, "error": "Could not connect to Pinecone"}
+
+             # Create document_id if not provided
+             if not document_id:
+                 document_id = str(uuid.uuid4())
+
+             # Load PDF using PyPDFLoader
+             logger.info(f"Reading PDF file: {file_path}")
+             if progress_callback:
+                 await progress_callback("pdf_loading", 0.5, "Loading PDF file")
+
+             loader = PyPDFLoader(file_path)
+             pages = loader.load()
+
+             # Extract and concatenate text from all pages
+             all_text = ""
+             for page in pages:
+                 all_text += page.page_content + "\n"
+
+             if progress_callback:
+                 await progress_callback("text_extraction", 0.6, "Extracted text from PDF")
+
+             # Split text into chunks
+             text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=300)
+             chunks = text_splitter.split_text(all_text)
+
+             logger.info(f"Split PDF file into {len(chunks)} chunks")
+             if progress_callback:
+                 await progress_callback("chunking", 0.7, f"Split document into {len(chunks)} chunks")
+
+             # Process embeddings for each chunk and upsert to Pinecone
+             vectors = []
+             for i, chunk in enumerate(chunks):
+                 # Update embedding progress every 5 chunks to avoid too many notifications
+                 if progress_callback and i % 5 == 0:
+                     embedding_progress = 0.7 + (0.3 * (i / len(chunks)))
+                     await progress_callback("embedding", embedding_progress, f"Processing chunk {i+1}/{len(chunks)}")
+
+                 # Create vector embedding for each chunk
+                 vector = embeddings_model.embed_query(chunk)
+
+                 # Prepare metadata for vector
+                 vector_metadata = {
+                     "document_id": document_id,
+                     "chunk_index": i,
+                     "text": chunk
+                 }
+
+                 # Add additional metadata if provided
+                 if metadata:
+                     for key, value in metadata.items():
+                         if key not in vector_metadata:
+                             vector_metadata[key] = value
+
+                 # Add vector to list for upserting
+                 vectors.append({
+                     "id": f"{document_id}_{i}",
+                     "values": vector,
+                     "metadata": vector_metadata
+                 })
+
+                 # Upsert in batches of 100 to avoid overloading
+                 if len(vectors) >= 100:
+                     await self._upsert_vectors(vectors)
+                     vectors = []
+
+             # Upsert any remaining vectors
+             if vectors:
+                 await self._upsert_vectors(vectors)
+
+             logger.info(f"Embedded and saved {len(chunks)} chunks from PDF with document_id: {document_id}")
+
+             # Final progress update
+             if progress_callback:
+                 await progress_callback("completed", 1.0, "PDF processing complete")
+
+             return {
+                 "success": True,
+                 "document_id": document_id,
+                 "chunks_processed": len(chunks),
+                 "total_text_length": len(all_text)
+             }
+
+         except Exception as e:
+             logger.error(f"Error processing PDF: {str(e)}")
+             if progress_callback:
+                 await progress_callback("error", 0, f"Error processing PDF: {str(e)}")
+             return {
+                 "success": False,
+                 "error": str(e)
+             }
+
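`process_pdf` reports progress through an awaitable callback invoked as `await progress_callback(stage, progress, message)` — a stage label, a 0.0-1.0 fraction, and a human-readable message. A minimal sketch of a compatible callback (the names `log_progress` and `events` are illustrative, not part of this commit):

```python
import asyncio

events = []

async def log_progress(stage: str, progress: float, message: str) -> None:
    # Matches the call sites in process_pdf: stage label, 0.0-1.0 fraction, text
    events.append((stage, progress, message))

# Simulate the first and last updates process_pdf would emit
asyncio.run(log_progress("pdf_loading", 0.5, "Loading PDF file"))
asyncio.run(log_progress("completed", 1.0, "PDF processing complete"))
```

A real callback would typically forward these tuples over the WebSocket notification channel instead of collecting them in a list.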
+     async def _upsert_vectors(self, vectors):
+         """Upsert vectors to Pinecone"""
+         try:
+             if not vectors:
+                 return
+
+             # Ensure we have a valid pinecone_index
+             if not self.pinecone_index:
+                 self.pinecone_index = self._init_pinecone_connection()
+                 if not self.pinecone_index:
+                     raise Exception("Cannot connect to Pinecone")
+
+             result = self.pinecone_index.upsert(
+                 vectors=vectors,
+                 namespace=self.namespace
+             )
+
+             logger.info(f"Upserted {len(vectors)} vectors to Pinecone")
+             return result
+         except Exception as e:
+             logger.error(f"Error upserting vectors: {str(e)}")
+             raise
+
+     async def delete_namespace(self):
+         """
+         Delete all vectors in the current namespace (equivalent to deleting the namespace).
+         """
+         # Initialize connection if needed
+         self.pinecone_index = self._init_pinecone_connection()
+         if not self.pinecone_index:
+             return {"success": False, "error": "Could not connect to Pinecone"}
+
+         try:
+             # delete_all=True will delete all vectors in the namespace
+             result = self.pinecone_index.delete(
+                 delete_all=True,
+                 namespace=self.namespace
+             )
+             logger.info(f"Deleted namespace '{self.namespace}' (all vectors).")
+             return {"success": True, "detail": result}
+         except Exception as e:
+             logger.error(f"Error deleting namespace '{self.namespace}': {e}")
+             return {"success": False, "error": str(e)}
+
+     async def list_documents(self):
+         """Get index statistics from Pinecone"""
+         try:
+             # Initialize Pinecone connection if not already done
+             self.pinecone_index = self._init_pinecone_connection()
+             if not self.pinecone_index:
+                 return {"success": False, "error": "Could not connect to Pinecone"}
+
+             # Get index information
+             stats = self.pinecone_index.describe_index_stats()
+
+             # Querying for every unique document_id is not efficient on large datasets;
+             # in practice, maintain the list of document_ids in a separate database.
+
+             return {
+                 "success": True,
+                 "total_vectors": stats.get('total_vector_count', 0),
+                 "namespace": self.namespace,
+                 "index_name": self.index_name
+             }
+         except Exception as e:
+             logger.error(f"Error getting document list: {str(e)}")
+             return {
+                 "success": False,
+                 "error": str(e)
+             }
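The batching pattern in `process_pdf` — flush every 100 vectors during the loop, then flush any remainder afterwards — can be sketched in isolation (the name `batch_flush` is illustrative, not part of this commit):

```python
def batch_flush(items, batch_size=100):
    """Yield items in batches of batch_size, mirroring the upsert loop above."""
    batch = []
    for item in items:
        batch.append(item)
        if len(batch) >= batch_size:
            # Flush a full batch, as the loop does when len(vectors) >= 100
            yield batch
            batch = []
    if batch:
        # Flush the remainder, as process_pdf does after the loop
        yield batch

batches = list(batch_flush(range(250), batch_size=100))
sizes = [len(b) for b in batches]  # -> [100, 100, 50]
```

Batching keeps each Pinecone upsert request small; the trailing flush is what prevents the final partial batch from being silently dropped.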
app/utils/utils.py ADDED
@@ -0,0 +1,478 @@
+ import logging
+ import time
+ import uuid
+ import threading
+ import os
+ from functools import wraps
+ from datetime import datetime, timedelta
+ import pytz
+ from typing import Callable, Any, Dict, Optional, List, Tuple, Set
+ import gc
+ import heapq
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+ )
+ logger = logging.getLogger(__name__)
+
+ # Asia/Ho_Chi_Minh timezone
+ asia_tz = pytz.timezone('Asia/Ho_Chi_Minh')
+
+ def generate_uuid():
+     """Generate a unique identifier"""
+     return str(uuid.uuid4())
+
+ def get_current_time():
+     """Get current time in ISO format"""
+     return datetime.now().isoformat()
+
+ def get_local_time():
+     """Get current time in the Asia/Ho_Chi_Minh timezone"""
+     return datetime.now(asia_tz).strftime("%Y-%m-%d %H:%M:%S")
+
+ def get_local_datetime():
+     """Get current datetime object in the Asia/Ho_Chi_Minh timezone"""
+     return datetime.now(asia_tz)
+
+ # For backward compatibility
+ get_vietnam_time = get_local_time
+ get_vietnam_datetime = get_local_datetime
+
+ def timer_decorator(func: Callable) -> Callable:
+     """
+     Decorator to time async function execution and log the result.
+     """
+     @wraps(func)
+     async def wrapper(*args, **kwargs):
+         start_time = time.time()
+         try:
+             result = await func(*args, **kwargs)
+             elapsed_time = time.time() - start_time
+             logger.info(f"Function {func.__name__} executed in {elapsed_time:.4f} seconds")
+             return result
+         except Exception as e:
+             elapsed_time = time.time() - start_time
+             logger.error(f"Function {func.__name__} failed after {elapsed_time:.4f} seconds: {e}")
+             raise
+     return wrapper
+
+ def sanitize_input(text):
+     """Sanitize input text"""
+     if not text:
+         return ""
+     # Remove potentially dangerous characters or patterns
+     return text.strip()
+
+ def truncate_text(text, max_length=100):
+     """
+     Truncate text to the given max length and add an ellipsis.
+     """
+     if not text or len(text) <= max_length:
+         return text
+     return text[:max_length] + "..."
+
+ class CacheStrategy:
+     """Cache loading strategy enumeration"""
+     LAZY = "lazy"    # Only load items into cache when requested
+     EAGER = "eager"  # Preload items into cache at initialization
+     MIXED = "mixed"  # Preload high-priority items, lazy load others
+
+ class CacheItem:
+     """Represents an item in the cache with metadata"""
+     def __init__(self, key: str, value: Any, ttl: int = 300, priority: int = 1):
+         self.key = key
+         self.value = value
+         self.expiry = datetime.now() + timedelta(seconds=ttl)
+         self.priority = priority  # Higher number = higher priority
+         self.access_count = 0     # Track number of accesses
+         self.last_accessed = datetime.now()
+
+     def is_expired(self) -> bool:
+         """Check if the item is expired"""
+         return datetime.now() > self.expiry
+
+     def touch(self):
+         """Update last accessed time and access count"""
+         self.last_accessed = datetime.now()
+         self.access_count += 1
+
+     def __lt__(self, other):
+         """For heap comparisons - lower-priority items are evicted first"""
+         # First compare priority
+         if self.priority != other.priority:
+             return self.priority < other.priority
+         # Then compare access frequency (less frequently accessed items are evicted first)
+         if self.access_count != other.access_count:
+             return self.access_count < other.access_count
+         # Finally compare last access time (oldest accessed first)
+         return self.last_accessed < other.last_accessed
+
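The `__lt__` ordering above means a plain sort puts eviction candidates first: lowest priority, then fewest accesses, then least recently used. A standalone sketch using tuples as a simplified stand-in for `CacheItem` (the data is illustrative):

```python
# Each entry: (priority, access_count, last_accessed_ts, key) - same comparison
# order as CacheItem.__lt__, since tuples compare element by element.
items = [
    (1, 5, 100.0, "a"),  # low priority, but accessed often
    (9, 1, 50.0, "b"),   # high priority: kept longest
    (1, 1, 200.0, "c"),  # low priority, rarely accessed
    (1, 1, 10.0, "d"),   # same as "c" but older, so it is evicted first
]
evict_order = [key for *_, key in sorted(items)]
# "d" then "c" (priority 1, one access, oldest first), then "a", then "b"
```

This is why `_evict_items` can simply call `items.sort()` and walk the list front to back.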
+     def get_size(self) -> int:
+         """Approximate memory size of the cache item in bytes"""
+         try:
+             import sys
+             return sys.getsizeof(self.value) + sys.getsizeof(self.key) + 64  # Additional overhead
+         except Exception:
+             # Default estimate if we can't get the size
+             return 1024
+
+ # Enhanced in-memory cache implementation
+ class EnhancedCache:
+     def __init__(self,
+                  strategy: str = "lazy",
+                  max_items: int = 10000,
+                  max_size_mb: int = 100,
+                  cleanup_interval: int = 60,
+                  stats_enabled: bool = True):
+         """
+         Initialize enhanced cache with a configurable strategy.
+
+         Args:
+             strategy: Cache loading strategy (lazy, eager, mixed)
+             max_items: Maximum number of items to store in cache
+             max_size_mb: Maximum size of cache in MB
+             cleanup_interval: Interval in seconds to run cleanup
+             stats_enabled: Whether to collect cache statistics
+         """
+         self._cache: Dict[str, CacheItem] = {}
+         self._namespace_cache: Dict[str, Set[str]] = {}  # Tracks keys by namespace
+         self._strategy = strategy
+         self._max_items = max_items
+         self._max_size_bytes = max_size_mb * 1024 * 1024
+         self._current_size_bytes = 0
+         self._stats_enabled = stats_enabled
+
+         # Statistics
+         self._hits = 0
+         self._misses = 0
+         self._sets = 0
+         self._evictions = 0
+         self._total_get_time = 0
+         self._total_set_time = 0
+
+         # Setup cleanup thread
+         self._last_cleanup = datetime.now()
+         self._cleanup_interval = cleanup_interval
+         self._lock = threading.RLock()
+
+         if cleanup_interval > 0:
+             self._start_cleanup_thread(cleanup_interval)
+
+         logger.info(f"Enhanced cache initialized with strategy={strategy}, max_items={max_items}, max_size={max_size_mb}MB")
+
+     def _start_cleanup_thread(self, interval: int):
+         """Start background thread for periodic cleanup"""
+         def cleanup_worker():
+             while True:
+                 time.sleep(interval)
+                 try:
+                     self.cleanup()
+                 except Exception as e:
+                     logger.error(f"Error in cache cleanup: {e}")
+
+         thread = threading.Thread(target=cleanup_worker, daemon=True)
+         thread.start()
+         logger.info(f"Cache cleanup thread started with interval {interval}s")
+
+     def get(self, key: str, namespace: str = None) -> Optional[Any]:
+         """Get value from cache if it exists and hasn't expired"""
+         if self._stats_enabled:
+             start_time = time.time()
+
+         # Use namespaced key if namespace is provided
+         cache_key = f"{namespace}:{key}" if namespace else key
+
+         with self._lock:
+             cache_item = self._cache.get(cache_key)
+
+             if cache_item:
+                 if cache_item.is_expired():
+                     # Clean up expired key
+                     self._remove_item(cache_key, namespace)
+                     if self._stats_enabled:
+                         self._misses += 1
+                     value = None
+                 else:
+                     # Update access metadata
+                     cache_item.touch()
+                     if self._stats_enabled:
+                         self._hits += 1
+                     value = cache_item.value
+             else:
+                 if self._stats_enabled:
+                     self._misses += 1
+                 value = None
+
+         if self._stats_enabled:
+             self._total_get_time += time.time() - start_time
+
+         return value
+
+     def set(self, key: str, value: Any, ttl: int = 300, priority: int = 1, namespace: str = None) -> None:
+         """Set a value in the cache with TTL in seconds"""
+         if self._stats_enabled:
+             start_time = time.time()
+
+         # Use namespaced key if namespace is provided
+         cache_key = f"{namespace}:{key}" if namespace else key
+
+         with self._lock:
+             # Create cache item
+             cache_item = CacheItem(cache_key, value, ttl, priority)
+             item_size = cache_item.get_size()
+
+             # Check if we need to make room
+             if (len(self._cache) >= self._max_items or
+                     self._current_size_bytes + item_size > self._max_size_bytes):
+                 self._evict_items(item_size)
+
+             # Update size tracking
+             if cache_key in self._cache:
+                 # If replacing, subtract the old size first
+                 self._current_size_bytes -= self._cache[cache_key].get_size()
+             self._current_size_bytes += item_size
+
+             # Store the item
+             self._cache[cache_key] = cache_item
+
+             # Update namespace tracking
+             if namespace:
+                 if namespace not in self._namespace_cache:
+                     self._namespace_cache[namespace] = set()
+                 self._namespace_cache[namespace].add(cache_key)
+
+         if self._stats_enabled:
+             self._sets += 1
+             self._total_set_time += time.time() - start_time
+
+     def delete(self, key: str, namespace: str = None) -> None:
+         """Delete a key from the cache"""
+         # Use namespaced key if namespace is provided
+         cache_key = f"{namespace}:{key}" if namespace else key
+
+         with self._lock:
+             self._remove_item(cache_key, namespace)
+
+     def _remove_item(self, key: str, namespace: str = None):
+         """Internal method to remove an item and update tracking"""
+         if key in self._cache:
+             # Update size tracking
+             self._current_size_bytes -= self._cache[key].get_size()
+             # Remove from cache
+             del self._cache[key]
+
+         # Update namespace tracking
+         if namespace and namespace in self._namespace_cache:
+             if key in self._namespace_cache[namespace]:
+                 self._namespace_cache[namespace].remove(key)
+                 # Clean up empty sets
+                 if not self._namespace_cache[namespace]:
+                     del self._namespace_cache[namespace]
+
+     def _evict_items(self, needed_space: int = 0) -> None:
+         """Evict items to make room in the cache"""
+         if not self._cache:
+             return
+
+         with self._lock:
+             # Convert cache items to a list for sorting
+             items = list(self._cache.values())
+
+             # Sort by priority, access count, and last accessed time
+             items.sort()  # Uses the __lt__ method of CacheItem
+
+             # Evict items until we have enough space
+             space_freed = 0
+             evicted_count = 0
+
+             for item in items:
+                 # Stop if we've made enough room
+                 if (len(self._cache) - evicted_count <= self._max_items * 0.9 and
+                         (space_freed >= needed_space or
+                          self._current_size_bytes - space_freed <= self._max_size_bytes * 0.9)):
+                     break
+
+                 # Skip high-priority items unless absolutely necessary
+                 if item.priority > 9 and evicted_count < len(items) // 2:
+                     continue
+
+                 # Evict this item
+                 item_size = item.get_size()
+                 namespace = item.key.split(':', 1)[0] if ':' in item.key else None
+                 self._remove_item(item.key, namespace)
+
+                 space_freed += item_size
+                 evicted_count += 1
+                 if self._stats_enabled:
+                     self._evictions += 1
+
+             logger.info(f"Cache eviction: removed {evicted_count} items, freed {space_freed / 1024:.2f}KB")
+
+     def clear(self, namespace: str = None) -> None:
+         """
+         Clear the cache or a specific namespace
+         """
+         with self._lock:
+             if namespace:
+                 # Clear only keys in the specified namespace
+                 if namespace in self._namespace_cache:
+                     keys_to_remove = list(self._namespace_cache[namespace])
+                     for key in keys_to_remove:
+                         self._remove_item(key, namespace)
+                     # The namespace itself is cleaned up in _remove_item
+             else:
+                 # Clear the entire cache
+                 self._cache.clear()
+                 self._namespace_cache.clear()
+                 self._current_size_bytes = 0
+
+         logger.info(f"Cache cleared{' for namespace ' + namespace if namespace else ''}")
+
+     def cleanup(self) -> None:
+         """Remove expired items and run garbage collection if needed"""
+         with self._lock:
+             now = datetime.now()
+             # Only run if it's been at least cleanup_interval since the last cleanup
+             if (now - self._last_cleanup).total_seconds() < self._cleanup_interval:
+                 return
+
+             # Find expired items
+             expired_keys = []
+             for key, item in self._cache.items():
+                 if item.is_expired():
+                     expired_keys.append((key, key.split(':', 1)[0] if ':' in key else None))
+
+             # Remove expired items
+             for key, namespace in expired_keys:
+                 self._remove_item(key, namespace)
+
+             # Update last cleanup time
+             self._last_cleanup = now
+
+             # Run garbage collection if we removed several items
+             if len(expired_keys) > 100:
+                 gc.collect()
+
+             logger.info(f"Cache cleanup: removed {len(expired_keys)} expired items")
+
+     def get_stats(self) -> Dict:
+         """Get cache statistics"""
+         with self._lock:
+             if not self._stats_enabled:
+                 return {"stats_enabled": False}
+
+             # Calculate hit rate
+             total_requests = self._hits + self._misses
+             hit_rate = (self._hits / total_requests) * 100 if total_requests > 0 else 0
+
+             # Calculate average times
+             avg_get_time = (self._total_get_time / total_requests) * 1000 if total_requests > 0 else 0
+             avg_set_time = (self._total_set_time / self._sets) * 1000 if self._sets > 0 else 0
+
+             return {
+                 "stats_enabled": True,
+                 "item_count": len(self._cache),
+                 "max_items": self._max_items,
+                 "size_bytes": self._current_size_bytes,
+                 "max_size_bytes": self._max_size_bytes,
+                 "hits": self._hits,
+                 "misses": self._misses,
+                 "hit_rate_percent": round(hit_rate, 2),
+                 "evictions": self._evictions,
+                 "avg_get_time_ms": round(avg_get_time, 3),
+                 "avg_set_time_ms": round(avg_set_time, 3),
+                 "namespace_count": len(self._namespace_cache),
+                 "namespaces": list(self._namespace_cache.keys())
+             }
+
+     def preload(self, items: List[Tuple[str, Any, int, int]], namespace: str = None) -> None:
+         """
+         Preload a list of items into the cache
+
+         Args:
+             items: List of (key, value, ttl, priority) tuples
+             namespace: Optional namespace for all items
+         """
+         for key, value, ttl, priority in items:
+             self.set(key, value, ttl, priority, namespace)
+
+         logger.info(f"Preloaded {len(items)} items into cache{' namespace ' + namespace if namespace else ''}")
+
+     def get_or_load(self, key: str, loader_func: Callable[[], Any],
+                     ttl: int = 300, priority: int = 1, namespace: str = None) -> Any:
+         """
+         Get from cache or load using the provided function
+
+         Args:
+             key: Cache key
+             loader_func: Function to call on a cache miss
+             ttl: TTL in seconds
+             priority: Item priority
+             namespace: Optional namespace
+
+         Returns:
+             Cached or freshly loaded value
+         """
+         # Try to get from cache first
+         value = self.get(key, namespace)
+
+         # If not in cache, load it
+         if value is None:
+             value = loader_func()
+             # Only cache if we got a valid value
+             if value is not None:
+                 self.set(key, value, ttl, priority, namespace)
+
+         return value
+
+ # Load cache configuration from environment variables
+ CACHE_STRATEGY = os.getenv("CACHE_STRATEGY", "mixed")
+ CACHE_MAX_ITEMS = int(os.getenv("CACHE_MAX_ITEMS", "10000"))
+ CACHE_MAX_SIZE_MB = int(os.getenv("CACHE_MAX_SIZE_MB", "100"))
+ CACHE_CLEANUP_INTERVAL = int(os.getenv("CACHE_CLEANUP_INTERVAL", "60"))
+ CACHE_STATS_ENABLED = os.getenv("CACHE_STATS_ENABLED", "true").lower() in ("true", "1", "yes")
+
+ # Initialize the enhanced cache
+ cache = EnhancedCache(
+     strategy=CACHE_STRATEGY,
+     max_items=CACHE_MAX_ITEMS,
+     max_size_mb=CACHE_MAX_SIZE_MB,
+     cleanup_interval=CACHE_CLEANUP_INTERVAL,
+     stats_enabled=CACHE_STATS_ENABLED
+ )
+
+ # Backward compatibility for SimpleCache - for a transition period
+ class SimpleCache:
+     def __init__(self):
+         """Legacy SimpleCache implementation that uses EnhancedCache underneath"""
+         logger.warning("SimpleCache is deprecated, please use EnhancedCache directly")
+
+     def get(self, key: str) -> Optional[Any]:
+         """Get value from cache if it exists and hasn't expired"""
+         return cache.get(key)
+
+     def set(self, key: str, value: Any, ttl: int = 300) -> None:
+         """Set a value in the cache with TTL in seconds"""
+         cache.set(key, value, ttl)
+
+     def delete(self, key: str) -> None:
+         """Delete a key from the cache"""
+         cache.delete(key)
+
+     def clear(self) -> None:
+         """Clear the entire cache"""
+         cache.clear()
+
+ def get_host_url(request) -> str:
+     """
+     Get the host URL from a request object.
+     """
+     host = request.headers.get("host", "localhost")
+     scheme = request.headers.get("x-forwarded-proto", "http")
+     return f"{scheme}://{host}"
+
+ def format_time(timestamp):
+     """
+     Format a timestamp into a human-readable string.
+     """
+     return timestamp.strftime("%Y-%m-%d %H:%M:%S")
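`get_or_load` above is a standard read-through pattern: check the cache, call the loader on a miss, and cache non-None results. A dict-based sketch of the same flow, with `EnhancedCache`'s locking, eviction, and stats stripped away (names here are illustrative):

```python
import time

_store = {}  # key -> (expiry_timestamp, value)

def get_or_load(key, loader, ttl=300):
    entry = _store.get(key)
    now = time.monotonic()
    if entry and entry[0] > now:   # cache hit, not yet expired
        return entry[1]
    value = loader()               # cache miss: load fresh
    if value is not None:          # only cache valid values, as above
        _store[key] = (now + ttl, value)
    return value

calls = []
def loader():
    calls.append(1)
    return "payload"

first = get_or_load("k", loader)
second = get_or_load("k", loader)  # served from cache; loader runs only once
```

Caching only non-None values means a failing loader is retried on the next call instead of a failure being cached for the full TTL.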
docker-compose.yml ADDED
@@ -0,0 +1,21 @@
+ version: '3'
+
+ services:
+   backend:
+     build:
+       context: .
+       dockerfile: Dockerfile
+     ports:
+       - "7860:7860"
+     env_file:
+       - .env
+     restart: unless-stopped
+     healthcheck:
+       test: ["CMD", "curl", "-f", "http://localhost:7860/health"]
+       interval: 30s
+       timeout: 10s
+       retries: 3
+       start_period: 40s
+     volumes:
+       - ./app:/app/app
+     command: uvicorn app:app --host 0.0.0.0 --port 7860 --reload
docs/api_documentation.md ADDED
@@ -0,0 +1,581 @@
1
+ # API Documentation
2
+
3
+ ## Frontend Setup
4
+
5
+ ```javascript
6
+ // Basic Axios setup
7
+ import axios from 'axios';
8
+
9
+ const api = axios.create({
10
+ baseURL: 'https://api.your-domain.com',
11
+ timeout: 10000,
12
+ headers: {
13
+ 'Content-Type': 'application/json',
14
+ 'Accept': 'application/json'
15
+ }
16
+ });
17
+
18
+ // Error handling
19
+ api.interceptors.response.use(
20
+ response => response.data,
21
+ error => {
22
+ const errorMessage = error.response?.data?.detail || 'An error occurred';
23
+ console.error('API Error:', errorMessage);
24
+ return Promise.reject(errorMessage);
25
+ }
26
+ );
27
+ ```
28
+
29
+ ## Caching System
30
+
31
+ - All GET endpoints support `use_cache=true` parameter (default)
32
+ - Cache TTL: 300 seconds (5 minutes)
33
+ - Cache is automatically invalidated on data changes
34
+
35
+ ## Authentication
36
+
37
+ Currently no authentication is required. If implemented in the future, use JWT Bearer tokens:
38
+
39
+ ```javascript
40
+ const api = axios.create({
41
+ // ...other config
42
+ headers: {
43
+ // ...other headers
44
+ 'Authorization': `Bearer ${token}`
45
+ }
46
+ });
47
+ ```
48
+
49
+ ## Error Codes
50
+
51
+ | Code | Description |
52
+ |------|-------------|
53
+ | 400 | Bad Request |
54
+ | 404 | Not Found |
55
+ | 500 | Internal Server Error |
56
+ | 503 | Service Unavailable |
57
+
58
+ ## PostgreSQL Endpoints
59
+
60
+ ### FAQ Endpoints
61
+
62
+ #### Get FAQs List
63
+ ```
64
+ GET /postgres/faq
65
+ ```
66
+
67
+ Parameters:
68
+ - `skip`: Number of items to skip (default: 0)
69
+ - `limit`: Maximum items to return (default: 100)
70
+ - `active_only`: Return only active items (default: false)
71
+ - `use_cache`: Use cached data if available (default: true)
72
+
73
+ Response:
74
+ ```json
75
+ [
76
+ {
77
+ "question": "How do I book a hotel?",
78
+ "answer": "You can book a hotel through our app or website.",
79
+ "is_active": true,
80
+ "id": 1,
81
+ "created_at": "2023-01-01T00:00:00",
82
+ "updated_at": "2023-01-01T00:00:00"
83
+ }
84
+ ]
85
+ ```
86
+
87
+ Example:
88
+ ```javascript
89
+ async function getFAQs() {
90
+ try {
91
+ const data = await api.get('/postgres/faq', {
92
+ params: { active_only: true, limit: 20 }
93
+ });
94
+ return data;
95
+ } catch (error) {
96
+ console.error('Error fetching FAQs:', error);
97
+ throw error;
98
+ }
99
+ }
100
+ ```
101
+
102
+ #### Create FAQ
103
+ ```
104
+ POST /postgres/faq
105
+ ```
106
+
107
+ Request Body:
108
+ ```json
109
+ {
110
+ "question": "How do I book a hotel?",
111
+ "answer": "You can book a hotel through our app or website.",
112
+ "is_active": true
113
+ }
114
+ ```
115
+
116
+ Response: Created FAQ object
#### Get FAQ Detail
```
GET /postgres/faq/{faq_id}
```

Parameters:
- `faq_id`: ID of the FAQ (required)
- `use_cache`: Use cached data if available (default: true)

Response: FAQ object

#### Update FAQ
```
PUT /postgres/faq/{faq_id}
```

Parameters:
- `faq_id`: ID of the FAQ to update (required)

Request Body: Partial or complete FAQ object
Response: Updated FAQ object

#### Delete FAQ
```
DELETE /postgres/faq/{faq_id}
```

Parameters:
- `faq_id`: ID of the FAQ to delete (required)

Response:
```json
{
  "status": "success",
  "message": "FAQ item 1 deleted"
}
```

#### Batch Operations

Create multiple FAQs:
```
POST /postgres/faqs/batch
```

Update status of multiple FAQs:
```
PUT /postgres/faqs/batch-update-status
```

Delete multiple FAQs:
```
DELETE /postgres/faqs/batch
```
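The request schemas for the batch endpoints are not shown here. A plausible body for the batch status update might be built like this; the field names (`ids`, `is_active`) are assumptions and should be checked against the actual schema:

```javascript
// Hypothetical request body for PUT /postgres/faqs/batch-update-status.
// Field names (ids, is_active) are assumptions, not confirmed by these docs.
function buildBatchStatusPayload(ids, isActive) {
  if (!Array.isArray(ids) || ids.length === 0) {
    throw new Error('ids must be a non-empty array');
  }
  return { ids, is_active: Boolean(isActive) };
}
```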
### Emergency Contact Endpoints

#### Get Emergency Contacts
```
GET /postgres/emergency
```

Parameters:
- `skip`: Number of items to skip (default: 0)
- `limit`: Maximum items to return (default: 100)
- `active_only`: Return only active items (default: false)
- `use_cache`: Use cached data if available (default: true)

Response: Array of Emergency Contact objects

#### Create Emergency Contact
```
POST /postgres/emergency
```

Request Body:
```json
{
  "name": "Fire Department",
  "phone_number": "114",
  "description": "Fire rescue services",
  "address": "Da Nang",
  "location": "16.0544, 108.2022",
  "priority": 1,
  "is_active": true
}
```

Response: Created Emergency Contact object
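A client may want to sanity-check a contact before posting it. The helper below mirrors the body shown above; the validation rules (required `name`/`phone_number`, positive integer `priority`) are illustrative assumptions, not constraints stated by the API:

```javascript
// Lightly validate an Emergency Contact body before POST /postgres/emergency.
// The validation rules here are illustrative, not enforced by these docs.
function buildEmergencyContact({ name, phone_number, priority = 1, ...rest }) {
  if (!name || !phone_number) {
    throw new Error('name and phone_number are required');
  }
  if (!Number.isInteger(priority) || priority < 1) {
    throw new Error('priority must be a positive integer');
  }
  return { name, phone_number, priority, is_active: true, ...rest };
}
```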
#### Get Emergency Contact
```
GET /postgres/emergency/{emergency_id}
```

#### Update Emergency Contact
```
PUT /postgres/emergency/{emergency_id}
```

#### Delete Emergency Contact
```
DELETE /postgres/emergency/{emergency_id}
```

#### Batch Operations

Create multiple Emergency Contacts:
```
POST /postgres/emergency/batch
```

Update status of multiple Emergency Contacts:
```
PUT /postgres/emergency/batch-update-status
```

Delete multiple Emergency Contacts:
```
DELETE /postgres/emergency/batch
```

### Event Endpoints

#### Get Events
```
GET /postgres/events
```

Parameters:
- `skip`: Number of items to skip (default: 0)
- `limit`: Maximum items to return (default: 100)
- `active_only`: Return only active items (default: false)
- `featured_only`: Return only featured items (default: false)
- `use_cache`: Use cached data if available (default: true)

Response: Array of Event objects
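The `skip`/`limit` parameters support simple offset pagination. A sketch of fetching every page follows; `fetchPage(skip, limit)` stands in for the real HTTP call and is assumed to resolve to an array of Event objects:

```javascript
// Offset pagination over GET /postgres/events using skip/limit.
// fetchPage(skip, limit) is a stand-in for the actual HTTP call.
async function fetchAllEvents(fetchPage, limit = 100) {
  const all = [];
  let skip = 0;
  while (true) {
    const page = await fetchPage(skip, limit);
    all.push(...page);
    if (page.length < limit) break; // short page means we reached the end
    skip += limit;
  }
  return all;
}
```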
#### Create Event
```
POST /postgres/events
```

Request Body:
```json
{
  "name": "Da Nang Fireworks Festival",
  "description": "International Fireworks Festival Da Nang 2023",
  "address": "Dragon Bridge, Da Nang",
  "location": "16.0610, 108.2277",
  "date_start": "2023-06-01T19:00:00",
  "date_end": "2023-06-01T22:00:00",
  "price": [
    {"type": "VIP", "amount": 500000},
    {"type": "Standard", "amount": 300000}
  ],
  "url": "https://danangfireworks.com",
  "is_active": true,
  "featured": true
}
```

Response: Created Event object
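The `price` field is an array of ticket tiers (`type`/`amount` pairs, as in the body above). A small client-side helper for picking the cheapest tier, purely illustrative:

```javascript
// Given an Event's price array of {type, amount} tiers, return the
// cheapest tier, or null if there are none. Hypothetical client helper.
function cheapestTier(price) {
  if (!Array.isArray(price) || price.length === 0) return null;
  return price.reduce((min, tier) => (tier.amount < min.amount ? tier : min));
}
```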
#### Get Event
```
GET /postgres/events/{event_id}
```

#### Update Event
```
PUT /postgres/events/{event_id}
```

#### Delete Event
```
DELETE /postgres/events/{event_id}
```

#### Batch Operations

Create multiple Events:
```
POST /postgres/events/batch
```

Update status of multiple Events:
```
PUT /postgres/events/batch-update-status
```

Delete multiple Events:
```
DELETE /postgres/events/batch
```

### About Pixity Endpoints

#### Get About Pixity
```
GET /postgres/about-pixity
```

Response:
```json
{
  "content": "PiXity is your smart, AI-powered local companion...",
  "id": 1,
  "created_at": "2023-01-01T00:00:00",
  "updated_at": "2023-01-01T00:00:00"
}
```

#### Update About Pixity
```
PUT /postgres/about-pixity
```

Request Body:
```json
{
  "content": "PiXity is your smart, AI-powered local companion..."
}
```

Response: Updated About Pixity object

### Da Nang Bucket List Endpoints

#### Get Da Nang Bucket List
```
GET /postgres/danang-bucket-list
```

Response: Bucket List object whose `content` field is a JSON-encoded string

#### Update Da Nang Bucket List
```
PUT /postgres/danang-bucket-list
```

### Solana Summit Endpoints

#### Get Solana Summit
```
GET /postgres/solana-summit
```

Response: Solana Summit object whose `content` field is a JSON-encoded string

#### Update Solana Summit
```
PUT /postgres/solana-summit
```

### Health Check
```
GET /postgres/health
```

Response:
```json
{
  "status": "healthy",
  "message": "PostgreSQL connection is working",
  "timestamp": "2023-01-01T00:00:00"
}
```

## MongoDB Endpoints

### Session Endpoints

#### Create Session
```
POST /session
```

Request Body:
```json
{
  "user_id": "user123",
  "query": "How do I book a room?",
  "timestamp": "2023-01-01T00:00:00",
  "metadata": {
    "client_info": "web",
    "location": "Da Nang"
  }
}
```

Response: Created Session object including the generated `session_id`

#### Update Session with Response
```
PUT /session/{session_id}/response
```

Request Body:
```json
{
  "response": "You can book a room through our app or website.",
  "response_timestamp": "2023-01-01T00:00:05",
  "metadata": {
    "response_time_ms": 234,
    "model_version": "gpt-4"
  }
}
```

Response: Updated Session object
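Together, the two calls above form a query/response lifecycle, and `response_time_ms` in the update metadata can be derived from the two timestamps. A sketch using the field names shown above (the HTTP calls themselves are omitted; `buildResponseUpdate` is a hypothetical helper):

```javascript
// Milliseconds between the session's original timestamp and the
// response timestamp, as carried in metadata.response_time_ms above.
function responseTimeMs(timestamp, responseTimestamp) {
  return new Date(responseTimestamp).getTime() - new Date(timestamp).getTime();
}

// Build the PUT /session/{session_id}/response body from a session object.
function buildResponseUpdate(session, responseText, modelVersion) {
  const responseTimestamp = new Date().toISOString();
  return {
    response: responseText,
    response_timestamp: responseTimestamp,
    metadata: {
      response_time_ms: responseTimeMs(session.timestamp, responseTimestamp),
      model_version: modelVersion,
    },
  };
}
```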
#### Get Session
```
GET /session/{session_id}
```

Response: Session object

#### Get User History
```
GET /history
```

Parameters:
- `user_id`: User ID (required)
- `limit`: Maximum sessions to return (default: 10)
- `skip`: Number of sessions to skip (default: 0)

Response:
```json
{
  "user_id": "user123",
  "sessions": [
    {
      "session_id": "60f7a8b9c1d2e3f4a5b6c7d8",
      "query": "How do I book a room?",
      "timestamp": "2023-01-01T00:00:00",
      "response": "You can book a room through our app or website.",
      "response_timestamp": "2023-01-01T00:00:05"
    }
  ],
  "total_count": 1
}
```

#### Health Check
```
GET /health
```

## RAG Endpoints

### Create Embedding
```
POST /embedding
```

Request Body:
```json
{
  "text": "Text to embed"
}
```

Response:
```json
{
  "embedding": [0.1, 0.2, 0.3, ...],
  "dimensions": 1536
}
```
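The returned vector can be compared client-side. For example, cosine similarity between two embeddings of equal dimension (the standard formula, not an endpoint of this API):

```javascript
// Cosine similarity between two embedding vectors of equal length:
// dot(a, b) / (|a| * |b|), in [-1, 1] for non-zero vectors.
function cosineSimilarity(a, b) {
  if (a.length !== b.length) throw new Error('dimension mismatch');
  let dot = 0, normA = 0, normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}
```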
### Process Chat Request
```
POST /chat
```

Request Body:
```json
{
  "query": "Can you tell me about Pixity?",
  "chat_history": [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hello! How can I help you?"}
  ]
}
```

Response:
```json
{
  "answer": "Pixity is a platform...",
  "sources": [
    {
      "document_id": "doc123",
      "chunk_id": "chunk456",
      "chunk_text": "Pixity was founded in...",
      "relevance_score": 0.92
    }
  ]
}
```
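`chat_history` uses the role/content message format shown above, so the client is responsible for carrying the conversation forward. A helper that appends one completed turn so the next request includes the full history (purely illustrative):

```javascript
// Append one user/assistant turn to a chat_history array in the
// role/content format expected by POST /chat. Returns a new array.
function appendTurn(history, userQuery, assistantAnswer) {
  return [
    ...history,
    { role: 'user', content: userQuery },
    { role: 'assistant', content: assistantAnswer },
  ];
}
```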
### Direct RAG Query
```
POST /rag
```

Request Body:
```json
{
  "query": "Can you tell me about Pixity?",
  "namespace": "about_pixity",
  "top_k": 3
}
```

Response: Query results with relevance scores

### Health Check
```
GET /health
```

## PDF Processing Endpoints

### Upload and Process PDF
```
POST /pdf/upload
```

Form Data:
- `file`: PDF file (required)
- `namespace`: Vector database namespace (default: "Default")
- `index_name`: Vector database index name (default: "testbot768")
- `title`: Document title (optional)
- `description`: Document description (optional)
- `user_id`: User ID for WebSocket updates (optional)

Response: Processing results with document_id
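Because this endpoint takes multipart form data rather than JSON, a browser or Node 18+ client would build the body with `FormData`. A sketch (field names follow the Form Data list above; the helper itself and the idea of sending it via `fetch('/pdf/upload', ...)` are assumptions):

```javascript
// Build the multipart body for POST /pdf/upload.
// Field names match the Form Data list above; buildUploadForm is hypothetical.
function buildUploadForm(file, options = {}) {
  const { namespace = 'Default', indexName = 'testbot768', title, description, userId } = options;
  const form = new FormData();
  form.append('file', file);
  form.append('namespace', namespace);
  form.append('index_name', indexName);
  if (title) form.append('title', title);
  if (description) form.append('description', description);
  if (userId) form.append('user_id', userId);
  return form;
}
```

The form would then be sent with `fetch`, letting the runtime set the multipart `Content-Type` boundary automatically.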
### Delete Documents in Namespace
```
DELETE /pdf/namespace
```

Parameters:
- `namespace`: Vector database namespace (default: "Default")
- `index_name`: Vector database index name (default: "testbot768")
- `user_id`: User ID for WebSocket updates (optional)

Response: Deletion results

### Get Documents List
```
GET /pdf/documents
```

Parameters:
- `namespace`: Vector database namespace (default: "Default")
- `index_name`: Vector database index name (default: "testbot768")

Response: List of documents in the namespace
pytest.ini ADDED
```ini
[pytest]
# Ignore warnings from the anyio module and internal deprecation warnings
filterwarnings =
    ignore::pytest.PytestAssertRewriteWarning:.*anyio
    ignore:.*general_plain_validator_function.* is deprecated.*:DeprecationWarning
    ignore:.*with_info_plain_validator_function.*:DeprecationWarning

# Other basic configuration
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
```
requirements.txt ADDED
```text
# FastAPI
fastapi==0.103.1
uvicorn[standard]==0.23.2
pydantic==2.4.2
python-dotenv==1.0.0
websockets==11.0.3

# MongoDB
pymongo==4.6.1
dnspython==2.4.2

# PostgreSQL
sqlalchemy==2.0.20
pydantic-settings==2.0.3
psycopg2-binary==2.9.7

# Pinecone & RAG
pinecone-client==3.0.0
langchain==0.1.4
langchain-core==0.1.19
langchain-community==0.0.14
langchain-google-genai==0.0.5
langchain-pinecone==0.0.1
faiss-cpu==1.7.4
google-generativeai==0.3.1

# Extras
pytz==2023.3
python-multipart==0.0.6
httpx==0.25.1
requests==2.31.0
beautifulsoup4==4.12.2
redis==5.0.1

# Testing
prometheus-client==0.17.1
pytest==7.4.0
pytest-cov==4.1.0
watchfiles==0.21.0

# Core dependencies
starlette==0.27.0
psutil==5.9.6

# Upload PDF
pypdf==3.17.4
```