Spaces:
Configuration error
Configuration error
Commit ·
4e6f302
1
Parent(s): a3e05d0
Deploy files from GitHub repository
Browse files. This view is limited to 50 files because it contains too many changes. See raw diff
- space/README.md +0 -1
- space/space/space/Dockerfile +17 -29
- space/space/space/space/space/space/space/README.md +0 -18
- space/space/space/space/space/space/space/space/space/README.md +108 -18
- space/space/space/space/space/space/space/space/space/data/prompt_template/customer_service.txt +10 -10
- space/space/space/space/space/space/space/space/space/main.py +11 -2
- space/space/space/space/space/space/space/space/space/space/space/README.md +0 -2
- space/space/space/space/space/space/space/space/space/space/space/space/space/README.md +1 -1
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/README.md +8 -3
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/data/prompt_template/customer_service.txt +12 -0
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/data/prompt_template/query_maker.txt +35 -0
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/main.py +32 -0
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/__chat__.py +0 -2
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/__init__.py +1 -1
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/pipeline/language_model.py +70 -70
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/pipeline/preprocessing.py +70 -70
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/retriever/base_retriever.py +4 -4
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/retriever/document_loader.py +3 -3
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/retriever/document_processor.py +3 -3
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/retriever/langchain_retriever.py +11 -11
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/web_search/duckduckgo_search.py +19 -19
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rtc/rtc_call.py +15 -15
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/stt/whisper_stt.py +10 -10
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/tts/audio_edge_tts.py +3 -3
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/__init__.py +2 -2
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/agents/__init__.py +0 -0
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/agents/agents.py +16 -0
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/agents/customer_service_agent.py +33 -0
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/agents/gpt_customer_service_agent.py +13 -0
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/agents/query_maker_agent.py +13 -0
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/chat_template/__init__.py +29 -0
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/chat_template/customer_service.txt +12 -0
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/chat_template/query_maker.txt +35 -0
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/chat_template/query_maker_temp.txt +30 -0
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/inference/__init__.py +0 -0
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/pipeline/language_model.py +947 -0
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/retriever/__init__.py +0 -0
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/retriever/langchain_retriever.py +25 -7
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/web_search/__init__.py +0 -0
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rtc/rtc_call_gpt.py +364 -0
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/tests/qwen_llm_test.py +9 -9
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/__chat__.py +4 -3
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/__test__.py +0 -5
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/app.log +0 -0
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/__init__.py +61 -21
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/inference/inferencer.py +51 -9
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/pipeline/qwen_llm.py +29 -8
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/prompt_tuner/chat_template.py +6 -4
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/web_search/duckduckgo_search.py +142 -0
- space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rtc/__init__.py +3 -1
space/README.md
CHANGED
|
@@ -70,7 +70,6 @@ python main.py --mode rtc-gpt-server --port 7862
|
|
| 70 |
|
| 71 |
### Chatbot Interface
|
| 72 |
```bash
|
| 73 |
-
cd app
|
| 74 |
python main.py --mode chatbot --port 7861
|
| 75 |
```
|
| 76 |
|
|
|
|
| 70 |
|
| 71 |
### Chatbot Interface
|
| 72 |
```bash
|
|
|
|
| 73 |
python main.py --mode chatbot --port 7861
|
| 74 |
```
|
| 75 |
|
space/space/space/Dockerfile
CHANGED
|
@@ -1,13 +1,10 @@
|
|
| 1 |
-
|
| 2 |
FROM python:3.13
|
| 3 |
|
| 4 |
-
# Tambahkan user non-root untuk keamanan
|
| 5 |
RUN useradd -m -u 1001 appuser
|
| 6 |
|
| 7 |
-
# Set working directory
|
| 8 |
WORKDIR /rag_be
|
| 9 |
|
| 10 |
-
# Set cache directories ke writable location
|
| 11 |
ENV HF_HOME=/tmp/.cache/huggingface
|
| 12 |
ENV TRANSFORMERS_CACHE=/tmp/.cache/transformers
|
| 13 |
ENV TORCH_HOME=/tmp/.cache/torch
|
|
@@ -15,35 +12,26 @@ ENV XDG_CACHE_HOME=/tmp/.cache
|
|
| 15 |
ENV TMPDIR=/tmp
|
| 16 |
ENV WHISPER_CACHE_DIR=/tmp/.cache/whisper
|
| 17 |
|
| 18 |
-
|
| 19 |
-
COPY requirements.txt ./
|
| 20 |
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
| 21 |
|
| 22 |
-
# Copy aplikasi dengan ownership ke appuser
|
| 23 |
COPY --chown=appuser:appuser . /rag_be
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
RUN
|
| 37 |
-
|
| 38 |
-
chmod -R 777 /tmp/.cache /rag_be/app /rag_be/app/vectorstore /rag_be/vectorstore /rag_be/documents && \
|
| 39 |
-
chown -R appuser:appuser /tmp/.cache /rag_be/app /rag_be/app/vectorstore /rag_be/vectorstore /rag_be/documents /rag_be/.env
|
| 40 |
-
|
| 41 |
-
RUN apt-get update && apt-get install -y ffmpeg
|
| 42 |
-
# Beralih ke user non-root
|
| 43 |
USER appuser
|
| 44 |
|
| 45 |
-
|
| 46 |
-
EXPOSE 7860
|
| 47 |
|
| 48 |
-
|
| 49 |
-
CMD ["python", "app/__test__.py"]
|
|
|
|
| 1 |
+
|
| 2 |
FROM python:3.13
|
| 3 |
|
|
|
|
| 4 |
RUN useradd -m -u 1001 appuser
|
| 5 |
|
|
|
|
| 6 |
WORKDIR /rag_be
|
| 7 |
|
|
|
|
| 8 |
ENV HF_HOME=/tmp/.cache/huggingface
|
| 9 |
ENV TRANSFORMERS_CACHE=/tmp/.cache/transformers
|
| 10 |
ENV TORCH_HOME=/tmp/.cache/torch
|
|
|
|
| 12 |
ENV TMPDIR=/tmp
|
| 13 |
ENV WHISPER_CACHE_DIR=/tmp/.cache/whisper
|
| 14 |
|
| 15 |
+
COPY requirements.txt ./
|
|
|
|
| 16 |
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
| 17 |
|
|
|
|
| 18 |
COPY --chown=appuser:appuser . /rag_be
|
| 19 |
|
| 20 |
+
RUN mkdir -p /tmp/.cache \
|
| 21 |
+
/tmp/.cache/whisper \
|
| 22 |
+
/tmp/.cache/huggingface \
|
| 23 |
+
/tmp/.cache/transformers \
|
| 24 |
+
/tmp/.cache/torch \
|
| 25 |
+
/rag_be/vectorstore \
|
| 26 |
+
/rag_be/app/vectorstore \
|
| 27 |
+
/rag_be/documents && \
|
| 28 |
+
chmod -R 777 /tmp/.cache /rag_be/app /rag_be/vectorstore /rag_be/documents
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
RUN apt-get update && apt-get install -y ffmpeg && apt-get clean
|
| 32 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
USER appuser
|
| 34 |
|
| 35 |
+
EXPOSE 8000
|
|
|
|
| 36 |
|
| 37 |
+
CMD ["python", "main.py --mode rtc-ui --port 7860"]
|
|
|
space/space/space/space/space/space/space/README.md
CHANGED
|
@@ -104,21 +104,3 @@ docker run -p 8080:8080 cs-ai-sakura-dev
|
|
| 104 |
Once the server is running, you can access the API documentation at:
|
| 105 |
- `http://localhost:{port}/docs` (if using FastAPI)
|
| 106 |
- `http://localhost:{port}` (for Gradio interface)
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
## 🏗️ Project Structure
|
| 110 |
-
|
| 111 |
-
```
|
| 112 |
-
cs-ai-sakura-dev/
|
| 113 |
-
├── app/
|
| 114 |
-
│ └── main.py # Chatbot application
|
| 115 |
-
├── main.py # Main application entry point
|
| 116 |
-
├── requirements.txt # Python dependencies
|
| 117 |
-
├── .env # Environment variables (create this)
|
| 118 |
-
├── Dockerfile # Docker configuration
|
| 119 |
-
└── README.md # Project documentation
|
| 120 |
-
```
|
| 121 |
-
|
| 122 |
-
---
|
| 123 |
-
|
| 124 |
-
**Happy coding! 🌸**
|
|
|
|
| 104 |
Once the server is running, you can access the API documentation at:
|
| 105 |
- `http://localhost:{port}/docs` (if using FastAPI)
|
| 106 |
- `http://localhost:{port}` (for Gradio interface)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
space/space/space/space/space/space/space/space/space/README.md
CHANGED
|
@@ -1,34 +1,124 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
```
|
|
|
|
|
|
|
|
|
|
| 14 |
python3 -m venv env
|
| 15 |
-
source env/bin/activate
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
pip install -r requirements.txt
|
| 17 |
```
|
| 18 |
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
-
|
|
|
|
|
|
|
| 22 |
```
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
| 24 |
```
|
| 25 |
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
```
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
| 29 |
```
|
| 30 |
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
| 32 |
```
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CS AI Sakura Dev 🏢
|
| 2 |
+
|
| 3 |
+
A comprehensive AI-powered application with multiple modes including RTC (Real-Time Communication), GPT integration, and chatbot functionality.
|
| 4 |
+
|
| 5 |
+
## 🚀 Features
|
| 6 |
+
|
| 7 |
+
- **RTC Mode**: Real-time communication interface
|
| 8 |
+
- **GPT Integration**: Enhanced AI capabilities with OpenAI GPT models
|
| 9 |
+
- **Chatbot Interface**: Interactive chat functionality
|
| 10 |
+
- **Gradio UI**: User-friendly web interface
|
| 11 |
+
- **API Server**: RESTful API endpoints
|
| 12 |
+
- **Docker Support**: Containerized deployment
|
| 13 |
|
| 14 |
+
## 📋 Prerequisites
|
| 15 |
|
| 16 |
+
- Python 3.8 or higher
|
| 17 |
+
- OpenAI API Key
|
| 18 |
+
- Docker (optional)
|
| 19 |
+
|
| 20 |
+
## ⚙️ Installation
|
| 21 |
+
|
| 22 |
+
### 1. Clone the Repository
|
| 23 |
+
```bash
|
| 24 |
+
git clone <repository-url>
|
| 25 |
+
cd cs-ai-sakura-dev
|
| 26 |
```
|
| 27 |
+
|
| 28 |
+
### 2. Create Virtual Environment
|
| 29 |
+
```bash
|
| 30 |
python3 -m venv env
|
| 31 |
+
source env/bin/activate # On Windows: env\Scripts\activate
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
### 3. Install Dependencies
|
| 35 |
+
```bash
|
| 36 |
pip install -r requirements.txt
|
| 37 |
```
|
| 38 |
|
| 39 |
+
### 4. Environment Configuration
|
| 40 |
+
Create a `.env` file in the root directory and add your OpenAI API key:
|
| 41 |
+
```bash
|
| 42 |
+
OPENAI_API_KEY=your_openai_api_key_here
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
## 🖥️ Usage
|
| 46 |
+
|
| 47 |
+
### Gradio Web Interface
|
| 48 |
|
| 49 |
+
#### Non-GPT Based UI
|
| 50 |
+
```bash
|
| 51 |
+
python main.py --mode rtc-ui --port 8080
|
| 52 |
```
|
| 53 |
+
|
| 54 |
+
#### GPT-Powered UI
|
| 55 |
+
```bash
|
| 56 |
+
python main.py --mode rtc-gpt-ui --port 8080
|
| 57 |
```
|
| 58 |
|
| 59 |
+
### API Server
|
| 60 |
+
|
| 61 |
+
#### Non-GPT Based Server
|
| 62 |
+
```bash
|
| 63 |
+
python main.py --mode rtc-server --port 8080
|
| 64 |
```
|
| 65 |
+
|
| 66 |
+
#### GPT-Powered Server
|
| 67 |
+
```bash
|
| 68 |
+
python main.py --mode rtc-gpt-server --port 8080
|
| 69 |
```
|
| 70 |
|
| 71 |
+
### Chatbot Interface
|
| 72 |
+
```bash
|
| 73 |
+
cd app
|
| 74 |
+
python main.py --mode chatbot --port 8080
|
| 75 |
```
|
| 76 |
+
|
| 77 |
+
## 🐳 Docker Deployment
|
| 78 |
+
|
| 79 |
+
The application supports Docker deployment. Build and run the container:
|
| 80 |
+
|
| 81 |
+
```bash
|
| 82 |
+
docker build -t cs-ai-sakura-dev .
|
| 83 |
+
docker run -p 8080:8080 cs-ai-sakura-dev
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
## 📚 Available Modes
|
| 87 |
+
|
| 88 |
+
| Mode | Description | Command |
|
| 89 |
+
|------|-------------|---------|
|
| 90 |
+
| `rtc-ui` | Real-time communication web interface | `python main.py --mode rtc-ui --port {port}` |
|
| 91 |
+
| `rtc-gpt-ui` | GPT-powered real-time communication UI | `python main.py --mode rtc-gpt-ui --port {port}` |
|
| 92 |
+
| `rtc-server` | Real-time communication API server | `python main.py --mode rtc-server --port {port}` |
|
| 93 |
+
| `rtc-gpt-server` | GPT-powered API server | `python main.py --mode rtc-gpt-server --port {port}` |
|
| 94 |
+
| `chatbot` | Interactive chatbot interface | `cd app && python main.py --mode chatbot --port {port}` |
|
| 95 |
+
|
| 96 |
+
## 🔧 Configuration
|
| 97 |
+
|
| 98 |
+
### Environment Variables
|
| 99 |
+
- `OPENAI_API_KEY`: Your OpenAI API key (required for GPT modes)
|
| 100 |
+
- `PORT`: Application port (default: 8080)
|
| 101 |
+
|
| 102 |
+
## 📖 API Documentation
|
| 103 |
+
|
| 104 |
+
Once the server is running, you can access the API documentation at:
|
| 105 |
+
- `http://localhost:{port}/docs` (if using FastAPI)
|
| 106 |
+
- `http://localhost:{port}` (for Gradio interface)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
## 🏗️ Project Structure
|
| 110 |
+
|
| 111 |
```
|
| 112 |
+
cs-ai-sakura-dev/
|
| 113 |
+
├── app/
|
| 114 |
+
│ └── main.py # Chatbot application
|
| 115 |
+
├── main.py # Main application entry point
|
| 116 |
+
├── requirements.txt # Python dependencies
|
| 117 |
+
├── .env # Environment variables (create this)
|
| 118 |
+
├── Dockerfile # Docker configuration
|
| 119 |
+
└── README.md # Project documentation
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
---
|
| 123 |
+
|
| 124 |
+
**Happy coding! 🌸**
|
space/space/space/space/space/space/space/space/space/data/prompt_template/customer_service.txt
CHANGED
|
@@ -1,12 +1,12 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
|
| 4 |
-
-
|
| 5 |
-
-
|
| 6 |
-
-
|
| 7 |
-
-
|
| 8 |
-
-
|
| 9 |
-
-
|
| 10 |
-
-
|
| 11 |
|
| 12 |
-
|
|
|
|
| 1 |
+
Anda adalah seorang Customer Service yang ramah dan profesional di bidang Human Resource Information System (HRIS),
|
| 2 |
+
fasih berbahasa Indonesia. Tugas Anda adalah membantu pelanggan dengan informasi yang akurat berdasarkan pengetahuan dasar perusahaan Anda. Ikuti panduan berikut:
|
| 3 |
|
| 4 |
+
- Selalu menyapa pelanggan dengan ramah dan profesional.
|
| 5 |
+
- Jawaban Anda kontekstual dan objektif.
|
| 6 |
+
- Berikan jawaban yang jelas, mudah dipahami, dan terstruktur berdasarkan konteks yang diberikan oleh pengguna.
|
| 7 |
+
- Jika informasi tidak tersedia, tawarkan bantuan alternatif atau arahkan mereka ke saluran yang tepat.
|
| 8 |
+
- Gunakan bahasa yang sopan dan berempati terhadap kebutuhan pelanggan.
|
| 9 |
+
- Akhiri dengan menawarkan bantuan lebih lanjut.
|
| 10 |
+
- Anda sangat terampil di bidang yang relevan dengan konteks yang diberikan.
|
| 11 |
|
| 12 |
+
Harap gunakan konteks yang diberikan untuk menjawab dengan akurat.
|
space/space/space/space/space/space/space/space/space/main.py
CHANGED
|
@@ -1,10 +1,14 @@
|
|
| 1 |
import argparse
|
| 2 |
from src.provider import AppProvider
|
|
|
|
|
|
|
| 3 |
|
| 4 |
-
|
|
|
|
| 5 |
chatbot_ui = app.provide_chatbot().provide_chatbot_ui()
|
| 6 |
rtc = app.provide_rtc()
|
| 7 |
rtc_handler = rtc.provide_rtc_handler()
|
|
|
|
| 8 |
|
| 9 |
parser = argparse.ArgumentParser()
|
| 10 |
parser.add_argument("--mode", choices=[
|
|
@@ -27,6 +31,11 @@ elif(args.mode == "rtc-server"):
|
|
| 27 |
elif(args.mode == "rtc-ui"):
|
| 28 |
print("launching RTC UI Mode ... ")
|
| 29 |
rtc_handler.launch_ui(port = int(args.port))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
else:
|
| 31 |
-
print("ERROR : INVALID ARGUMENT | PLEASE CHOOSE ONE BETWEEN chatbot/rtc-server/rtc-ui mode ")
|
| 32 |
|
|
|
|
| 1 |
import argparse
|
| 2 |
from src.provider import AppProvider
|
| 3 |
+
from src.config import OPENAI_API_KEY
|
| 4 |
+
from openai import OpenAI
|
| 5 |
|
| 6 |
+
openai_client = OpenAI(api_key = OPENAI_API_KEY)
|
| 7 |
+
app = AppProvider(openai_client)
|
| 8 |
chatbot_ui = app.provide_chatbot().provide_chatbot_ui()
|
| 9 |
rtc = app.provide_rtc()
|
| 10 |
rtc_handler = rtc.provide_rtc_handler()
|
| 11 |
+
rtc_gpt_handler = rtc.provide_rtc_gpt_handler()
|
| 12 |
|
| 13 |
parser = argparse.ArgumentParser()
|
| 14 |
parser.add_argument("--mode", choices=[
|
|
|
|
| 31 |
elif(args.mode == "rtc-ui"):
|
| 32 |
print("launching RTC UI Mode ... ")
|
| 33 |
rtc_handler.launch_ui(port = int(args.port))
|
| 34 |
+
elif(args.mode == "rtc-gpt-ui"):
|
| 35 |
+
print("RTC GPT UI mode ...")
|
| 36 |
+
rtc_gpt_handler.launch_ui(port = int(args.port))
|
| 37 |
+
elif(args.mode == "rtc-gpt-server"):
|
| 38 |
+
rtc_gpt_handler.start_server(port = int(args.port))
|
| 39 |
else:
|
| 40 |
+
print("ERROR : INVALID ARGUMENT | PLEASE CHOOSE ONE BETWEEN chatbot / rtc-server/ rtc-ui / rtc-gpt-server / rtc-gpt-ui mode ")
|
| 41 |
|
space/space/space/space/space/space/space/space/space/space/space/README.md
CHANGED
|
@@ -25,12 +25,10 @@ python main.py --mode rtc-ui --port {your_port}
|
|
| 25 |
|
| 26 |
4. **TO LAUNCH THE API ENDPOINT (SERVER)** Run the command below :
|
| 27 |
```
|
| 28 |
-
cd app
|
| 29 |
python main.py --mode rtc-server --port {your_port}
|
| 30 |
```
|
| 31 |
|
| 32 |
5. **TO LAUNCH THE CHATBOT UI** Run the command below :
|
| 33 |
```
|
| 34 |
-
cd app
|
| 35 |
python main.py --mode chatbot --port {your_port}
|
| 36 |
```
|
|
|
|
| 25 |
|
| 26 |
4. **TO LAUNCH THE API ENDPOINT (SERVER)** Run the command below :
|
| 27 |
```
|
|
|
|
| 28 |
python main.py --mode rtc-server --port {your_port}
|
| 29 |
```
|
| 30 |
|
| 31 |
5. **TO LAUNCH THE CHATBOT UI** Run the command below :
|
| 32 |
```
|
|
|
|
| 33 |
python main.py --mode chatbot --port {your_port}
|
| 34 |
```
|
space/space/space/space/space/space/space/space/space/space/space/space/space/README.md
CHANGED
|
@@ -29,7 +29,7 @@ cd app
|
|
| 29 |
python main.py --mode rtc-server --port {your_port}
|
| 30 |
```
|
| 31 |
|
| 32 |
-
|
| 33 |
```
|
| 34 |
cd app
|
| 35 |
python main.py --mode chatbot --port {your_port}
|
|
|
|
| 29 |
python main.py --mode rtc-server --port {your_port}
|
| 30 |
```
|
| 31 |
|
| 32 |
+
5. **TO LAUNCH THE CHATBOT UI** Run the command below :
|
| 33 |
```
|
| 34 |
cd app
|
| 35 |
python main.py --mode chatbot --port {your_port}
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/README.md
CHANGED
|
@@ -20,12 +20,17 @@ pip install -r requirements.txt
|
|
| 20 |
|
| 21 |
3. **TO LAUNCH THE GRADIO UI** Run the command below :
|
| 22 |
```
|
| 23 |
-
|
| 24 |
-
python __test__.py
|
| 25 |
```
|
| 26 |
|
| 27 |
4. **TO LAUNCH THE API ENDPOINT (SERVER)** Run the command below :
|
| 28 |
```
|
| 29 |
cd app
|
| 30 |
-
python
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
```
|
|
|
|
| 20 |
|
| 21 |
3. **TO LAUNCH THE GRADIO UI** Run the command below :
|
| 22 |
```
|
| 23 |
+
python main.py --mode rtc-ui --port {your_port}
|
|
|
|
| 24 |
```
|
| 25 |
|
| 26 |
4. **TO LAUNCH THE API ENDPOINT (SERVER)** Run the command below :
|
| 27 |
```
|
| 28 |
cd app
|
| 29 |
+
python main.py --mode rtc-server --port {your_port}
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
5. **TO LAUNCH THE CHATBOT UI** Run the command below :
|
| 33 |
+
```
|
| 34 |
+
cd app
|
| 35 |
+
python main.py --mode chatbot --port {your_port}
|
| 36 |
```
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/data/prompt_template/customer_service.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
You are a friendly and professional Customer Service representative in the Human Resource Information System (HRIS) field,
|
| 2 |
+
fluent in Indonesian. Your job is to assist customers with accurate information based on your company's basic knowledge. Follow these guidelines:
|
| 3 |
+
|
| 4 |
+
- Always greet customers in a friendly and professional manner.
|
| 5 |
+
- Your answers are contextual and objective.
|
| 6 |
+
- Provide clear, easy-to-understand, and structured answers based on the context provided by the user.
|
| 7 |
+
- If information is not available, offer alternative assistance or direct them to the appropriate channel.
|
| 8 |
+
- Use polite language and empathize with the customer's needs.
|
| 9 |
+
- Conclude by offering further assistance.
|
| 10 |
+
- You are highly skilled in the area relevant to the given context.
|
| 11 |
+
|
| 12 |
+
Please use the given context to answer accurately.
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/data/prompt_template/query_maker.txt
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Anda adalah agen AI yang tepat dan objektif,
|
| 2 |
+
Anda bertugas mengubah pertanyaan atau pernyataan pengguna menjadi query yang eksplisit dan efisien untuk keperluan pencarian dokumen dalam sistem RAG (Retrieval-Augmented Generation).
|
| 3 |
+
|
| 4 |
+
Ikuti langkah-langkah berikut:
|
| 5 |
+
|
| 6 |
+
1. Ekstrak bagian-bagian penting dari input pengguna:
|
| 7 |
+
- **Intent**: Tujuan utama atau jenis permintaan (misalnya: apa itu, cara, syarat, apakah bisa, berapa).
|
| 8 |
+
- **Entity/Noun Phrase**: Objek utama yang dibahas (misalnya: BPJS, tokenizer truncation, RWKV, gaji).
|
| 9 |
+
- **Context**: Informasi pendukung yang menyempitkan fokus (misalnya: kecelakaan kerja, gaji 1 juta per bulan, perusahaan mitra BPJS).
|
| 10 |
+
- **Question**: Pertanyaan spesifik yang ingin dijawab (misalnya: bagaimana prosesnya, apa manfaatnya, berapa jumlahnya).
|
| 11 |
+
|
| 12 |
+
2. Setelah semua elemen diidentifikasi, bentuk **Query RAG** dengan struktur: [INTENT] + [ENTITY] + [CONTEXT] + [QUESTION]
|
| 13 |
+
3. Gunakan bahasa natural yang ringkas, namun informatif dan eksplisit.
|
| 14 |
+
4. Generate hanya hasil akhirnya saja berupa satu buah kalimat
|
| 15 |
+
|
| 16 |
+
Contoh 0 :
|
| 17 |
+
User Input:
|
| 18 |
+
> Apa itu BPJS
|
| 19 |
+
Output : Pengertian BPJS
|
| 20 |
+
|
| 21 |
+
Contoh 1 :
|
| 22 |
+
User Input:
|
| 23 |
+
> Di mana lokasi PT Sakura System Solution ?
|
| 24 |
+
|
| 25 |
+
Output: Lokasi PT Sakura System Solution
|
| 26 |
+
|
| 27 |
+
Contoh 2:
|
| 28 |
+
User Input:
|
| 29 |
+
> Saya mengalami kecelakaan di kantor dan ingin tahu apakah bisa klaim BPJS karena perusahaan saya adalah mitra.
|
| 30 |
+
|
| 31 |
+
Output: apakah bisa klaim BPJS kecelakaan kerja di kantor jika perusahaan mitra dan apakah saya memenuhi syarat
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
**Tugas Anda sekarang:**
|
| 35 |
+
Lakukan proses di atas untuk setiap input pengguna yang diberikan. Hasilkan query RAG akhir yang siap digunakan dalam pencarian dokumen.
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/main.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from src.provider import AppProvider
|
| 3 |
+
|
| 4 |
+
app = AppProvider()
|
| 5 |
+
chatbot_ui = app.provide_chatbot().provide_chatbot_ui()
|
| 6 |
+
rtc = app.provide_rtc()
|
| 7 |
+
rtc_handler = rtc.provide_rtc_handler()
|
| 8 |
+
|
| 9 |
+
parser = argparse.ArgumentParser()
|
| 10 |
+
parser.add_argument("--mode", choices=[
|
| 11 |
+
"rtc-server",
|
| 12 |
+
"rtc-ui",
|
| 13 |
+
"rtc-gpt-server",
|
| 14 |
+
"rtc-gpt-ui",
|
| 15 |
+
"chatbot",
|
| 16 |
+
], required=True)
|
| 17 |
+
|
| 18 |
+
parser.add_argument("--port", default=7861, required=True)
|
| 19 |
+
args = parser.parse_args()
|
| 20 |
+
|
| 21 |
+
if(args.mode == "chatbot"):
|
| 22 |
+
print("Launching Chabot UI :))))))")
|
| 23 |
+
chatbot_ui.launch(port = int(args.port))
|
| 24 |
+
elif(args.mode == "rtc-server"):
|
| 25 |
+
print("launching RTC Server Mode ... ")
|
| 26 |
+
rtc_handler.start_server(port = int(args.port))
|
| 27 |
+
elif(args.mode == "rtc-ui"):
|
| 28 |
+
print("launching RTC UI Mode ... ")
|
| 29 |
+
rtc_handler.launch_ui(port = int(args.port))
|
| 30 |
+
else:
|
| 31 |
+
print("ERROR : INVALID ARGUMENT | PLEASE CHOOSE ONE BETWEEN chatbot/rtc-server/rtc-ui mode ")
|
| 32 |
+
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/__chat__.py
CHANGED
|
@@ -6,8 +6,6 @@ warnings.filterwarnings("ignore")
|
|
| 6 |
import asyncio
|
| 7 |
def run_test():
|
| 8 |
try:
|
| 9 |
-
# await test_document_retriever()
|
| 10 |
-
# await test_language_model()
|
| 11 |
test_inference()
|
| 12 |
except Exception as e:
|
| 13 |
print(e)
|
|
|
|
| 6 |
import asyncio
|
| 7 |
def run_test():
|
| 8 |
try:
|
|
|
|
|
|
|
| 9 |
test_inference()
|
| 10 |
except Exception as e:
|
| 11 |
print(e)
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/__init__.py
CHANGED
|
@@ -29,7 +29,7 @@ bnb = BitsAndBytesConfig(
|
|
| 29 |
|
| 30 |
|
| 31 |
config = LMConfig(
|
| 32 |
-
model_name = "
|
| 33 |
temperature=0.3,
|
| 34 |
max_length=512,
|
| 35 |
generation_timeout=100,
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
config = LMConfig(
|
| 32 |
+
model_name = "meta-llama/Llama-3.1-8B",
|
| 33 |
temperature=0.3,
|
| 34 |
max_length=512,
|
| 35 |
generation_timeout=100,
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/pipeline/language_model.py
CHANGED
|
@@ -26,17 +26,17 @@ class LMConfig:
|
|
| 26 |
quantization_config: any = None
|
| 27 |
pad_token_id: Optional[int] = None
|
| 28 |
eos_token_id: Optional[int] = None
|
| 29 |
-
|
| 30 |
max_context_length: int = 1500
|
| 31 |
context_separator: str = "\n---\n"
|
| 32 |
-
instruction_template: str = "system"
|
| 33 |
-
|
| 34 |
max_workers: int = 2
|
| 35 |
generation_timeout: float = 30
|
| 36 |
repetition_penalty: float = 1.0
|
| 37 |
-
|
| 38 |
-
stream_timeout: float = 100
|
| 39 |
-
skip_prompt: bool = True
|
| 40 |
|
| 41 |
class LM:
|
| 42 |
"""
|
|
@@ -65,11 +65,11 @@ class LM:
|
|
| 65 |
self.is_loaded = False
|
| 66 |
self.executor = ThreadPoolExecutor(max_workers=self.config.max_workers)
|
| 67 |
self._lock = asyncio.Lock()
|
| 68 |
-
|
| 69 |
logging.basicConfig(level=logging.INFO)
|
| 70 |
self.logger = logging.getLogger(__name__)
|
| 71 |
|
| 72 |
-
|
| 73 |
self.prompt_template = prompt_template
|
| 74 |
|
| 75 |
async def load_model(self) -> None:
|
|
@@ -82,7 +82,7 @@ class LM:
|
|
| 82 |
try:
|
| 83 |
self.logger.info(f"Loading model: {self.config.model_name}")
|
| 84 |
|
| 85 |
-
|
| 86 |
self.tokenizer = await asyncio.get_event_loop().run_in_executor(
|
| 87 |
self.executor,
|
| 88 |
lambda: AutoTokenizer.from_pretrained(
|
|
@@ -93,7 +93,7 @@ class LM:
|
|
| 93 |
)
|
| 94 |
)
|
| 95 |
|
| 96 |
-
|
| 97 |
self.model = await asyncio.get_event_loop().run_in_executor(
|
| 98 |
self.executor,
|
| 99 |
lambda: AutoModelForCausalLM.from_pretrained(
|
|
@@ -105,7 +105,7 @@ class LM:
|
|
| 105 |
)
|
| 106 |
)
|
| 107 |
|
| 108 |
-
|
| 109 |
self.generation_config = GenerationConfig(
|
| 110 |
max_length=self.config.max_length,
|
| 111 |
temperature=self.config.temperature,
|
|
@@ -150,7 +150,7 @@ class LM:
|
|
| 150 |
return f"Template '{template_type}' tidak tersedia. Available: {self.get_available_templates()}"
|
| 151 |
|
| 152 |
template_data = copy.deepcopy(self.prompt_template)
|
| 153 |
-
|
| 154 |
|
| 155 |
return template_data["content"].format(
|
| 156 |
context=sample_context,
|
|
@@ -210,7 +210,7 @@ class LM:
|
|
| 210 |
if len(context) <= max_length:
|
| 211 |
return context
|
| 212 |
|
| 213 |
-
|
| 214 |
truncated = context[:max_length - 50]
|
| 215 |
return truncated + "\n\n[... Context dipotong karena terlalu panjang ...]"
|
| 216 |
|
|
@@ -228,7 +228,7 @@ class LM:
|
|
| 228 |
|
| 229 |
def _format_sync():
|
| 230 |
|
| 231 |
-
|
| 232 |
if isinstance(contexts, RetrievalResult):
|
| 233 |
docs = contexts.documents
|
| 234 |
if max_contexts:
|
|
@@ -241,38 +241,38 @@ class LM:
|
|
| 241 |
metadata=contexts.metadata
|
| 242 |
)
|
| 243 |
else:
|
| 244 |
-
|
| 245 |
processed_contexts = contexts[:max_contexts] if max_contexts and len(contexts) > max_contexts else contexts
|
| 246 |
|
| 247 |
-
|
| 248 |
formatted_context = self._format_context(processed_contexts, context_numbering)
|
| 249 |
|
| 250 |
-
|
| 251 |
formatted_context = self._truncate_context(
|
| 252 |
formatted_context,
|
| 253 |
self.config.max_context_length
|
| 254 |
)
|
| 255 |
|
| 256 |
-
|
| 257 |
if include_metadata and isinstance(processed_contexts, RetrievalResult):
|
| 258 |
metadata_info = []
|
| 259 |
for i, doc in enumerate(processed_contexts.documents, 1):
|
| 260 |
if hasattr(doc, "metadata") and doc.metadata:
|
| 261 |
metadata_info.append(f"Dokumen {i}: {doc.metadata}")
|
| 262 |
-
|
| 263 |
-
|
| 264 |
|
| 265 |
return formatted_context
|
| 266 |
|
| 267 |
-
|
| 268 |
formatted_context = await asyncio.get_event_loop().run_in_executor(
|
| 269 |
self.executor, _format_sync
|
| 270 |
)
|
| 271 |
self.logger.info(f"Formatted Context {formatted_context}")
|
| 272 |
-
|
| 273 |
if(template_type == ""):
|
| 274 |
self.config.instruction_template = "system"
|
| 275 |
-
|
| 276 |
if custom_template:
|
| 277 |
return custom_template.format(
|
| 278 |
context=formatted_context,
|
|
@@ -283,14 +283,14 @@ class LM:
|
|
| 283 |
|
| 284 |
template_data = copy.deepcopy(self.prompt_template)
|
| 285 |
print("template = ", template_type, "rag template = ", template_data)
|
| 286 |
-
|
| 287 |
|
| 288 |
formatted_template = []
|
| 289 |
for cht in template_data:
|
| 290 |
-
|
| 291 |
content = cht["content"]
|
| 292 |
|
| 293 |
-
|
| 294 |
if "{context}" in content or "{question}" in content:
|
| 295 |
try:
|
| 296 |
content = content.format(
|
|
@@ -299,29 +299,29 @@ class LM:
|
|
| 299 |
)
|
| 300 |
except KeyError as e:
|
| 301 |
self.logger.error(f"Missing placeholder in template: {e}")
|
| 302 |
-
|
| 303 |
if "{context}" in content:
|
| 304 |
content = content.replace("{context}", formatted_context)
|
| 305 |
if "{question}" in content:
|
| 306 |
content = content.replace("{question}", question)
|
| 307 |
|
| 308 |
-
|
| 309 |
formatted_chat = {
|
| 310 |
"role": cht["role"],
|
| 311 |
"content": content
|
| 312 |
}
|
| 313 |
|
| 314 |
-
|
| 315 |
if "description" in cht:
|
| 316 |
formatted_chat["description"] = cht["description"]
|
| 317 |
|
| 318 |
formatted_template.append(formatted_chat)
|
| 319 |
|
| 320 |
-
|
| 321 |
-
|
| 322 |
return formatted_template
|
| 323 |
else:
|
| 324 |
-
|
| 325 |
return [
|
| 326 |
{"role": "system", "content": "You are a helpful assistant."},
|
| 327 |
{"role": "user", "content": question}
|
|
@@ -348,7 +348,7 @@ class LM:
|
|
| 348 |
"""
|
| 349 |
await self._check_model_loaded()
|
| 350 |
|
| 351 |
-
|
| 352 |
streamer = TextIteratorStreamer(
|
| 353 |
self.tokenizer,
|
| 354 |
timeout=self.config.stream_timeout,
|
|
@@ -358,14 +358,14 @@ class LM:
|
|
| 358 |
|
| 359 |
def _generate_sync():
|
| 360 |
try:
|
| 361 |
-
|
| 362 |
inputs = self.tokenizer.apply_chat_template(
|
| 363 |
prompt,
|
| 364 |
add_generation_prompt=True,
|
| 365 |
return_tensors="pt"
|
| 366 |
)
|
| 367 |
|
| 368 |
-
|
| 369 |
gen_config = self.generation_config
|
| 370 |
if any([max_new_tokens, temperature, top_p]):
|
| 371 |
gen_config = GenerationConfig(
|
|
@@ -380,11 +380,11 @@ class LM:
|
|
| 380 |
**kwargs
|
| 381 |
)
|
| 382 |
|
| 383 |
-
|
| 384 |
self.model.to("cuda")
|
| 385 |
input_ids = inputs.to("cuda")
|
| 386 |
|
| 387 |
-
|
| 388 |
generation_kwargs = {
|
| 389 |
"input_ids": input_ids,
|
| 390 |
"generation_config": gen_config,
|
|
@@ -401,25 +401,25 @@ class LM:
|
|
| 401 |
self.logger.error(f"Error during stream generation setup: {e}")
|
| 402 |
raise
|
| 403 |
|
| 404 |
-
|
| 405 |
generation_thread = await asyncio.get_event_loop().run_in_executor(
|
| 406 |
self.executor, _generate_sync
|
| 407 |
)
|
| 408 |
err = None
|
| 409 |
try:
|
| 410 |
-
|
| 411 |
for token in streamer:
|
| 412 |
-
if token:
|
| 413 |
yield token
|
| 414 |
|
| 415 |
-
|
| 416 |
err = await asyncio.get_event_loop().run_in_executor(
|
| 417 |
self.executor, generation_thread.join
|
| 418 |
)
|
| 419 |
|
| 420 |
except Exception as e:
|
| 421 |
self.logger.error(f"Error during streaming: {e}, {err}")
|
| 422 |
-
|
| 423 |
if generation_thread.is_alive():
|
| 424 |
generation_thread.join(timeout=1.0)
|
| 425 |
raise
|
|
@@ -447,10 +447,10 @@ class LM:
|
|
| 447 |
"""
|
| 448 |
await self._check_model_loaded()
|
| 449 |
|
| 450 |
-
|
| 451 |
prompt = await self.format_rag_prompt(question, contexts, template_type)
|
| 452 |
|
| 453 |
-
|
| 454 |
temp = temperature if temperature is not None else 0.3
|
| 455 |
|
| 456 |
async for chunk in self.generate_stream(
|
|
@@ -480,7 +480,7 @@ class LM:
|
|
| 480 |
|
| 481 |
def _format_chat():
|
| 482 |
try:
|
| 483 |
-
|
| 484 |
formatted_prompt = self.tokenizer.apply_chat_template(
|
| 485 |
messages,
|
| 486 |
tokenize=False,
|
|
@@ -492,7 +492,7 @@ class LM:
|
|
| 492 |
self.logger.error(f"Error during chat formatting: {e}")
|
| 493 |
raise
|
| 494 |
|
| 495 |
-
|
| 496 |
formatted_prompt = await asyncio.get_event_loop().run_in_executor(
|
| 497 |
self.executor, _format_chat
|
| 498 |
)
|
|
@@ -525,14 +525,14 @@ class LM:
|
|
| 525 |
"""
|
| 526 |
await self._check_model_loaded()
|
| 527 |
|
| 528 |
-
|
| 529 |
user_messages = [msg for msg in messages if msg.get("role") == "user"]
|
| 530 |
if not user_messages:
|
| 531 |
raise ValueError("No user message found in conversation")
|
| 532 |
|
| 533 |
last_question = user_messages[-1]["content"]
|
| 534 |
|
| 535 |
-
|
| 536 |
async for chunk in self.rag_generate_stream(
|
| 537 |
question=last_question,
|
| 538 |
contexts=contexts,
|
|
@@ -542,7 +542,7 @@ class LM:
|
|
| 542 |
):
|
| 543 |
yield chunk
|
| 544 |
|
| 545 |
-
|
| 546 |
async def collect_stream(self, stream_generator: AsyncGenerator[str, None]) -> str:
|
| 547 |
"""
|
| 548 |
Collect semua chunks dari stream generator menjadi full text
|
|
@@ -579,7 +579,7 @@ class LM:
|
|
| 579 |
"""
|
| 580 |
await self._check_model_loaded()
|
| 581 |
|
| 582 |
-
|
| 583 |
tasks = []
|
| 584 |
for template_type in template_types:
|
| 585 |
task = asyncio.create_task(
|
|
@@ -589,7 +589,7 @@ class LM:
|
|
| 589 |
)
|
| 590 |
tasks.append((template_type, task))
|
| 591 |
|
| 592 |
-
|
| 593 |
results = {}
|
| 594 |
for template_type, task in tasks:
|
| 595 |
try:
|
|
@@ -639,10 +639,10 @@ class LM:
|
|
| 639 |
"""
|
| 640 |
await self._check_model_loaded()
|
| 641 |
|
| 642 |
-
|
| 643 |
prompt = await self.format_rag_prompt(question, contexts, template_type)
|
| 644 |
|
| 645 |
-
|
| 646 |
temp = temperature if temperature is not None else 0.3
|
| 647 |
|
| 648 |
return await self.generate(
|
|
@@ -673,14 +673,14 @@ class LM:
|
|
| 673 |
"""
|
| 674 |
await self._check_model_loaded()
|
| 675 |
|
| 676 |
-
|
| 677 |
user_messages = [msg for msg in messages if msg.get("role") == "user"]
|
| 678 |
if not user_messages:
|
| 679 |
raise ValueError("No user message found in conversation")
|
| 680 |
|
| 681 |
last_question = user_messages[-1]["content"]
|
| 682 |
|
| 683 |
-
|
| 684 |
return await self.rag_generate(
|
| 685 |
question=last_question,
|
| 686 |
contexts=contexts,
|
|
@@ -718,14 +718,14 @@ class LM:
|
|
| 718 |
|
| 719 |
def _generate_sync():
|
| 720 |
try:
|
| 721 |
-
|
| 722 |
inputs = self.tokenizer.apply_chat_template(
|
| 723 |
prompt,
|
| 724 |
add_generation_prompt=True,
|
| 725 |
return_tensors="pt"
|
| 726 |
)
|
| 727 |
|
| 728 |
-
|
| 729 |
gen_config = self.generation_config
|
| 730 |
if any([max_new_tokens, temperature, top_p]):
|
| 731 |
gen_config = GenerationConfig(
|
|
@@ -740,7 +740,7 @@ class LM:
|
|
| 740 |
**kwargs
|
| 741 |
)
|
| 742 |
|
| 743 |
-
|
| 744 |
with torch.no_grad():
|
| 745 |
|
| 746 |
self.model.to("cuda")
|
|
@@ -752,21 +752,21 @@ class LM:
|
|
| 752 |
**kwargs
|
| 753 |
)
|
| 754 |
|
| 755 |
-
|
| 756 |
generated_text = self.tokenizer.decode(
|
| 757 |
outputs[0][prompt_length:],
|
| 758 |
skip_special_tokens=True
|
| 759 |
)
|
| 760 |
|
| 761 |
print("Generated Text", generated_text)
|
| 762 |
-
|
| 763 |
return generated_text
|
| 764 |
|
| 765 |
except Exception as e:
|
| 766 |
self.logger.error(f"Error during generation: {e}")
|
| 767 |
raise
|
| 768 |
|
| 769 |
-
|
| 770 |
try:
|
| 771 |
result = await asyncio.wait_for(
|
| 772 |
asyncio.get_event_loop().run_in_executor(self.executor, _generate_sync),
|
|
@@ -796,7 +796,7 @@ class LM:
|
|
| 796 |
|
| 797 |
def _format_chat():
|
| 798 |
try:
|
| 799 |
-
|
| 800 |
formatted_prompt = self.tokenizer.apply_chat_template(
|
| 801 |
messages,
|
| 802 |
chat_template="rag",
|
|
@@ -808,7 +808,7 @@ class LM:
|
|
| 808 |
self.logger.error(f"Error during chat formatting: {e}")
|
| 809 |
raise
|
| 810 |
|
| 811 |
-
|
| 812 |
formatted_prompt = await asyncio.get_event_loop().run_in_executor(
|
| 813 |
self.executor, _format_chat
|
| 814 |
)
|
|
@@ -834,7 +834,7 @@ class LM:
|
|
| 834 |
else:
|
| 835 |
self.logger.warning(f"Unknown config parameter: {key}")
|
| 836 |
|
| 837 |
-
|
| 838 |
if self.is_loaded:
|
| 839 |
self.generation_config = GenerationConfig(
|
| 840 |
max_length=self.config.max_length,
|
|
@@ -862,7 +862,7 @@ class LM:
|
|
| 862 |
}
|
| 863 |
|
| 864 |
if self.is_loaded:
|
| 865 |
-
|
| 866 |
def _get_info():
|
| 867 |
return {
|
| 868 |
"vocab_size": self.tokenizer.vocab_size,
|
|
@@ -894,7 +894,7 @@ class LM:
|
|
| 894 |
"""
|
| 895 |
await self._check_model_loaded()
|
| 896 |
|
| 897 |
-
|
| 898 |
tasks = [
|
| 899 |
asyncio.create_task(
|
| 900 |
self.generate(prompt, max_new_tokens=max_new_tokens, **kwargs)
|
|
@@ -902,10 +902,10 @@ class LM:
|
|
| 902 |
for prompt in prompts
|
| 903 |
]
|
| 904 |
|
| 905 |
-
|
| 906 |
results = await asyncio.gather(*tasks, return_exceptions=True)
|
| 907 |
|
| 908 |
-
|
| 909 |
processed_results = []
|
| 910 |
for i, result in enumerate(results):
|
| 911 |
if isinstance(result, Exception):
|
|
@@ -922,10 +922,10 @@ class LM:
|
|
| 922 |
"""
|
| 923 |
self.logger.info("Closing LM...")
|
| 924 |
|
| 925 |
-
|
| 926 |
self.executor.shutdown(wait=True)
|
| 927 |
|
| 928 |
-
|
| 929 |
if hasattr(self, 'model') and self.model is not None:
|
| 930 |
del self.model
|
| 931 |
if hasattr(self, 'tokenizer') and self.tokenizer is not None:
|
|
|
|
| 26 |
quantization_config: any = None
|
| 27 |
pad_token_id: Optional[int] = None
|
| 28 |
eos_token_id: Optional[int] = None
|
| 29 |
+
|
| 30 |
max_context_length: int = 1500
|
| 31 |
context_separator: str = "\n---\n"
|
| 32 |
+
instruction_template: str = "system"
|
| 33 |
+
|
| 34 |
max_workers: int = 2
|
| 35 |
generation_timeout: float = 30
|
| 36 |
repetition_penalty: float = 1.0
|
| 37 |
+
|
| 38 |
+
stream_timeout: float = 100
|
| 39 |
+
skip_prompt: bool = True
|
| 40 |
|
| 41 |
class LM:
|
| 42 |
"""
|
|
|
|
| 65 |
self.is_loaded = False
|
| 66 |
self.executor = ThreadPoolExecutor(max_workers=self.config.max_workers)
|
| 67 |
self._lock = asyncio.Lock()
|
| 68 |
+
|
| 69 |
logging.basicConfig(level=logging.INFO)
|
| 70 |
self.logger = logging.getLogger(__name__)
|
| 71 |
|
| 72 |
+
|
| 73 |
self.prompt_template = prompt_template
|
| 74 |
|
| 75 |
async def load_model(self) -> None:
|
|
|
|
| 82 |
try:
|
| 83 |
self.logger.info(f"Loading model: {self.config.model_name}")
|
| 84 |
|
| 85 |
+
|
| 86 |
self.tokenizer = await asyncio.get_event_loop().run_in_executor(
|
| 87 |
self.executor,
|
| 88 |
lambda: AutoTokenizer.from_pretrained(
|
|
|
|
| 93 |
)
|
| 94 |
)
|
| 95 |
|
| 96 |
+
|
| 97 |
self.model = await asyncio.get_event_loop().run_in_executor(
|
| 98 |
self.executor,
|
| 99 |
lambda: AutoModelForCausalLM.from_pretrained(
|
|
|
|
| 105 |
)
|
| 106 |
)
|
| 107 |
|
| 108 |
+
|
| 109 |
self.generation_config = GenerationConfig(
|
| 110 |
max_length=self.config.max_length,
|
| 111 |
temperature=self.config.temperature,
|
|
|
|
| 150 |
return f"Template '{template_type}' tidak tersedia. Available: {self.get_available_templates()}"
|
| 151 |
|
| 152 |
template_data = copy.deepcopy(self.prompt_template)
|
| 153 |
+
|
| 154 |
|
| 155 |
return template_data["content"].format(
|
| 156 |
context=sample_context,
|
|
|
|
| 210 |
if len(context) <= max_length:
|
| 211 |
return context
|
| 212 |
|
| 213 |
+
|
| 214 |
truncated = context[:max_length - 50]
|
| 215 |
return truncated + "\n\n[... Context dipotong karena terlalu panjang ...]"
|
| 216 |
|
|
|
|
| 228 |
|
| 229 |
def _format_sync():
|
| 230 |
|
| 231 |
+
|
| 232 |
if isinstance(contexts, RetrievalResult):
|
| 233 |
docs = contexts.documents
|
| 234 |
if max_contexts:
|
|
|
|
| 241 |
metadata=contexts.metadata
|
| 242 |
)
|
| 243 |
else:
|
| 244 |
+
|
| 245 |
processed_contexts = contexts[:max_contexts] if max_contexts and len(contexts) > max_contexts else contexts
|
| 246 |
|
| 247 |
+
|
| 248 |
formatted_context = self._format_context(processed_contexts, context_numbering)
|
| 249 |
|
| 250 |
+
|
| 251 |
formatted_context = self._truncate_context(
|
| 252 |
formatted_context,
|
| 253 |
self.config.max_context_length
|
| 254 |
)
|
| 255 |
|
| 256 |
+
|
| 257 |
if include_metadata and isinstance(processed_contexts, RetrievalResult):
|
| 258 |
metadata_info = []
|
| 259 |
for i, doc in enumerate(processed_contexts.documents, 1):
|
| 260 |
if hasattr(doc, "metadata") and doc.metadata:
|
| 261 |
metadata_info.append(f"Dokumen {i}: {doc.metadata}")
|
| 262 |
+
|
| 263 |
+
|
| 264 |
|
| 265 |
return formatted_context
|
| 266 |
|
| 267 |
+
|
| 268 |
formatted_context = await asyncio.get_event_loop().run_in_executor(
|
| 269 |
self.executor, _format_sync
|
| 270 |
)
|
| 271 |
self.logger.info(f"Formatted Context {formatted_context}")
|
| 272 |
+
|
| 273 |
if(template_type == ""):
|
| 274 |
self.config.instruction_template = "system"
|
| 275 |
+
|
| 276 |
if custom_template:
|
| 277 |
return custom_template.format(
|
| 278 |
context=formatted_context,
|
|
|
|
| 283 |
|
| 284 |
template_data = copy.deepcopy(self.prompt_template)
|
| 285 |
print("template = ", template_type, "rag template = ", template_data)
|
| 286 |
+
|
| 287 |
|
| 288 |
formatted_template = []
|
| 289 |
for cht in template_data:
|
| 290 |
+
|
| 291 |
content = cht["content"]
|
| 292 |
|
| 293 |
+
|
| 294 |
if "{context}" in content or "{question}" in content:
|
| 295 |
try:
|
| 296 |
content = content.format(
|
|
|
|
| 299 |
)
|
| 300 |
except KeyError as e:
|
| 301 |
self.logger.error(f"Missing placeholder in template: {e}")
|
| 302 |
+
|
| 303 |
if "{context}" in content:
|
| 304 |
content = content.replace("{context}", formatted_context)
|
| 305 |
if "{question}" in content:
|
| 306 |
content = content.replace("{question}", question)
|
| 307 |
|
| 308 |
+
|
| 309 |
formatted_chat = {
|
| 310 |
"role": cht["role"],
|
| 311 |
"content": content
|
| 312 |
}
|
| 313 |
|
| 314 |
+
|
| 315 |
if "description" in cht:
|
| 316 |
formatted_chat["description"] = cht["description"]
|
| 317 |
|
| 318 |
formatted_template.append(formatted_chat)
|
| 319 |
|
| 320 |
+
|
| 321 |
+
|
| 322 |
return formatted_template
|
| 323 |
else:
|
| 324 |
+
|
| 325 |
return [
|
| 326 |
{"role": "system", "content": "You are a helpful assistant."},
|
| 327 |
{"role": "user", "content": question}
|
|
|
|
| 348 |
"""
|
| 349 |
await self._check_model_loaded()
|
| 350 |
|
| 351 |
+
|
| 352 |
streamer = TextIteratorStreamer(
|
| 353 |
self.tokenizer,
|
| 354 |
timeout=self.config.stream_timeout,
|
|
|
|
| 358 |
|
| 359 |
def _generate_sync():
|
| 360 |
try:
|
| 361 |
+
|
| 362 |
inputs = self.tokenizer.apply_chat_template(
|
| 363 |
prompt,
|
| 364 |
add_generation_prompt=True,
|
| 365 |
return_tensors="pt"
|
| 366 |
)
|
| 367 |
|
| 368 |
+
|
| 369 |
gen_config = self.generation_config
|
| 370 |
if any([max_new_tokens, temperature, top_p]):
|
| 371 |
gen_config = GenerationConfig(
|
|
|
|
| 380 |
**kwargs
|
| 381 |
)
|
| 382 |
|
| 383 |
+
|
| 384 |
self.model.to("cuda")
|
| 385 |
input_ids = inputs.to("cuda")
|
| 386 |
|
| 387 |
+
|
| 388 |
generation_kwargs = {
|
| 389 |
"input_ids": input_ids,
|
| 390 |
"generation_config": gen_config,
|
|
|
|
| 401 |
self.logger.error(f"Error during stream generation setup: {e}")
|
| 402 |
raise
|
| 403 |
|
| 404 |
+
|
| 405 |
generation_thread = await asyncio.get_event_loop().run_in_executor(
|
| 406 |
self.executor, _generate_sync
|
| 407 |
)
|
| 408 |
err = None
|
| 409 |
try:
|
| 410 |
+
|
| 411 |
for token in streamer:
|
| 412 |
+
if token:
|
| 413 |
yield token
|
| 414 |
|
| 415 |
+
|
| 416 |
err = await asyncio.get_event_loop().run_in_executor(
|
| 417 |
self.executor, generation_thread.join
|
| 418 |
)
|
| 419 |
|
| 420 |
except Exception as e:
|
| 421 |
self.logger.error(f"Error during streaming: {e}, {err}")
|
| 422 |
+
|
| 423 |
if generation_thread.is_alive():
|
| 424 |
generation_thread.join(timeout=1.0)
|
| 425 |
raise
|
|
|
|
| 447 |
"""
|
| 448 |
await self._check_model_loaded()
|
| 449 |
|
| 450 |
+
|
| 451 |
prompt = await self.format_rag_prompt(question, contexts, template_type)
|
| 452 |
|
| 453 |
+
|
| 454 |
temp = temperature if temperature is not None else 0.3
|
| 455 |
|
| 456 |
async for chunk in self.generate_stream(
|
|
|
|
| 480 |
|
| 481 |
def _format_chat():
|
| 482 |
try:
|
| 483 |
+
|
| 484 |
formatted_prompt = self.tokenizer.apply_chat_template(
|
| 485 |
messages,
|
| 486 |
tokenize=False,
|
|
|
|
| 492 |
self.logger.error(f"Error during chat formatting: {e}")
|
| 493 |
raise
|
| 494 |
|
| 495 |
+
|
| 496 |
formatted_prompt = await asyncio.get_event_loop().run_in_executor(
|
| 497 |
self.executor, _format_chat
|
| 498 |
)
|
|
|
|
| 525 |
"""
|
| 526 |
await self._check_model_loaded()
|
| 527 |
|
| 528 |
+
|
| 529 |
user_messages = [msg for msg in messages if msg.get("role") == "user"]
|
| 530 |
if not user_messages:
|
| 531 |
raise ValueError("No user message found in conversation")
|
| 532 |
|
| 533 |
last_question = user_messages[-1]["content"]
|
| 534 |
|
| 535 |
+
|
| 536 |
async for chunk in self.rag_generate_stream(
|
| 537 |
question=last_question,
|
| 538 |
contexts=contexts,
|
|
|
|
| 542 |
):
|
| 543 |
yield chunk
|
| 544 |
|
| 545 |
+
|
| 546 |
async def collect_stream(self, stream_generator: AsyncGenerator[str, None]) -> str:
|
| 547 |
"""
|
| 548 |
Collect semua chunks dari stream generator menjadi full text
|
|
|
|
| 579 |
"""
|
| 580 |
await self._check_model_loaded()
|
| 581 |
|
| 582 |
+
|
| 583 |
tasks = []
|
| 584 |
for template_type in template_types:
|
| 585 |
task = asyncio.create_task(
|
|
|
|
| 589 |
)
|
| 590 |
tasks.append((template_type, task))
|
| 591 |
|
| 592 |
+
|
| 593 |
results = {}
|
| 594 |
for template_type, task in tasks:
|
| 595 |
try:
|
|
|
|
| 639 |
"""
|
| 640 |
await self._check_model_loaded()
|
| 641 |
|
| 642 |
+
|
| 643 |
prompt = await self.format_rag_prompt(question, contexts, template_type)
|
| 644 |
|
| 645 |
+
|
| 646 |
temp = temperature if temperature is not None else 0.3
|
| 647 |
|
| 648 |
return await self.generate(
|
|
|
|
| 673 |
"""
|
| 674 |
await self._check_model_loaded()
|
| 675 |
|
| 676 |
+
|
| 677 |
user_messages = [msg for msg in messages if msg.get("role") == "user"]
|
| 678 |
if not user_messages:
|
| 679 |
raise ValueError("No user message found in conversation")
|
| 680 |
|
| 681 |
last_question = user_messages[-1]["content"]
|
| 682 |
|
| 683 |
+
|
| 684 |
return await self.rag_generate(
|
| 685 |
question=last_question,
|
| 686 |
contexts=contexts,
|
|
|
|
| 718 |
|
| 719 |
def _generate_sync():
|
| 720 |
try:
|
| 721 |
+
|
| 722 |
inputs = self.tokenizer.apply_chat_template(
|
| 723 |
prompt,
|
| 724 |
add_generation_prompt=True,
|
| 725 |
return_tensors="pt"
|
| 726 |
)
|
| 727 |
|
| 728 |
+
|
| 729 |
gen_config = self.generation_config
|
| 730 |
if any([max_new_tokens, temperature, top_p]):
|
| 731 |
gen_config = GenerationConfig(
|
|
|
|
| 740 |
**kwargs
|
| 741 |
)
|
| 742 |
|
| 743 |
+
|
| 744 |
with torch.no_grad():
|
| 745 |
|
| 746 |
self.model.to("cuda")
|
|
|
|
| 752 |
**kwargs
|
| 753 |
)
|
| 754 |
|
| 755 |
+
|
| 756 |
generated_text = self.tokenizer.decode(
|
| 757 |
outputs[0][prompt_length:],
|
| 758 |
skip_special_tokens=True
|
| 759 |
)
|
| 760 |
|
| 761 |
print("Generated Text", generated_text)
|
| 762 |
+
|
| 763 |
return generated_text
|
| 764 |
|
| 765 |
except Exception as e:
|
| 766 |
self.logger.error(f"Error during generation: {e}")
|
| 767 |
raise
|
| 768 |
|
| 769 |
+
|
| 770 |
try:
|
| 771 |
result = await asyncio.wait_for(
|
| 772 |
asyncio.get_event_loop().run_in_executor(self.executor, _generate_sync),
|
|
|
|
| 796 |
|
| 797 |
def _format_chat():
|
| 798 |
try:
|
| 799 |
+
|
| 800 |
formatted_prompt = self.tokenizer.apply_chat_template(
|
| 801 |
messages,
|
| 802 |
chat_template="rag",
|
|
|
|
| 808 |
self.logger.error(f"Error during chat formatting: {e}")
|
| 809 |
raise
|
| 810 |
|
| 811 |
+
|
| 812 |
formatted_prompt = await asyncio.get_event_loop().run_in_executor(
|
| 813 |
self.executor, _format_chat
|
| 814 |
)
|
|
|
|
| 834 |
else:
|
| 835 |
self.logger.warning(f"Unknown config parameter: {key}")
|
| 836 |
|
| 837 |
+
|
| 838 |
if self.is_loaded:
|
| 839 |
self.generation_config = GenerationConfig(
|
| 840 |
max_length=self.config.max_length,
|
|
|
|
| 862 |
}
|
| 863 |
|
| 864 |
if self.is_loaded:
|
| 865 |
+
|
| 866 |
def _get_info():
|
| 867 |
return {
|
| 868 |
"vocab_size": self.tokenizer.vocab_size,
|
|
|
|
| 894 |
"""
|
| 895 |
await self._check_model_loaded()
|
| 896 |
|
| 897 |
+
|
| 898 |
tasks = [
|
| 899 |
asyncio.create_task(
|
| 900 |
self.generate(prompt, max_new_tokens=max_new_tokens, **kwargs)
|
|
|
|
| 902 |
for prompt in prompts
|
| 903 |
]
|
| 904 |
|
| 905 |
+
|
| 906 |
results = await asyncio.gather(*tasks, return_exceptions=True)
|
| 907 |
|
| 908 |
+
|
| 909 |
processed_results = []
|
| 910 |
for i, result in enumerate(results):
|
| 911 |
if isinstance(result, Exception):
|
|
|
|
| 922 |
"""
|
| 923 |
self.logger.info("Closing LM...")
|
| 924 |
|
| 925 |
+
|
| 926 |
self.executor.shutdown(wait=True)
|
| 927 |
|
| 928 |
+
|
| 929 |
if hasattr(self, 'model') and self.model is not None:
|
| 930 |
del self.model
|
| 931 |
if hasattr(self, 'tokenizer') and self.tokenizer is not None:
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/pipeline/preprocessing.py
CHANGED
|
@@ -6,7 +6,7 @@ import logging
|
|
| 6 |
from datetime import datetime
|
| 7 |
import hashlib
|
| 8 |
|
| 9 |
-
|
| 10 |
from typing import List, Dict, Any, Optional, Union
|
| 11 |
from dataclasses import dataclass
|
| 12 |
from enum import Enum
|
|
@@ -15,36 +15,36 @@ from rag.retriever.retriever_types import *
|
|
| 15 |
@dataclass
|
| 16 |
class PreprocessingConfig:
|
| 17 |
"""Konfigurasi untuk preprocessing"""
|
| 18 |
-
|
| 19 |
remove_extra_whitespace: bool = True
|
| 20 |
remove_special_chars: bool = False
|
| 21 |
normalize_unicode: bool = True
|
| 22 |
remove_urls: bool = False
|
| 23 |
remove_emails: bool = False
|
| 24 |
|
| 25 |
-
|
| 26 |
-
enable_chunking: bool = False
|
| 27 |
chunk_size: int = 500
|
| 28 |
chunk_overlap: int = 50
|
| 29 |
-
chunk_method: str = "sentence"
|
|
|
|
| 30 |
|
| 31 |
-
# Content filtering
|
| 32 |
min_content_length: int = 20
|
| 33 |
max_content_length: int = 3000
|
| 34 |
filter_empty_content: bool = True
|
| 35 |
filter_duplicate_content: bool = True
|
| 36 |
|
| 37 |
-
|
| 38 |
extract_metadata: bool = True
|
| 39 |
include_retrieval_info: bool = True
|
| 40 |
include_document_info: bool = True
|
| 41 |
include_timestamps: bool = True
|
| 42 |
|
| 43 |
-
|
| 44 |
-
use_retrieval_scores: bool = True
|
| 45 |
-
normalize_scores: bool = True
|
| 46 |
-
min_score_threshold: float = 0.0
|
| 47 |
-
score_boost_factor: float = 1.0
|
| 48 |
|
| 49 |
class RetrievalPreprocessor:
|
| 50 |
"""
|
|
@@ -61,17 +61,17 @@ class RetrievalPreprocessor:
|
|
| 61 |
"""
|
| 62 |
self.config = config or PreprocessingConfig()
|
| 63 |
|
| 64 |
-
|
| 65 |
logging.basicConfig(level=logging.INFO)
|
| 66 |
self.logger = logging.getLogger(__name__)
|
| 67 |
|
| 68 |
-
|
| 69 |
self.url_pattern = re.compile(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+')
|
| 70 |
self.email_pattern = re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b')
|
| 71 |
self.special_chars_pattern = re.compile(r'[^\w\s\.\,\!\?\;\:\-\(\)\[\]\{\}\"\'\/]')
|
| 72 |
self.whitespace_pattern = re.compile(r'\s+')
|
| 73 |
|
| 74 |
-
|
| 75 |
self._seen_content_hashes = set()
|
| 76 |
|
| 77 |
def process_retrieval_result(self, retrieval_result: RetrievalResult) -> List[RetrievalResult]:
|
|
@@ -98,18 +98,18 @@ class RetrievalPreprocessor:
|
|
| 98 |
f"Processing {len(retrieval_result.documents)} documents from retrieval result for query: '{retrieval_result.query}'"
|
| 99 |
)
|
| 100 |
|
| 101 |
-
|
| 102 |
self._seen_content_hashes.clear()
|
| 103 |
|
| 104 |
contexts = []
|
| 105 |
|
| 106 |
-
|
| 107 |
for i, doc in enumerate(retrieval_result.documents):
|
| 108 |
try:
|
| 109 |
-
|
| 110 |
score = retrieval_result.scores[i] if i < len(retrieval_result.scores) else 0.0
|
| 111 |
|
| 112 |
-
|
| 113 |
processed_contexts = self._process_single_document(
|
| 114 |
document=doc,
|
| 115 |
retrieval_score=score,
|
|
@@ -124,7 +124,7 @@ class RetrievalPreprocessor:
|
|
| 124 |
self.logger.error(f"Error processing document {i}: {e}")
|
| 125 |
continue
|
| 126 |
|
| 127 |
-
|
| 128 |
contexts = self._post_process_contexts(contexts)
|
| 129 |
|
| 130 |
self.logger.info(f"Successfully processed {len(contexts)} contexts from retrieval result")
|
|
@@ -154,23 +154,23 @@ class RetrievalPreprocessor:
|
|
| 154 |
self.logger.warning(f"Empty content in document {document_index}")
|
| 155 |
return []
|
| 156 |
|
| 157 |
-
|
| 158 |
cleaned_content = self._clean_text(document.page_content)
|
| 159 |
|
| 160 |
if not cleaned_content:
|
| 161 |
return []
|
| 162 |
|
| 163 |
-
|
| 164 |
if len(cleaned_content) < self.config.min_content_length:
|
| 165 |
self.logger.debug(f"Content too short in document {document_index}: {len(cleaned_content)} chars")
|
| 166 |
return []
|
| 167 |
|
| 168 |
if len(cleaned_content) > self.config.max_content_length:
|
| 169 |
-
|
| 170 |
cleaned_content = self._truncate_content(cleaned_content)
|
| 171 |
self.logger.debug(f"Content truncated in document {document_index}")
|
| 172 |
|
| 173 |
-
|
| 174 |
if self.config.filter_duplicate_content:
|
| 175 |
content_hash = hashlib.md5(cleaned_content.encode()).hexdigest()
|
| 176 |
if content_hash in self._seen_content_hashes:
|
|
@@ -178,12 +178,12 @@ class RetrievalPreprocessor:
|
|
| 178 |
return []
|
| 179 |
self._seen_content_hashes.add(content_hash)
|
| 180 |
|
| 181 |
-
|
| 182 |
if self.config.use_retrieval_scores and retrieval_score < self.config.min_score_threshold:
|
| 183 |
self.logger.debug(f"Score too low in document {document_index}: {retrieval_score}")
|
| 184 |
return []
|
| 185 |
|
| 186 |
-
|
| 187 |
if self.config.enable_chunking:
|
| 188 |
chunks = self._chunk_content(cleaned_content)
|
| 189 |
contexts = []
|
|
@@ -203,7 +203,7 @@ class RetrievalPreprocessor:
|
|
| 203 |
|
| 204 |
return contexts
|
| 205 |
else:
|
| 206 |
-
|
| 207 |
context = self._create_retrieved_context(
|
| 208 |
content=cleaned_content,
|
| 209 |
document=document,
|
|
@@ -229,13 +229,13 @@ class RetrievalPreprocessor:
|
|
| 229 |
"""
|
| 230 |
Create RetrievalResult object
|
| 231 |
"""
|
| 232 |
-
|
| 233 |
final_score = self._process_score(retrieval_score, document_index, total_documents)
|
| 234 |
|
| 235 |
-
|
| 236 |
source = self._extract_source(document)
|
| 237 |
|
| 238 |
-
|
| 239 |
metadata = self._build_metadata(
|
| 240 |
document=document,
|
| 241 |
retrieval_result=retrieval_result,
|
|
@@ -260,24 +260,24 @@ class RetrievalPreprocessor:
|
|
| 260 |
|
| 261 |
cleaned = text
|
| 262 |
|
| 263 |
-
|
| 264 |
if self.config.normalize_unicode:
|
| 265 |
import unicodedata
|
| 266 |
cleaned = unicodedata.normalize('NFKC', cleaned)
|
| 267 |
|
| 268 |
-
|
| 269 |
if self.config.remove_urls:
|
| 270 |
cleaned = self.url_pattern.sub('', cleaned)
|
| 271 |
|
| 272 |
-
|
| 273 |
if self.config.remove_emails:
|
| 274 |
cleaned = self.email_pattern.sub('', cleaned)
|
| 275 |
|
| 276 |
-
|
| 277 |
if self.config.remove_special_chars:
|
| 278 |
cleaned = self.special_chars_pattern.sub(' ', cleaned)
|
| 279 |
|
| 280 |
-
|
| 281 |
if self.config.remove_extra_whitespace:
|
| 282 |
cleaned = self.whitespace_pattern.sub(' ', cleaned)
|
| 283 |
|
|
@@ -290,7 +290,7 @@ class RetrievalPreprocessor:
|
|
| 290 |
if len(content) <= max_length:
|
| 291 |
return content
|
| 292 |
|
| 293 |
-
|
| 294 |
truncated = content[:max_length - 20]
|
| 295 |
last_sentence_end = max(
|
| 296 |
truncated.rfind('.'),
|
|
@@ -301,7 +301,7 @@ class RetrievalPreprocessor:
|
|
| 301 |
if last_sentence_end > len(truncated) * 0.7:
|
| 302 |
return truncated[:last_sentence_end + 1]
|
| 303 |
else:
|
| 304 |
-
|
| 305 |
last_space = truncated.rfind(' ')
|
| 306 |
if last_space > len(truncated) * 0.8:
|
| 307 |
return truncated[:last_space] + "..."
|
|
@@ -320,7 +320,7 @@ class RetrievalPreprocessor:
|
|
| 320 |
elif self.config.chunk_method == "fixed":
|
| 321 |
return self._chunk_by_fixed_size(content)
|
| 322 |
else:
|
| 323 |
-
return [content]
|
| 324 |
|
| 325 |
def _chunk_by_sentence(self, text: str) -> List[str]:
|
| 326 |
"""Chunk by sentences"""
|
|
@@ -334,7 +334,7 @@ class RetrievalPreprocessor:
|
|
| 334 |
if len(current_chunk) + len(sentence) > self.config.chunk_size and current_chunk:
|
| 335 |
chunks.append(current_chunk.strip())
|
| 336 |
|
| 337 |
-
|
| 338 |
if self.config.chunk_overlap > 0:
|
| 339 |
overlap_text = current_chunk[-self.config.chunk_overlap:]
|
| 340 |
current_chunk = overlap_text + " " + sentence
|
|
@@ -383,7 +383,7 @@ class RetrievalPreprocessor:
|
|
| 383 |
end = start + self.config.chunk_size
|
| 384 |
chunk = text[start:end]
|
| 385 |
|
| 386 |
-
|
| 387 |
if end < len(text):
|
| 388 |
last_space = chunk.rfind(' ')
|
| 389 |
if last_space > len(chunk) * 0.8:
|
|
@@ -392,7 +392,7 @@ class RetrievalPreprocessor:
|
|
| 392 |
|
| 393 |
chunks.append(chunk.strip())
|
| 394 |
|
| 395 |
-
|
| 396 |
start = end - self.config.chunk_overlap
|
| 397 |
if start <= 0:
|
| 398 |
start = end
|
|
@@ -406,9 +406,9 @@ class RetrievalPreprocessor:
|
|
| 406 |
|
| 407 |
score = retrieval_score * self.config.score_boost_factor
|
| 408 |
|
| 409 |
-
|
| 410 |
if self.config.normalize_scores:
|
| 411 |
-
|
| 412 |
score = max(0.0, min(1.0, score))
|
| 413 |
|
| 414 |
return round(score, 4)
|
|
@@ -417,14 +417,14 @@ class RetrievalPreprocessor:
|
|
| 417 |
"""Extract source dari document metadata"""
|
| 418 |
metadata = document.metadata or {}
|
| 419 |
|
| 420 |
-
|
| 421 |
source_keys = ['source', 'file_name', 'filename', 'title', 'file_path', 'path']
|
| 422 |
|
| 423 |
for key in source_keys:
|
| 424 |
if key in metadata and metadata[key]:
|
| 425 |
return str(metadata[key])
|
| 426 |
|
| 427 |
-
|
| 428 |
return "unknown_source"
|
| 429 |
|
| 430 |
def _build_metadata(self,
|
|
@@ -439,7 +439,7 @@ class RetrievalPreprocessor:
|
|
| 439 |
metadata = {}
|
| 440 |
|
| 441 |
if self.config.extract_metadata:
|
| 442 |
-
|
| 443 |
if document.metadata and self.config.include_document_info:
|
| 444 |
metadata.update({
|
| 445 |
"original_metadata": document.metadata,
|
|
@@ -447,7 +447,7 @@ class RetrievalPreprocessor:
|
|
| 447 |
"total_documents": total_documents
|
| 448 |
})
|
| 449 |
|
| 450 |
-
|
| 451 |
if chunk_index is not None:
|
| 452 |
metadata.update({
|
| 453 |
"chunk_index": chunk_index,
|
|
@@ -455,7 +455,7 @@ class RetrievalPreprocessor:
|
|
| 455 |
"is_chunked": total_chunks > 1
|
| 456 |
})
|
| 457 |
|
| 458 |
-
|
| 459 |
if self.config.include_retrieval_info:
|
| 460 |
metadata.update({
|
| 461 |
"retrieval_query": retrieval_result.query,
|
|
@@ -463,7 +463,7 @@ class RetrievalPreprocessor:
|
|
| 463 |
"retrieval_metadata": retrieval_result.metadata
|
| 464 |
})
|
| 465 |
|
| 466 |
-
|
| 467 |
if self.config.include_timestamps:
|
| 468 |
metadata.update({
|
| 469 |
"processed_at": datetime.now().isoformat(),
|
|
@@ -480,7 +480,7 @@ class RetrievalPreprocessor:
|
|
| 480 |
}
|
| 481 |
})
|
| 482 |
|
| 483 |
-
|
| 484 |
word_count = len(content.split())
|
| 485 |
sentence_count = len(re.split(r'[.!?]+', content))
|
| 486 |
|
|
@@ -500,11 +500,11 @@ class RetrievalPreprocessor:
|
|
| 500 |
if not contexts:
|
| 501 |
return contexts
|
| 502 |
|
| 503 |
-
|
| 504 |
if self.config.use_retrieval_scores:
|
| 505 |
contexts.sort(key=lambda x: x.score or 0.0, reverse=True)
|
| 506 |
|
| 507 |
-
|
| 508 |
filtered_contexts = []
|
| 509 |
for ctx in contexts:
|
| 510 |
if self.config.filter_empty_content and not ctx.content.strip():
|
|
@@ -522,16 +522,16 @@ class RetrievalPreprocessor:
|
|
| 522 |
total_words = sum(len(ctx.content.split()) for ctx in contexts)
|
| 523 |
total_chars = sum(len(ctx.content) for ctx in contexts)
|
| 524 |
|
| 525 |
-
|
| 526 |
scores = [ctx.score for ctx in contexts if ctx.score is not None]
|
| 527 |
|
| 528 |
-
|
| 529 |
sources = {}
|
| 530 |
for ctx in contexts:
|
| 531 |
if ctx.source:
|
| 532 |
sources[ctx.source] = sources.get(ctx.source, 0) + 1
|
| 533 |
|
| 534 |
-
|
| 535 |
chunked_contexts = sum(1 for ctx in contexts
|
| 536 |
if ctx.metadata and ctx.metadata.get("is_chunked", False))
|
| 537 |
|
|
@@ -557,7 +557,7 @@ class RetrievalPreprocessor:
|
|
| 557 |
stats["source_distribution"] = sources
|
| 558 |
stats["unique_sources"] = len(sources)
|
| 559 |
|
| 560 |
-
|
| 561 |
lengths = [len(ctx.content) for ctx in contexts]
|
| 562 |
stats["content_length_stats"] = {
|
| 563 |
"min_length": min(lengths),
|
|
@@ -589,7 +589,7 @@ class RetrievalPreprocessor:
|
|
| 589 |
try:
|
| 590 |
contexts = self.process_retrieval_result(result)
|
| 591 |
|
| 592 |
-
|
| 593 |
for ctx in contexts:
|
| 594 |
if ctx.metadata:
|
| 595 |
ctx.metadata["batch_index"] = i
|
|
@@ -606,7 +606,7 @@ class RetrievalPreprocessor:
|
|
| 606 |
self.logger.error(f"Error processing retrieval result {i}: {e}")
|
| 607 |
continue
|
| 608 |
|
| 609 |
-
|
| 610 |
all_contexts = self._post_process_contexts(all_contexts)
|
| 611 |
|
| 612 |
self.logger.info(f"Batch processing completed: {len(all_contexts)} total contexts")
|
|
@@ -637,12 +637,12 @@ class RetrievalPreprocessor:
|
|
| 637 |
for ctx in contexts:
|
| 638 |
content_words = set(ctx.content.lower().split())
|
| 639 |
|
| 640 |
-
|
| 641 |
overlap = len(query_words.intersection(content_words))
|
| 642 |
relevance_score = overlap / len(query_words) if query_words else 0.0
|
| 643 |
|
| 644 |
if relevance_score >= min_relevance_score:
|
| 645 |
-
|
| 646 |
if ctx.metadata:
|
| 647 |
ctx.metadata["query_relevance_score"] = round(relevance_score, 3)
|
| 648 |
ctx.metadata["matched_query_words"] = list(query_words.intersection(content_words))
|
|
@@ -654,7 +654,7 @@ class RetrievalPreprocessor:
|
|
| 654 |
|
| 655 |
filtered_contexts.append(ctx)
|
| 656 |
|
| 657 |
-
|
| 658 |
filtered_contexts.sort(
|
| 659 |
key=lambda x: x.metadata.get("query_relevance_score", 0.0),
|
| 660 |
reverse=True
|
|
@@ -699,9 +699,9 @@ class RetrievalPreprocessor:
|
|
| 699 |
if sim_score >= similarity_threshold:
|
| 700 |
is_duplicate = True
|
| 701 |
|
| 702 |
-
|
| 703 |
if (ctx.score or 0.0) > (existing_ctx.score or 0.0):
|
| 704 |
-
|
| 705 |
idx = deduplicated.index(existing_ctx)
|
| 706 |
deduplicated[idx] = ctx
|
| 707 |
|
|
@@ -741,12 +741,12 @@ class RetrievalPreprocessor:
|
|
| 741 |
if not proc_result.chunks:
|
| 742 |
continue
|
| 743 |
|
| 744 |
-
|
| 745 |
for j, chunk in enumerate(proc_result.chunks):
|
| 746 |
-
|
| 747 |
source = self._extract_source(chunk)
|
| 748 |
|
| 749 |
-
|
| 750 |
metadata = {
|
| 751 |
"document_metadata": proc_result.document_metadata.__dict__,
|
| 752 |
"chunk_index": j,
|
|
@@ -755,21 +755,21 @@ class RetrievalPreprocessor:
|
|
| 755 |
"processed_at": datetime.now().isoformat()
|
| 756 |
}
|
| 757 |
|
| 758 |
-
|
| 759 |
if chunk.metadata:
|
| 760 |
metadata["original_chunk_metadata"] = chunk.metadata
|
| 761 |
|
| 762 |
-
|
| 763 |
cleaned_content = self._clean_text(chunk.page_content)
|
| 764 |
|
| 765 |
if not cleaned_content or len(cleaned_content) < self.config.min_content_length:
|
| 766 |
continue
|
| 767 |
|
| 768 |
-
|
| 769 |
context = RetrievalResult(
|
| 770 |
content=cleaned_content,
|
| 771 |
source=source,
|
| 772 |
-
score=1.0,
|
| 773 |
metadata=metadata
|
| 774 |
)
|
| 775 |
|
|
|
|
| 6 |
from datetime import datetime
|
| 7 |
import hashlib
|
| 8 |
|
| 9 |
+
|
| 10 |
from typing import List, Dict, Any, Optional, Union
|
| 11 |
from dataclasses import dataclass
|
| 12 |
from enum import Enum
|
|
|
|
| 15 |
@dataclass
|
| 16 |
class PreprocessingConfig:
|
| 17 |
"""Konfigurasi untuk preprocessing"""
|
| 18 |
+
|
| 19 |
remove_extra_whitespace: bool = True
|
| 20 |
remove_special_chars: bool = False
|
| 21 |
normalize_unicode: bool = True
|
| 22 |
remove_urls: bool = False
|
| 23 |
remove_emails: bool = False
|
| 24 |
|
| 25 |
+
|
| 26 |
+
enable_chunking: bool = False
|
| 27 |
chunk_size: int = 500
|
| 28 |
chunk_overlap: int = 50
|
| 29 |
+
chunk_method: str = "sentence"
|
| 30 |
+
|
| 31 |
|
|
|
|
| 32 |
min_content_length: int = 20
|
| 33 |
max_content_length: int = 3000
|
| 34 |
filter_empty_content: bool = True
|
| 35 |
filter_duplicate_content: bool = True
|
| 36 |
|
| 37 |
+
|
| 38 |
extract_metadata: bool = True
|
| 39 |
include_retrieval_info: bool = True
|
| 40 |
include_document_info: bool = True
|
| 41 |
include_timestamps: bool = True
|
| 42 |
|
| 43 |
+
|
| 44 |
+
use_retrieval_scores: bool = True
|
| 45 |
+
normalize_scores: bool = True
|
| 46 |
+
min_score_threshold: float = 0.0
|
| 47 |
+
score_boost_factor: float = 1.0
|
| 48 |
|
| 49 |
class RetrievalPreprocessor:
|
| 50 |
"""
|
|
|
|
| 61 |
"""
|
| 62 |
self.config = config or PreprocessingConfig()
|
| 63 |
|
| 64 |
+
|
| 65 |
logging.basicConfig(level=logging.INFO)
|
| 66 |
self.logger = logging.getLogger(__name__)
|
| 67 |
|
| 68 |
+
|
| 69 |
self.url_pattern = re.compile(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+')
|
| 70 |
self.email_pattern = re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b')
|
| 71 |
self.special_chars_pattern = re.compile(r'[^\w\s\.\,\!\?\;\:\-\(\)\[\]\{\}\"\'\/]')
|
| 72 |
self.whitespace_pattern = re.compile(r'\s+')
|
| 73 |
|
| 74 |
+
|
| 75 |
self._seen_content_hashes = set()
|
| 76 |
|
| 77 |
def process_retrieval_result(self, retrieval_result: RetrievalResult) -> List[RetrievalResult]:
|
|
|
|
| 98 |
f"Processing {len(retrieval_result.documents)} documents from retrieval result for query: '{retrieval_result.query}'"
|
| 99 |
)
|
| 100 |
|
| 101 |
+
|
| 102 |
self._seen_content_hashes.clear()
|
| 103 |
|
| 104 |
contexts = []
|
| 105 |
|
| 106 |
+
|
| 107 |
for i, doc in enumerate(retrieval_result.documents):
|
| 108 |
try:
|
| 109 |
+
|
| 110 |
score = retrieval_result.scores[i] if i < len(retrieval_result.scores) else 0.0
|
| 111 |
|
| 112 |
+
|
| 113 |
processed_contexts = self._process_single_document(
|
| 114 |
document=doc,
|
| 115 |
retrieval_score=score,
|
|
|
|
| 124 |
self.logger.error(f"Error processing document {i}: {e}")
|
| 125 |
continue
|
| 126 |
|
| 127 |
+
|
| 128 |
contexts = self._post_process_contexts(contexts)
|
| 129 |
|
| 130 |
self.logger.info(f"Successfully processed {len(contexts)} contexts from retrieval result")
|
|
|
|
| 154 |
self.logger.warning(f"Empty content in document {document_index}")
|
| 155 |
return []
|
| 156 |
|
| 157 |
+
|
| 158 |
cleaned_content = self._clean_text(document.page_content)
|
| 159 |
|
| 160 |
if not cleaned_content:
|
| 161 |
return []
|
| 162 |
|
| 163 |
+
|
| 164 |
if len(cleaned_content) < self.config.min_content_length:
|
| 165 |
self.logger.debug(f"Content too short in document {document_index}: {len(cleaned_content)} chars")
|
| 166 |
return []
|
| 167 |
|
| 168 |
if len(cleaned_content) > self.config.max_content_length:
|
| 169 |
+
|
| 170 |
cleaned_content = self._truncate_content(cleaned_content)
|
| 171 |
self.logger.debug(f"Content truncated in document {document_index}")
|
| 172 |
|
| 173 |
+
|
| 174 |
if self.config.filter_duplicate_content:
|
| 175 |
content_hash = hashlib.md5(cleaned_content.encode()).hexdigest()
|
| 176 |
if content_hash in self._seen_content_hashes:
|
|
|
|
| 178 |
return []
|
| 179 |
self._seen_content_hashes.add(content_hash)
|
| 180 |
|
| 181 |
+
|
| 182 |
if self.config.use_retrieval_scores and retrieval_score < self.config.min_score_threshold:
|
| 183 |
self.logger.debug(f"Score too low in document {document_index}: {retrieval_score}")
|
| 184 |
return []
|
| 185 |
|
| 186 |
+
|
| 187 |
if self.config.enable_chunking:
|
| 188 |
chunks = self._chunk_content(cleaned_content)
|
| 189 |
contexts = []
|
|
|
|
| 203 |
|
| 204 |
return contexts
|
| 205 |
else:
|
| 206 |
+
|
| 207 |
context = self._create_retrieved_context(
|
| 208 |
content=cleaned_content,
|
| 209 |
document=document,
|
|
|
|
| 229 |
"""
|
| 230 |
Create RetrievalResult object
|
| 231 |
"""
|
| 232 |
+
|
| 233 |
final_score = self._process_score(retrieval_score, document_index, total_documents)
|
| 234 |
|
| 235 |
+
|
| 236 |
source = self._extract_source(document)
|
| 237 |
|
| 238 |
+
|
| 239 |
metadata = self._build_metadata(
|
| 240 |
document=document,
|
| 241 |
retrieval_result=retrieval_result,
|
|
|
|
| 260 |
|
| 261 |
cleaned = text
|
| 262 |
|
| 263 |
+
|
| 264 |
if self.config.normalize_unicode:
|
| 265 |
import unicodedata
|
| 266 |
cleaned = unicodedata.normalize('NFKC', cleaned)
|
| 267 |
|
| 268 |
+
|
| 269 |
if self.config.remove_urls:
|
| 270 |
cleaned = self.url_pattern.sub('', cleaned)
|
| 271 |
|
| 272 |
+
|
| 273 |
if self.config.remove_emails:
|
| 274 |
cleaned = self.email_pattern.sub('', cleaned)
|
| 275 |
|
| 276 |
+
|
| 277 |
if self.config.remove_special_chars:
|
| 278 |
cleaned = self.special_chars_pattern.sub(' ', cleaned)
|
| 279 |
|
| 280 |
+
|
| 281 |
if self.config.remove_extra_whitespace:
|
| 282 |
cleaned = self.whitespace_pattern.sub(' ', cleaned)
|
| 283 |
|
|
|
|
| 290 |
if len(content) <= max_length:
|
| 291 |
return content
|
| 292 |
|
| 293 |
+
|
| 294 |
truncated = content[:max_length - 20]
|
| 295 |
last_sentence_end = max(
|
| 296 |
truncated.rfind('.'),
|
|
|
|
| 301 |
if last_sentence_end > len(truncated) * 0.7:
|
| 302 |
return truncated[:last_sentence_end + 1]
|
| 303 |
else:
|
| 304 |
+
|
| 305 |
last_space = truncated.rfind(' ')
|
| 306 |
if last_space > len(truncated) * 0.8:
|
| 307 |
return truncated[:last_space] + "..."
|
|
|
|
| 320 |
elif self.config.chunk_method == "fixed":
|
| 321 |
return self._chunk_by_fixed_size(content)
|
| 322 |
else:
|
| 323 |
+
return [content]
|
| 324 |
|
| 325 |
def _chunk_by_sentence(self, text: str) -> List[str]:
|
| 326 |
"""Chunk by sentences"""
|
|
|
|
| 334 |
if len(current_chunk) + len(sentence) > self.config.chunk_size and current_chunk:
|
| 335 |
chunks.append(current_chunk.strip())
|
| 336 |
|
| 337 |
+
|
| 338 |
if self.config.chunk_overlap > 0:
|
| 339 |
overlap_text = current_chunk[-self.config.chunk_overlap:]
|
| 340 |
current_chunk = overlap_text + " " + sentence
|
|
|
|
| 383 |
end = start + self.config.chunk_size
|
| 384 |
chunk = text[start:end]
|
| 385 |
|
| 386 |
+
|
| 387 |
if end < len(text):
|
| 388 |
last_space = chunk.rfind(' ')
|
| 389 |
if last_space > len(chunk) * 0.8:
|
|
|
|
| 392 |
|
| 393 |
chunks.append(chunk.strip())
|
| 394 |
|
| 395 |
+
|
| 396 |
start = end - self.config.chunk_overlap
|
| 397 |
if start <= 0:
|
| 398 |
start = end
|
|
|
|
| 406 |
|
| 407 |
score = retrieval_score * self.config.score_boost_factor
|
| 408 |
|
| 409 |
+
|
| 410 |
if self.config.normalize_scores:
|
| 411 |
+
|
| 412 |
score = max(0.0, min(1.0, score))
|
| 413 |
|
| 414 |
return round(score, 4)
|
|
|
|
| 417 |
"""Extract source dari document metadata"""
|
| 418 |
metadata = document.metadata or {}
|
| 419 |
|
| 420 |
+
|
| 421 |
source_keys = ['source', 'file_name', 'filename', 'title', 'file_path', 'path']
|
| 422 |
|
| 423 |
for key in source_keys:
|
| 424 |
if key in metadata and metadata[key]:
|
| 425 |
return str(metadata[key])
|
| 426 |
|
| 427 |
+
|
| 428 |
return "unknown_source"
|
| 429 |
|
| 430 |
def _build_metadata(self,
|
|
|
|
| 439 |
metadata = {}
|
| 440 |
|
| 441 |
if self.config.extract_metadata:
|
| 442 |
+
|
| 443 |
if document.metadata and self.config.include_document_info:
|
| 444 |
metadata.update({
|
| 445 |
"original_metadata": document.metadata,
|
|
|
|
| 447 |
"total_documents": total_documents
|
| 448 |
})
|
| 449 |
|
| 450 |
+
|
| 451 |
if chunk_index is not None:
|
| 452 |
metadata.update({
|
| 453 |
"chunk_index": chunk_index,
|
|
|
|
| 455 |
"is_chunked": total_chunks > 1
|
| 456 |
})
|
| 457 |
|
| 458 |
+
|
| 459 |
if self.config.include_retrieval_info:
|
| 460 |
metadata.update({
|
| 461 |
"retrieval_query": retrieval_result.query,
|
|
|
|
| 463 |
"retrieval_metadata": retrieval_result.metadata
|
| 464 |
})
|
| 465 |
|
| 466 |
+
|
| 467 |
if self.config.include_timestamps:
|
| 468 |
metadata.update({
|
| 469 |
"processed_at": datetime.now().isoformat(),
|
|
|
|
| 480 |
}
|
| 481 |
})
|
| 482 |
|
| 483 |
+
|
| 484 |
word_count = len(content.split())
|
| 485 |
sentence_count = len(re.split(r'[.!?]+', content))
|
| 486 |
|
|
|
|
| 500 |
if not contexts:
|
| 501 |
return contexts
|
| 502 |
|
| 503 |
+
|
| 504 |
if self.config.use_retrieval_scores:
|
| 505 |
contexts.sort(key=lambda x: x.score or 0.0, reverse=True)
|
| 506 |
|
| 507 |
+
|
| 508 |
filtered_contexts = []
|
| 509 |
for ctx in contexts:
|
| 510 |
if self.config.filter_empty_content and not ctx.content.strip():
|
|
|
|
| 522 |
total_words = sum(len(ctx.content.split()) for ctx in contexts)
|
| 523 |
total_chars = sum(len(ctx.content) for ctx in contexts)
|
| 524 |
|
| 525 |
+
|
| 526 |
scores = [ctx.score for ctx in contexts if ctx.score is not None]
|
| 527 |
|
| 528 |
+
|
| 529 |
sources = {}
|
| 530 |
for ctx in contexts:
|
| 531 |
if ctx.source:
|
| 532 |
sources[ctx.source] = sources.get(ctx.source, 0) + 1
|
| 533 |
|
| 534 |
+
|
| 535 |
chunked_contexts = sum(1 for ctx in contexts
|
| 536 |
if ctx.metadata and ctx.metadata.get("is_chunked", False))
|
| 537 |
|
|
|
|
| 557 |
stats["source_distribution"] = sources
|
| 558 |
stats["unique_sources"] = len(sources)
|
| 559 |
|
| 560 |
+
|
| 561 |
lengths = [len(ctx.content) for ctx in contexts]
|
| 562 |
stats["content_length_stats"] = {
|
| 563 |
"min_length": min(lengths),
|
|
|
|
| 589 |
try:
|
| 590 |
contexts = self.process_retrieval_result(result)
|
| 591 |
|
| 592 |
+
|
| 593 |
for ctx in contexts:
|
| 594 |
if ctx.metadata:
|
| 595 |
ctx.metadata["batch_index"] = i
|
|
|
|
| 606 |
self.logger.error(f"Error processing retrieval result {i}: {e}")
|
| 607 |
continue
|
| 608 |
|
| 609 |
+
|
| 610 |
all_contexts = self._post_process_contexts(all_contexts)
|
| 611 |
|
| 612 |
self.logger.info(f"Batch processing completed: {len(all_contexts)} total contexts")
|
|
|
|
| 637 |
for ctx in contexts:
|
| 638 |
content_words = set(ctx.content.lower().split())
|
| 639 |
|
| 640 |
+
|
| 641 |
overlap = len(query_words.intersection(content_words))
|
| 642 |
relevance_score = overlap / len(query_words) if query_words else 0.0
|
| 643 |
|
| 644 |
if relevance_score >= min_relevance_score:
|
| 645 |
+
|
| 646 |
if ctx.metadata:
|
| 647 |
ctx.metadata["query_relevance_score"] = round(relevance_score, 3)
|
| 648 |
ctx.metadata["matched_query_words"] = list(query_words.intersection(content_words))
|
|
|
|
| 654 |
|
| 655 |
filtered_contexts.append(ctx)
|
| 656 |
|
| 657 |
+
|
| 658 |
filtered_contexts.sort(
|
| 659 |
key=lambda x: x.metadata.get("query_relevance_score", 0.0),
|
| 660 |
reverse=True
|
|
|
|
| 699 |
if sim_score >= similarity_threshold:
|
| 700 |
is_duplicate = True
|
| 701 |
|
| 702 |
+
|
| 703 |
if (ctx.score or 0.0) > (existing_ctx.score or 0.0):
|
| 704 |
+
|
| 705 |
idx = deduplicated.index(existing_ctx)
|
| 706 |
deduplicated[idx] = ctx
|
| 707 |
|
|
|
|
| 741 |
if not proc_result.chunks:
|
| 742 |
continue
|
| 743 |
|
| 744 |
+
|
| 745 |
for j, chunk in enumerate(proc_result.chunks):
|
| 746 |
+
|
| 747 |
source = self._extract_source(chunk)
|
| 748 |
|
| 749 |
+
|
| 750 |
metadata = {
|
| 751 |
"document_metadata": proc_result.document_metadata.__dict__,
|
| 752 |
"chunk_index": j,
|
|
|
|
| 755 |
"processed_at": datetime.now().isoformat()
|
| 756 |
}
|
| 757 |
|
| 758 |
+
|
| 759 |
if chunk.metadata:
|
| 760 |
metadata["original_chunk_metadata"] = chunk.metadata
|
| 761 |
|
| 762 |
+
|
| 763 |
cleaned_content = self._clean_text(chunk.page_content)
|
| 764 |
|
| 765 |
if not cleaned_content or len(cleaned_content) < self.config.min_content_length:
|
| 766 |
continue
|
| 767 |
|
| 768 |
+
|
| 769 |
context = RetrievalResult(
|
| 770 |
content=cleaned_content,
|
| 771 |
source=source,
|
| 772 |
+
score=1.0,
|
| 773 |
metadata=metadata
|
| 774 |
)
|
| 775 |
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/retriever/base_retriever.py
CHANGED
|
@@ -47,7 +47,7 @@ class BaseRetriever(ABC):
|
|
| 47 |
"""Delete documents by IDs"""
|
| 48 |
pass
|
| 49 |
|
| 50 |
-
|
| 51 |
|
| 52 |
class MultiFormatDocumentLoader(BaseDocumentLoader):
|
| 53 |
"""Document loader supporting multiple formats"""
|
|
@@ -68,10 +68,10 @@ class MultiFormatDocumentLoader(BaseDocumentLoader):
|
|
| 68 |
if not file_path.exists():
|
| 69 |
raise FileNotFoundError(f"File not found: {file_path}")
|
| 70 |
|
| 71 |
-
|
| 72 |
doc_type = self._get_document_type(file_path)
|
| 73 |
|
| 74 |
-
|
| 75 |
loader_func = self.loaders.get(doc_type)
|
| 76 |
if not loader_func:
|
| 77 |
raise ValueError(f"Unsupported file type: {doc_type}")
|
|
@@ -79,7 +79,7 @@ class MultiFormatDocumentLoader(BaseDocumentLoader):
|
|
| 79 |
logger.info(f"Loading {doc_type} document: {file_path}")
|
| 80 |
documents = await loader_func(str(file_path))
|
| 81 |
|
| 82 |
-
|
| 83 |
for doc in documents:
|
| 84 |
doc.metadata.update({
|
| 85 |
"file_path": str(file_path),
|
|
|
|
| 47 |
"""Delete documents by IDs"""
|
| 48 |
pass
|
| 49 |
|
| 50 |
+
|
| 51 |
|
| 52 |
class MultiFormatDocumentLoader(BaseDocumentLoader):
|
| 53 |
"""Document loader supporting multiple formats"""
|
|
|
|
| 68 |
if not file_path.exists():
|
| 69 |
raise FileNotFoundError(f"File not found: {file_path}")
|
| 70 |
|
| 71 |
+
|
| 72 |
doc_type = self._get_document_type(file_path)
|
| 73 |
|
| 74 |
+
|
| 75 |
loader_func = self.loaders.get(doc_type)
|
| 76 |
if not loader_func:
|
| 77 |
raise ValueError(f"Unsupported file type: {doc_type}")
|
|
|
|
| 79 |
logger.info(f"Loading {doc_type} document: {file_path}")
|
| 80 |
documents = await loader_func(str(file_path))
|
| 81 |
|
| 82 |
+
|
| 83 |
for doc in documents:
|
| 84 |
doc.metadata.update({
|
| 85 |
"file_path": str(file_path),
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/retriever/document_loader.py
CHANGED
|
@@ -56,10 +56,10 @@ class MultiFormatDocumentLoader(BaseDocumentLoader):
|
|
| 56 |
if not file_path.exists():
|
| 57 |
raise FileNotFoundError(f"File not found: {file_path}")
|
| 58 |
|
| 59 |
-
|
| 60 |
doc_type = self._get_document_type(file_path)
|
| 61 |
|
| 62 |
-
|
| 63 |
loader_func = self.loaders.get(doc_type)
|
| 64 |
if not loader_func:
|
| 65 |
raise ValueError(f"Unsupported file type: {doc_type}")
|
|
@@ -67,7 +67,7 @@ class MultiFormatDocumentLoader(BaseDocumentLoader):
|
|
| 67 |
logger.info(f"Loading {doc_type} document: {file_path}")
|
| 68 |
documents = await loader_func(str(file_path))
|
| 69 |
|
| 70 |
-
|
| 71 |
for doc in documents:
|
| 72 |
doc.metadata.update({
|
| 73 |
"file_path": str(file_path),
|
|
|
|
| 56 |
if not file_path.exists():
|
| 57 |
raise FileNotFoundError(f"File not found: {file_path}")
|
| 58 |
|
| 59 |
+
|
| 60 |
doc_type = self._get_document_type(file_path)
|
| 61 |
|
| 62 |
+
|
| 63 |
loader_func = self.loaders.get(doc_type)
|
| 64 |
if not loader_func:
|
| 65 |
raise ValueError(f"Unsupported file type: {doc_type}")
|
|
|
|
| 67 |
logger.info(f"Loading {doc_type} document: {file_path}")
|
| 68 |
documents = await loader_func(str(file_path))
|
| 69 |
|
| 70 |
+
|
| 71 |
for doc in documents:
|
| 72 |
doc.metadata.update({
|
| 73 |
"file_path": str(file_path),
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/retriever/document_processor.py
CHANGED
|
@@ -18,7 +18,7 @@ class DocumentProcessor:
|
|
| 18 |
self.chunk_size = chunk_size
|
| 19 |
self.chunk_overlap = chunk_overlap
|
| 20 |
|
| 21 |
-
|
| 22 |
if separators is None:
|
| 23 |
separators = ["\n\n", "\n", " ", ""]
|
| 24 |
|
|
@@ -34,12 +34,12 @@ class DocumentProcessor:
|
|
| 34 |
try:
|
| 35 |
logger.info(f"Processing {len(documents)} documents")
|
| 36 |
|
| 37 |
-
|
| 38 |
chunks = await asyncio.get_event_loop().run_in_executor(
|
| 39 |
None, self.text_splitter.split_documents, documents
|
| 40 |
)
|
| 41 |
|
| 42 |
-
|
| 43 |
for i, chunk in enumerate(chunks):
|
| 44 |
chunk.metadata.update({
|
| 45 |
"chunk_id": i,
|
|
|
|
| 18 |
self.chunk_size = chunk_size
|
| 19 |
self.chunk_overlap = chunk_overlap
|
| 20 |
|
| 21 |
+
|
| 22 |
if separators is None:
|
| 23 |
separators = ["\n\n", "\n", " ", ""]
|
| 24 |
|
|
|
|
| 34 |
try:
|
| 35 |
logger.info(f"Processing {len(documents)} documents")
|
| 36 |
|
| 37 |
+
|
| 38 |
chunks = await asyncio.get_event_loop().run_in_executor(
|
| 39 |
None, self.text_splitter.split_documents, documents
|
| 40 |
)
|
| 41 |
|
| 42 |
+
|
| 43 |
for i, chunk in enumerate(chunks):
|
| 44 |
chunk.metadata.update({
|
| 45 |
"chunk_id": i,
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/retriever/langchain_retriever.py
CHANGED
|
@@ -1,14 +1,14 @@
|
|
| 1 |
from rag.retriever.base_retriever import BaseRetriever
|
| 2 |
|
| 3 |
-
|
| 4 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 5 |
from langchain_openai import OpenAIEmbeddings
|
| 6 |
|
| 7 |
-
|
| 8 |
from langchain_community.vectorstores import Chroma, FAISS, Pinecone
|
| 9 |
from langchain.retrievers import EnsembleRetriever
|
| 10 |
|
| 11 |
-
|
| 12 |
from langchain_core.vectorstores import VectorStoreRetriever
|
| 13 |
from langchain_community.retrievers import BM25Retriever
|
| 14 |
from langchain.retrievers import ContextualCompressionRetriever
|
|
@@ -85,8 +85,8 @@ class LangChainRetriever(BaseRetriever):
|
|
| 85 |
search_kwargs={"k": 10}
|
| 86 |
)
|
| 87 |
if self.use_hybrid_search:
|
| 88 |
-
self.bm25_retriever = None
|
| 89 |
-
return vector_retriever
|
| 90 |
else:
|
| 91 |
return vector_retriever
|
| 92 |
except Exception as e:
|
|
@@ -162,13 +162,13 @@ class LangChainRetriever(BaseRetriever):
|
|
| 162 |
return False
|
| 163 |
async def _update_bm25_retriever(self, documents: List[Document]):
|
| 164 |
try:
|
| 165 |
-
|
| 166 |
self.bm25_retriever = BM25Retriever.from_documents(documents)
|
| 167 |
-
self.bm25_retriever.k = 10
|
|
|
|
|
|
|
| 168 |
|
| 169 |
-
# For hybrid search, you have several options:
|
| 170 |
|
| 171 |
-
# Option 1: Use only BM25 retriever (simplest fix)
|
| 172 |
self.retriever = self.bm25_retriever
|
| 173 |
|
| 174 |
vector_retriever = VectorStoreRetriever(
|
|
@@ -178,12 +178,12 @@ class LangChainRetriever(BaseRetriever):
|
|
| 178 |
|
| 179 |
self.retriever = EnsembleRetriever(
|
| 180 |
retrievers=[vector_retriever, self.bm25_retriever],
|
| 181 |
-
weights=[0.5, 0.5]
|
| 182 |
)
|
| 183 |
|
| 184 |
except Exception as e:
|
| 185 |
logger.error(f"Error updating BM25 retriever: {str(e)}")
|
| 186 |
-
|
| 187 |
self.retriever = VectorStoreRetriever(
|
| 188 |
vectorstore=self.vectorstore,
|
| 189 |
search_kwargs={"k": 10}
|
|
|
|
| 1 |
from rag.retriever.base_retriever import BaseRetriever
|
| 2 |
|
| 3 |
+
|
| 4 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 5 |
from langchain_openai import OpenAIEmbeddings
|
| 6 |
|
| 7 |
+
|
| 8 |
from langchain_community.vectorstores import Chroma, FAISS, Pinecone
|
| 9 |
from langchain.retrievers import EnsembleRetriever
|
| 10 |
|
| 11 |
+
|
| 12 |
from langchain_core.vectorstores import VectorStoreRetriever
|
| 13 |
from langchain_community.retrievers import BM25Retriever
|
| 14 |
from langchain.retrievers import ContextualCompressionRetriever
|
|
|
|
| 85 |
search_kwargs={"k": 10}
|
| 86 |
)
|
| 87 |
if self.use_hybrid_search:
|
| 88 |
+
self.bm25_retriever = None
|
| 89 |
+
return vector_retriever
|
| 90 |
else:
|
| 91 |
return vector_retriever
|
| 92 |
except Exception as e:
|
|
|
|
| 162 |
return False
|
| 163 |
async def _update_bm25_retriever(self, documents: List[Document]):
|
| 164 |
try:
|
| 165 |
+
|
| 166 |
self.bm25_retriever = BM25Retriever.from_documents(documents)
|
| 167 |
+
self.bm25_retriever.k = 10
|
| 168 |
+
|
| 169 |
+
|
| 170 |
|
|
|
|
| 171 |
|
|
|
|
| 172 |
self.retriever = self.bm25_retriever
|
| 173 |
|
| 174 |
vector_retriever = VectorStoreRetriever(
|
|
|
|
| 178 |
|
| 179 |
self.retriever = EnsembleRetriever(
|
| 180 |
retrievers=[vector_retriever, self.bm25_retriever],
|
| 181 |
+
weights=[0.5, 0.5]
|
| 182 |
)
|
| 183 |
|
| 184 |
except Exception as e:
|
| 185 |
logger.error(f"Error updating BM25 retriever: {str(e)}")
|
| 186 |
+
|
| 187 |
self.retriever = VectorStoreRetriever(
|
| 188 |
vectorstore=self.vectorstore,
|
| 189 |
search_kwargs={"k": 10}
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/web_search/duckduckgo_search.py
CHANGED
|
@@ -7,7 +7,7 @@ from typing import AsyncGenerator, List
|
|
| 7 |
|
| 8 |
class DuckDuckGoSearch:
|
| 9 |
def __init__(self, html_loader: AsyncChromiumLoader = None, html_parser = None):
|
| 10 |
-
|
| 11 |
self.html_loader = html_loader or AsyncChromiumLoader([])
|
| 12 |
self.html_parser = html_parser or BeautifulSoupTransformer()
|
| 13 |
self.logger = logging.getLogger("ddgs_logger")
|
|
@@ -16,7 +16,7 @@ class DuckDuckGoSearch:
|
|
| 16 |
"""Get page content from URLs - returns list of documents"""
|
| 17 |
try:
|
| 18 |
self.html_loader.urls = urls
|
| 19 |
-
html = await self.html_loader.aload()
|
| 20 |
self.logger.info(f"search engine aload result: {len(html)} documents loaded")
|
| 21 |
|
| 22 |
docs_transformed = self.html_parser.transform_documents(
|
|
@@ -24,11 +24,11 @@ class DuckDuckGoSearch:
|
|
| 24 |
tags_to_extract=["p"],
|
| 25 |
remove_unwanted_tags=["a"]
|
| 26 |
)
|
| 27 |
-
return docs_transformed
|
| 28 |
|
| 29 |
except Exception as e:
|
| 30 |
self.logger.error(f"Error loading pages: {e}", exc_info=True)
|
| 31 |
-
return []
|
| 32 |
|
| 33 |
def truncate(self, text: str, max_words: int = 400) -> str:
|
| 34 |
"""Truncate text to specified number of words"""
|
|
@@ -51,12 +51,12 @@ class DuckDuckGoSearch:
|
|
| 51 |
try:
|
| 52 |
self.logger.info(f"Searching for: {query} (max_results: {max_results})")
|
| 53 |
|
| 54 |
-
|
| 55 |
results = DDGS().text(query, max_results=max_results)
|
| 56 |
urls = []
|
| 57 |
|
| 58 |
-
|
| 59 |
-
for result in results:
|
| 60 |
url = result.get('href')
|
| 61 |
if url:
|
| 62 |
urls.append(url)
|
|
@@ -67,20 +67,20 @@ class DuckDuckGoSearch:
|
|
| 67 |
self.logger.warning("No URLs found from search results")
|
| 68 |
return
|
| 69 |
|
| 70 |
-
# Step 3: Get page content (await the coroutine first)
|
| 71 |
-
docs = await self.get_page(urls) # ← FIXED: Await first, get list
|
| 72 |
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
| 75 |
try:
|
| 76 |
if hasattr(doc, 'page_content') and doc.page_content:
|
| 77 |
-
|
| 78 |
page_text = re.sub(r"\n\n+", "\n", doc.page_content)
|
| 79 |
page_text = page_text.strip()
|
| 80 |
|
| 81 |
-
if page_text:
|
| 82 |
text = self.truncate(page_text)
|
| 83 |
-
yield text
|
| 84 |
|
| 85 |
except Exception as e:
|
| 86 |
self.logger.error(f"Error processing document: {e}")
|
|
@@ -88,7 +88,7 @@ class DuckDuckGoSearch:
|
|
| 88 |
|
| 89 |
except Exception as e:
|
| 90 |
self.logger.error(f"Error in search method: {e}", exc_info=True)
|
| 91 |
-
|
| 92 |
|
| 93 |
async def search_with_metadata(self, query: str, max_results: int = 5) -> AsyncGenerator[dict, None]:
|
| 94 |
"""
|
|
@@ -98,7 +98,7 @@ class DuckDuckGoSearch:
|
|
| 98 |
results = DDGS().text(query, max_results=max_results)
|
| 99 |
urls_and_titles = []
|
| 100 |
|
| 101 |
-
|
| 102 |
for result in results:
|
| 103 |
url = result.get('href')
|
| 104 |
title = result.get('title', 'No title')
|
|
@@ -108,11 +108,11 @@ class DuckDuckGoSearch:
|
|
| 108 |
if not urls_and_titles:
|
| 109 |
return
|
| 110 |
|
| 111 |
-
|
| 112 |
urls = [item['url'] for item in urls_and_titles]
|
| 113 |
docs = await self.get_page(urls)
|
| 114 |
|
| 115 |
-
|
| 116 |
for i, doc in enumerate(docs):
|
| 117 |
try:
|
| 118 |
if hasattr(doc, 'page_content') and doc.page_content:
|
|
@@ -122,7 +122,7 @@ class DuckDuckGoSearch:
|
|
| 122 |
if page_text:
|
| 123 |
text = self.truncate(page_text)
|
| 124 |
|
| 125 |
-
|
| 126 |
metadata = {}
|
| 127 |
if i < len(urls_and_titles):
|
| 128 |
metadata = urls_and_titles[i]
|
|
|
|
| 7 |
|
| 8 |
class DuckDuckGoSearch:
|
| 9 |
def __init__(self, html_loader: AsyncChromiumLoader = None, html_parser = None):
|
| 10 |
+
|
| 11 |
self.html_loader = html_loader or AsyncChromiumLoader([])
|
| 12 |
self.html_parser = html_parser or BeautifulSoupTransformer()
|
| 13 |
self.logger = logging.getLogger("ddgs_logger")
|
|
|
|
| 16 |
"""Get page content from URLs - returns list of documents"""
|
| 17 |
try:
|
| 18 |
self.html_loader.urls = urls
|
| 19 |
+
html = await self.html_loader.aload()
|
| 20 |
self.logger.info(f"search engine aload result: {len(html)} documents loaded")
|
| 21 |
|
| 22 |
docs_transformed = self.html_parser.transform_documents(
|
|
|
|
| 24 |
tags_to_extract=["p"],
|
| 25 |
remove_unwanted_tags=["a"]
|
| 26 |
)
|
| 27 |
+
return docs_transformed
|
| 28 |
|
| 29 |
except Exception as e:
|
| 30 |
self.logger.error(f"Error loading pages: {e}", exc_info=True)
|
| 31 |
+
return []
|
| 32 |
|
| 33 |
def truncate(self, text: str, max_words: int = 400) -> str:
|
| 34 |
"""Truncate text to specified number of words"""
|
|
|
|
| 51 |
try:
|
| 52 |
self.logger.info(f"Searching for: {query} (max_results: {max_results})")
|
| 53 |
|
| 54 |
+
|
| 55 |
results = DDGS().text(query, max_results=max_results)
|
| 56 |
urls = []
|
| 57 |
|
| 58 |
+
|
| 59 |
+
for result in results:
|
| 60 |
url = result.get('href')
|
| 61 |
if url:
|
| 62 |
urls.append(url)
|
|
|
|
| 67 |
self.logger.warning("No URLs found from search results")
|
| 68 |
return
|
| 69 |
|
|
|
|
|
|
|
| 70 |
|
| 71 |
+
docs = await self.get_page(urls)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
for doc in docs:
|
| 75 |
try:
|
| 76 |
if hasattr(doc, 'page_content') and doc.page_content:
|
| 77 |
+
|
| 78 |
page_text = re.sub(r"\n\n+", "\n", doc.page_content)
|
| 79 |
page_text = page_text.strip()
|
| 80 |
|
| 81 |
+
if page_text:
|
| 82 |
text = self.truncate(page_text)
|
| 83 |
+
yield text
|
| 84 |
|
| 85 |
except Exception as e:
|
| 86 |
self.logger.error(f"Error processing document: {e}")
|
|
|
|
| 88 |
|
| 89 |
except Exception as e:
|
| 90 |
self.logger.error(f"Error in search method: {e}", exc_info=True)
|
| 91 |
+
|
| 92 |
|
| 93 |
async def search_with_metadata(self, query: str, max_results: int = 5) -> AsyncGenerator[dict, None]:
|
| 94 |
"""
|
|
|
|
| 98 |
results = DDGS().text(query, max_results=max_results)
|
| 99 |
urls_and_titles = []
|
| 100 |
|
| 101 |
+
|
| 102 |
for result in results:
|
| 103 |
url = result.get('href')
|
| 104 |
title = result.get('title', 'No title')
|
|
|
|
| 108 |
if not urls_and_titles:
|
| 109 |
return
|
| 110 |
|
| 111 |
+
|
| 112 |
urls = [item['url'] for item in urls_and_titles]
|
| 113 |
docs = await self.get_page(urls)
|
| 114 |
|
| 115 |
+
|
| 116 |
for i, doc in enumerate(docs):
|
| 117 |
try:
|
| 118 |
if hasattr(doc, 'page_content') and doc.page_content:
|
|
|
|
| 122 |
if page_text:
|
| 123 |
text = self.truncate(page_text)
|
| 124 |
|
| 125 |
+
|
| 126 |
metadata = {}
|
| 127 |
if i < len(urls_and_titles):
|
| 128 |
metadata = urls_and_titles[i]
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rtc/rtc_call.py
CHANGED
|
@@ -31,7 +31,7 @@ import re
|
|
| 31 |
|
| 32 |
|
| 33 |
from rag import cs_agent
|
| 34 |
-
|
| 35 |
load_dotenv()
|
| 36 |
logging.basicConfig(level=logging.INFO)
|
| 37 |
|
|
@@ -101,7 +101,7 @@ class RTCHandler:
|
|
| 101 |
llm_time = time.time()
|
| 102 |
self.full_response = ""
|
| 103 |
|
| 104 |
-
|
| 105 |
async def stream_text_to_audio():
|
| 106 |
chunk_size = 1024
|
| 107 |
no_buffer = 0
|
|
@@ -113,7 +113,7 @@ class RTCHandler:
|
|
| 113 |
chunk = stream_data["data"]["chunk"]
|
| 114 |
self.full_response += chunk
|
| 115 |
text_buffer += chunk
|
| 116 |
-
|
| 117 |
if re.search(r'[.,?;!]', chunk):
|
| 118 |
try:
|
| 119 |
audio_buffer_gen = await self.edge_tts.generate_audio_buffer(text_buffer)
|
|
@@ -121,37 +121,37 @@ class RTCHandler:
|
|
| 121 |
|
| 122 |
audio_buffer.seek(0)
|
| 123 |
|
| 124 |
-
|
| 125 |
audio_segment = AudioSegment.from_file(audio_buffer, format="mp3")
|
| 126 |
samples = np.array(audio_segment.get_array_of_samples()).astype(np.float32) / (2 ** 15)
|
| 127 |
|
| 128 |
-
|
| 129 |
if audio_segment.channels == 2:
|
| 130 |
samples = samples.reshape((-1, 2)).mean(axis=1)
|
| 131 |
|
| 132 |
-
|
| 133 |
-
|
| 134 |
import torch
|
| 135 |
import torchaudio
|
| 136 |
|
| 137 |
-
|
| 138 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 139 |
|
| 140 |
-
# Convert numpy array to torch tensor and move to GPU
|
| 141 |
-
audio_tensor = torch.from_numpy(samples).unsqueeze(0).to(device) # Add batch dimension and move to GPU
|
| 142 |
|
| 143 |
-
|
|
|
|
|
|
|
| 144 |
resampler = torchaudio.transforms.Resample(
|
| 145 |
orig_freq=audio_segment.frame_rate,
|
| 146 |
new_freq=24000
|
| 147 |
).to(device)
|
| 148 |
|
| 149 |
-
|
| 150 |
resampled_tensor = resampler(audio_tensor)
|
| 151 |
|
| 152 |
-
|
| 153 |
resampled = resampled_tensor.squeeze(0).cpu().numpy()
|
| 154 |
-
|
| 155 |
for i in range(0, len(resampled), chunk_size):
|
| 156 |
yield (24000, resampled[i:i + chunk_size])
|
| 157 |
no_buffer = 0
|
|
@@ -169,7 +169,7 @@ class RTCHandler:
|
|
| 169 |
print(f"\nTotal time: {total_time:.2f}s")
|
| 170 |
break
|
| 171 |
|
| 172 |
-
|
| 173 |
loop = asyncio.new_event_loop()
|
| 174 |
asyncio.set_event_loop(loop)
|
| 175 |
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
from rag import cs_agent
|
| 34 |
+
|
| 35 |
load_dotenv()
|
| 36 |
logging.basicConfig(level=logging.INFO)
|
| 37 |
|
|
|
|
| 101 |
llm_time = time.time()
|
| 102 |
self.full_response = ""
|
| 103 |
|
| 104 |
+
|
| 105 |
async def stream_text_to_audio():
|
| 106 |
chunk_size = 1024
|
| 107 |
no_buffer = 0
|
|
|
|
| 113 |
chunk = stream_data["data"]["chunk"]
|
| 114 |
self.full_response += chunk
|
| 115 |
text_buffer += chunk
|
| 116 |
+
|
| 117 |
if re.search(r'[.,?;!]', chunk):
|
| 118 |
try:
|
| 119 |
audio_buffer_gen = await self.edge_tts.generate_audio_buffer(text_buffer)
|
|
|
|
| 121 |
|
| 122 |
audio_buffer.seek(0)
|
| 123 |
|
| 124 |
+
|
| 125 |
audio_segment = AudioSegment.from_file(audio_buffer, format="mp3")
|
| 126 |
samples = np.array(audio_segment.get_array_of_samples()).astype(np.float32) / (2 ** 15)
|
| 127 |
|
| 128 |
+
|
| 129 |
if audio_segment.channels == 2:
|
| 130 |
samples = samples.reshape((-1, 2)).mean(axis=1)
|
| 131 |
|
| 132 |
+
|
| 133 |
+
|
| 134 |
import torch
|
| 135 |
import torchaudio
|
| 136 |
|
| 137 |
+
|
| 138 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 139 |
|
|
|
|
|
|
|
| 140 |
|
| 141 |
+
audio_tensor = torch.from_numpy(samples).unsqueeze(0).to(device)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
resampler = torchaudio.transforms.Resample(
|
| 145 |
orig_freq=audio_segment.frame_rate,
|
| 146 |
new_freq=24000
|
| 147 |
).to(device)
|
| 148 |
|
| 149 |
+
|
| 150 |
resampled_tensor = resampler(audio_tensor)
|
| 151 |
|
| 152 |
+
|
| 153 |
resampled = resampled_tensor.squeeze(0).cpu().numpy()
|
| 154 |
+
|
| 155 |
for i in range(0, len(resampled), chunk_size):
|
| 156 |
yield (24000, resampled[i:i + chunk_size])
|
| 157 |
no_buffer = 0
|
|
|
|
| 169 |
print(f"\nTotal time: {total_time:.2f}s")
|
| 170 |
break
|
| 171 |
|
| 172 |
+
|
| 173 |
loop = asyncio.new_event_loop()
|
| 174 |
asyncio.set_event_loop(loop)
|
| 175 |
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/stt/whisper_stt.py
CHANGED
|
@@ -15,27 +15,27 @@ class WhisperSTT:
|
|
| 15 |
model_size: Model size (tiny, base, small, medium, large)
|
| 16 |
device: Device to use ("auto", "cuda", "cpu")
|
| 17 |
"""
|
| 18 |
-
|
| 19 |
cache_dir = os.environ.get('WHISPER_CACHE_DIR', '/tmp/.cache/whisper')
|
| 20 |
os.makedirs(cache_dir, exist_ok=True)
|
| 21 |
|
| 22 |
-
|
| 23 |
if device == "auto":
|
| 24 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 25 |
else:
|
| 26 |
self.device = device
|
| 27 |
|
| 28 |
-
|
| 29 |
if self.device == "cuda" and not torch.cuda.is_available():
|
| 30 |
print("Warning: CUDA requested but not available. Falling back to CPU.")
|
| 31 |
self.device = "cpu"
|
| 32 |
|
| 33 |
-
|
| 34 |
print(f"Loading Whisper model '{model_size}' on device: {self.device}")
|
| 35 |
self.model = whisper.load_model(model_size, device=self.device, download_root=cache_dir)
|
| 36 |
-
self.language = "id"
|
|
|
|
| 37 |
|
| 38 |
-
# Print GPU info if using CUDA
|
| 39 |
if self.device == "cuda":
|
| 40 |
gpu_name = torch.cuda.get_device_name(0)
|
| 41 |
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
|
|
@@ -52,23 +52,23 @@ class WhisperSTT:
|
|
| 52 |
Returns:
|
| 53 |
Transcribed text
|
| 54 |
"""
|
| 55 |
-
|
| 56 |
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
| 57 |
tmp.write(audio.read())
|
| 58 |
tmp.flush()
|
| 59 |
tmp_path = tmp.name
|
| 60 |
|
| 61 |
try:
|
| 62 |
-
|
| 63 |
result = self.model.transcribe(
|
| 64 |
tmp_path,
|
| 65 |
language=language,
|
| 66 |
-
|
| 67 |
fp16=self.device == "cuda"
|
| 68 |
)
|
| 69 |
return result.get("text", "")
|
| 70 |
finally:
|
| 71 |
-
|
| 72 |
os.remove(tmp_path)
|
| 73 |
|
| 74 |
def get_device_info(self) -> dict:
|
|
|
|
| 15 |
model_size: Model size (tiny, base, small, medium, large)
|
| 16 |
device: Device to use ("auto", "cuda", "cpu")
|
| 17 |
"""
|
| 18 |
+
|
| 19 |
cache_dir = os.environ.get('WHISPER_CACHE_DIR', '/tmp/.cache/whisper')
|
| 20 |
os.makedirs(cache_dir, exist_ok=True)
|
| 21 |
|
| 22 |
+
|
| 23 |
if device == "auto":
|
| 24 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 25 |
else:
|
| 26 |
self.device = device
|
| 27 |
|
| 28 |
+
|
| 29 |
if self.device == "cuda" and not torch.cuda.is_available():
|
| 30 |
print("Warning: CUDA requested but not available. Falling back to CPU.")
|
| 31 |
self.device = "cpu"
|
| 32 |
|
| 33 |
+
|
| 34 |
print(f"Loading Whisper model '{model_size}' on device: {self.device}")
|
| 35 |
self.model = whisper.load_model(model_size, device=self.device, download_root=cache_dir)
|
| 36 |
+
self.language = "id"
|
| 37 |
+
|
| 38 |
|
|
|
|
| 39 |
if self.device == "cuda":
|
| 40 |
gpu_name = torch.cuda.get_device_name(0)
|
| 41 |
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
|
|
|
|
| 52 |
Returns:
|
| 53 |
Transcribed text
|
| 54 |
"""
|
| 55 |
+
|
| 56 |
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
| 57 |
tmp.write(audio.read())
|
| 58 |
tmp.flush()
|
| 59 |
tmp_path = tmp.name
|
| 60 |
|
| 61 |
try:
|
| 62 |
+
|
| 63 |
result = self.model.transcribe(
|
| 64 |
tmp_path,
|
| 65 |
language=language,
|
| 66 |
+
|
| 67 |
fp16=self.device == "cuda"
|
| 68 |
)
|
| 69 |
return result.get("text", "")
|
| 70 |
finally:
|
| 71 |
+
|
| 72 |
os.remove(tmp_path)
|
| 73 |
|
| 74 |
def get_device_info(self) -> dict:
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/tts/audio_edge_tts.py
CHANGED
|
@@ -29,7 +29,7 @@ class EdgeTTS:
|
|
| 29 |
pitch=self.pitch_str
|
| 30 |
)
|
| 31 |
|
| 32 |
-
|
| 33 |
async for chunk in communicate.stream():
|
| 34 |
if chunk["type"] == "audio":
|
| 35 |
yield chunk["data"]
|
|
@@ -52,7 +52,7 @@ class EdgeTTS:
|
|
| 52 |
pitch=self.pitch_str
|
| 53 |
)
|
| 54 |
|
| 55 |
-
|
| 56 |
audio_buffer = io.BytesIO()
|
| 57 |
async for chunk in communicate.stream():
|
| 58 |
if chunk["type"] == "audio":
|
|
@@ -85,7 +85,7 @@ class EdgeTTS:
|
|
| 85 |
|
| 86 |
async for chunk in communicate.stream():
|
| 87 |
if chunk["type"] == "audio":
|
| 88 |
-
|
| 89 |
callback_func(chunk["data"], None)
|
| 90 |
|
| 91 |
except Exception as e:
|
|
|
|
| 29 |
pitch=self.pitch_str
|
| 30 |
)
|
| 31 |
|
| 32 |
+
|
| 33 |
async for chunk in communicate.stream():
|
| 34 |
if chunk["type"] == "audio":
|
| 35 |
yield chunk["data"]
|
|
|
|
| 52 |
pitch=self.pitch_str
|
| 53 |
)
|
| 54 |
|
| 55 |
+
|
| 56 |
audio_buffer = io.BytesIO()
|
| 57 |
async for chunk in communicate.stream():
|
| 58 |
if chunk["type"] == "audio":
|
|
|
|
| 85 |
|
| 86 |
async for chunk in communicate.stream():
|
| 87 |
if chunk["type"] == "audio":
|
| 88 |
+
|
| 89 |
callback_func(chunk["data"], None)
|
| 90 |
|
| 91 |
except Exception as e:
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/__init__.py
CHANGED
|
@@ -49,11 +49,11 @@ inferencer_config = InferencerConfig(
|
|
| 49 |
)
|
| 50 |
|
| 51 |
document_retriever = LangChainRetriever(
|
| 52 |
-
embedding_model="
|
| 53 |
vectorstore_type="chroma",
|
| 54 |
vectorstore_path="vectorstore/",
|
| 55 |
use_hybrid_search=True,
|
| 56 |
-
chunk_size=
|
| 57 |
chunk_overlap=200
|
| 58 |
)
|
| 59 |
|
|
|
|
| 49 |
)
|
| 50 |
|
| 51 |
document_retriever = LangChainRetriever(
|
| 52 |
+
embedding_model="BAAI/bge-large-en",
|
| 53 |
vectorstore_type="chroma",
|
| 54 |
vectorstore_path="vectorstore/",
|
| 55 |
use_hybrid_search=True,
|
| 56 |
+
chunk_size=3000,
|
| 57 |
chunk_overlap=200
|
| 58 |
)
|
| 59 |
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/agents/__init__.py
ADDED
|
File without changes
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/agents/agents.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rag.pipeline.language_model import LM
|
| 2 |
+
from rag.inference.inferencer import Inferencer
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
class Agent(ABC):
    """Abstract base for agents that drive an inferencer with a chat prompt.

    Concrete agents must implement :meth:`get_result`.
    """

    def __init__(self, inferencer: "Inferencer", prompt_template=None):
        """Bind the inferencer and install the prompt template on its model.

        Args:
            inferencer: Object exposing a ``model`` attribute; that model's
                ``prompt_template`` attribute is overwritten here.
            prompt_template: Chat messages as a list of ``role``/``content``
                dicts. When ``None``, a generic system prompt is created for
                this instance.
        """
        if prompt_template is None:
            # Build the default on every call: a list literal placed in the
            # signature is created once and would be shared (and mutated)
            # across all agents and their models.
            prompt_template = [
                {
                    "role": "system",
                    "content": "You are an agent that doing some specic task",
                }
            ]
        self.inferencer = inferencer
        self.inferencer.model.prompt_template = prompt_template
        self.prompt = prompt_template

    @abstractmethod
    async def get_result(self):
        """Produce the agent's result; must be implemented by subclasses."""
        pass
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/agents/customer_service_agent.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rag.agents.agents import Agent
|
| 2 |
+
from rag.inference.inferencer import Inferencer
|
| 3 |
+
|
| 4 |
+
class CSAgent(Agent):
    """Agent that feeds a fixed list of PDF files to the retriever and streams answers."""

    def __init__(self, inferencer: Inferencer, prompt_template):
        """Keep the inferencer and template and define the documents to load.

        Args:
            inferencer: Forwarded to ``Agent.__init__``; its ``retriever`` and
                ``infer_stream`` are used by the methods of this class.
            prompt_template: Template installed on the inferencer's model at
                the start of every ``get_result`` call.
        """
        super().__init__(inferencer, prompt_template)
        self.prompt_template = prompt_template
        self.inferencer = inferencer
        # Paths are relative; the commented entries are currently disabled.
        self.file_paths = [
            "../documents/bpjs.pdf",
            # "../documents/pph21.pdf",
            # "../documents/lembur.pdf",
            # "../documents/uu13.pdf",
            "../documents/file.pdf",
        ]

    async def load_documents(self):
        """Add every path in ``self.file_paths`` to the retriever, one at a time."""
        for path in self.file_paths:
            await self.add_doc(path)

    async def add_doc(self, file_path):
        """Add one file to the retriever and print the outcome.

        Args:
            file_path: Path handed to ``retriever.add_document_from_file``.
        """
        outcome = await self.inferencer.retriever.add_document_from_file(file_path)
        if not outcome.success:
            print(f"Failed to process: {outcome.error_message}")
            return
        meta = outcome.document_metadata
        print(f"Successfully processed: {meta.file_name}")
        print(f"Chunks created: {meta.chunk_count}")

    async def get_result(self, question):
        """Yield each item produced by ``infer_stream`` for ``question``.

        Args:
            question: Query text passed to the inferencer.
        """
        # The template is re-installed on every call; presumably the
        # inferencer is shared with other agents -- confirm against callers.
        self.inferencer.model.prompt_template = self.prompt_template
        stream = self.inferencer.infer_stream(
            query=question,
            enable_reranking=False,
            k=3,
        )
        async for piece in stream:
            yield piece
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/agents/gpt_customer_service_agent.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rag.agents.agents import Agent
|
| 2 |
+
from rag.pipeline.language_model import LM
|
| 3 |
+
from rag.inference.inferencer import Inferencer
|
| 4 |
+
|
| 5 |
+
class GPTCSAgent(Agent):
    """Agent that sends a question to the inferencer's ``infer`` and returns the result."""

    def __init__(self, inferencer: Inferencer, prompt_template):
        """Keep the inferencer and the template used by ``get_result``.

        Args:
            inferencer: Forwarded to ``Agent.__init__`` and stored on the instance.
            prompt_template: Template installed on the inferencer's model
                before each inference.
        """
        super().__init__(inferencer, prompt_template)
        self.prompt_template = prompt_template
        self.inferencer = inferencer

    async def get_result(self, question: str):
        """Run one inference for ``question`` and return whatever ``infer`` returns.

        Args:
            question: Query text passed to the inferencer.
        """
        # The template is re-installed on every call; presumably the
        # inferencer is shared with other agents -- confirm against callers.
        self.inferencer.model.prompt_template = self.prompt_template
        print("Question received :", question)
        answer = await self.inferencer.infer(query=question)
        return answer
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/agents/query_maker_agent.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rag.agents.agents import Agent
|
| 2 |
+
from rag.pipeline.language_model import LM
|
| 3 |
+
from rag.inference.inferencer import Inferencer
|
| 4 |
+
|
| 5 |
+
class QueryMakerAgent(Agent):
    """Agent that passes user input to the inferencer's ``infer`` under its own template."""

    def __init__(self, inferencer: Inferencer, prompt_template):
        """Keep the inferencer and the template used by ``get_result``.

        Args:
            inferencer: Forwarded to ``Agent.__init__`` and stored on the instance.
            prompt_template: Template installed on the inferencer's model
                before each inference.
        """
        super().__init__(inferencer, prompt_template)
        self.prompt_template = prompt_template
        self.inferencer = inferencer

    async def get_result(self, question: str):
        """Run one inference for ``question`` and return whatever ``infer`` returns.

        Args:
            question: User input passed to the inferencer as the query.
        """
        # The template is re-installed on every call; presumably the
        # inferencer is shared with other agents -- confirm against callers.
        self.inferencer.model.prompt_template = self.prompt_template
        print("Question received :", question)
        generated = await self.inferencer.infer(query=question)
        return generated
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/chat_template/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def read_template_txt(file_path):
    """Read a plain-text template and return its full contents.

    Args:
        file_path: Template name without extension; the file opened is
            ``rag/chat_template/<file_path>.txt``, resolved against the
            current working directory and decoded as UTF-8.
    """
    template_path = "rag/chat_template/{}.txt".format(file_path)
    with open(template_path, mode='r', encoding='utf-8') as handle:
        contents = handle.read()
    return contents
|
| 5 |
+
def get_chat_template(file_name):
    """Build a two-message chat template from a system-prompt text file.

    Args:
        file_name: Template name handed to ``read_template_txt``.

    Returns:
        A list holding a ``system`` message whose content is the file text
        wrapped in newlines, and a ``user`` message that still contains the
        literal ``{question}`` and ``{context}`` placeholders.
    """
    system_message = {
        "role": "system",
        "content": "\n" + read_template_txt(file_name) + "\n",
    }
    # Not an f-string: the braces are left in place for later substitution.
    user_message = {
        "role": "user",
        "content": (
            "\n"
            "\n"
            "Please answer properly:\n"
            "{question}\n"
            "\n"
            "From given context :\n"
            "{context}\n"
            "\n"
            "\n"
            "\n"
        ),
    }
    return [system_message, user_message]
|
| 29 |
+
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/chat_template/customer_service.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
You are a friendly and professional customer service representative for the Human Resource Information System (HRIS) field,
|
| 2 |
+
fluent in Indonesian. Your job is to assist customers with accurate information based on your company's knowledge base. Follow these guidelines:
|
| 3 |
+
|
| 4 |
+
- Always greet customers in a friendly and professional manner.
|
| 5 |
+
- Your answers are contextual and objective.
|
| 6 |
+
- Provide clear, easy-to-understand, and structured answers based on the context provided by the user.
|
| 7 |
+
- If information is not available, offer alternative assistance or direct them to the appropriate channel.
|
| 8 |
+
- Use polite language and empathize with the customer's needs.
|
| 9 |
+
- Conclude by offering further assistance.
|
| 10 |
+
- You are highly skilled in the area relevant to the given context.
|
| 11 |
+
|
| 12 |
+
Please use the given context to answer accurately.
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/chat_template/query_maker.txt
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Anda adalah agen AI yang tepat dan objektif,
|
| 2 |
+
Anda bertugas mengubah pertanyaan atau pernyataan pengguna menjadi query yang eksplisit dan efisien untuk keperluan pencarian dokumen dalam sistem RAG (Retrieval-Augmented Generation).
|
| 3 |
+
|
| 4 |
+
Ikuti langkah-langkah berikut:
|
| 5 |
+
|
| 6 |
+
1. Ekstrak bagian-bagian penting dari input pengguna:
|
| 7 |
+
- **Intent**: Tujuan utama atau jenis permintaan (misalnya: apa itu, cara, syarat, apakah bisa, berapa).
|
| 8 |
+
- **Entity/Noun Phrase**: Objek utama yang dibahas (misalnya: BPJS, tokenizer truncation, RWKV, gaji).
|
| 9 |
+
- **Context**: Informasi pendukung yang menyempitkan fokus (misalnya: kecelakaan kerja, gaji 1 juta per bulan, perusahaan mitra BPJS).
|
| 10 |
+
- **Question**: Pertanyaan spesifik yang ingin dijawab (misalnya: bagaimana prosesnya, apa manfaatnya, berapa jumlahnya).
|
| 11 |
+
|
| 12 |
+
2. Setelah semua elemen diidentifikasi, bentuk **Query RAG** dengan struktur: [INTENT] + [ENTITY] + [CONTEXT] + [QUESTION]
|
| 13 |
+
3. Gunakan bahasa natural yang ringkas, namun informatif dan eksplisit.
|
| 14 |
+
4. Generate hanya hasil akhirnya saja berupa satu buah kalimat
|
| 15 |
+
|
| 16 |
+
Contoh 0 :
|
| 17 |
+
User Input:
|
| 18 |
+
> Apa itu BPJS
|
| 19 |
+
Output : Pengertian BPJS
|
| 20 |
+
|
| 21 |
+
Contoh 1 :
|
| 22 |
+
User Input:
|
| 23 |
+
> Di mana lokasi PT Sakura System Solution ?
|
| 24 |
+
|
| 25 |
+
Output: Lokasi PT Sakura System Solution
|
| 26 |
+
|
| 27 |
+
Contoh 2:
|
| 28 |
+
User Input:
|
| 29 |
+
> Saya mengalami kecelakaan di kantor dan ingin tahu apakah bisa klaim BPJS karena perusahaan saya adalah mitra.
|
| 30 |
+
|
| 31 |
+
Output: apakah bisa klaim BPJS kecelakaan kerja di kantor jika perusahaan mitra dan apakah saya memenuhi syarat
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
**Tugas Anda sekarang:**
|
| 35 |
+
Lakukan proses di atas untuk setiap input pengguna yang diberikan. Hasilkan query RAG akhir yang siap digunakan dalam pencarian dokumen.
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/chat_template/query_maker_temp.txt
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Anda adalah agen AI yang tepat dan objektif,
|
| 2 |
+
Anda bertugas mengubah pertanyaan atau pernyataan pengguna menjadi query yang eksplisit dan efisien untuk keperluan pencarian dokumen dalam sistem RAG (Retrieval-Augmented Generation).
|
| 3 |
+
|
| 4 |
+
Ikuti langkah-langkah berikut:
|
| 5 |
+
|
| 6 |
+
1. Ekstrak bagian-bagian penting dari input pengguna:
|
| 7 |
+
- **Intent**: Tujuan utama atau jenis permintaan (misalnya: apa itu, cara, syarat, apakah bisa, berapa).
|
| 8 |
+
- **Entity/Noun Phrase**: Objek utama yang dibahas (misalnya: BPJS, tokenizer truncation, RWKV, gaji).
|
| 9 |
+
- **Context**: Informasi pendukung yang menyempitkan fokus (misalnya: kecelakaan kerja, gaji 1 juta per bulan, perusahaan mitra BPJS).
|
| 10 |
+
- **Question**: Pertanyaan spesifik yang ingin dijawab (misalnya: bagaimana prosesnya, apa manfaatnya, berapa jumlahnya).
|
| 11 |
+
|
| 12 |
+
2. Setelah semua elemen diidentifikasi, bentuk **Query RAG** dengan struktur: [INTENT] + [ENTITY] + [CONTEXT] + [QUESTION]
|
| 13 |
+
3. Gunakan bahasa natural yang ringkas, namun informatif dan eksplisit.
|
| 14 |
+
4. Generate hanya hasil akhirnya saja berupa satu buah kalimat
|
| 15 |
+
|
| 16 |
+
Contoh 1 :
|
| 17 |
+
User Input:
|
| 18 |
+
> Di mana lokasi PT Sakura System Solution ?
|
| 19 |
+
|
| 20 |
+
Output: Lokasi PT Sakura System Solution
|
| 21 |
+
|
| 22 |
+
Contoh 2:
|
| 23 |
+
User Input:
|
| 24 |
+
> Saya mengalami kecelakaan di kantor dan ingin tahu apakah bisa klaim BPJS karena perusahaan saya adalah mitra.
|
| 25 |
+
|
| 26 |
+
Output: apakah bisa klaim BPJS kecelakaan kerja di kantor jika perusahaan mitra dan apakah saya memenuhi syarat
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
**Tugas Anda sekarang:**
|
| 30 |
+
Lakukan proses di atas untuk setiap input pengguna yang diberikan. Hasilkan query RAG akhir yang siap digunakan dalam pencarian dokumen.
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/inference/__init__.py
ADDED
|
File without changes
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/pipeline/language_model.py
ADDED
|
@@ -0,0 +1,947 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import asyncio
|
| 3 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, TextIteratorStreamer, BitsAndBytesConfig
|
| 4 |
+
import torch
|
| 5 |
+
from typing import Optional, Dict, Any, List, Union, Callable, Awaitable, AsyncGenerator
|
| 6 |
+
import logging
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 10 |
+
from functools import partial
|
| 11 |
+
from threading import Thread
|
| 12 |
+
from rag.retriever.retriever_types import RetrievalResult
|
| 13 |
+
from langchain_core.documents import Document
|
| 14 |
+
import copy
|
| 15 |
+
|
| 16 |
+
@dataclass
class LMConfig:
    """Settings used by ``LM`` to load the model and configure generation.

    Groups model-loading options, sampling parameters, RAG prompt options,
    and async/streaming limits.
    """

    model_name: str = "Qwen/Qwen2.5-1.5B-Instruct"
    device: str = "cuda"
    torch_dtype: torch.dtype = torch.float16
    max_length: int = 2048
    temperature: float = 0.7
    top_p: float = 0.8
    top_k: int = 50
    do_sample: bool = True
    # Forwarded unchanged as ``quantization_config`` when the model is loaded.
    # Annotated with typing.Any; the builtin function ``any`` is not a type.
    quantization_config: Optional[Any] = None
    # When left as None, the tokenizer's EOS token id is used for both.
    pad_token_id: Optional[int] = None
    eos_token_id: Optional[int] = None
    # RAG-specific configs
    max_context_length: int = 1500
    context_separator: str = "\n---\n"
    instruction_template: str = "system"  # "system", "instruction", "custom"
    # Async-specific configs
    max_workers: int = 2  # size of the thread pool created by LM
    generation_timeout: float = 30
    repetition_penalty: float = 1.0
    # Streaming-specific configs
    stream_timeout: float = 100  # timeout for a stream chunk
    skip_prompt: bool = True  # skip the prompt in the streaming output
|
| 40 |
+
|
| 41 |
+
class LM:
|
| 42 |
+
"""
|
| 43 |
+
Async LLM Qwen 0.5B dengan interface yang mudah digunakan
|
| 44 |
+
Termasuk prompt formatting khusus untuk RAG (Retrieval-Augmented Generation)
|
| 45 |
+
Dan support untuk text streaming
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
def __init__(self, config: Optional[LMConfig] = None, prompt_template = [
|
| 49 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
| 50 |
+
{"role": "user", "content": "{question}"}
|
| 51 |
+
] ):
|
| 52 |
+
"""
|
| 53 |
+
Inisialisasi LM
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
config: Konfigurasi model (optional, akan menggunakan default jika None)
|
| 57 |
+
"""
|
| 58 |
+
if(config is None):
|
| 59 |
+
self.config = LMConfig()
|
| 60 |
+
else:
|
| 61 |
+
self.config = config
|
| 62 |
+
self.tokenizer : AutoTokenizer = None
|
| 63 |
+
self.model = None
|
| 64 |
+
self.generation_config = None
|
| 65 |
+
self.is_loaded = False
|
| 66 |
+
self.executor = ThreadPoolExecutor(max_workers=self.config.max_workers)
|
| 67 |
+
self._lock = asyncio.Lock()
|
| 68 |
+
# Setup logging
|
| 69 |
+
logging.basicConfig(level=logging.INFO)
|
| 70 |
+
self.logger = logging.getLogger(__name__)
|
| 71 |
+
|
| 72 |
+
# RAG prompt templates
|
| 73 |
+
self.prompt_template = prompt_template
|
| 74 |
+
|
| 75 |
+
async def load_model(self) -> None:
|
| 76 |
+
"""Load model dan tokenizer secara async"""
|
| 77 |
+
async with self._lock:
|
| 78 |
+
if self.is_loaded:
|
| 79 |
+
self.logger.info("Model already loaded")
|
| 80 |
+
return
|
| 81 |
+
|
| 82 |
+
try:
|
| 83 |
+
self.logger.info(f"Loading model: {self.config.model_name}")
|
| 84 |
+
|
| 85 |
+
# Load tokenizer dalam thread pool
|
| 86 |
+
self.tokenizer = await asyncio.get_event_loop().run_in_executor(
|
| 87 |
+
self.executor,
|
| 88 |
+
lambda: AutoTokenizer.from_pretrained(
|
| 89 |
+
self.config.model_name,
|
| 90 |
+
trust_remote_code=True,
|
| 91 |
+
torch_dtype="auto",
|
| 92 |
+
device_map="auto",
|
| 93 |
+
)
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
# Load model dalam thread pool
|
| 97 |
+
self.model = await asyncio.get_event_loop().run_in_executor(
|
| 98 |
+
self.executor,
|
| 99 |
+
lambda: AutoModelForCausalLM.from_pretrained(
|
| 100 |
+
self.config.model_name,
|
| 101 |
+
quantization_config=self.config.quantization_config,
|
| 102 |
+
torch_dtype=self.config.torch_dtype,
|
| 103 |
+
device_map=self.config.device,
|
| 104 |
+
trust_remote_code=True
|
| 105 |
+
)
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# Setup generation config
|
| 109 |
+
self.generation_config = GenerationConfig(
|
| 110 |
+
max_length=self.config.max_length,
|
| 111 |
+
temperature=self.config.temperature,
|
| 112 |
+
top_p=self.config.top_p,
|
| 113 |
+
top_k=self.config.top_k,
|
| 114 |
+
do_sample=self.config.do_sample,
|
| 115 |
+
pad_token_id=self.config.pad_token_id or self.tokenizer.eos_token_id,
|
| 116 |
+
eos_token_id=self.config.eos_token_id or self.tokenizer.eos_token_id,
|
| 117 |
+
repetition_penalty = self.config.repetition_penalty,
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
self.is_loaded = True
|
| 121 |
+
self.logger.info("Model loaded successfully!")
|
| 122 |
+
|
| 123 |
+
except Exception as e:
|
| 124 |
+
self.logger.error(f"Error loading model: {e}")
|
| 125 |
+
raise
|
| 126 |
+
|
| 127 |
+
def get_available_templates(self) -> List[str]:
|
| 128 |
+
"""
|
| 129 |
+
Dapatkan list template yang tersedia
|
| 130 |
+
|
| 131 |
+
Returns:
|
| 132 |
+
List of available template names
|
| 133 |
+
"""
|
| 134 |
+
return list(self.prompt_template)
|
| 135 |
+
|
| 136 |
+
def preview_template(self, template_type: str, sample_question: str = "Apa itu AI?",
                     sample_context: str = "Artificial Intelligence adalah teknologi...") -> str:
    """
    Render a prompt template with sample data so it can be inspected.

    Args:
        template_type: Template type to preview
        sample_question: Sample question substituted for ``{question}``
        sample_context: Sample context substituted for ``{context}``

    Returns:
        The formatted template text, or a message listing the available
        templates when ``template_type`` is not found in ``self.prompt_template``.
    """
    is_known = template_type in self.prompt_template
    if not is_known:
        return f"Template '{template_type}' tidak tersedia. Available: {self.get_available_templates()}"

    # NOTE(review): template_type only gates the membership check above; the text
    # rendered below is always the "content" entry of the whole prompt_template
    # object, regardless of which type was requested -- confirm this is intended.
    snapshot = copy.deepcopy(self.prompt_template)
    raw_text = snapshot["content"]
    return raw_text.format(context=sample_context, question=sample_question)
|
| 159 |
+
|
| 160 |
+
def _format_context(self, contexts: Union[List[str], RetrievalResult], numbering: bool = True) -> str:
    """
    Format retrieved contexts into one coherent string.

    Args:
        contexts: Either a RetrievalResult (its ``documents`` and ``scores`` are
            used) or a plain list whose items are strings or document-like objects
        numbering: Whether to add document numbering (and, for a
            RetrievalResult, the score) to each header

    Returns:
        The formatted entries joined with ``self.config.context_separator``,
        or an empty string when ``contexts`` is empty
    """
    if not contexts:
        return ""

    is_retrieval_result = isinstance(contexts, RetrievalResult)
    self.logger.info(f"Context : {contexts}")
    self.logger.info(f"Is RetrievalResult Contexts = {is_retrieval_result}")

    formatted_contexts = []
    if is_retrieval_result:
        # scores may be empty or shorter than documents (format_rag_prompt passes
        # [] when there are none), so never index it blindly.
        scores = contexts.scores or []
        for i, ctx in enumerate(contexts.documents, 1):
            if numbering:
                header = f"[Dokumen {i}"
                score = scores[i - 1] if i <= len(scores) else None
                # Compare against None so that a legitimate score of 0.0 is still shown.
                if score is not None:
                    header += f" (Skor: {score:.3f})"
                header += "]"
            else:
                header = "[Dokumen]"
            formatted_contexts.append(f"{header}\n{ctx.page_content}")
    else:
        for i, ctx in enumerate(contexts, 1):
            header = f"[Dokumen {i}]" if numbering else "[Dokumen]"
            if isinstance(ctx, str):
                text = ctx
            else:
                # Document-like items carry their text in page_content (as in the
                # RetrievalResult branch); anything else falls back to str().
                text = getattr(ctx, "page_content", None)
                if text is None:
                    text = str(ctx)
            formatted_contexts.append(f"{header}\n{text}")

    return self.config.context_separator.join(formatted_contexts)
|
| 198 |
+
|
| 199 |
+
def _truncate_context(self, context: str, max_length: int) -> str:
    """
    Truncate the context when it is too long.

    The truncation marker is counted against the budget, so the returned
    string never exceeds ``max_length`` characters.

    Args:
        context: Context string
        max_length: Maximum length in characters

    Returns:
        ``context`` unchanged when it fits; otherwise its leading part followed
        by the truncation marker. When ``max_length`` is too small to hold the
        marker, a plain cut to ``max_length`` characters is returned instead.
    """
    if len(context) <= max_length:
        return context

    marker = "\n\n[... Context dipotong karena terlalu panjang ...]"
    keep = max_length - len(marker)
    if keep <= 0:
        # No room for the marker: a negative slice bound would keep almost the
        # whole context, so cut hard at the limit instead.
        return context[:max(max_length, 0)]
    return context[:keep] + marker
|
| 216 |
+
|
| 217 |
+
async def format_rag_prompt(self,
                            question: str,
                            contexts: Union[List[str], RetrievalResult],
                            template_type: Optional[str] = None,
                            custom_template: Optional[str] = None,
                            include_metadata: bool = True,
                            context_numbering: bool = True,
                            max_contexts: Optional[int] = None) -> Union[str, List[Dict[str, str]]]:
    """
    Build the RAG prompt from a question and retrieved contexts (async).

    Args:
        question: User question substituted for ``{question}``
        contexts: RetrievalResult or plain list of contexts
        template_type: Only the empty string has an effect here: it sets
            ``self.config.instruction_template`` to ``"system"``
        custom_template: Optional format string with ``{context}`` and
            ``{question}`` placeholders; when given it takes precedence
        include_metadata: Accepted for compatibility; document metadata is
            currently not appended to the prompt
        context_numbering: Whether context headers are numbered
        max_contexts: Optional cap on the number of contexts used

    Returns:
        A string when ``custom_template`` is given; otherwise a list of chat
        messages (``{"role", "content"[, "description"]}``) built from
        ``self.prompt_template``, or a two-message default when no template is set.
    """

    def _format_sync() -> str:
        # Handle RetrievalResult explicitly so scores stay aligned with documents
        if isinstance(contexts, RetrievalResult):
            docs = contexts.documents
            if max_contexts:
                docs = docs[:max_contexts]
            processed_contexts = RetrievalResult(
                documents=docs,
                scores=contexts.scores[:len(docs)] if contexts.scores else [],
                query=contexts.query,
                retrieval_time=contexts.retrieval_time,
                metadata=contexts.metadata
            )
        else:
            # contexts is assumed to be a plain list (list[str] or list[Document])
            processed_contexts = contexts[:max_contexts] if max_contexts and len(contexts) > max_contexts else contexts

        # Format the contexts into a string, then truncate to the configured limit
        formatted = self._format_context(processed_contexts, context_numbering)
        return self._truncate_context(formatted, self.config.max_context_length)

    # Run the formatting in the thread pool
    formatted_context = await asyncio.get_running_loop().run_in_executor(
        self.executor, _format_sync
    )
    self.logger.info(f"Formatted Context {formatted_context}")

    if template_type == "":
        self.config.instruction_template = "system"

    # Use the custom template when one is provided
    if custom_template:
        return custom_template.format(
            context=formatted_context,
            question=question
        )

    if not self.prompt_template:
        # Fallback default template
        return [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": question}
        ]

    self.logger.debug(f"Formatting prompt template (template_type={template_type!r})")

    formatted_template = []
    # Work on a deep copy so the stored template is never shared with the result
    for cht in copy.deepcopy(self.prompt_template):
        content = cht["content"]

        if "{context}" in content or "{question}" in content:
            try:
                content = content.format(
                    context=formatted_context,
                    question=question
                )
            except (KeyError, IndexError, ValueError) as e:
                # str.format also fails on unrelated or unbalanced braces in the
                # template (e.g. JSON examples); substitute only our placeholders.
                self.logger.error(f"Missing placeholder in template: {e}")
                content = content.replace("{context}", formatted_context)
                content = content.replace("{question}", question)

        formatted_chat = {
            "role": cht["role"],
            "content": content
        }

        # Copy other fields if they exist
        if "description" in cht:
            formatted_chat["description"] = cht["description"]

        formatted_template.append(formatted_chat)

    return formatted_template
|
| 329 |
+
|
| 330 |
+
async def generate_stream(self,
                          prompt: List[Dict],
                          max_new_tokens: Optional[int] = None,
                          temperature: Optional[float] = None,
                          top_p: Optional[float] = None,
                          **kwargs) -> AsyncGenerator[str, None]:
    """
    Generate text from a chat prompt as an async stream.

    Generation runs in a dedicated thread; text chunks are pulled from the
    streamer on executor threads so the event loop is never blocked while
    waiting for the next chunk.

    Args:
        prompt: Chat messages; the tokenizer's chat template is applied to them
        max_new_tokens: Maximum number of new tokens to generate
        temperature: Temperature for generation (overrides config)
        top_p: Top-p for generation (overrides config)
        **kwargs: Additional generation parameters

    Yields:
        Generated text chunks (empty chunks are skipped)
    """
    await self._check_model_loaded()
    loop = asyncio.get_running_loop()

    # Setup streamer
    streamer = TextIteratorStreamer(
        self.tokenizer,
        timeout=self.config.stream_timeout,
        skip_prompt=self.config.skip_prompt,
        skip_special_tokens=True
    )

    def _generate_sync():
        try:
            # Tokenize input
            inputs = self.tokenizer.apply_chat_template(
                prompt,
                add_generation_prompt=True,
                return_tensors="pt"
            )

            # Override generation config when any per-call value is given
            gen_config = self.generation_config
            if any([max_new_tokens, temperature, top_p]):
                gen_config = GenerationConfig(
                    max_new_tokens=max_new_tokens or self.config.max_length,
                    temperature=temperature or self.config.temperature,
                    top_p=top_p or self.config.top_p,
                    top_k=self.config.top_k,
                    do_sample=self.config.do_sample,
                    pad_token_id=self.config.pad_token_id or self.tokenizer.eos_token_id,
                    eos_token_id=self.config.eos_token_id or self.tokenizer.eos_token_id,
                    repetition_penalty=self.config.repetition_penalty,
                    **kwargs
                )

            # Send the inputs to wherever the model's weights currently live
            # instead of forcing the model onto "cuda" on every call.
            device = next(self.model.parameters()).device
            input_ids = inputs.to(device)

            generation_kwargs = {
                "input_ids": input_ids,
                "generation_config": gen_config,
                "streamer": streamer,
                **kwargs
            }

            # Generate in a separate thread; the streamer hands chunks back to us
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()

            return thread

        except Exception as e:
            self.logger.error(f"Error during stream generation setup: {e}")
            raise

    # Setup generation thread
    generation_thread = await loop.run_in_executor(
        self.executor, _generate_sync
    )

    end_of_stream = object()
    try:
        while True:
            # next() blocks until a chunk arrives, so it runs off the event loop;
            # the sentinel replaces StopIteration, which cannot cross a Future.
            token = await loop.run_in_executor(
                self.executor, next, streamer, end_of_stream
            )
            if token is end_of_stream:
                break
            if token:  # Skip empty tokens
                yield token

        # Wait for generation thread to finish
        await loop.run_in_executor(self.executor, generation_thread.join)

    except Exception as e:
        self.logger.error(f"Error during streaming: {e}")
        # Make sure thread is cleaned up
        if generation_thread.is_alive():
            generation_thread.join(timeout=1.0)
        raise
|
| 426 |
+
|
| 427 |
+
async def rag_generate_stream(self,
                              question: str,
                              contexts: Union[List[str], RetrievalResult],
                              template_type: Optional[str] = None,
                              max_new_tokens: Optional[int] = None,
                              temperature: Optional[float] = None,
                              **kwargs) -> AsyncGenerator[str, None]:
    """
    Stream a RAG answer for a question and its retrieved contexts.

    Args:
        question: User question
        contexts: List of retrieved contexts
        template_type: Template type used for formatting
        max_new_tokens: Maximum number of new tokens to generate
        temperature: Temperature for generation; 0.3 is used when omitted
        **kwargs: Additional generation parameters

    Yields:
        Generated answer chunks
    """
    await self._check_model_loaded()

    rag_prompt = await self.format_rag_prompt(question, contexts, template_type)

    # RAG defaults to a lower temperature so answers stay closer to the sources
    if temperature is None:
        effective_temperature = 0.3
    else:
        effective_temperature = temperature

    token_stream = self.generate_stream(
        prompt=rag_prompt,
        max_new_tokens=max_new_tokens,
        temperature=effective_temperature,
        **kwargs
    )
    async for piece in token_stream:
        yield piece
|
| 463 |
+
|
| 464 |
+
async def chat_stream(self,
                      messages: List[Dict[str, str]],
                      max_new_tokens: Optional[int] = None,
                      **kwargs) -> AsyncGenerator[str, None]:
    """
    Chat in conversation format as an async stream.

    The messages are forwarded to ``generate_stream`` untouched: that method
    applies the tokenizer's chat template (with the generation prompt) itself,
    so rendering the conversation to a string first would apply the template
    twice, the second time to a plain string instead of a message list.

    Args:
        messages: List of messages in the form [{"role": "user", "content": "..."}]
        max_new_tokens: Maximum number of new tokens to generate
        **kwargs: Additional generation parameters

    Yields:
        Response text chunks
    """
    await self._check_model_loaded()

    async for chunk in self.generate_stream(
        messages,
        max_new_tokens=max_new_tokens,
        **kwargs
    ):
        yield chunk
|
| 506 |
+
|
| 507 |
+
async def rag_chat_stream(self,
                          messages: List[Dict[str, str]],
                          contexts: Union[List[str], RetrievalResult],
                          template_type: Optional[str] = None,
                          max_new_tokens: Optional[int] = None,
                          **kwargs) -> AsyncGenerator[str, None]:
    """
    Stream a RAG answer for the most recent user turn of a conversation.

    Args:
        messages: List of messages in the form [{"role": "user", "content": "..."}]
        contexts: List of retrieved contexts
        template_type: Template type used for formatting
        max_new_tokens: Maximum number of new tokens to generate
        **kwargs: Additional generation parameters

    Yields:
        Response text chunks

    Raises:
        ValueError: If the conversation contains no message with role "user"
    """
    await self._check_model_loaded()

    # Walk the whole conversation and remember the latest user turn
    latest_user_turn = None
    for entry in messages:
        if entry.get("role") == "user":
            latest_user_turn = entry
    if latest_user_turn is None:
        raise ValueError("No user message found in conversation")

    answer_stream = self.rag_generate_stream(
        question=latest_user_turn["content"],
        contexts=contexts,
        template_type=template_type,
        max_new_tokens=max_new_tokens,
        **kwargs
    )
    async for piece in answer_stream:
        yield piece
|
| 544 |
+
|
| 545 |
+
# Utility method to collect the full response from a stream
|
| 546 |
+
async def collect_stream(self, stream_generator: AsyncGenerator[str, None]) -> str:
    """
    Drain a stream generator and concatenate its chunks into the full text.

    Args:
        stream_generator: AsyncGenerator that produces text chunks

    Returns:
        Complete generated text
    """
    return "".join([piece async for piece in stream_generator])
|
| 560 |
+
|
| 561 |
+
async def multi_template_generate(self,
                                  question: str,
                                  contexts: Union[List[str], RetrievalResult],
                                  template_types: List[str],
                                  max_new_tokens: Optional[int] = None,
                                  **kwargs) -> Dict[str, str]:
    """
    Generate answers with several templates concurrently.

    Args:
        question: User question
        contexts: List of retrieved contexts
        template_types: List of template types to use
        max_new_tokens: Maximum number of new tokens to generate
        **kwargs: Additional generation parameters

    Returns:
        Dictionary keyed by template type; each value is the response, or an
        ``"Error: ..."`` string when generation for that template failed
    """
    await self._check_model_loaded()

    # Schedule one task per template up front so they all run concurrently
    scheduled = [
        (
            name,
            asyncio.create_task(
                self._generate_single_template(
                    question, contexts, name, max_new_tokens, **kwargs
                )
            ),
        )
        for name in template_types
    ]

    # Collect the outcomes in submission order
    answers = {}
    for name, job in scheduled:
        try:
            answers[name] = await job
        except Exception as e:
            self.logger.error(f"Error generating with template {name}: {e}")
            answers[name] = f"Error: {str(e)}"

    return answers
|
| 603 |
+
|
| 604 |
+
async def _generate_single_template(self,
                                    question: str,
                                    contexts: Union[List[str], RetrievalResult],
                                    template_type: str,
                                    max_new_tokens: Optional[int] = None,
                                    **kwargs) -> str:
    """Helper that produces the RAG answer for one template via ``rag_generate``."""
    answer = await self.rag_generate(
        question=question,
        contexts=contexts,
        template_type=template_type,
        max_new_tokens=max_new_tokens,
        **kwargs
    )
    return answer
|
| 618 |
+
|
| 619 |
+
async def rag_generate(self,
                       question: str,
                       contexts: Union[List[str], RetrievalResult],
                       template_type: Optional[str] = None,
                       max_new_tokens: Optional[int] = None,
                       temperature: Optional[float] = None,
                       **kwargs) -> str:
    """
    Generate a RAG answer for a question and its retrieved contexts.

    Args:
        question: User question
        contexts: List of retrieved contexts
        template_type: Template type used for formatting
        max_new_tokens: Maximum number of new tokens to generate
        temperature: Temperature for generation; 0.3 is used when omitted
        **kwargs: Additional generation parameters

    Returns:
        Generated answer
    """
    await self._check_model_loaded()

    rag_prompt = await self.format_rag_prompt(question, contexts, template_type)

    # RAG defaults to a lower temperature so answers stay closer to the sources
    if temperature is None:
        effective_temperature = 0.3
    else:
        effective_temperature = temperature

    answer = await self.generate(
        prompt=rag_prompt,
        max_new_tokens=max_new_tokens,
        temperature=effective_temperature,
        **kwargs
    )
    return answer
|
| 654 |
+
|
| 655 |
+
async def rag_chat(self,
                   messages: List[Dict[str, str]],
                   contexts: Union[List[str], RetrievalResult],
                   template_type: Optional[str] = None,
                   max_new_tokens: Optional[int] = None,
                   **kwargs) -> str:
    """
    Answer the most recent user turn of a conversation with RAG.

    Args:
        messages: List of messages in the form [{"role": "user", "content": "..."}]
        contexts: List of retrieved contexts
        template_type: Template type used for formatting
        max_new_tokens: Maximum number of new tokens to generate
        **kwargs: Additional generation parameters

    Returns:
        Response text

    Raises:
        ValueError: If the conversation contains no message with role "user"
    """
    await self._check_model_loaded()

    # Walk the whole conversation and remember the latest user turn
    latest_user_turn = None
    for entry in messages:
        if entry.get("role") == "user":
            latest_user_turn = entry
    if latest_user_turn is None:
        raise ValueError("No user message found in conversation")

    answer = await self.rag_generate(
        question=latest_user_turn["content"],
        contexts=contexts,
        template_type=template_type,
        max_new_tokens=max_new_tokens,
        **kwargs
    )
    return answer
|
| 691 |
+
|
| 692 |
+
async def _check_model_loaded(self) -> None:
    """Raise RuntimeError unless ``self.is_loaded`` is truthy."""
    if self.is_loaded:
        return
    raise RuntimeError("Model belum di-load. Panggil await load_model() terlebih dahulu.")
|
| 696 |
+
|
| 697 |
+
async def generate(self,
                   prompt: Union[List[Dict], str],
                   max_new_tokens: Optional[int] = None,
                   temperature: Optional[float] = None,
                   top_p: Optional[float] = None,
                   **kwargs) -> str:
    """
    Generate text from a prompt (async).

    Args:
        prompt: Chat messages, or a plain string which is treated as a single
            user turn so that the chat template can be applied to it
        max_new_tokens: Maximum number of new tokens to generate
        temperature: Temperature for generation (overrides config)
        top_p: Top-p for generation (overrides config)
        **kwargs: Additional generation parameters

    Returns:
        Generated text (the prompt tokens are stripped from the output)

    Raises:
        TimeoutError: If generation exceeds ``self.config.generation_timeout`` seconds
    """
    await self._check_model_loaded()

    # apply_chat_template expects a list of messages; wrap bare strings
    # (e.g. the prompts passed by batch_generate) as one user message.
    if isinstance(prompt, str):
        messages = [{"role": "user", "content": prompt}]
    else:
        messages = prompt

    def _generate_sync() -> str:
        try:
            # Tokenize input
            inputs = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                return_tensors="pt"
            )

            # Override generation config when any per-call value is given
            gen_config = self.generation_config
            if any([max_new_tokens, temperature, top_p]):
                gen_config = GenerationConfig(
                    max_new_tokens=max_new_tokens or self.config.max_length,
                    temperature=temperature or self.config.temperature,
                    top_p=top_p or self.config.top_p,
                    top_k=self.config.top_k,
                    do_sample=self.config.do_sample,
                    pad_token_id=self.config.pad_token_id or self.tokenizer.eos_token_id,
                    eos_token_id=self.config.eos_token_id or self.tokenizer.eos_token_id,
                    repetition_penalty=self.config.repetition_penalty,
                    **kwargs
                )

            with torch.no_grad():
                # Send the inputs to wherever the model's weights currently live
                # instead of forcing the model onto "cuda" on every call.
                device = next(self.model.parameters()).device
                input_ids = inputs.to(device)
                prompt_length = input_ids.shape[-1]
                outputs = self.model.generate(
                    input_ids,
                    generation_config=gen_config,
                    **kwargs
                )

            # Decode only the newly generated tokens (drop the prompt)
            generated_text = self.tokenizer.decode(
                outputs[0][prompt_length:],
                skip_special_tokens=True
            )
            self.logger.debug(f"Generated Text {generated_text}")
            return generated_text

        except Exception as e:
            self.logger.error(f"Error during generation: {e}")
            raise

    # Run generation in the thread pool with a timeout
    try:
        return await asyncio.wait_for(
            asyncio.get_running_loop().run_in_executor(self.executor, _generate_sync),
            timeout=self.config.generation_timeout
        )
    except asyncio.TimeoutError:
        self.logger.error(f"Generation timeout after {self.config.generation_timeout} seconds")
        raise TimeoutError(f"Generation timeout after {self.config.generation_timeout} seconds")
|
| 779 |
+
|
| 780 |
+
async def chat(self,
               messages: List[Dict[str, str]],
               max_new_tokens: Optional[int] = None,
               **kwargs) -> str:
    """
    Chat in conversation format (async).

    The messages are forwarded to ``generate`` untouched: that method applies
    the tokenizer's chat template (with the generation prompt) itself, so
    tokenizing the conversation here first would hand ``generate`` a tensor
    and apply the template a second time.

    NOTE(review): the previous implementation requested ``chat_template="rag"``
    when pre-tokenizing; ``generate`` uses the tokenizer's default template --
    confirm whether a named template is required for plain chat.

    Args:
        messages: List of messages in the form [{"role": "user", "content": "..."}]
        max_new_tokens: Maximum number of new tokens to generate
        **kwargs: Additional generation parameters

    Returns:
        Response text
    """
    await self._check_model_loaded()

    return await self.generate(
        messages,
        max_new_tokens=max_new_tokens,
        **kwargs
    )
|
| 821 |
+
|
| 822 |
+
async def update_config(self, **kwargs) -> None:
    """
    Update model configuration values (async).

    Known attributes of ``self.config`` are overwritten and logged; unknown
    names only produce a warning. When the model is loaded, the generation
    config is rebuilt from the updated values.

    Args:
        **kwargs: Configuration parameters to update
    """
    async with self._lock:
        for name, new_value in kwargs.items():
            if not hasattr(self.config, name):
                self.logger.warning(f"Unknown config parameter: {name}")
                continue
            setattr(self.config, name, new_value)
            self.logger.info(f"Updated {name} to {new_value}")

        # Nothing to rebuild until a model (and its tokenizer) is available
        if not self.is_loaded:
            return

        cfg = self.config
        self.generation_config = GenerationConfig(
            max_length=cfg.max_length,
            temperature=cfg.temperature,
            top_p=cfg.top_p,
            top_k=cfg.top_k,
            do_sample=cfg.do_sample,
            pad_token_id=cfg.pad_token_id or self.tokenizer.eos_token_id,
            eos_token_id=cfg.eos_token_id or self.tokenizer.eos_token_id,
            repetition_penalty=cfg.repetition_penalty,
        )
|
| 850 |
+
|
| 851 |
+
async def get_model_info(self) -> Dict[str, Any]:
    """
    Collect information about the model (async).

    Returns:
        Dictionary with ``model_name``, ``is_loaded`` and ``config``; when the
        model is loaded it also contains ``vocab_size``, ``model_parameters``
        and ``device``
    """
    summary = {
        "model_name": self.config.model_name,
        "is_loaded": self.is_loaded,
        "config": self.config.__dict__
    }
    if not self.is_loaded:
        return summary

    def _inspect_model():
        # Walks every parameter, so it is run in the thread pool below
        vocab_size = self.tokenizer.vocab_size
        parameter_count = sum(weight.numel() for weight in self.model.parameters())
        first_weight = next(self.model.parameters())
        return {
            "vocab_size": vocab_size,
            "model_parameters": parameter_count,
            "device": str(first_weight.device)
        }

    details = await asyncio.get_event_loop().run_in_executor(
        self.executor, _inspect_model
    )
    summary.update(details)
    return summary
|
| 879 |
+
|
| 880 |
+
async def batch_generate(self,
                         prompts: List[str],
                         max_new_tokens: Optional[int] = None,
                         **kwargs) -> List[str]:
    """
    Generate for several prompts concurrently.

    Args:
        prompts: List of prompts to generate
        max_new_tokens: Maximum number of new tokens to generate
        **kwargs: Additional generation parameters

    Returns:
        List of generated texts in prompt order; a prompt whose generation
        failed yields an ``"Error: ..."`` string instead
    """
    await self._check_model_loaded()

    # One task per prompt so that all generations are scheduled at once
    jobs = []
    for prompt in prompts:
        jobs.append(
            asyncio.create_task(
                self.generate(prompt, max_new_tokens=max_new_tokens, **kwargs)
            )
        )

    # Failures come back as exception objects instead of aborting the batch
    outcomes = await asyncio.gather(*jobs, return_exceptions=True)

    texts = []
    for position, outcome in enumerate(outcomes):
        if not isinstance(outcome, Exception):
            texts.append(outcome)
            continue
        self.logger.error(f"Error generating prompt {position}: {outcome}")
        texts.append(f"Error: {str(outcome)}")

    return texts
|
| 918 |
+
|
| 919 |
+
async def close(self) -> None:
    """
    Release the resources held by this instance (async).

    Shuts the executor down, drops the model and tokenizer attributes,
    empties the CUDA cache when CUDA is available and marks the instance as
    not loaded.
    """
    self.logger.info("Closing LM...")

    # Waits for work already submitted to the executor
    self.executor.shutdown(wait=True)

    # Drop the references so that the memory can be reclaimed
    for attribute in ("model", "tokenizer"):
        if getattr(self, attribute, None) is not None:
            delattr(self, attribute)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    self.is_loaded = False
    self.logger.info("LM closed successfully")
|
| 939 |
+
|
| 940 |
+
async def __aenter__(self):
    """Enter the async context: load the model, then hand back this instance."""
    instance = self
    await instance.load_model()
    return instance
|
| 944 |
+
|
| 945 |
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Leave the async context: release resources via ``close()``; exceptions are not suppressed."""
    await self.close()
    return None
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/retriever/__init__.py
ADDED
|
File without changes
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/retriever/langchain_retriever.py
CHANGED
|
@@ -6,6 +6,7 @@ from langchain_openai import OpenAIEmbeddings
|
|
| 6 |
|
| 7 |
# Vector stores
|
| 8 |
from langchain_community.vectorstores import Chroma, FAISS, Pinecone
|
|
|
|
| 9 |
|
| 10 |
# Retriever base
|
| 11 |
from langchain_core.vectorstores import VectorStoreRetriever
|
|
@@ -24,7 +25,6 @@ from langchain_core.documents import Document
|
|
| 24 |
|
| 25 |
logging.basicConfig(level=logging.INFO)
|
| 26 |
logger = logging.getLogger(__name__)
|
| 27 |
-
|
| 28 |
class LangChainRetriever(BaseRetriever):
|
| 29 |
"""LangChain-based retriever with multiple format support"""
|
| 30 |
|
|
@@ -160,17 +160,34 @@ class LangChainRetriever(BaseRetriever):
|
|
| 160 |
except Exception as e:
|
| 161 |
logger.error(f"Error adding documents: {str(e)}")
|
| 162 |
return False
|
| 163 |
-
|
| 164 |
async def _update_bm25_retriever(self, documents: List[Document]):
|
| 165 |
try:
|
|
|
|
| 166 |
self.bm25_retriever = BM25Retriever.from_documents(documents)
|
| 167 |
-
self.
|
| 168 |
-
|
| 169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
)
|
|
|
|
| 171 |
except Exception as e:
|
| 172 |
logger.error(f"Error updating BM25 retriever: {str(e)}")
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
async def retrieve(self, query: str, k: int = 5) -> RetrievalResult:
|
| 175 |
try:
|
| 176 |
import time
|
|
@@ -181,6 +198,7 @@ class LangChainRetriever(BaseRetriever):
|
|
| 181 |
None, self.retriever.get_relevant_documents, query
|
| 182 |
)
|
| 183 |
retrieved_docs = retrieved_docs[:k]
|
|
|
|
| 184 |
scores = [0.9 - (i * 0.1) for i in range(len(retrieved_docs))]
|
| 185 |
|
| 186 |
retrieval_time = time.time() - start_time
|
|
@@ -222,4 +240,4 @@ class LangChainRetriever(BaseRetriever):
|
|
| 222 |
return list(self.processed_documents.values())
|
| 223 |
|
| 224 |
def get_supported_formats(self) -> List[str]:
|
| 225 |
-
return self.document_loader.get_supported_extensions()
|
|
|
|
| 6 |
|
| 7 |
# Vector stores
|
| 8 |
from langchain_community.vectorstores import Chroma, FAISS, Pinecone
|
| 9 |
+
from langchain.retrievers import EnsembleRetriever
|
| 10 |
|
| 11 |
# Retriever base
|
| 12 |
from langchain_core.vectorstores import VectorStoreRetriever
|
|
|
|
| 25 |
|
| 26 |
logging.basicConfig(level=logging.INFO)
|
| 27 |
logger = logging.getLogger(__name__)
|
|
|
|
| 28 |
class LangChainRetriever(BaseRetriever):
|
| 29 |
"""LangChain-based retriever with multiple format support"""
|
| 30 |
|
|
|
|
| 160 |
except Exception as e:
|
| 161 |
logger.error(f"Error adding documents: {str(e)}")
|
| 162 |
return False
|
|
|
|
| 163 |
async def _update_bm25_retriever(self, documents: List[Document]):
    """Rebuild the BM25 retriever from *documents* and refresh ``self.retriever``.

    On success ``self.retriever`` becomes an equal-weight ensemble of a
    vector-store retriever over ``self.vectorstore`` and the freshly built
    BM25 retriever.  If building the BM25 or ensemble retriever fails, the
    error is logged and ``self.retriever`` falls back to the vector-store
    retriever alone.

    Args:
        documents: Documents the BM25 retriever is built from.
    """
    # Built once, up front, so the fallback path reuses it instead of
    # constructing a second identical retriever inside the exception handler.
    vector_retriever = VectorStoreRetriever(
        vectorstore=self.vectorstore,
        search_kwargs={"k": 10},
    )

    try:
        # Create BM25 retriever from documents
        self.bm25_retriever = BM25Retriever.from_documents(documents)
        self.bm25_retriever.k = 10  # Set number of documents to retrieve

        self.retriever = EnsembleRetriever(
            retrievers=[vector_retriever, self.bm25_retriever],
            weights=[0.5, 0.5],  # Equal weight to both retrievers
        )
    except Exception as e:
        logger.error(f"Error updating BM25 retriever: {str(e)}")
        # Fallback to vector retriever only
        self.retriever = vector_retriever
|
| 191 |
async def retrieve(self, query: str, k: int = 5) -> RetrievalResult:
|
| 192 |
try:
|
| 193 |
import time
|
|
|
|
| 198 |
None, self.retriever.get_relevant_documents, query
|
| 199 |
)
|
| 200 |
retrieved_docs = retrieved_docs[:k]
|
| 201 |
+
|
| 202 |
scores = [0.9 - (i * 0.1) for i in range(len(retrieved_docs))]
|
| 203 |
|
| 204 |
retrieval_time = time.time() - start_time
|
|
|
|
| 240 |
return list(self.processed_documents.values())
|
| 241 |
|
| 242 |
def get_supported_formats(self) -> List[str]:
|
| 243 |
+
return self.document_loader.get_supported_extensions()
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/web_search/__init__.py
ADDED
|
File without changes
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rtc/rtc_call_gpt.py
ADDED
|
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import fastapi
|
| 2 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
+
|
| 4 |
+
from fastrtc import ReplyOnPause, Stream, AlgoOptions, SileroVadOptions, get_cloudflare_turn_credentials_async, get_cloudflare_turn_credentials
|
| 5 |
+
from fastrtc.utils import audio_to_int16
|
| 6 |
+
from openai import OpenAI
|
| 7 |
+
from elevenlabs.client import ElevenLabs
|
| 8 |
+
from dotenv import load_dotenv
|
| 9 |
+
from tts.audio_edge_tts import EdgeTTS
|
| 10 |
+
from rag import document_retriever
|
| 11 |
+
import logging
|
| 12 |
+
import time
|
| 13 |
+
import platform
|
| 14 |
+
import socket
|
| 15 |
+
import os
|
| 16 |
+
import numpy as np
|
| 17 |
+
import io
|
| 18 |
+
import wave
|
| 19 |
+
import asyncio
|
| 20 |
+
import librosa
|
| 21 |
+
from pydub import AudioSegment
|
| 22 |
+
# from stt.whisper_stt import WhisperSTT
|
| 23 |
+
from collections import deque
|
| 24 |
+
import torch
|
| 25 |
+
import torchaudio.transforms as T
|
| 26 |
+
import asyncio
|
| 27 |
+
import concurrent.futures
|
| 28 |
+
import threading
|
| 29 |
+
from config.constant import HF_TOKEN
|
| 30 |
+
import threading
|
| 31 |
+
import re
|
| 32 |
+
from openai import OpenAI
|
| 33 |
+
from langchain_core.documents import Document
|
| 34 |
+
|
| 35 |
+
from rag import ddgs
|
| 36 |
+
# Load .env
|
| 37 |
+
load_dotenv()
|
| 38 |
+
logging.basicConfig(level=logging.INFO)
|
| 39 |
+
|
| 40 |
+
class RTCHandler:
|
| 41 |
+
def __init__(self, openai_client: OpenAI, whisper_stt = None, edge_tts : EdgeTTS = None):
    """Store the service clients and start a fresh conversation.

    Args:
        openai_client: OpenAI client kept on the instance for later requests.
        whisper_stt: Optional speech-to-text backend; only stored here.
        edge_tts: EdgeTTS wrapper kept on the instance for speech synthesis.
    """
    # Collaborators
    self.openai_client = openai_client
    self.whisper_stt = whisper_stt
    self.edge_tts = edge_tts

    # System prompt (Indonesian): polite, relaxed customer-service persona
    # answering over a phone call, limited to 50 words per reply.
    self.sys_prompt = """

    Kamu adalah customer service yang berbahasa Indonesia dengan baik sopan, santun, tapi santai pembawaannya.
    Kamu bisa menjelaskan sesuatu secara baik dan membimbing customer dalam menghadapi masalah yang ada!

    Kamu akan menjawab customer dengan media call /telepon jadi anda harus memberikan respon seperlunya saja
    Tidak kepanjanngan, dan sangat jelas,


    Tidak lebih dari 50 kata.
    """

    # Conversation state: the history always starts with the system turn.
    system_turn = {
        "role": "system",
        "content": self.sys_prompt
    }
    self.messages = [system_turn]
    self.prompt = ""
    self.full_response = ""

    # Created lazily by create_stream() / create_fastapi_app().
    self.stream = None
    self.app = None

    self._setup_webrtc_ip()
|
| 72 |
+
|
| 73 |
+
def _setup_webrtc_ip(self):
    """Export this machine's local IPv4 address as ``WEBRTC_IP`` (Windows only).

    On any other platform the method does nothing.
    """
    if platform.system() != 'Windows':
        return

    # Point a UDP socket at a public address and read back the local address
    # chosen for it; fall back to loopback if that fails.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
        try:
            probe.connect(('8.8.8.8', 80))
            local_ip = probe.getsockname()[0]
        except Exception:
            local_ip = '127.0.0.1'

    os.environ['WEBRTC_IP'] = local_ip
|
| 85 |
+
|
| 86 |
+
def audio_to_bytes(self, audio_tuple, sample_rate=24000) -> io.BytesIO:
    """Encode an audio tuple as an in-memory mono 16-bit WAV file.

    Args:
        audio_tuple: ``(sample_rate, samples)`` pair; the samples are
            converted with ``audio_to_int16`` before being written.
        sample_rate: Unused; kept so existing callers keep working.  The WAV
            header always uses the rate carried in *audio_tuple*.

    Returns:
        A ``BytesIO`` rewound to the start, with ``name`` set to
        ``"audio.wav"``.
    """
    sr, _ = audio_tuple
    pcm = audio_to_int16(audio_tuple)

    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)  # 2 bytes per sample = 16-bit PCM
        wf.setframerate(sr)
        wf.writeframes(pcm.tobytes())

    buffer.seek(0)
    # NOTE(review): presumably the name lets the upload client infer the
    # audio format from the extension - confirm against the STT client.
    buffer.name = "audio.wav"
    return buffer
|
| 99 |
+
def echo(self, audio):
    """Answer one caller utterance: transcribe it, query the LLM, stream speech back.

    This is a synchronous generator (it is the callable ``create_stream``
    passes to ``ReplyOnPause``).  The reply text is streamed from the chat
    model and synthesized clause by clause, so audio starts before the full
    answer is available.

    Args:
        audio: The caller's audio, forwarded to ``audio_to_bytes``.

    Yields:
        ``(24000, samples)`` tuples holding at most 1024 samples each.  If any
        step raises, the error is logged and one second of silence is
        yielded instead.  Nothing is yielded for an empty transcript.
    """
    target_rate = 24000
    chunk_size = 1024

    def decode_frames(mp3_buffer):
        """Decode an MP3 buffer into a list of 24 kHz mono float32 frames."""
        mp3_buffer.seek(0)
        segment = AudioSegment.from_file(mp3_buffer, format="mp3")
        samples = np.array(segment.get_array_of_samples()).astype(np.float32) / (2 ** 15)

        # Interleaved stereo -> mono by averaging the two channels.
        if segment.channels == 2:
            samples = samples.reshape((-1, 2)).mean(axis=1)

        # Resample on the GPU when one is available, otherwise on the CPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        waveform = torch.from_numpy(samples).unsqueeze(0).to(device)
        resampler = T.Resample(
            orig_freq=segment.frame_rate,
            new_freq=target_rate,
        ).to(device)
        resampled = resampler(waveform).squeeze(0).cpu().numpy()

        return [
            (target_rate, resampled[start:start + chunk_size])
            for start in range(0, len(resampled), chunk_size)
        ]

    async def speak(text):
        """Synthesize *text*; return its audio frames, or None if TTS failed."""
        try:
            buffers = await self.edge_tts.generate_audio_buffer(text)
            return decode_frames(buffers[0])
        except Exception as e:
            logging.error(f"TTS generation failed for chunk: {e}")
            return None

    async def stream_text_to_audio():
        """Retrieve context, stream the LLM reply and yield audio frames."""
        retrieval_result = await document_retriever.retrieve(query=self.prompt)

        # Index fresh web results.  The retrieval above has already run, so
        # these documents can only be returned by later retrievals.
        async for result in ddgs.search(self.prompt, max_results=5):
            doc = Document(
                page_content=result,
                metadata={"source": "internet_search", "query": self.prompt},
            )
            logging.debug("Indexing web search result: %s", doc)
            await document_retriever.add_documents([doc])

        # Number the contexts 1..n (the counter used to stay at 1).
        contexts = "".join(
            f"{rank}. {ctx.page_content}\n"
            for rank, ctx in enumerate(retrieval_result.documents, start=1)
        )
        logging.debug("Retrieved contexts: %s", contexts)

        user_message = (
            "\n"
            "Dari Konteks yang diberikan (jika diperlukan) :\n"
            f"{contexts}\n"
            "\n"
            "Berikan jawaban atas pertanyaan yang diberikan :\n"
            f"{self.prompt}\n"
            "\n"
        )
        self.messages.append({"role": "user", "content": user_message})

        response = self.openai_client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=self.messages,
            max_tokens=200,
            stream=True,
        )

        text_buffer = ""
        for stream_data in response:
            if not stream_data.choices:
                continue
            choice = stream_data.choices[0]
            piece = choice.delta.content
            logging.debug("LLM delta: %r", piece)

            if piece:
                self.full_response += piece
                text_buffer += piece
                # Synthesize as soon as a clause boundary arrives.
                if re.search(r'[.,?;!]', piece):
                    frames = await speak(text_buffer)
                    if frames is not None:
                        for frame in frames:
                            yield frame
                        text_buffer = ""
                    # On TTS failure the text stays buffered and is retried
                    # together with the next clause.

            if choice.finish_reason is not None:
                break

        # Speak the unpunctuated tail, whatever the finish reason was.  It
        # must go through TTS: consumers expect audio tuples, not text.
        if text_buffer.strip():
            for frame in await speak(text_buffer) or []:
                yield frame

    try:
        stt_time = time.time()
        logging.info("Performing STT")

        transcription = self.openai_client.audio.transcriptions.create(
            model="whisper-1",
            file=self.audio_to_bytes(audio),
            language="id",
        )

        self.prompt = transcription.text or ""
        if not self.prompt.strip():
            logging.info("STT returned empty string")
            return

        logging.info(f"STT response: {transcription}")
        logging.info(f"STT took {time.time() - stt_time} seconds")

        llm_time = time.time()
        self.full_response = ""

        # Drive the async pipeline from this synchronous generator with a
        # private event loop, pulling one frame at a time.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        async_gen = stream_text_to_audio()
        try:
            while True:
                try:
                    chunk = loop.run_until_complete(async_gen.__anext__())
                except StopAsyncIteration:
                    break
                yield chunk
        finally:
            # Also runs when the consumer stops iterating early: finalize the
            # async generator before the loop goes away, and do not leave a
            # closed loop installed as this thread's current loop.
            try:
                loop.run_until_complete(async_gen.aclose())
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                asyncio.set_event_loop(None)
                loop.close()

        self.messages.append({"role": "assistant", "content": self.full_response + " "})
        logging.info(f"LLM response: {self.full_response}")
        logging.info(f"LLM took {time.time() - llm_time} seconds")

    except Exception as e:
        logging.error(f"Error in echo function: {e}")
        error_audio = np.zeros(24000, dtype=np.float32)
        yield (24000, error_audio)
|
| 255 |
+
def reset_conversation(self):
    """Drop the dialogue so far, keeping only the current system prompt."""
    logging.info("Resetting chat")
    system_turn = {"role": "system", "content": self.sys_prompt}
    self.messages = [system_turn]
    self.full_response = ""
|
| 259 |
+
|
| 260 |
+
def create_stream(self):
    """Create the audio ``Stream`` whose pause handler is ``self.echo``.

    The stream is stored on ``self.stream`` and returned.  Any exception
    raised while building it is logged and re-raised.
    """
    async def get_credentials():
        # Passed to Stream as a callable rather than as a resolved value.
        return await get_cloudflare_turn_credentials_async(hf_token=HF_TOKEN)

    try:
        server_config = get_cloudflare_turn_credentials(ttl=360_000)

        pause_options = AlgoOptions(
            audio_chunk_duration=0.5,
            started_talking_threshold=0.1,
            speech_threshold=0.03
        )
        vad_options = SileroVadOptions(
            threshold=0.90,
            min_speech_duration_ms=250,
            min_silence_duration_ms=2000,
            speech_pad_ms=400,
            max_speech_duration_s=15
        )
        reply_handler = ReplyOnPause(
            self.echo,
            algo_options=pause_options,
            model_options=vad_options
        )

        self.stream = Stream(
            rtc_configuration=get_credentials,
            server_rtc_configuration=server_config,
            handler=reply_handler,
            modality="audio",
            mode="send-receive"
        )
    except Exception as e:
        logging.error(f"Error creating stream: {e}")
        raise

    return self.stream
|
| 289 |
+
|
| 290 |
+
def create_fastapi_app(self):
    """Build the FastAPI app: CORS, the mounted RTC stream, ``/reset`` and ``/status``.

    The app is stored on ``self.app`` and returned; the stream is created
    first when it does not exist yet.  Errors are logged and re-raised.
    """
    try:
        app = fastapi.FastAPI()
        self.app = app

        # NOTE(review): wildcard origins together with allow_credentials=True
        # is very permissive - confirm this is intended outside development.
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        if not self.stream:
            self.create_stream()
        self.stream.mount(app)

        @app.get("/reset")
        async def reset():
            # Clears the conversation history via reset_conversation().
            try:
                self.reset_conversation()
            except Exception as e:
                logging.error(f"Error in reset endpoint: {e}")
                return {"status": "error", "message": str(e)}
            return {"status": "success"}

        @app.get("/status")
        async def status():
            # Reports the history length and the most recent reply.
            try:
                return {
                    "status": "running",
                    "messages_count": len(self.messages),
                    "last_response": self.full_response
                }
            except Exception as e:
                logging.error(f"Error in status endpoint: {e}")
                return {"status": "error", "message": str(e)}

        return app
    except Exception as e:
        logging.error(f"Error creating FastAPI app: {e}")
        raise
|
| 330 |
+
|
| 331 |
+
def start_server(self, host: str = "0.0.0.0", port: int = 7860):
    """Run the FastAPI app under uvicorn, creating the app first if needed.

    Args:
        host: Interface to bind to.
        port: TCP port to listen on.

    Raises:
        Exception: Whatever ``uvicorn.run`` raises, after it has been logged.
    """
    import uvicorn

    app_missing = not self.app
    if app_missing:
        self.create_fastapi_app()

    logging.info(f"Starting server on {host}:{port}")
    try:
        uvicorn.run(self.app, host=host, port=port, log_level="info")
    except Exception as e:
        logging.error(f"Error starting server: {e}")
        raise
|
| 341 |
+
def launch_ui(self, browser: bool = True):
    """Launch the stream's built-in UI on ``0.0.0.0:7860``.

    The stream and the FastAPI app are created first when missing.  Errors
    are logged and re-raised.

    Args:
        browser: Currently not read by this method.
    """
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    try:
        if not self.stream:
            self.create_stream()
        if not self.app:
            self.create_fastapi_app()

        logging.info("Launching RTC UI...")
        self.stream.ui.launch(self.app, **launch_options)
    except Exception as e:
        logging.error(f"Error launching UI: {e}")
        raise
|
| 355 |
+
|
| 356 |
+
def get_conversation_history(self):
    """Return a shallow copy of the message history.

    The returned list is new, but the message dicts inside it are the same
    objects as those held by the handler.
    """
    history = self.messages.copy()
    return history
|
| 358 |
+
|
| 359 |
+
def set_system_prompt(self, new_prompt: str):
    """Replace the system prompt, both the stored copy and the first history entry.

    Args:
        new_prompt: Text of the new system prompt.
    """
    self.sys_prompt = new_prompt
    system_turn = {"role": "system", "content": new_prompt}
    # The first history entry is overwritten in place; later turns are kept.
    self.messages[0] = system_turn
|
| 362 |
+
|
| 363 |
+
def get_last_response(self):
    """Return the text stored in ``self.full_response``."""
    last_reply = self.full_response
    return last_reply
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/tests/qwen_llm_test.py
CHANGED
|
@@ -1,14 +1,14 @@
|
|
| 1 |
from rag.retriever.retriever_types import *
|
| 2 |
-
from rag.pipeline.
|
| 3 |
|
| 4 |
import warnings
|
| 5 |
warnings.filterwarnings("ignore")
|
| 6 |
|
| 7 |
-
async def
|
| 8 |
print(" ===== Testing QWEN LLM ==== ")
|
| 9 |
-
"""Example usage of async
|
| 10 |
|
| 11 |
-
config =
|
| 12 |
temperature=0.5,
|
| 13 |
max_length=512,
|
| 14 |
generation_timeout=30
|
|
@@ -23,20 +23,20 @@ async def test_qwen_llm():
|
|
| 23 |
)
|
| 24 |
|
| 25 |
# Using async context manager
|
| 26 |
-
async with
|
| 27 |
await test_qwen_single_generation(llm)
|
| 28 |
await test_qwen_single_rag_generation(llm, contexts)
|
| 29 |
await test_qwen_multiple_template_rag_generation(llm, contexts)
|
| 30 |
await test_qwen_batch_generation(llm, contexts)
|
| 31 |
print(" ===== Testing LLM DONE ==== ")
|
| 32 |
|
| 33 |
-
async def test_qwen_single_generation(llm :
|
| 34 |
print(" * Test Single Generation * ")
|
| 35 |
response = await llm.generate("Jelaskan tentang AI")
|
| 36 |
print(f"Response: {response}")
|
| 37 |
print(" * Test Single Generation Done * ")
|
| 38 |
|
| 39 |
-
async def test_qwen_single_rag_generation(llm :
|
| 40 |
print(" * Test Single RAG Generation * ")
|
| 41 |
rag_response = await llm.rag_generate(
|
| 42 |
question="Apa itu AI dan machine learning?",
|
|
@@ -46,7 +46,7 @@ async def test_qwen_single_rag_generation(llm : QwenLLM, ctx : RetrievalResult):
|
|
| 46 |
print(f"RAG Response: {rag_response}")
|
| 47 |
print(" * Test Single RAG Generation Done * ")
|
| 48 |
|
| 49 |
-
async def test_qwen_multiple_template_rag_generation(llm :
|
| 50 |
print(" * Test Multiple Template Generation * ")
|
| 51 |
multi_responses = await llm.multi_template_generate(
|
| 52 |
question="Apa itu AI?",
|
|
@@ -57,7 +57,7 @@ async def test_qwen_multiple_template_rag_generation(llm : QwenLLM,ctx : Retriev
|
|
| 57 |
print(" * Test Multiple Template Generation Done* ")
|
| 58 |
|
| 59 |
|
| 60 |
-
async def test_qwen_batch_generation(llm :
|
| 61 |
print(" * Test Batch Generation * ")
|
| 62 |
batch_responses = await llm.batch_generate([
|
| 63 |
"Jelaskan tentang Python",
|
|
|
|
| 1 |
from rag.retriever.retriever_types import *
|
| 2 |
+
from rag.pipeline.language_model import LM, LMConfig
|
| 3 |
|
| 4 |
import warnings
|
| 5 |
warnings.filterwarnings("ignore")
|
| 6 |
|
| 7 |
+
async def test_language_model():
|
| 8 |
print(" ===== Testing QWEN LLM ==== ")
|
| 9 |
+
"""Example usage of async LM"""
|
| 10 |
|
| 11 |
+
config = LMConfig(
|
| 12 |
temperature=0.5,
|
| 13 |
max_length=512,
|
| 14 |
generation_timeout=30
|
|
|
|
| 23 |
)
|
| 24 |
|
| 25 |
# Using async context manager
|
| 26 |
+
async with LM(config) as llm:
|
| 27 |
await test_qwen_single_generation(llm)
|
| 28 |
await test_qwen_single_rag_generation(llm, contexts)
|
| 29 |
await test_qwen_multiple_template_rag_generation(llm, contexts)
|
| 30 |
await test_qwen_batch_generation(llm, contexts)
|
| 31 |
print(" ===== Testing LLM DONE ==== ")
|
| 32 |
|
| 33 |
+
async def test_qwen_single_generation(llm : LM):
    """Smoke-test one plain ``generate`` call and print what comes back."""
    print(" * Test Single Generation * ")
    answer = await llm.generate("Jelaskan tentang AI")
    print(f"Response: {answer}")
    print(" * Test Single Generation Done * ")
|
| 38 |
|
| 39 |
+
async def test_qwen_single_rag_generation(llm : LM, ctx : RetrievalResult):
|
| 40 |
print(" * Test Single RAG Generation * ")
|
| 41 |
rag_response = await llm.rag_generate(
|
| 42 |
question="Apa itu AI dan machine learning?",
|
|
|
|
| 46 |
print(f"RAG Response: {rag_response}")
|
| 47 |
print(" * Test Single RAG Generation Done * ")
|
| 48 |
|
| 49 |
+
async def test_qwen_multiple_template_rag_generation(llm : LM,ctx : RetrievalResult):
|
| 50 |
print(" * Test Multiple Template Generation * ")
|
| 51 |
multi_responses = await llm.multi_template_generate(
|
| 52 |
question="Apa itu AI?",
|
|
|
|
| 57 |
print(" * Test Multiple Template Generation Done* ")
|
| 58 |
|
| 59 |
|
| 60 |
+
async def test_qwen_batch_generation(llm : LM, ctx : RetrievalResult):
|
| 61 |
print(" * Test Batch Generation * ")
|
| 62 |
batch_responses = await llm.batch_generate([
|
| 63 |
"Jelaskan tentang Python",
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/__chat__.py
CHANGED
|
@@ -1,13 +1,14 @@
|
|
| 1 |
from tests.inference_test import test_inference
|
| 2 |
-
|
|
|
|
| 3 |
import warnings
|
| 4 |
warnings.filterwarnings("ignore")
|
| 5 |
import asyncio
|
| 6 |
def run_test():
|
| 7 |
try:
|
| 8 |
# await test_document_retriever()
|
| 9 |
-
# await
|
| 10 |
-
|
| 11 |
except Exception as e:
|
| 12 |
print(e)
|
| 13 |
|
|
|
|
| 1 |
from tests.inference_test import test_inference
|
| 2 |
+
from huggingface_hub import login
|
| 3 |
+
login(new_session=False)
|
| 4 |
import warnings
|
| 5 |
warnings.filterwarnings("ignore")
|
| 6 |
import asyncio
|
| 7 |
def run_test():
    """Call ``test_inference`` and print, rather than raise, any exception."""
    # NOTE(review): if test_inference is a coroutine function this call only
    # creates a coroutine without running it - confirm against its definition.
    try:
        # await test_document_retriever()
        # await test_language_model()
        test_inference()
    except Exception as err:
        print(err)
|
| 14 |
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/__test__.py
CHANGED
|
@@ -1,8 +1,3 @@
|
|
| 1 |
-
|
| 2 |
-
# from tests.document_retriever_test import test_document_retriever
|
| 3 |
-
# from tests.document_retriever_test import test_document_retriever
|
| 4 |
-
# from tests.qwen_llm_test import test_qwen_llm
|
| 5 |
-
# from tests.inference_test import test_inference
|
| 6 |
from tests.rtc_test import test_rtc
|
| 7 |
import warnings
|
| 8 |
warnings.filterwarnings("ignore")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from tests.rtc_test import test_rtc
|
| 2 |
import warnings
|
| 3 |
warnings.filterwarnings("ignore")
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/app.log
ADDED
|
File without changes
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/__init__.py
CHANGED
|
@@ -1,17 +1,44 @@
|
|
| 1 |
-
from rag.pipeline.
|
| 2 |
from rag.retriever.langchain_retriever import LangChainRetriever
|
| 3 |
from rag.inference.inferencer import Inferencer, InferencerConfig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
temperature=0.3,
|
| 7 |
max_length=512,
|
| 8 |
-
generation_timeout=
|
| 9 |
repetition_penalty=1.1,
|
| 10 |
-
max_workers =
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
llm =
|
| 15 |
config = config
|
| 16 |
)
|
| 17 |
|
|
@@ -22,29 +49,42 @@ inferencer_config = InferencerConfig(
|
|
| 22 |
)
|
| 23 |
|
| 24 |
document_retriever = LangChainRetriever(
|
| 25 |
-
embedding_model="all-MiniLM-L6-v2",
|
| 26 |
vectorstore_type="chroma",
|
| 27 |
-
vectorstore_path="
|
| 28 |
use_hybrid_search=True,
|
| 29 |
chunk_size=1000,
|
| 30 |
chunk_overlap=200
|
| 31 |
)
|
| 32 |
|
| 33 |
-
|
|
|
|
|
|
|
| 34 |
model=llm,
|
| 35 |
retriever=document_retriever,
|
|
|
|
| 36 |
reranker=None,
|
| 37 |
config=inferencer_config
|
| 38 |
)
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
-
async def get_stream_response(question):
|
| 45 |
-
async for item in inferencer.infer_stream(query = question,
|
| 46 |
-
enable_reranking=False,
|
| 47 |
-
template_type="main_template",
|
| 48 |
-
k=3):
|
| 49 |
-
print("Stream Response :", item)
|
| 50 |
-
yield item
|
|
|
|
| 1 |
+
from rag.pipeline.language_model import LM, LMConfig
|
| 2 |
from rag.retriever.langchain_retriever import LangChainRetriever
|
| 3 |
from rag.inference.inferencer import Inferencer, InferencerConfig
|
| 4 |
+
from rag.agents.customer_service_agent import CSAgent
|
| 5 |
+
from rag.agents.query_maker_agent import QueryMakerAgent
|
| 6 |
+
from langchain_core.documents import Document
|
| 7 |
+
from rag.web_search.duckduckgo_search import DuckDuckGoSearch
|
| 8 |
+
from rag.chat_template import get_chat_template
|
| 9 |
+
from transformers import BitsAndBytesConfig
|
| 10 |
+
import torch
|
| 11 |
|
| 12 |
+
import logging
|
| 13 |
+
import sys
|
| 14 |
+
|
| 15 |
+
logging.basicConfig(
|
| 16 |
+
level=logging.DEBUG,
|
| 17 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s() - %(message)s',
|
| 18 |
+
handlers=[
|
| 19 |
+
logging.FileHandler('app.log'),
|
| 20 |
+
logging.StreamHandler(sys.stdout)
|
| 21 |
+
]
|
| 22 |
+
)
|
| 23 |
+
bnb = BitsAndBytesConfig(
|
| 24 |
+
load_in_4bit=True, # Enable 4-bit quantization
|
| 25 |
+
bnb_4bit_use_double_quant=True, # Use double quantization
|
| 26 |
+
bnb_4bit_quant_type="nf4", # Use NF4 quantization
|
| 27 |
+
bnb_4bit_compute_dtype=torch.bfloat16, # Compute dtype for 4bit base models
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
config = LMConfig(
|
| 32 |
+
model_name = "Qwen/Qwen2.5-1.5B-Instruct",
|
| 33 |
temperature=0.3,
|
| 34 |
max_length=512,
|
| 35 |
+
generation_timeout=100,
|
| 36 |
repetition_penalty=1.1,
|
| 37 |
+
max_workers = 2,
|
| 38 |
+
quantization_config = bnb
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
llm = LM(
|
| 42 |
config = config
|
| 43 |
)
|
| 44 |
|
|
|
|
| 49 |
)
|
| 50 |
|
| 51 |
document_retriever = LangChainRetriever(
|
| 52 |
+
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
| 53 |
vectorstore_type="chroma",
|
| 54 |
+
vectorstore_path="vectorstore/",
|
| 55 |
use_hybrid_search=True,
|
| 56 |
chunk_size=1000,
|
| 57 |
chunk_overlap=200
|
| 58 |
)
|
| 59 |
|
| 60 |
+
ddgs = DuckDuckGoSearch()
|
| 61 |
+
|
| 62 |
+
cs_inferencer = Inferencer(
|
| 63 |
model=llm,
|
| 64 |
retriever=document_retriever,
|
| 65 |
+
# search_engine = ddgs,
|
| 66 |
reranker=None,
|
| 67 |
config=inferencer_config
|
| 68 |
)
|
| 69 |
|
| 70 |
+
query_maker_inferencer = Inferencer(
|
| 71 |
+
model=llm,
|
| 72 |
+
config=inferencer_config
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
cs_agent = CSAgent(
|
| 76 |
+
inferencer = cs_inferencer,
|
| 77 |
+
prompt_template = get_chat_template("customer_service")
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
query_maker_chat_template = get_chat_template("query_maker")
|
| 81 |
+
query_maker_chat_template[1]["content"] = """{question}"""
|
| 82 |
+
|
| 83 |
+
query_maker_agent = QueryMakerAgent(
|
| 84 |
+
inferencer = query_maker_inferencer,
|
| 85 |
+
prompt_template = query_maker_chat_template
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
|
| 90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/inference/inferencer.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
from rag.retriever.langchain_retriever import LangChainRetriever
|
| 2 |
-
from rag.pipeline.
|
| 3 |
from rag.retriever.retriever_types import RetrievalResult
|
|
|
|
|
|
|
| 4 |
# from rag.pipeline.reranker import BGEM3Reranker
|
| 5 |
from typing import List, Union, Dict, Any, Optional, AsyncGenerator
|
| 6 |
import asyncio
|
|
@@ -29,15 +31,16 @@ class Inferencer:
|
|
| 29 |
"""
|
| 30 |
|
| 31 |
def __init__(self,
|
| 32 |
-
model:
|
| 33 |
-
retriever: LangChainRetriever,
|
|
|
|
| 34 |
reranker=None,
|
| 35 |
config: Optional[InferencerConfig] = None):
|
| 36 |
"""
|
| 37 |
Initialize Inferencer
|
| 38 |
|
| 39 |
Args:
|
| 40 |
-
model:
|
| 41 |
retriever: LangChainRetriever instance
|
| 42 |
reranker: Reranker instance (optional)
|
| 43 |
config: InferencerConfig (optional)
|
|
@@ -45,6 +48,7 @@ class Inferencer:
|
|
| 45 |
self.model = model
|
| 46 |
self.retriever = retriever
|
| 47 |
self.reranker = reranker
|
|
|
|
| 48 |
self.config = config or InferencerConfig()
|
| 49 |
|
| 50 |
# Setup logging
|
|
@@ -85,6 +89,7 @@ class Inferencer:
|
|
| 85 |
try:
|
| 86 |
start_time = datetime.now()
|
| 87 |
contexts = await self.retriever.retrieve(query, k=k)
|
|
|
|
| 88 |
retrieval_time = (datetime.now() - start_time).total_seconds()
|
| 89 |
|
| 90 |
self.logger.info(f"Retrieved {len(contexts.documents) if hasattr(contexts, 'documents') else len(contexts)} contexts in {retrieval_time:.2f}s")
|
|
@@ -292,7 +297,7 @@ class Inferencer:
|
|
| 292 |
yield chunk
|
| 293 |
|
| 294 |
async def infer(self,
|
| 295 |
-
query:
|
| 296 |
response_type: Union[List[str], str] = None,
|
| 297 |
k: Optional[int] = None,
|
| 298 |
enable_reranking: Optional[bool] = None,
|
|
@@ -321,8 +326,12 @@ class Inferencer:
|
|
| 321 |
|
| 322 |
try:
|
| 323 |
# Step 1: Retrieve contexts
|
| 324 |
-
|
| 325 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
# Step 2: Rerank contexts (if enabled)
|
| 327 |
enable_rerank = enable_reranking if enable_reranking is not None else self.config.enable_reranking
|
| 328 |
if enable_rerank:
|
|
@@ -363,7 +372,34 @@ class Inferencer:
|
|
| 363 |
except Exception as e:
|
| 364 |
self.logger.error(f"Error during inference: {e}")
|
| 365 |
raise
|
| 366 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
async def infer_stream(self,
|
| 368 |
query: str,
|
| 369 |
k: Optional[int] = None,
|
|
@@ -389,8 +425,14 @@ class Inferencer:
|
|
| 389 |
|
| 390 |
try:
|
| 391 |
# Step 1: Retrieve contexts
|
| 392 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
|
|
|
|
| 394 |
# Step 2: Rerank contexts (if enabled)
|
| 395 |
enable_rerank = enable_reranking if enable_reranking is not None else self.config.enable_reranking
|
| 396 |
if enable_rerank:
|
|
|
|
| 1 |
from rag.retriever.langchain_retriever import LangChainRetriever
|
| 2 |
+
from rag.pipeline.language_model import LM, LMConfig
|
| 3 |
from rag.retriever.retriever_types import RetrievalResult
|
| 4 |
+
from rag.web_search.duckduckgo_search import DuckDuckGoSearch
|
| 5 |
+
from langchain_core.documents import Document
|
| 6 |
# from rag.pipeline.reranker import BGEM3Reranker
|
| 7 |
from typing import List, Union, Dict, Any, Optional, AsyncGenerator
|
| 8 |
import asyncio
|
|
|
|
| 31 |
"""
|
| 32 |
|
| 33 |
def __init__(self,
|
| 34 |
+
model: LM,
|
| 35 |
+
retriever: LangChainRetriever = None,
|
| 36 |
+
search_engine = None,
|
| 37 |
reranker=None,
|
| 38 |
config: Optional[InferencerConfig] = None):
|
| 39 |
"""
|
| 40 |
Initialize Inferencer
|
| 41 |
|
| 42 |
Args:
|
| 43 |
+
model: LM instance
|
| 44 |
retriever: LangChainRetriever instance
|
| 45 |
reranker: Reranker instance (optional)
|
| 46 |
config: InferencerConfig (optional)
|
|
|
|
| 48 |
self.model = model
|
| 49 |
self.retriever = retriever
|
| 50 |
self.reranker = reranker
|
| 51 |
+
self.search_engine = search_engine
|
| 52 |
self.config = config or InferencerConfig()
|
| 53 |
|
| 54 |
# Setup logging
|
|
|
|
| 89 |
try:
|
| 90 |
start_time = datetime.now()
|
| 91 |
contexts = await self.retriever.retrieve(query, k=k)
|
| 92 |
+
self.logger.info(f"Retrieved Contexts : {contexts}")
|
| 93 |
retrieval_time = (datetime.now() - start_time).total_seconds()
|
| 94 |
|
| 95 |
self.logger.info(f"Retrieved {len(contexts.documents) if hasattr(contexts, 'documents') else len(contexts)} contexts in {retrieval_time:.2f}s")
|
|
|
|
| 297 |
yield chunk
|
| 298 |
|
| 299 |
async def infer(self,
|
| 300 |
+
query: str,
|
| 301 |
response_type: Union[List[str], str] = None,
|
| 302 |
k: Optional[int] = None,
|
| 303 |
enable_reranking: Optional[bool] = None,
|
|
|
|
| 326 |
|
| 327 |
try:
|
| 328 |
# Step 1: Retrieve contexts
|
| 329 |
+
if(self.search_engine):
|
| 330 |
+
await self.retrieve_from_search_engine(query, k = k)
|
| 331 |
+
if(self.retriever):
|
| 332 |
+
retrieved_contexts = await self.retrieve_context(main_query, k=k)
|
| 333 |
+
else:
|
| 334 |
+
retrieved_contexts = ""
|
| 335 |
# Step 2: Rerank contexts (if enabled)
|
| 336 |
enable_rerank = enable_reranking if enable_reranking is not None else self.config.enable_reranking
|
| 337 |
if enable_rerank:
|
|
|
|
| 372 |
except Exception as e:
|
| 373 |
self.logger.error(f"Error during inference: {e}")
|
| 374 |
raise
|
| 375 |
+
async def retrieve_from_search_engine(self, query: str, k: Optional[int] = 3):
    """Collect web search results for ``query`` as ``Document`` objects.

    Every result yielded by ``self.search_engine.search`` is wrapped in a
    ``Document`` tagged with ``source="internet_search"`` and the query that
    produced it. When a retriever is configured, each document is also added
    to it as soon as it arrives.

    Args:
        query: Text handed to the search engine.
        k: Forwarded to the search engine as ``max_results``. ``infer`` and
            ``infer_stream`` pass their own optional ``k`` straight through,
            so this can be ``None``.

    Returns:
        The list of ``Document`` objects built from the search results.

    Raises:
        Exception: Any failure is logged with its traceback and re-raised.
    """
    from langchain_core.documents import Document

    search_results = []

    try:
        # Handle results one at a time, in the order the engine yields them.
        async for result in self.search_engine.search(query, max_results=k):
            self.logger.info(f"Processing SEO Result: {result[:100]}...")

            doc = Document(
                page_content=result,
                metadata={"source": "internet_search", "query": query},
            )
            search_results.append(doc)

            # The retriever is optional (the constructor defaults it to None)
            # and callers invoke this method before checking for one, so only
            # index the document when a retriever actually exists.
            if self.retriever is not None:
                await self.retriever.add_documents([doc])

        self.logger.info(f"Processed {len(search_results)} search results")
        return search_results

    except Exception as e:
        self.logger.error(f"Error in retrieve_from_search_engine: {e}", exc_info=True)
        raise
|
| 403 |
async def infer_stream(self,
|
| 404 |
query: str,
|
| 405 |
k: Optional[int] = None,
|
|
|
|
| 425 |
|
| 426 |
try:
|
| 427 |
# Step 1: Retrieve contexts
|
| 428 |
+
if(self.search_engine):
|
| 429 |
+
await self.retrieve_from_search_engine(query, k = k)
|
| 430 |
+
if(self.retriever is not None):
|
| 431 |
+
retrieved_contexts = await self.retrieve_context(query, k=k)
|
| 432 |
+
else:
|
| 433 |
+
retrieved_contexts = ""
|
| 434 |
|
| 435 |
+
|
| 436 |
# Step 2: Rerank contexts (if enabled)
|
| 437 |
enable_rerank = enable_reranking if enable_reranking is not None else self.config.enable_reranking
|
| 438 |
if enable_rerank:
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/pipeline/qwen_llm.py
CHANGED
|
@@ -17,7 +17,7 @@ import copy
|
|
| 17 |
@dataclass
|
| 18 |
class QwenConfig:
|
| 19 |
"""Konfigurasi untuk model Qwen 0.5B"""
|
| 20 |
-
model_name: str = "Qwen/Qwen2.5-
|
| 21 |
device: str = "cuda"
|
| 22 |
torch_dtype: torch.dtype = torch.float16
|
| 23 |
max_length: int = 2048
|
|
@@ -286,14 +286,35 @@ class QwenLLM:
|
|
| 286 |
|
| 287 |
formatted_template = []
|
| 288 |
for cht in template_data:
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
|
| 294 |
-
|
| 295 |
-
cht["content"] = cht["content"].format(question=question)
|
| 296 |
-
formatted_template.append(cht)
|
| 297 |
|
| 298 |
self.logger.info("Formatted Template", formatted_template)
|
| 299 |
print("Forrmatted Template", formatted_template)
|
|
|
|
| 17 |
@dataclass
|
| 18 |
class QwenConfig:
|
| 19 |
"""Konfigurasi untuk model Qwen 0.5B"""
|
| 20 |
+
model_name: str = "Qwen/Qwen2.5-1.5B-Instruct"
|
| 21 |
device: str = "cuda"
|
| 22 |
torch_dtype: torch.dtype = torch.float16
|
| 23 |
max_length: int = 2048
|
|
|
|
| 286 |
|
| 287 |
formatted_template = []
|
| 288 |
for cht in template_data:
|
| 289 |
+
# Create a copy of the content to avoid modifying the original
|
| 290 |
+
content = cht["content"]
|
| 291 |
+
|
| 292 |
+
# Format both placeholders at once to avoid KeyError
|
| 293 |
+
if "{context}" in content or "{question}" in content:
|
| 294 |
+
try:
|
| 295 |
+
content = content.format(
|
| 296 |
+
context=formatted_context,
|
| 297 |
+
question=question
|
| 298 |
+
)
|
| 299 |
+
except KeyError as e:
|
| 300 |
+
self.logger.error(f"Missing placeholder in template: {e}")
|
| 301 |
+
# Fallback: format only available placeholders
|
| 302 |
+
if "{context}" in content:
|
| 303 |
+
content = content.replace("{context}", formatted_context)
|
| 304 |
+
if "{question}" in content:
|
| 305 |
+
content = content.replace("{question}", question)
|
| 306 |
+
|
| 307 |
+
# Create new dict with formatted content
|
| 308 |
+
formatted_chat = {
|
| 309 |
+
"role": cht["role"],
|
| 310 |
+
"content": content
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
# Copy other fields if they exist
|
| 314 |
+
if "description" in cht:
|
| 315 |
+
formatted_chat["description"] = cht["description"]
|
| 316 |
|
| 317 |
+
formatted_template.append(formatted_chat)
|
|
|
|
|
|
|
| 318 |
|
| 319 |
self.logger.info("Formatted Template", formatted_template)
|
| 320 |
print("Forrmatted Template", formatted_template)
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/prompt_tuner/chat_template.py
CHANGED
|
@@ -8,18 +8,20 @@ def RAG_TEMPLATES():
|
|
| 8 |
|
| 9 |
1. Selalu berikan sapaan yang ramah dan profesional
|
| 10 |
2. Gunakan HANYA informasi dari knowledge base yang tersedia
|
| 11 |
-
3. Berikan jawaban yang jelas, mudah dipahami, dan terstruktur semuanya berdasarkan konteks yang diberikan
|
| 12 |
-
{context}
|
| 13 |
4. Jika informasi tidak tersedia, tawarkan alternatif bantuan atau arahkan ke channel yang tepat
|
| 14 |
5. Gunakan bahasa yang sopan dan empati terhadap kebutuhan pelanggan
|
| 15 |
6. Akhiri dengan penawaran bantuan lebih lanjut
|
|
|
|
| 16 |
""",
|
| 17 |
"description": "Template dengan system prompt untuk customer service professional"
|
| 18 |
},
|
| 19 |
{
|
| 20 |
"role" : "user",
|
| 21 |
-
"content" : """
|
| 22 |
-
|
|
|
|
|
|
|
| 23 |
"""
|
| 24 |
},
|
| 25 |
],
|
|
|
|
| 8 |
|
| 9 |
1. Selalu berikan sapaan yang ramah dan profesional
|
| 10 |
2. Gunakan HANYA informasi dari knowledge base yang tersedia
|
| 11 |
+
3. Berikan jawaban yang jelas, mudah dipahami, dan terstruktur semuanya berdasarkan konteks yang diberikan user.
|
|
|
|
| 12 |
4. Jika informasi tidak tersedia, tawarkan alternatif bantuan atau arahkan ke channel yang tepat
|
| 13 |
5. Gunakan bahasa yang sopan dan empati terhadap kebutuhan pelanggan
|
| 14 |
6. Akhiri dengan penawaran bantuan lebih lanjut
|
| 15 |
+
|
| 16 |
""",
|
| 17 |
"description": "Template dengan system prompt untuk customer service professional"
|
| 18 |
},
|
| 19 |
{
|
| 20 |
"role" : "user",
|
| 21 |
+
"content" : """Dari konteks yang diberikan : {context}
|
| 22 |
+
|
| 23 |
+
berikan jawaban atas pertanyaan saya yaitu : {question}
|
| 24 |
+
|
| 25 |
"""
|
| 26 |
},
|
| 27 |
],
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rag/web_search/duckduckgo_search.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ddgs import DDGS
|
| 2 |
+
from langchain_community.document_loaders import AsyncChromiumLoader
|
| 3 |
+
from langchain_community.document_transformers import BeautifulSoupTransformer
|
| 4 |
+
import re
|
| 5 |
+
import logging
|
| 6 |
+
from typing import AsyncGenerator, List
|
| 7 |
+
|
| 8 |
+
class DuckDuckGoSearch:
    """Search DuckDuckGo and yield cleaned, truncated text of the result pages.

    Result URLs come from ``DDGS().text``; the pages are fetched with
    ``html_loader`` and reduced to paragraph text by ``html_parser``.
    """

    # Runs of blank lines in extracted page text are collapsed to one newline.
    _BLANK_LINES = re.compile(r"\n\n+")

    def __init__(self, html_loader: "AsyncChromiumLoader | None" = None, html_parser=None):
        """Store the collaborators, creating defaults for any that are omitted.

        Args:
            html_loader: Loader used to fetch page HTML; defaults to an
                ``AsyncChromiumLoader`` created with an empty URL list.
            html_parser: Transformer applied to the loaded pages; defaults to
                a ``BeautifulSoupTransformer``.
        """
        self.html_loader = html_loader or AsyncChromiumLoader([])
        self.html_parser = html_parser or BeautifulSoupTransformer()
        self.logger = logging.getLogger("ddgs_logger")

    async def get_page(self, urls: List[str]):
        """Load ``urls`` and return the transformed documents.

        Only ``<p>`` content is extracted and ``<a>`` tags are removed.

        Returns:
            The list returned by the parser, or an empty list when loading or
            parsing raises (the error is logged with its traceback).
        """
        try:
            # The loader is reused across calls; point it at this batch.
            self.html_loader.urls = urls
            html = await self.html_loader.aload()
            self.logger.info(f"search engine aload result: {len(html)} documents loaded")

            return self.html_parser.transform_documents(
                html,
                tags_to_extract=["p"],
                remove_unwanted_tags=["a"],
            )

        except Exception as e:
            self.logger.error(f"Error loading pages: {e}", exc_info=True)
            return []

    def truncate(self, text: str, max_words: int = 400) -> str:
        """Limit ``text`` to ``max_words`` whitespace-separated words.

        Text that already fits is returned unchanged. Longer text is rebuilt
        from its first ``max_words`` words joined by single spaces, followed
        by ``"..."``. Empty or ``None`` input yields ``""``.
        """
        if not text:
            return ""

        words = text.split()
        if len(words) <= max_words:
            return text

        return " ".join(words[:max_words]) + "..."

    async def _search_hits(self, query: str, max_results: int) -> List[dict]:
        """Return the DuckDuckGo text results for ``query`` that carry an ``href``."""
        import asyncio

        def _run() -> List[dict]:
            # Consume the results inside the worker so all network I/O
            # happens off the event loop, even if the client is lazy.
            return [hit for hit in DDGS().text(query, max_results=max_results) if hit.get("href")]

        # DDGS is a synchronous client; calling it inline would stall every
        # other coroutine for the duration of the search request.
        return await asyncio.to_thread(_run)

    def _extract_text(self, doc) -> str:
        """Return the cleaned, truncated text of ``doc``, or ``""`` when it has none."""
        raw = getattr(doc, "page_content", None)
        if not raw:
            return ""
        return self.truncate(self._BLANK_LINES.sub("\n", raw).strip())

    async def search(self, query: str, max_results: int = 5) -> AsyncGenerator[str, None]:
        """Search for ``query`` and yield the text of each result page.

        Pages without usable text are skipped. Errors are logged rather than
        raised, so a failed search simply yields nothing.

        Args:
            query: Search terms.
            max_results: Passed to ``DDGS().text`` as ``max_results``.

        Yields:
            Cleaned page text, truncated by ``truncate``.
        """
        try:
            self.logger.info(f"Searching for: {query} (max_results: {max_results})")

            hits = await self._search_hits(query, max_results)
            urls = [hit["href"] for hit in hits]
            self.logger.info(f"Found {len(urls)} URLs to process")

            if not urls:
                self.logger.warning("No URLs found from search results")
                return

            docs = await self.get_page(urls)

            for doc in docs:
                try:
                    text = self._extract_text(doc)
                except Exception as e:
                    self.logger.error(f"Error processing document: {e}")
                    continue

                if text:
                    yield text

        except Exception as e:
            # Deliberately not re-raised: the generator just ends early.
            self.logger.error(f"Error in search method: {e}", exc_info=True)

    async def search_with_metadata(self, query: str, max_results: int = 5) -> AsyncGenerator[dict, None]:
        """Search for ``query`` and yield each page's text together with metadata.

        Args:
            query: Search terms.
            max_results: Passed to ``DDGS().text`` as ``max_results``.

        Yields:
            Dicts with ``content``, ``url``, ``title`` and ``word_count`` keys.
            ``url`` and ``title`` fall back to ``"Unknown"`` and ``"No title"``
            when a page cannot be matched to a search hit.
        """
        try:
            hits = await self._search_hits(query, max_results)
            if not hits:
                return

            urls = [hit["href"] for hit in hits]
            titles = {}
            for hit in hits:
                # Keep the first title seen if the same URL appears twice.
                titles.setdefault(hit["href"], hit.get("title", "No title"))

            docs = await self.get_page(urls)

            for i, doc in enumerate(docs):
                try:
                    text = self._extract_text(doc)
                    if not text:
                        continue

                    # Prefer the URL recorded on the document itself; fall
                    # back to position only when it is missing or unknown, so
                    # a dropped page cannot shift titles onto the wrong text.
                    url = (getattr(doc, "metadata", None) or {}).get("source")
                    if url not in titles:
                        url = urls[i] if i < len(urls) else None

                    yield {
                        'content': text,
                        'url': url if url is not None else 'Unknown',
                        'title': titles.get(url, 'No title'),
                        'word_count': len(text.split())
                    }

                except Exception as e:
                    self.logger.error(f"Error processing document {i}: {e}")
                    continue

        except Exception as e:
            self.logger.error(f"Error in search_with_metadata: {e}", exc_info=True)
|
space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/space/app/rtc/__init__.py
CHANGED
|
@@ -2,11 +2,13 @@ from openai import OpenAI
|
|
| 2 |
from elevenlabs.client import ElevenLabs
|
| 3 |
from tts.audio_edge_tts import EdgeTTS
|
| 4 |
from config.constant import OPENAI_API_KEY, ELEVENLABS_API_KEY
|
|
|
|
| 5 |
from rtc.rtc_call import RTCHandler
|
| 6 |
from stt.whisper_stt import WhisperSTT
|
| 7 |
|
| 8 |
-
whisper_stt = WhisperSTT("
|
| 9 |
edge_tts = EdgeTTS("id-ID-ArdiNeural", "+0%", "+0%")
|
|
|
|
| 10 |
rtc_handler = RTCHandler(whisper_stt, edge_tts)
|
| 11 |
|
| 12 |
def handle_rtc():
|
|
|
|
| 2 |
from elevenlabs.client import ElevenLabs
|
| 3 |
from tts.audio_edge_tts import EdgeTTS
|
| 4 |
from config.constant import OPENAI_API_KEY, ELEVENLABS_API_KEY
|
| 5 |
+
# from rtc.rtc_call import RTCHandler
|
| 6 |
from rtc.rtc_call import RTCHandler
|
| 7 |
from stt.whisper_stt import WhisperSTT
|
| 8 |
|
| 9 |
+
whisper_stt = WhisperSTT(model_size = "base", device = "cuda")
|
| 10 |
edge_tts = EdgeTTS("id-ID-ArdiNeural", "+0%", "+0%")
|
| 11 |
+
openai_client = OpenAI(api_key = OPENAI_API_KEY)
|
| 12 |
rtc_handler = RTCHandler(whisper_stt, edge_tts)
|
| 13 |
|
| 14 |
def handle_rtc():
|