Spaces:
Configuration error
Configuration error
Commit ·
d5bcb72
1
Parent(s): e77156b
Deploy files from GitHub repository
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- app/rag/__init__.py +2 -2
- app/rag/agents/customer_service_agent.py +3 -3
- app/rag/retriever/langchain_retriever.py +2 -0
- space/app/rag/__init__.py +2 -2
- space/app/rag/agents/customer_service_agent.py +3 -3
- space/app/rag/retriever/langchain_retriever.py +2 -0
- space/space/app/rag/agents/__init__.py +0 -0
- space/space/app/rag/agents/agents.py +16 -0
- space/space/app/rag/agents/customer_service_agent.py +33 -0
- space/space/app/rag/agents/gpt_customer_service_agent.py +13 -0
- space/space/app/rag/agents/query_maker_agent.py +13 -0
- space/space/app/rag/chat_template/__init__.py +29 -0
- space/space/app/rag/chat_template/customer_service.txt +12 -0
- space/space/app/rag/chat_template/query_maker.txt +35 -0
- space/space/app/rag/chat_template/query_maker_temp.txt +30 -0
- space/space/app/rag/inference/__init__.py +0 -0
- space/space/app/rag/pipeline/language_model.py +947 -0
- space/space/app/rag/retriever/__init__.py +0 -0
- space/space/app/rag/retriever/langchain_retriever.py +25 -7
- space/space/app/rag/web_search/__init__.py +0 -0
- space/space/app/rtc/rtc_call_gpt.py +364 -0
- space/space/app/tests/qwen_llm_test.py +9 -9
- space/space/space/app/__chat__.py +4 -3
- space/space/space/app/__test__.py +0 -5
- space/space/space/app/app.log +0 -0
- space/space/space/app/rag/__init__.py +61 -21
- space/space/space/app/rag/inference/inferencer.py +51 -9
- space/space/space/app/rag/pipeline/qwen_llm.py +29 -8
- space/space/space/app/rag/prompt_tuner/chat_template.py +6 -4
- space/space/space/app/rag/web_search/duckduckgo_search.py +142 -0
- space/space/space/app/rtc/__init__.py +3 -1
- space/space/space/app/rtc/rtc_call.py +3 -3
- space/space/space/app/stt/whisper_stt.py +70 -7
- space/space/space/app/tests/ddgs_test.py +7 -0
- space/space/space/app/tests/inference_test.py +14 -70
- space/space/space/space/space/.env.example +3 -0
- space/space/space/space/space/.gitattributes +37 -0
- space/space/space/space/space/.github/workflows/deploy-to-huggingface.yml +52 -0
- space/space/space/space/space/.gitignore +9 -0
- space/space/space/space/space/Dockerfile +49 -0
- space/space/space/space/space/README.md +31 -0
- space/space/space/space/space/app/.gradio/certificate.pem +31 -0
- space/space/space/space/space/app/__chat__.py +14 -0
- space/space/space/space/space/app/__server__.py +3 -0
- space/space/space/space/space/app/__test__.py +19 -0
- space/space/space/space/space/app/config/__init__.py +0 -0
- space/space/space/space/space/app/config/constant.py +7 -0
- space/space/space/space/space/app/rag/__init__.py +50 -0
- space/space/space/space/space/app/rag/inference/inferencer.py +552 -0
- space/space/space/space/space/app/rag/pipeline/__init__.py +0 -0
app/rag/__init__.py
CHANGED
|
@@ -49,11 +49,11 @@ inferencer_config = InferencerConfig(
|
|
| 49 |
)
|
| 50 |
|
| 51 |
document_retriever = LangChainRetriever(
|
| 52 |
-
embedding_model="
|
| 53 |
vectorstore_type="chroma",
|
| 54 |
vectorstore_path="vectorstore/",
|
| 55 |
use_hybrid_search=True,
|
| 56 |
-
chunk_size=
|
| 57 |
chunk_overlap=200
|
| 58 |
)
|
| 59 |
|
|
|
|
| 49 |
)
|
| 50 |
|
| 51 |
document_retriever = LangChainRetriever(
|
| 52 |
+
embedding_model="BAAI/bge-large-en",
|
| 53 |
vectorstore_type="chroma",
|
| 54 |
vectorstore_path="vectorstore/",
|
| 55 |
use_hybrid_search=True,
|
| 56 |
+
chunk_size=3000,
|
| 57 |
chunk_overlap=200
|
| 58 |
)
|
| 59 |
|
app/rag/agents/customer_service_agent.py
CHANGED
|
@@ -8,9 +8,9 @@ class CSAgent(Agent):
|
|
| 8 |
self.prompt_template = prompt_template
|
| 9 |
self.file_paths = [
|
| 10 |
"../documents/bpjs.pdf",
|
| 11 |
-
"../documents/pph21.pdf",
|
| 12 |
-
"../documents/lembur.pdf",
|
| 13 |
-
"../documents/uu13.pdf",
|
| 14 |
"../documents/file.pdf",
|
| 15 |
]
|
| 16 |
async def load_documents(self):
|
|
|
|
| 8 |
self.prompt_template = prompt_template
|
| 9 |
self.file_paths = [
|
| 10 |
"../documents/bpjs.pdf",
|
| 11 |
+
# "../documents/pph21.pdf",
|
| 12 |
+
# "../documents/lembur.pdf",
|
| 13 |
+
# "../documents/uu13.pdf",
|
| 14 |
"../documents/file.pdf",
|
| 15 |
]
|
| 16 |
async def load_documents(self):
|
app/rag/retriever/langchain_retriever.py
CHANGED
|
@@ -175,6 +175,7 @@ class LangChainRetriever(BaseRetriever):
|
|
| 175 |
vectorstore=self.vectorstore,
|
| 176 |
search_kwargs={"k": 10}
|
| 177 |
)
|
|
|
|
| 178 |
self.retriever = EnsembleRetriever(
|
| 179 |
retrievers=[vector_retriever, self.bm25_retriever],
|
| 180 |
weights=[0.5, 0.5] # Equal weight to both retrievers
|
|
@@ -197,6 +198,7 @@ class LangChainRetriever(BaseRetriever):
|
|
| 197 |
None, self.retriever.get_relevant_documents, query
|
| 198 |
)
|
| 199 |
retrieved_docs = retrieved_docs[:k]
|
|
|
|
| 200 |
scores = [0.9 - (i * 0.1) for i in range(len(retrieved_docs))]
|
| 201 |
|
| 202 |
retrieval_time = time.time() - start_time
|
|
|
|
| 175 |
vectorstore=self.vectorstore,
|
| 176 |
search_kwargs={"k": 10}
|
| 177 |
)
|
| 178 |
+
|
| 179 |
self.retriever = EnsembleRetriever(
|
| 180 |
retrievers=[vector_retriever, self.bm25_retriever],
|
| 181 |
weights=[0.5, 0.5] # Equal weight to both retrievers
|
|
|
|
| 198 |
None, self.retriever.get_relevant_documents, query
|
| 199 |
)
|
| 200 |
retrieved_docs = retrieved_docs[:k]
|
| 201 |
+
|
| 202 |
scores = [0.9 - (i * 0.1) for i in range(len(retrieved_docs))]
|
| 203 |
|
| 204 |
retrieval_time = time.time() - start_time
|
space/app/rag/__init__.py
CHANGED
|
@@ -49,11 +49,11 @@ inferencer_config = InferencerConfig(
|
|
| 49 |
)
|
| 50 |
|
| 51 |
document_retriever = LangChainRetriever(
|
| 52 |
-
embedding_model="
|
| 53 |
vectorstore_type="chroma",
|
| 54 |
vectorstore_path="vectorstore/",
|
| 55 |
use_hybrid_search=True,
|
| 56 |
-
chunk_size=
|
| 57 |
chunk_overlap=200
|
| 58 |
)
|
| 59 |
|
|
|
|
| 49 |
)
|
| 50 |
|
| 51 |
document_retriever = LangChainRetriever(
|
| 52 |
+
embedding_model="BAAI/bge-large-en",
|
| 53 |
vectorstore_type="chroma",
|
| 54 |
vectorstore_path="vectorstore/",
|
| 55 |
use_hybrid_search=True,
|
| 56 |
+
chunk_size=3000,
|
| 57 |
chunk_overlap=200
|
| 58 |
)
|
| 59 |
|
space/app/rag/agents/customer_service_agent.py
CHANGED
|
@@ -8,9 +8,9 @@ class CSAgent(Agent):
|
|
| 8 |
self.prompt_template = prompt_template
|
| 9 |
self.file_paths = [
|
| 10 |
"../documents/bpjs.pdf",
|
| 11 |
-
"../documents/pph21.pdf",
|
| 12 |
-
"../documents/lembur.pdf",
|
| 13 |
-
"../documents/uu13.pdf",
|
| 14 |
"../documents/file.pdf",
|
| 15 |
]
|
| 16 |
async def load_documents(self):
|
|
|
|
| 8 |
self.prompt_template = prompt_template
|
| 9 |
self.file_paths = [
|
| 10 |
"../documents/bpjs.pdf",
|
| 11 |
+
# "../documents/pph21.pdf",
|
| 12 |
+
# "../documents/lembur.pdf",
|
| 13 |
+
# "../documents/uu13.pdf",
|
| 14 |
"../documents/file.pdf",
|
| 15 |
]
|
| 16 |
async def load_documents(self):
|
space/app/rag/retriever/langchain_retriever.py
CHANGED
|
@@ -175,6 +175,7 @@ class LangChainRetriever(BaseRetriever):
|
|
| 175 |
vectorstore=self.vectorstore,
|
| 176 |
search_kwargs={"k": 10}
|
| 177 |
)
|
|
|
|
| 178 |
self.retriever = EnsembleRetriever(
|
| 179 |
retrievers=[vector_retriever, self.bm25_retriever],
|
| 180 |
weights=[0.5, 0.5] # Equal weight to both retrievers
|
|
@@ -197,6 +198,7 @@ class LangChainRetriever(BaseRetriever):
|
|
| 197 |
None, self.retriever.get_relevant_documents, query
|
| 198 |
)
|
| 199 |
retrieved_docs = retrieved_docs[:k]
|
|
|
|
| 200 |
scores = [0.9 - (i * 0.1) for i in range(len(retrieved_docs))]
|
| 201 |
|
| 202 |
retrieval_time = time.time() - start_time
|
|
|
|
| 175 |
vectorstore=self.vectorstore,
|
| 176 |
search_kwargs={"k": 10}
|
| 177 |
)
|
| 178 |
+
|
| 179 |
self.retriever = EnsembleRetriever(
|
| 180 |
retrievers=[vector_retriever, self.bm25_retriever],
|
| 181 |
weights=[0.5, 0.5] # Equal weight to both retrievers
|
|
|
|
| 198 |
None, self.retriever.get_relevant_documents, query
|
| 199 |
)
|
| 200 |
retrieved_docs = retrieved_docs[:k]
|
| 201 |
+
|
| 202 |
scores = [0.9 - (i * 0.1) for i in range(len(retrieved_docs))]
|
| 203 |
|
| 204 |
retrieval_time = time.time() - start_time
|
space/space/app/rag/agents/__init__.py
ADDED
|
File without changes
|
space/space/app/rag/agents/agents.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rag.pipeline.language_model import LM
|
| 2 |
+
from rag.inference.inferencer import Inferencer
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
class Agent(ABC):
    """Abstract base for agents that drive an inferencer with a prompt template.

    Constructing an agent stores the inferencer, installs ``prompt_template``
    on ``inferencer.model.prompt_template`` and keeps a reference to the same
    object in ``self.prompt``. Subclasses must implement ``get_result``.
    """

    def __init__(self, inferencer: "Inferencer", prompt_template=None):
        """Bind the agent to an inferencer and install its prompt template.

        Args:
            inferencer: Object exposing a ``model`` attribute; that model's
                ``prompt_template`` attribute is overwritten here.
            prompt_template: Chat messages as a list of ``role``/``content``
                dicts. When omitted, a fresh one-message system prompt is
                built for this instance, so instances never share (and
                accidentally mutate) a single default list.
        """
        if prompt_template is None:
            prompt_template = [
                {
                    "role": "system",
                    "content": "You are an agent that doing some specic task",
                }
            ]
        self.inferencer = inferencer
        self.inferencer.model.prompt_template = prompt_template
        self.prompt = prompt_template

    @abstractmethod
    async def get_result(self):
        """Produce the agent's result; concrete agents define the behavior."""
        pass
|
space/space/app/rag/agents/customer_service_agent.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rag.agents.agents import Agent
|
| 2 |
+
from rag.inference.inferencer import Inferencer
|
| 3 |
+
|
| 4 |
+
class CSAgent(Agent):
    """Customer-service agent: indexes local PDF files and streams answers."""

    def __init__(self, inferencer : Inferencer , prompt_template):
        """Store the inferencer and template and set the documents to index.

        Args:
            inferencer: Inferencer whose retriever and model this agent uses.
            prompt_template: Template installed on the model before each query.
        """
        super().__init__(inferencer, prompt_template)
        self.inferencer = inferencer
        self.prompt_template = prompt_template
        # pph21.pdf, lembur.pdf and uu13.pdf are intentionally left out of
        # the list in this revision.
        self.file_paths = ["../documents/bpjs.pdf", "../documents/file.pdf"]

    async def load_documents(self):
        """Add every path in ``self.file_paths`` to the retriever, in order."""
        for path in self.file_paths:
            await self.add_doc(path)

    async def add_doc(self, file_path):
        """Add one file through the inferencer's retriever and print the outcome.

        Args:
            file_path: Path passed to ``retriever.add_document_from_file``.
        """
        outcome = await self.inferencer.retriever.add_document_from_file(file_path)
        if not outcome.success:
            print(f"Failed to process: {outcome.error_message}")
            return
        print(f"Successfully processed: {outcome.document_metadata.file_name}")
        print(f"Chunks created: {outcome.document_metadata.chunk_count}")

    async def get_result(self, question):
        """Yield the items produced by ``inferencer.infer_stream`` for ``question``.

        The agent's template is re-installed on the model first; the stream
        is requested with reranking disabled and ``k=3``.
        """
        self.inferencer.model.prompt_template = self.prompt_template
        stream = self.inferencer.infer_stream(
            query=question,
            enable_reranking=False,
            k=3,
        )
        async for chunk in stream:
            yield chunk
|
space/space/app/rag/agents/gpt_customer_service_agent.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rag.agents.agents import Agent
|
| 2 |
+
from rag.pipeline.language_model import LM
|
| 3 |
+
from rag.inference.inferencer import Inferencer
|
| 4 |
+
|
| 5 |
+
class GPTCSAgent(Agent):
    """Customer-service agent that returns one complete answer per question."""

    def __init__(self, inferencer : Inferencer , prompt_template):
        """Store the inferencer and the template used for every query.

        Args:
            inferencer: Inferencer used to answer questions.
            prompt_template: Template installed on the model before each query.
        """
        super().__init__(inferencer, prompt_template)
        self.inferencer = inferencer
        self.prompt_template = prompt_template

    async def get_result(self, question : str):
        """Install this agent's template, log the question and run inference.

        Args:
            question: Text forwarded as ``query`` to ``inferencer.infer``.

        Returns:
            Whatever ``inferencer.infer`` returns for the question.
        """
        self.inferencer.model.prompt_template = self.prompt_template
        print("Question received :", question)
        answer = await self.inferencer.infer(query=question)
        return answer
|
space/space/app/rag/agents/query_maker_agent.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rag.agents.agents import Agent
|
| 2 |
+
from rag.pipeline.language_model import LM
|
| 3 |
+
from rag.inference.inferencer import Inferencer
|
| 4 |
+
|
| 5 |
+
class QueryMakerAgent(Agent):
    """Agent that runs a question through the inferencer with its own template."""

    def __init__(self, inferencer : Inferencer , prompt_template):
        """Store the inferencer and the template used for every query.

        Args:
            inferencer: Inferencer used to process questions.
            prompt_template: Template installed on the model before each query.
        """
        super().__init__(inferencer, prompt_template)
        self.inferencer = inferencer
        self.prompt_template = prompt_template

    async def get_result(self, question : str):
        """Install this agent's template, log the question and run inference.

        Args:
            question: Text forwarded as ``query`` to ``inferencer.infer``.

        Returns:
            Whatever ``inferencer.infer`` returns for the question.
        """
        self.inferencer.model.prompt_template = self.prompt_template
        print("Question received :", question)
        rewritten = await self.inferencer.infer(query=question)
        return rewritten
|
space/space/app/rag/chat_template/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def read_template_txt(file_path):
    """Read a plain-text template and return its full contents.

    Args:
        file_path: Template name without extension; the file opened is
            ``rag/chat_template/<file_path>.txt``, resolved against the
            current working directory and decoded as UTF-8.

    Returns:
        The file contents as a single string.
    """
    template_path = f"rag/chat_template/{file_path}.txt"
    with open(template_path, mode='r', encoding='utf-8') as handle:
        text = handle.read()
    return text
|
| 5 |
+
def get_chat_template(file_name):
    """Build a two-message chat template from a system-prompt text file.

    Args:
        file_name: Template name handed to ``read_template_txt``.

    Returns:
        A list with a ``system`` message wrapping the file's text and a
        ``user`` message that still contains the literal ``{question}`` and
        ``{context}`` placeholders (it is a plain string, not an f-string).
    """
    sys_prompt = read_template_txt(file_name)

    # NOTE(review): confirm the indentation inside these triple-quoted
    # literals against the repository; they are written here without
    # leading whitespace on the inner lines.
    system_message = {
        "role" : "system",
        "content" : f"""
{sys_prompt}
"""
    }
    user_message = {
        "role" : "user",
        "content" : """

Please answer properly:
{question}

From given context :
{context}



"""
    }
    return [system_message, user_message]
|
| 29 |
+
|
space/space/app/rag/chat_template/customer_service.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
You are a friendly and professional customer service representative in the Human Resource Information System (HRIS) field,
|
| 2 |
+
fluent in Indonesian. Your job is to assist customers with accurate information based on your company's basic knowledge. Follow these guidelines:
|
| 3 |
+
|
| 4 |
+
- Always greet customers in a friendly and professional manner.
|
| 5 |
+
- Your answers are contextual and objective.
|
| 6 |
+
- Provide clear, easy-to-understand, and structured answers based on the context provided by the user.
|
| 7 |
+
- If information is not available, offer alternative assistance or direct them to the appropriate channel.
|
| 8 |
+
- Use polite language and empathize with the customer's needs.
|
| 9 |
+
- Conclude by offering further assistance.
|
| 10 |
+
- You are highly skilled in the area relevant to the given context.
|
| 11 |
+
|
| 12 |
+
Please use the given context to answer accurately.
|
space/space/app/rag/chat_template/query_maker.txt
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Anda adalah agen AI yang tepat dan objektif,
|
| 2 |
+
Anda bertugas mengubah pertanyaan atau pernyataan pengguna menjadi query yang eksplisit dan efisien untuk keperluan pencarian dokumen dalam sistem RAG (Retrieval-Augmented Generation).
|
| 3 |
+
|
| 4 |
+
Ikuti langkah-langkah berikut:
|
| 5 |
+
|
| 6 |
+
1. Ekstrak bagian-bagian penting dari input pengguna:
|
| 7 |
+
- **Intent**: Tujuan utama atau jenis permintaan (misalnya: apa itu, cara, syarat, apakah bisa, berapa).
|
| 8 |
+
- **Entity/Noun Phrase**: Objek utama yang dibahas (misalnya: BPJS, tokenizer truncation, RWKV, gaji).
|
| 9 |
+
- **Context**: Informasi pendukung yang menyempitkan fokus (misalnya: kecelakaan kerja, gaji 1 juta per bulan, perusahaan mitra BPJS).
|
| 10 |
+
- **Question**: Pertanyaan spesifik yang ingin dijawab (misalnya: bagaimana prosesnya, apa manfaatnya, berapa jumlahnya).
|
| 11 |
+
|
| 12 |
+
2. Setelah semua elemen diidentifikasi, bentuk **Query RAG** dengan struktur: [INTENT] + [ENTITY] + [CONTEXT] + [QUESTION]
|
| 13 |
+
3. Gunakan bahasa natural yang ringkas, namun informatif dan eksplisit.
|
| 14 |
+
4. Generate hanya hasil akhirnya saja berupa satu buah kalimat
|
| 15 |
+
|
| 16 |
+
Contoh 0 :
|
| 17 |
+
User Input:
|
| 18 |
+
> Apa itu BPJS
|
| 19 |
+
Output : Pengertian BPJS
|
| 20 |
+
|
| 21 |
+
Contoh 1 :
|
| 22 |
+
User Input:
|
| 23 |
+
> Di mana lokasi PT Sakura System Solution ?
|
| 24 |
+
|
| 25 |
+
Output: Lokasi PT Sakura System Solution
|
| 26 |
+
|
| 27 |
+
Contoh 2:
|
| 28 |
+
User Input:
|
| 29 |
+
> Saya mengalami kecelakaan di kantor dan ingin tahu apakah bisa klaim BPJS karena perusahaan saya adalah mitra.
|
| 30 |
+
|
| 31 |
+
Output: apakah bisa klaim BPJS kecelakaan kerja di kantor jika perusahaan mitra dan apakah saya memenuhi syarat
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
**Tugas Anda sekarang:**
|
| 35 |
+
Lakukan proses di atas untuk setiap input pengguna yang diberikan. Hasilkan query RAG akhir yang siap digunakan dalam pencarian dokumen.
|
space/space/app/rag/chat_template/query_maker_temp.txt
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Anda adalah agen AI yang tepat dan objektif,
|
| 2 |
+
Anda bertugas mengubah pertanyaan atau pernyataan pengguna menjadi query yang eksplisit dan efisien untuk keperluan pencarian dokumen dalam sistem RAG (Retrieval-Augmented Generation).
|
| 3 |
+
|
| 4 |
+
Ikuti langkah-langkah berikut:
|
| 5 |
+
|
| 6 |
+
1. Ekstrak bagian-bagian penting dari input pengguna:
|
| 7 |
+
- **Intent**: Tujuan utama atau jenis permintaan (misalnya: apa itu, cara, syarat, apakah bisa, berapa).
|
| 8 |
+
- **Entity/Noun Phrase**: Objek utama yang dibahas (misalnya: BPJS, tokenizer truncation, RWKV, gaji).
|
| 9 |
+
- **Context**: Informasi pendukung yang menyempitkan fokus (misalnya: kecelakaan kerja, gaji 1 juta per bulan, perusahaan mitra BPJS).
|
| 10 |
+
- **Question**: Pertanyaan spesifik yang ingin dijawab (misalnya: bagaimana prosesnya, apa manfaatnya, berapa jumlahnya).
|
| 11 |
+
|
| 12 |
+
2. Setelah semua elemen diidentifikasi, bentuk **Query RAG** dengan struktur: [INTENT] + [ENTITY] + [CONTEXT] + [QUESTION]
|
| 13 |
+
3. Gunakan bahasa natural yang ringkas, namun informatif dan eksplisit.
|
| 14 |
+
4. Generate hanya hasil akhirnya saja berupa satu buah kalimat
|
| 15 |
+
|
| 16 |
+
Contoh 1 :
|
| 17 |
+
User Input:
|
| 18 |
+
> Di mana lokasi PT Sakura System Solution ?
|
| 19 |
+
|
| 20 |
+
Output: Lokasi PT Sakura System Solution
|
| 21 |
+
|
| 22 |
+
Contoh 2:
|
| 23 |
+
User Input:
|
| 24 |
+
> Saya mengalami kecelakaan di kantor dan ingin tahu apakah bisa klaim BPJS karena perusahaan saya adalah mitra.
|
| 25 |
+
|
| 26 |
+
Output: apakah bisa klaim BPJS kecelakaan kerja di kantor jika perusahaan mitra dan apakah saya memenuhi syarat
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
**Tugas Anda sekarang:**
|
| 30 |
+
Lakukan proses di atas untuk setiap input pengguna yang diberikan. Hasilkan query RAG akhir yang siap digunakan dalam pencarian dokumen.
|
space/space/app/rag/inference/__init__.py
ADDED
|
File without changes
|
space/space/app/rag/pipeline/language_model.py
ADDED
|
@@ -0,0 +1,947 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import asyncio
|
| 3 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, TextIteratorStreamer, BitsAndBytesConfig
|
| 4 |
+
import torch
|
| 5 |
+
from typing import Optional, Dict, Any, List, Union, Callable, Awaitable, AsyncGenerator
|
| 6 |
+
import logging
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 10 |
+
from functools import partial
|
| 11 |
+
from threading import Thread
|
| 12 |
+
from rag.retriever.retriever_types import RetrievalResult
|
| 13 |
+
from langchain_core.documents import Document
|
| 14 |
+
import copy
|
| 15 |
+
|
| 16 |
+
@dataclass
class LMConfig:
    """Settings for loading the language model and generating text with it."""

    # Model loading
    model_name: str = "Qwen/Qwen2.5-1.5B-Instruct"
    device: str = "cuda"
    torch_dtype: torch.dtype = torch.float16
    # Generation / sampling
    max_length: int = 2048
    temperature: float = 0.7
    top_p: float = 0.8
    top_k: int = 50
    do_sample: bool = True
    # Annotated with typing.Any (the builtin ``any`` is a function, not a type).
    quantization_config: Any = None
    pad_token_id: Optional[int] = None
    eos_token_id: Optional[int] = None
    # RAG-specific configs
    max_context_length: int = 1500
    context_separator: str = "\n---\n"
    instruction_template: str = "system"  # "system", "instruction", "custom"
    # Async-specific configs
    max_workers: int = 2
    generation_timeout: float = 30
    repetition_penalty: float = 1.0
    # Streaming-specific configs
    stream_timeout: float = 100  # timeout for a stream chunk
    skip_prompt: bool = True  # skip the prompt in the streaming output
|
| 40 |
+
|
| 41 |
+
class LM:
|
| 42 |
+
"""
|
| 43 |
+
Async LLM Qwen 0.5B dengan interface yang mudah digunakan
|
| 44 |
+
Termasuk prompt formatting khusus untuk RAG (Retrieval-Augmented Generation)
|
| 45 |
+
Dan support untuk text streaming
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
def __init__(self, config: Optional[LMConfig] = None, prompt_template=None):
    """Initialise the LM wrapper without loading any weights.

    Args:
        config: Model configuration; a default ``LMConfig()`` is used when
            this is None.
        prompt_template: Chat messages (list of ``role``/``content`` dicts).
            When omitted, a fresh system + user template is built for this
            instance so that instances never share one mutable default list.
    """
    if prompt_template is None:
        prompt_template = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "{question}"},
        ]

    self.config = LMConfig() if config is None else config
    # Tokenizer, model and generation config stay unset until load_model().
    self.tokenizer : AutoTokenizer = None
    self.model = None
    self.generation_config = None
    self.is_loaded = False
    self.executor = ThreadPoolExecutor(max_workers=self.config.max_workers)
    self._lock = asyncio.Lock()
    # Setup logging
    logging.basicConfig(level=logging.INFO)
    self.logger = logging.getLogger(__name__)

    # RAG prompt templates
    self.prompt_template = prompt_template
|
| 74 |
+
|
| 75 |
+
async def load_model(self) -> None:
    """Load the tokenizer and the model asynchronously, at most once.

    Both ``from_pretrained`` calls run on ``self.executor`` so the event loop
    is not blocked. The whole operation is guarded by ``self._lock``; if the
    model is already loaded the call returns immediately.

    Raises:
        Exception: Any error raised while loading is logged and re-raised.
    """
    async with self._lock:
        if self.is_loaded:
            self.logger.info("Model already loaded")
            return

        # Inside a coroutine the running loop is the one to schedule work on.
        loop = asyncio.get_running_loop()
        try:
            self.logger.info(f"Loading model: {self.config.model_name}")

            # Load the tokenizer in the thread pool.
            # NOTE(review): torch_dtype/device_map are model-loading options;
            # confirm the tokenizer actually needs them.
            self.tokenizer = await loop.run_in_executor(
                self.executor,
                lambda: AutoTokenizer.from_pretrained(
                    self.config.model_name,
                    trust_remote_code=True,
                    torch_dtype="auto",
                    device_map="auto",
                )
            )

            # Load the model in the thread pool.
            self.model = await loop.run_in_executor(
                self.executor,
                lambda: AutoModelForCausalLM.from_pretrained(
                    self.config.model_name,
                    quantization_config=self.config.quantization_config,
                    torch_dtype=self.config.torch_dtype,
                    device_map=self.config.device,
                    trust_remote_code=True
                )
            )

            # Fall back to the tokenizer's EOS id only when the config value
            # is unset; an explicit token id of 0 must be respected.
            pad_token_id = self.config.pad_token_id
            if pad_token_id is None:
                pad_token_id = self.tokenizer.eos_token_id
            eos_token_id = self.config.eos_token_id
            if eos_token_id is None:
                eos_token_id = self.tokenizer.eos_token_id

            # Setup generation config
            self.generation_config = GenerationConfig(
                max_length=self.config.max_length,
                temperature=self.config.temperature,
                top_p=self.config.top_p,
                top_k=self.config.top_k,
                do_sample=self.config.do_sample,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                repetition_penalty=self.config.repetition_penalty,
            )

            self.is_loaded = True
            self.logger.info("Model loaded successfully!")

        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            self.logger.exception("Error loading model: %s", e)
            raise
|
| 126 |
+
|
| 127 |
+
def get_available_templates(self) -> List[str]:
    """Return a new list built by iterating ``self.prompt_template``.

    Returns:
        For a mapping this is the list of its keys; for a list of chat
        messages it is a shallow copy of that list.
    """
    templates = [*self.prompt_template]
    return templates
|
| 135 |
+
|
| 136 |
+
def preview_template(self, template_type: str, sample_question: str = "Apa itu AI?",
                     sample_context: str = "Artificial Intelligence adalah teknologi...") -> str:
    """Preview a template filled in with sample data.

    Two shapes of ``self.prompt_template`` are handled:

    * a list of chat messages (``role``/``content`` dicts): the first message
      whose ``role`` equals ``template_type`` is previewed;
    * a single dict: if ``template_type`` is one of its keys, its
      ``"content"`` entry is previewed (the original behaviour).

    Args:
        template_type: Role (list form) or key (dict form) to preview.
        sample_question: Value substituted for ``{question}``.
        sample_context: Value substituted for ``{context}``.

    Returns:
        The formatted template text, or a message listing the available
        templates when ``template_type`` does not match anything.
    """
    templates = self.prompt_template
    content = None

    if isinstance(templates, dict):
        if template_type in templates:
            content = templates["content"]
    else:
        # List-of-messages form: a plain membership test against dicts can
        # never match a string, so look the message up by its role.
        for message in templates:
            if isinstance(message, dict) and message.get("role") == template_type:
                content = message["content"]
                break

    if content is None:
        return f"Template '{template_type}' tidak tersedia. Available: {self.get_available_templates()}"

    return content.format(
        context=sample_context,
        question=sample_question
    )
|
| 159 |
+
|
| 160 |
+
def _format_context(self, contexts: Union[List[str], RetrievalResult], numbering: bool = True) -> str:
    """Join retrieved contexts into one string for the prompt.

    Args:
        contexts: Either a ``RetrievalResult`` (its ``documents`` and
            ``scores`` are used) or a plain list whose items are rendered
            with ``str()``.
        numbering: When True each header carries the 1-based document
            number (and, for a ``RetrievalResult``, a truthy score);
            otherwise every header is ``[Dokumen]``.

    Returns:
        The formatted entries joined with ``self.config.context_separator``,
        or an empty string when ``contexts`` is falsy.
    """
    if not contexts:
        return ""

    self.logger.info(f"Context : {contexts}")
    self.logger.info(f"Is RetrievalResult Contexts = {isinstance(contexts, RetrievalResult)}")

    formatted_contexts = []
    if isinstance(contexts, RetrievalResult):
        # scores may be empty or shorter than documents; never index past it.
        scores = contexts.scores or []
        for i, ctx in enumerate(contexts.documents, 1):
            if numbering:
                header = f"[Dokumen {i}"
                score = scores[i - 1] if i <= len(scores) else None
                # As before, a falsy score (None or 0) is left out of the header.
                if score:
                    header += f" (Skor: {score:.3f})"
                header += "]"
            else:
                header = "[Dokumen]"
            formatted_contexts.append(f"{header}\n{ctx.page_content}")
    else:
        for i, ctx in enumerate(contexts, 1):
            header = f"[Dokumen {i}]" if numbering else "[Dokumen]"
            # str() is a no-op for strings, so one branch covers every item type.
            formatted_contexts.append(f"{header}\n{str(ctx)}")

    return self.config.context_separator.join(formatted_contexts)
|
| 198 |
+
|
| 199 |
+
def _truncate_context(self, context: str, max_length: int) -> str:
    """
    Shorten ``context`` so the result never exceeds ``max_length`` characters.

    When the text is cut, a marker line is appended so the reader can see
    that the context is incomplete. The marker counts towards
    ``max_length``.

    Args:
        context: Context string.
        max_length: Maximum length of the returned string, in characters.

    Returns:
        ``context`` unchanged when it already fits; otherwise a prefix of it
        followed by the truncation marker. When ``max_length`` is too small
        to hold the marker at all, a bare prefix of ``max_length``
        characters is returned instead.
    """
    if len(context) <= max_length:
        return context

    marker = "\n\n[... Context dipotong karena terlalu panjang ...]"
    keep = max_length - len(marker)
    if keep <= 0:
        # No room for the marker: hard-cut instead of slicing with a
        # negative index, which would keep almost the whole context.
        return context[:max(max_length, 0)]
    return context[:keep] + marker
|
| 216 |
+
|
| 217 |
+
async def format_rag_prompt(self,
                            question: str,
                            contexts: Union[List[str], RetrievalResult],
                            template_type: Optional[str] = None,
                            custom_template: Optional[str] = None,
                            include_metadata: bool = True,
                            context_numbering: bool = True,
                            max_contexts: Optional[int] = None) -> Union[str, List[Dict[str, str]]]:
    """
    Build the RAG prompt for ``question`` from the retrieved ``contexts``.

    The contexts are limited to ``max_contexts``, rendered with
    ``_format_context`` and cut to ``self.config.max_context_length`` with
    ``_truncate_context`` (both run in ``self.executor``). The resulting text
    and the question are then substituted into a template.

    Args:
        question: User question.
        contexts: A ``RetrievalResult`` or a plain list of contexts.
        template_type: Only the empty string has an effect here: it sets
            ``self.config.instruction_template`` to ``"system"``.
        custom_template: ``str.format`` template with ``{context}`` and
            ``{question}`` placeholders; takes precedence when given.
        include_metadata: Accepted for backward compatibility; document
            metadata is currently not appended to the context.
        context_numbering: Forwarded to ``_format_context``.
        max_contexts: Keep only the first ``max_contexts`` contexts.

    Returns:
        The filled-in ``custom_template`` as a string when one is given.
        Otherwise a list of chat messages: ``self.prompt_template`` with the
        placeholders filled in, or a minimal system/user pair when no
        template is configured.
    """

    def _format_sync() -> str:
        # Limit the number of contexts, keeping scores aligned with documents.
        if isinstance(contexts, RetrievalResult):
            docs = contexts.documents
            if max_contexts:
                docs = docs[:max_contexts]
            processed_contexts = RetrievalResult(
                documents=docs,
                scores=contexts.scores[:len(docs)] if contexts.scores else [],
                query=contexts.query,
                retrieval_time=contexts.retrieval_time,
                metadata=contexts.metadata
            )
        else:
            # Anything else is treated as a plain list (of str or Document).
            processed_contexts = contexts[:max_contexts] if max_contexts and len(contexts) > max_contexts else contexts

        formatted = self._format_context(processed_contexts, context_numbering)
        return self._truncate_context(formatted, self.config.max_context_length)

    # String formatting of large contexts is kept off the event loop.
    formatted_context = await asyncio.get_running_loop().run_in_executor(
        self.executor, _format_sync
    )
    self.logger.info(f"Formatted Context {formatted_context}")

    # NOTE(review): this mutates the shared config as a side effect of
    # formatting a prompt; kept as-is because callers may rely on it.
    if template_type == "":
        self.config.instruction_template = "system"

    if custom_template:
        return custom_template.format(
            context=formatted_context,
            question=question
        )

    if not self.prompt_template:
        # Fallback default template
        return [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": question}
        ]

    self.logger.debug(f"question {question}")
    template_data = copy.deepcopy(self.prompt_template)
    self.logger.debug(f"template = {template_type} rag template = {template_data}")

    formatted_template = []
    for cht in template_data:
        content = cht["content"]

        if "{context}" in content or "{question}" in content:
            try:
                content = content.format(
                    context=formatted_context,
                    question=question
                )
            except (KeyError, IndexError, ValueError) as e:
                # The template contains other braces (e.g. a JSON example)
                # that str.format cannot parse; substitute only the two
                # known placeholders literally.
                self.logger.error(f"Missing placeholder in template: {e}")
                content = content.replace("{context}", formatted_context)
                content = content.replace("{question}", question)

        formatted_chat = {
            "role": cht["role"],
            "content": content
        }
        if "description" in cht:
            formatted_chat["description"] = cht["description"]

        formatted_template.append(formatted_chat)

    return formatted_template
|
| 329 |
+
|
| 330 |
+
async def generate_stream(self,
                          prompt: List[Dict],
                          max_new_tokens: Optional[int] = None,
                          temperature: Optional[float] = None,
                          top_p: Optional[float] = None,
                          **kwargs) -> AsyncGenerator[str, None]:
    """
    Generate text for a chat prompt and yield it chunk by chunk.

    The prompt is run through the tokenizer's chat template, generation is
    started in a background thread, and decoded chunks are pulled from a
    ``TextIteratorStreamer``. Each pull runs in ``self.executor`` so that
    waiting for the next chunk does not block the event loop.

    Args:
        prompt: Chat messages passed to ``tokenizer.apply_chat_template``.
        max_new_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature (overrides the config).
        top_p: Top-p value (overrides the config).
        **kwargs: Extra generation parameters; forwarded both to the
            overriding ``GenerationConfig`` and to ``model.generate``.

    Yields:
        Non-empty generated text chunks.

    Raises:
        RuntimeError: If the model has not been loaded.
        Exception: Whatever setup or streaming raises is logged and re-raised.
    """
    await self._check_model_loaded()
    loop = asyncio.get_running_loop()

    streamer = TextIteratorStreamer(
        self.tokenizer,
        timeout=self.config.stream_timeout,
        skip_prompt=self.config.skip_prompt,
        skip_special_tokens=True
    )

    def _generate_sync():
        try:
            # Tokenize input
            inputs = self.tokenizer.apply_chat_template(
                prompt,
                add_generation_prompt=True,
                return_tensors="pt"
            )

            # Build a one-off generation config only when the caller overrides something.
            gen_config = self.generation_config
            if any([max_new_tokens, temperature, top_p]):
                gen_config = GenerationConfig(
                    max_new_tokens=max_new_tokens or self.config.max_length,
                    temperature=temperature or self.config.temperature,
                    top_p=top_p or self.config.top_p,
                    top_k=self.config.top_k,
                    do_sample=self.config.do_sample,
                    pad_token_id=self.config.pad_token_id or self.tokenizer.eos_token_id,
                    eos_token_id=self.config.eos_token_id or self.tokenizer.eos_token_id,
                    repetition_penalty=self.config.repetition_penalty,
                    **kwargs
                )

            # Use the GPU when there is one; fall back to CPU instead of
            # failing on machines without CUDA.
            device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model.to(device)
            input_ids = inputs.to(device)

            generation_kwargs = {
                "input_ids": input_ids,
                "generation_config": gen_config,
                "streamer": streamer,
                **kwargs
            }

            # model.generate blocks until done, so it gets its own thread
            # while the streamer hands decoded text back to us.
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()

            return thread

        except Exception as e:
            self.logger.error(f"Error during stream generation setup: {e}")
            raise

    generation_thread = await loop.run_in_executor(
        self.executor, _generate_sync
    )

    # StopIteration cannot travel through a Future, so the end of the stream
    # is signalled with a sentinel default passed to next().
    end_of_stream = object()
    token_iter = iter(streamer)
    try:
        while True:
            token = await loop.run_in_executor(
                self.executor, next, token_iter, end_of_stream
            )
            if token is end_of_stream:
                break
            if token:  # Skip empty chunks
                yield token

        # Wait for generation thread to finish
        await loop.run_in_executor(self.executor, generation_thread.join)

    except Exception as e:
        self.logger.error(f"Error during streaming: {e}")
        # Give the generation thread a brief chance to wind down.
        if generation_thread.is_alive():
            await loop.run_in_executor(self.executor, generation_thread.join, 1.0)
        raise
|
| 426 |
+
|
| 427 |
+
async def rag_generate_stream(self,
                              question: str,
                              contexts: Union[List[str], RetrievalResult],
                              template_type: Optional[str] = None,
                              max_new_tokens: Optional[int] = None,
                              temperature: Optional[float] = None,
                              **kwargs) -> AsyncGenerator[str, None]:
    """
    Stream a RAG answer for ``question`` grounded in ``contexts``.

    Args:
        question: User question.
        contexts: Retrieved contexts (plain list or ``RetrievalResult``).
        template_type: Template type forwarded to ``format_rag_prompt``.
        max_new_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature; 0.3 is used when omitted.
        **kwargs: Extra generation parameters for ``generate_stream``.

    Yields:
        Generated answer chunks.
    """
    await self._check_model_loaded()

    rag_prompt = await self.format_rag_prompt(question, contexts, template_type)

    # RAG answers default to a low temperature to stay close to the sources.
    effective_temperature = 0.3 if temperature is None else temperature

    answer_stream = self.generate_stream(
        prompt=rag_prompt,
        max_new_tokens=max_new_tokens,
        temperature=effective_temperature,
        **kwargs
    )
    async for piece in answer_stream:
        yield piece
|
| 463 |
+
|
| 464 |
+
async def chat_stream(self,
                      messages: List[Dict[str, str]],
                      max_new_tokens: Optional[int] = None,
                      **kwargs) -> AsyncGenerator[str, None]:
    """
    Stream a reply to a conversation.

    ``generate_stream`` applies the tokenizer's chat template itself, so the
    messages are handed over unchanged. Rendering them to a string here
    first would make ``generate_stream`` apply the template a second time to
    something that is no longer a message list.

    Args:
        messages: Conversation as ``[{"role": "user", "content": "..."}]``.
        max_new_tokens: Maximum number of new tokens to generate.
        **kwargs: Extra generation parameters for ``generate_stream``.

    Yields:
        Response text chunks.
    """
    await self._check_model_loaded()

    async for chunk in self.generate_stream(
        messages,
        max_new_tokens=max_new_tokens,
        **kwargs
    ):
        yield chunk
|
| 506 |
+
|
| 507 |
+
async def rag_chat_stream(self,
                          messages: List[Dict[str, str]],
                          contexts: Union[List[str], RetrievalResult],
                          template_type: Optional[str] = None,
                          max_new_tokens: Optional[int] = None,
                          **kwargs) -> AsyncGenerator[str, None]:
    """
    Stream a RAG answer to the most recent user message of a conversation.

    Only the last message whose role is ``"user"`` is used as the question;
    earlier turns are not passed on.

    Args:
        messages: Conversation as ``[{"role": "user", "content": "..."}]``.
        contexts: Retrieved contexts (plain list or ``RetrievalResult``).
        template_type: Template type forwarded to ``rag_generate_stream``.
        max_new_tokens: Maximum number of new tokens to generate.
        **kwargs: Extra generation parameters.

    Yields:
        Response text chunks.

    Raises:
        ValueError: If the conversation contains no user message.
    """
    await self._check_model_loaded()

    # Walk the whole conversation and remember the latest user turn.
    latest_user_turn = None
    for turn in messages:
        if turn.get("role") == "user":
            latest_user_turn = turn
    if latest_user_turn is None:
        raise ValueError("No user message found in conversation")

    answer_stream = self.rag_generate_stream(
        question=latest_user_turn["content"],
        contexts=contexts,
        template_type=template_type,
        max_new_tokens=max_new_tokens,
        **kwargs
    )
    async for piece in answer_stream:
        yield piece
|
| 544 |
+
|
| 545 |
+
# Utility method untuk collect full response dari stream
|
| 546 |
+
async def collect_stream(self, stream_generator: AsyncGenerator[str, None]) -> str:
    """
    Drain a streaming generator and return everything it produced as one string.

    Args:
        stream_generator: Async generator yielding text chunks.

    Returns:
        The chunks concatenated in the order they were yielded.
    """
    return "".join([piece async for piece in stream_generator])
|
| 560 |
+
|
| 561 |
+
async def multi_template_generate(self,
                                  question: str,
                                  contexts: Union[List[str], RetrievalResult],
                                  template_types: List[str],
                                  max_new_tokens: Optional[int] = None,
                                  **kwargs) -> Dict[str, str]:
    """
    Answer the same question once per template, running the generations concurrently.

    Args:
        question: User question.
        contexts: Retrieved contexts (plain list or ``RetrievalResult``).
        template_types: Template types to generate with.
        max_new_tokens: Maximum number of new tokens to generate.
        **kwargs: Extra generation parameters.

    Returns:
        Mapping of template type to its response. A template whose
        generation failed maps to ``"Error: <message>"`` instead.
    """
    await self._check_model_loaded()

    # Schedule every generation up front so they overlap.
    scheduled = [
        (
            name,
            asyncio.create_task(
                self._generate_single_template(
                    question, contexts, name, max_new_tokens, **kwargs
                )
            ),
        )
        for name in template_types
    ]

    # Collect in submission order; one failure does not abort the others.
    answers = {}
    for name, job in scheduled:
        try:
            answers[name] = await job
        except Exception as e:
            self.logger.error(f"Error generating with template {name}: {e}")
            answers[name] = f"Error: {str(e)}"

    return answers
|
| 603 |
+
|
| 604 |
+
async def _generate_single_template(self,
                                    question: str,
                                    contexts: Union[List[str], RetrievalResult],
                                    template_type: str,
                                    max_new_tokens: Optional[int] = None,
                                    **kwargs) -> str:
    """Generate one RAG answer with ``template_type`` by delegating to ``rag_generate``."""
    answer = await self.rag_generate(
        question=question,
        contexts=contexts,
        template_type=template_type,
        max_new_tokens=max_new_tokens,
        **kwargs
    )
    return answer
|
| 618 |
+
|
| 619 |
+
async def rag_generate(self,
                       question: str,
                       contexts: Union[List[str], RetrievalResult],
                       template_type: Optional[str] = None,
                       max_new_tokens: Optional[int] = None,
                       temperature: Optional[float] = None,
                       **kwargs) -> str:
    """
    Generate a complete RAG answer for ``question`` grounded in ``contexts``.

    Args:
        question: User question.
        contexts: Retrieved contexts (plain list or ``RetrievalResult``).
        template_type: Template type forwarded to ``format_rag_prompt``.
        max_new_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature; 0.3 is used when omitted.
        **kwargs: Extra generation parameters for ``generate``.

    Returns:
        Generated answer.
    """
    await self._check_model_loaded()

    rag_prompt = await self.format_rag_prompt(question, contexts, template_type)

    # RAG answers default to a low temperature to stay close to the sources.
    effective_temperature = 0.3 if temperature is None else temperature

    answer = await self.generate(
        prompt=rag_prompt,
        max_new_tokens=max_new_tokens,
        temperature=effective_temperature,
        **kwargs
    )
    return answer
|
| 654 |
+
|
| 655 |
+
async def rag_chat(self,
                   messages: List[Dict[str, str]],
                   contexts: Union[List[str], RetrievalResult],
                   template_type: Optional[str] = None,
                   max_new_tokens: Optional[int] = None,
                   **kwargs) -> str:
    """
    Produce a RAG answer to the most recent user message of a conversation.

    Only the last message whose role is ``"user"`` is used as the question;
    earlier turns are not passed on.

    Args:
        messages: Conversation as ``[{"role": "user", "content": "..."}]``.
        contexts: Retrieved contexts (plain list or ``RetrievalResult``).
        template_type: Template type forwarded to ``rag_generate``.
        max_new_tokens: Maximum number of new tokens to generate.
        **kwargs: Extra generation parameters.

    Returns:
        Response text.

    Raises:
        ValueError: If the conversation contains no user message.
    """
    await self._check_model_loaded()

    # Walk the whole conversation and remember the latest user turn.
    latest_user_turn = None
    for turn in messages:
        if turn.get("role") == "user":
            latest_user_turn = turn
    if latest_user_turn is None:
        raise ValueError("No user message found in conversation")

    reply = await self.rag_generate(
        question=latest_user_turn["content"],
        contexts=contexts,
        template_type=template_type,
        max_new_tokens=max_new_tokens,
        **kwargs
    )
    return reply
|
| 691 |
+
|
| 692 |
+
async def _check_model_loaded(self) -> None:
    """Raise ``RuntimeError`` unless ``self.is_loaded`` is truthy."""
    if self.is_loaded:
        return
    raise RuntimeError("Model belum di-load. Panggil await load_model() terlebih dahulu.")
|
| 696 |
+
|
| 697 |
+
async def generate(self,
                   prompt: Union[List[Dict], str],
                   max_new_tokens: Optional[int] = None,
                   temperature: Optional[float] = None,
                   top_p: Optional[float] = None,
                   **kwargs) -> str:
    """
    Generate a complete response for a prompt.

    The prompt is run through the tokenizer's chat template and generated in
    ``self.executor``; the call is bounded by
    ``self.config.generation_timeout``.

    Args:
        prompt: Input passed to ``tokenizer.apply_chat_template``.
        max_new_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature (overrides the config).
        top_p: Top-p value (overrides the config).
        **kwargs: Extra generation parameters; forwarded both to the
            overriding ``GenerationConfig`` and to ``model.generate``.

    Returns:
        The generated text, without the prompt.

    Raises:
        RuntimeError: If the model has not been loaded.
        TimeoutError: If generation exceeds the configured timeout.
    """
    await self._check_model_loaded()

    def _generate_sync():
        try:
            # Tokenize input
            inputs = self.tokenizer.apply_chat_template(
                prompt,
                add_generation_prompt=True,
                return_tensors="pt"
            )

            # Build a one-off generation config only when the caller overrides something.
            gen_config = self.generation_config
            if any([max_new_tokens, temperature, top_p]):
                gen_config = GenerationConfig(
                    max_new_tokens=max_new_tokens or self.config.max_length,
                    temperature=temperature or self.config.temperature,
                    top_p=top_p or self.config.top_p,
                    top_k=self.config.top_k,
                    do_sample=self.config.do_sample,
                    pad_token_id=self.config.pad_token_id or self.tokenizer.eos_token_id,
                    eos_token_id=self.config.eos_token_id or self.tokenizer.eos_token_id,
                    repetition_penalty=self.config.repetition_penalty,
                    **kwargs
                )

            # Use the GPU when there is one; fall back to CPU instead of
            # failing on machines without CUDA.
            device = "cuda" if torch.cuda.is_available() else "cpu"

            with torch.no_grad():
                self.model.to(device)
                input_ids = inputs.to(device)
                prompt_length = input_ids.shape[-1]
                outputs = self.model.generate(
                    input_ids,
                    generation_config=gen_config,
                    **kwargs
                )

            # Decode only the newly generated tokens, dropping the prompt.
            generated_text = self.tokenizer.decode(
                outputs[0][prompt_length:],
                skip_special_tokens=True
            )

            self.logger.debug(f"Generated Text {generated_text}")
            return generated_text

        except Exception as e:
            self.logger.error(f"Error during generation: {e}")
            raise

    # Run generation in the thread pool, bounded by the configured timeout.
    # NOTE: a timeout stops the wait; the worker thread itself cannot be
    # interrupted and keeps running until model.generate returns.
    try:
        return await asyncio.wait_for(
            asyncio.get_running_loop().run_in_executor(self.executor, _generate_sync),
            timeout=self.config.generation_timeout
        )
    except asyncio.TimeoutError:
        self.logger.error(f"Generation timeout after {self.config.generation_timeout} seconds")
        raise TimeoutError(f"Generation timeout after {self.config.generation_timeout} seconds")
|
| 779 |
+
|
| 780 |
+
async def chat(self,
               messages: List[Dict[str, str]],
               max_new_tokens: Optional[int] = None,
               **kwargs) -> str:
    """
    Produce a reply to a conversation.

    ``generate`` applies the tokenizer's chat template itself, so the
    messages are handed over unchanged. Templating them here first would
    pass an already-encoded prompt into ``generate``, which would then try
    to apply the template to it a second time.

    Args:
        messages: Conversation as ``[{"role": "user", "content": "..."}]``.
        max_new_tokens: Maximum number of new tokens to generate.
        **kwargs: Extra generation parameters for ``generate``.

    Returns:
        Response text.
    """
    await self._check_model_loaded()

    return await self.generate(
        messages,
        max_new_tokens=max_new_tokens,
        **kwargs
    )
|
| 821 |
+
|
| 822 |
+
async def update_config(self, **kwargs) -> None:
    """
    Apply configuration overrides while holding ``self._lock``.

    Keys that are not attributes of ``self.config`` are logged as warnings
    and skipped. When the model is loaded, ``self.generation_config`` is
    rebuilt from the updated config afterwards.

    Args:
        **kwargs: Configuration attributes to update, by name.
    """
    async with self._lock:
        for name, new_value in kwargs.items():
            if not hasattr(self.config, name):
                self.logger.warning(f"Unknown config parameter: {name}")
                continue
            setattr(self.config, name, new_value)
            self.logger.info(f"Updated {name} to {new_value}")

        if not self.is_loaded:
            return

        # Keep the cached generation settings in sync with the new config.
        generation_settings = {
            "max_length": self.config.max_length,
            "temperature": self.config.temperature,
            "top_p": self.config.top_p,
            "top_k": self.config.top_k,
            "do_sample": self.config.do_sample,
            "pad_token_id": self.config.pad_token_id or self.tokenizer.eos_token_id,
            "eos_token_id": self.config.eos_token_id or self.tokenizer.eos_token_id,
            "repetition_penalty": self.config.repetition_penalty,
        }
        self.generation_config = GenerationConfig(**generation_settings)
|
| 850 |
+
|
| 851 |
+
async def get_model_info(self) -> Dict[str, Any]:
    """
    Describe the model and its configuration.

    Returns:
        Dictionary with ``model_name``, ``is_loaded`` and ``config`` (the
        config object's ``__dict__``). When the model is loaded it also
        contains ``vocab_size``, ``model_parameters`` and ``device``.
    """
    summary = {
        "model_name": self.config.model_name,
        "is_loaded": self.is_loaded,
        "config": self.config.__dict__
    }
    if not self.is_loaded:
        return summary

    def _collect_runtime_details():
        # Reads the tokenizer and model; executed in the thread pool.
        details = {}
        details["vocab_size"] = self.tokenizer.vocab_size
        details["model_parameters"] = sum(p.numel() for p in self.model.parameters())
        details["device"] = str(next(self.model.parameters()).device)
        return details

    loop = asyncio.get_event_loop()
    summary.update(await loop.run_in_executor(self.executor, _collect_runtime_details))
    return summary
|
| 879 |
+
|
| 880 |
+
async def batch_generate(self,
                         prompts: List[str],
                         max_new_tokens: Optional[int] = None,
                         **kwargs) -> List[str]:
    """
    Generate a response for every prompt, running the generations concurrently.

    Args:
        prompts: Prompts to generate for.
        max_new_tokens: Maximum number of new tokens to generate.
        **kwargs: Extra generation parameters for ``generate``.

    Returns:
        One entry per prompt, in input order. A prompt whose generation
        failed yields ``"Error: <message>"`` instead of text.
    """
    await self._check_model_loaded()

    pending = []
    for single_prompt in prompts:
        job = asyncio.create_task(
            self.generate(single_prompt, max_new_tokens=max_new_tokens, **kwargs)
        )
        pending.append(job)

    # Exceptions are returned as values so one failure does not hide the rest.
    outcomes = await asyncio.gather(*pending, return_exceptions=True)

    texts = []
    for index, outcome in enumerate(outcomes):
        if not isinstance(outcome, Exception):
            texts.append(outcome)
            continue
        self.logger.error(f"Error generating prompt {index}: {outcome}")
        texts.append(f"Error: {str(outcome)}")

    return texts
|
| 918 |
+
|
| 919 |
+
async def close(self) -> None:
    """
    Release the resources held by this instance.

    Shuts down the executor (waiting for running work), deletes the
    ``model`` and ``tokenizer`` attributes when they are set, empties the
    CUDA cache when CUDA is available, and marks the instance as not loaded.
    """
    log = self.logger
    log.info("Closing LM...")

    self.executor.shutdown(wait=True)

    # Drop the references so the memory can actually be reclaimed.
    for attr_name in ("model", "tokenizer"):
        if getattr(self, attr_name, None) is not None:
            delattr(self, attr_name)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    self.is_loaded = False
    log.info("LM closed successfully")
|
| 939 |
+
|
| 940 |
+
async def __aenter__(self):
    """Enter the async context: await ``load_model()`` and hand back this instance."""
    await self.load_model()
    entered = self
    return entered
|
| 944 |
+
|
| 945 |
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Leave the async context: await ``close()``; exceptions from the body are not suppressed."""
    await self.close()
    return None
|
space/space/app/rag/retriever/__init__.py
ADDED
|
File without changes
|
space/space/app/rag/retriever/langchain_retriever.py
CHANGED
|
@@ -6,6 +6,7 @@ from langchain_openai import OpenAIEmbeddings
|
|
| 6 |
|
| 7 |
# Vector stores
|
| 8 |
from langchain_community.vectorstores import Chroma, FAISS, Pinecone
|
|
|
|
| 9 |
|
| 10 |
# Retriever base
|
| 11 |
from langchain_core.vectorstores import VectorStoreRetriever
|
|
@@ -24,7 +25,6 @@ from langchain_core.documents import Document
|
|
| 24 |
|
| 25 |
logging.basicConfig(level=logging.INFO)
|
| 26 |
logger = logging.getLogger(__name__)
|
| 27 |
-
|
| 28 |
class LangChainRetriever(BaseRetriever):
|
| 29 |
"""LangChain-based retriever with multiple format support"""
|
| 30 |
|
|
@@ -160,17 +160,34 @@ class LangChainRetriever(BaseRetriever):
|
|
| 160 |
except Exception as e:
|
| 161 |
logger.error(f"Error adding documents: {str(e)}")
|
| 162 |
return False
|
| 163 |
-
|
| 164 |
async def _update_bm25_retriever(self, documents: List[Document]):
|
| 165 |
try:
|
|
|
|
| 166 |
self.bm25_retriever = BM25Retriever.from_documents(documents)
|
| 167 |
-
self.
|
| 168 |
-
|
| 169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
)
|
|
|
|
| 171 |
except Exception as e:
|
| 172 |
logger.error(f"Error updating BM25 retriever: {str(e)}")
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
async def retrieve(self, query: str, k: int = 5) -> RetrievalResult:
|
| 175 |
try:
|
| 176 |
import time
|
|
@@ -181,6 +198,7 @@ class LangChainRetriever(BaseRetriever):
|
|
| 181 |
None, self.retriever.get_relevant_documents, query
|
| 182 |
)
|
| 183 |
retrieved_docs = retrieved_docs[:k]
|
|
|
|
| 184 |
scores = [0.9 - (i * 0.1) for i in range(len(retrieved_docs))]
|
| 185 |
|
| 186 |
retrieval_time = time.time() - start_time
|
|
@@ -222,4 +240,4 @@ class LangChainRetriever(BaseRetriever):
|
|
| 222 |
return list(self.processed_documents.values())
|
| 223 |
|
| 224 |
def get_supported_formats(self) -> List[str]:
|
| 225 |
-
return self.document_loader.get_supported_extensions()
|
|
|
|
| 6 |
|
| 7 |
# Vector stores
|
| 8 |
from langchain_community.vectorstores import Chroma, FAISS, Pinecone
|
| 9 |
+
from langchain.retrievers import EnsembleRetriever
|
| 10 |
|
| 11 |
# Retriever base
|
| 12 |
from langchain_core.vectorstores import VectorStoreRetriever
|
|
|
|
| 25 |
|
| 26 |
logging.basicConfig(level=logging.INFO)
|
| 27 |
logger = logging.getLogger(__name__)
|
|
|
|
| 28 |
class LangChainRetriever(BaseRetriever):
|
| 29 |
"""LangChain-based retriever with multiple format support"""
|
| 30 |
|
|
|
|
| 160 |
except Exception as e:
|
| 161 |
logger.error(f"Error adding documents: {str(e)}")
|
| 162 |
return False
|
|
|
|
| 163 |
async def _update_bm25_retriever(self, documents: List[Document]):
    """
    Rebuild the BM25 retriever from ``documents`` and install a hybrid retriever.

    On success ``self.retriever`` is an ``EnsembleRetriever`` combining a
    vector-store retriever and the BM25 retriever with equal weights. If
    anything fails, the error is logged and ``self.retriever`` is set to a
    vector-store retriever only.
    """

    def _vector_retriever():
        # Retriever over self.vectorstore returning up to 10 hits.
        return VectorStoreRetriever(
            vectorstore=self.vectorstore,
            search_kwargs={"k": 10}
        )

    try:
        self.bm25_retriever = BM25Retriever.from_documents(documents)
        self.bm25_retriever.k = 10  # number of documents BM25 returns

        # Interim state: BM25 alone, until the hybrid retriever is built.
        self.retriever = self.bm25_retriever

        dense_retriever = _vector_retriever()
        self.retriever = EnsembleRetriever(
            retrievers=[dense_retriever, self.bm25_retriever],
            weights=[0.5, 0.5]  # equal weight to both retrievers
        )

    except Exception as e:
        logger.error(f"Error updating BM25 retriever: {str(e)}")
        # Fall back to the vector retriever only.
        self.retriever = _vector_retriever()
|
| 191 |
async def retrieve(self, query: str, k: int = 5) -> RetrievalResult:
|
| 192 |
try:
|
| 193 |
import time
|
|
|
|
| 198 |
None, self.retriever.get_relevant_documents, query
|
| 199 |
)
|
| 200 |
retrieved_docs = retrieved_docs[:k]
|
| 201 |
+
|
| 202 |
scores = [0.9 - (i * 0.1) for i in range(len(retrieved_docs))]
|
| 203 |
|
| 204 |
retrieval_time = time.time() - start_time
|
|
|
|
| 240 |
return list(self.processed_documents.values())
|
| 241 |
|
| 242 |
def get_supported_formats(self) -> List[str]:
|
| 243 |
+
return self.document_loader.get_supported_extensions()
|
space/space/app/rag/web_search/__init__.py
ADDED
|
File without changes
|
space/space/app/rtc/rtc_call_gpt.py
ADDED
|
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import fastapi
|
| 2 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
+
|
| 4 |
+
from fastrtc import ReplyOnPause, Stream, AlgoOptions, SileroVadOptions, get_cloudflare_turn_credentials_async, get_cloudflare_turn_credentials
|
| 5 |
+
from fastrtc.utils import audio_to_int16
|
| 6 |
+
from openai import OpenAI
|
| 7 |
+
from elevenlabs.client import ElevenLabs
|
| 8 |
+
from dotenv import load_dotenv
|
| 9 |
+
from tts.audio_edge_tts import EdgeTTS
|
| 10 |
+
from rag import document_retriever
|
| 11 |
+
import logging
|
| 12 |
+
import time
|
| 13 |
+
import platform
|
| 14 |
+
import socket
|
| 15 |
+
import os
|
| 16 |
+
import numpy as np
|
| 17 |
+
import io
|
| 18 |
+
import wave
|
| 19 |
+
import asyncio
|
| 20 |
+
import librosa
|
| 21 |
+
from pydub import AudioSegment
|
| 22 |
+
# from stt.whisper_stt import WhisperSTT
|
| 23 |
+
from collections import deque
|
| 24 |
+
import torch
|
| 25 |
+
import torchaudio.transforms as T
|
| 26 |
+
import asyncio
|
| 27 |
+
import concurrent.futures
|
| 28 |
+
import threading
|
| 29 |
+
from config.constant import HF_TOKEN
|
| 30 |
+
import threading
|
| 31 |
+
import re
|
| 32 |
+
from openai import OpenAI
|
| 33 |
+
from langchain_core.documents import Document
|
| 34 |
+
|
| 35 |
+
from rag import ddgs
|
| 36 |
+
# Load .env
|
| 37 |
+
load_dotenv()
|
| 38 |
+
logging.basicConfig(level=logging.INFO)
|
| 39 |
+
|
| 40 |
+
class RTCHandler:
    """Voice customer-service handler built on FastRTC.

    One turn of ``echo`` transcribes the caller's audio with the OpenAI
    transcription API, retrieves context from ``document_retriever``, streams
    a chat completion and converts the reply to audio with EdgeTTS.
    """

    # Sample rate (Hz) of every audio tuple yielded by ``echo``.
    OUTPUT_SAMPLE_RATE = 24000
    # Number of samples carried by each yielded audio tuple.
    AUDIO_CHUNK_SIZE = 1024

    def __init__(self, openai_client: "OpenAI", whisper_stt=None, edge_tts: "EdgeTTS" = None):
        """Initialize the RTC handler.

        Args:
            openai_client: OpenAI client used for transcription and chat completion.
            whisper_stt: Optional local STT engine; stored but not used by ``echo``.
            edge_tts: EdgeTTS instance whose ``generate_audio_buffer`` synthesizes replies.
        """
        self.whisper_stt = whisper_stt
        self.edge_tts = edge_tts
        self.prompt = ""
        # Runtime prompt text, kept verbatim.
        self.sys_prompt = """

Kamu adalah customer service yang berbahasa Indonesia dengan baik sopan, santun, tapi santai pembawaannya.
Kamu bisa menjelaskan sesuatu secara baik dan membimbing customer dalam menghadapi masalah yang ada!

Kamu akan menjawab customer dengan media call /telepon jadi anda harus memberikan respon seperlunya saja
Tidak kepanjanngan, dan sangat jelas,


Tidak lebih dari 50 kata.
"""
        self.openai_client = openai_client
        self.messages = [
            {
                "role": "system",
                "content": self.sys_prompt
            }
        ]
        self.full_response = ""
        self.stream = None
        self.app = None

        self._setup_webrtc_ip()

    def _setup_webrtc_ip(self):
        """On Windows, export the outbound interface address as ``WEBRTC_IP``."""
        if platform.system() != 'Windows':
            return
        # Connecting a UDP socket only selects the local interface to read back.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
            try:
                probe.connect(('8.8.8.8', 80))
                local_ip = probe.getsockname()[0]
            except Exception:
                local_ip = '127.0.0.1'
        os.environ['WEBRTC_IP'] = local_ip

    def audio_to_bytes(self, audio_tuple, sample_rate=24000) -> io.BytesIO:
        """Encode a ``(sample_rate, samples)`` tuple as an in-memory mono 16-bit WAV.

        Args:
            audio_tuple: ``(sample_rate, samples)`` pair as delivered to ``echo``.
            sample_rate: Unused; kept for backward compatibility. The rate
                written to the WAV header is the one inside ``audio_tuple``.

        Returns:
            A ``BytesIO`` positioned at offset 0 and named ``audio.wav``.
        """
        sr, _ = audio_tuple
        audio_int16 = audio_to_int16(audio_tuple)

        buffer = io.BytesIO()
        with wave.open(buffer, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(sr)
            wf.writeframes(audio_int16.tobytes())
        buffer.seek(0)
        # The upload API reads the file name from this attribute.
        buffer.name = "audio.wav"
        return buffer

    @staticmethod
    def _format_contexts(documents):
        """Render retrieved documents as a numbered list, one ``N. text`` line each."""
        return "".join(
            f"{idx}. {doc.page_content}" + "\n"
            for idx, doc in enumerate(documents, start=1)
        )

    def _iter_frames(self, pcm):
        """Split a PCM array into ``(OUTPUT_SAMPLE_RATE, samples)`` tuples of AUDIO_CHUNK_SIZE samples."""
        for start in range(0, len(pcm), self.AUDIO_CHUNK_SIZE):
            yield (self.OUTPUT_SAMPLE_RATE, pcm[start:start + self.AUDIO_CHUNK_SIZE])

    async def _synthesize_pcm(self, text):
        """Synthesize ``text`` with EdgeTTS and return mono float32 samples at OUTPUT_SAMPLE_RATE."""
        audio_buffer = (await self.edge_tts.generate_audio_buffer(text))[0]
        audio_buffer.seek(0)

        # Convert MP3 to PCM
        audio_segment = AudioSegment.from_file(audio_buffer, format="mp3")
        # NOTE(review): the 2**15 scale assumes 16-bit samples - confirm for other sample widths.
        samples = np.array(audio_segment.get_array_of_samples()).astype(np.float32) / (2 ** 15)

        # Handle stereo to mono
        if audio_segment.channels == 2:
            samples = samples.reshape((-1, 2)).mean(axis=1)

        if audio_segment.frame_rate == self.OUTPUT_SAMPLE_RATE:
            # Already at the target rate; nothing to resample.
            return samples

        import torch
        import torchaudio

        # Resample on the GPU when one is available, otherwise on the CPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        audio_tensor = torch.from_numpy(samples).unsqueeze(0).to(device)
        resampler = torchaudio.transforms.Resample(
            orig_freq=audio_segment.frame_rate,
            new_freq=self.OUTPUT_SAMPLE_RATE
        ).to(device)
        return resampler(audio_tensor).squeeze(0).cpu().numpy()

    async def _stream_text_to_audio(self):
        """Answer ``self.prompt`` and yield the reply as audio tuples.

        Retrieves context, indexes fresh web-search results, streams the chat
        completion and synthesizes audio each time a punctuation mark arrives.
        Any text left over when the stream ends is synthesized as well.
        """
        retrieval_result = await document_retriever.retrieve(query=self.prompt)

        # Web results are indexed after the retrieval above, so they are not
        # part of this turn's context.
        # NOTE(review): confirm this ordering is intended.
        async for result in ddgs.search(self.prompt, max_results=5):
            doc = Document(
                page_content=result,
                metadata={"source": "internet_search", "query": self.prompt}
            )
            logging.debug(f"Indexing search result: {doc}")
            await document_retriever.add_documents([doc])

        contexts = self._format_contexts(retrieval_result.documents)
        logging.info(f"Retrieved Contexts : {contexts}")

        self.messages.append({"role": "user", "content": f"""
Dari Konteks yang diberikan (jika diperlukan) :
{contexts}

Berikan jawaban atas pertanyaan yang diberikan :
{self.prompt}

"""})

        response = self.openai_client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=self.messages,
            max_tokens=200,
            stream=True
        )

        text_buffer = ""
        for stream_data in response:
            if not stream_data.choices:
                continue
            choice = stream_data.choices[0]
            piece = choice.delta.content
            if piece:
                self.full_response += piece
                text_buffer += piece
                # Speak as soon as a clause boundary arrives to keep latency low.
                if re.search(r'[.,?;!]', piece):
                    try:
                        pcm = await self._synthesize_pcm(text_buffer)
                    except Exception as e:
                        # Keep the text buffered so it is retried with the next clause.
                        logging.error(f"TTS generation failed for chunk: {e}")
                    else:
                        text_buffer = ""
                        for frame in self._iter_frames(pcm):
                            yield frame
            if choice.finish_reason is not None:
                # Covers "stop" as well as "length" and other terminal reasons.
                break

        # Speak whatever is left without trailing punctuation.
        if text_buffer.strip():
            try:
                pcm = await self._synthesize_pcm(text_buffer)
            except Exception as e:
                logging.error(f"TTS generation failed for chunk: {e}")
            else:
                for frame in self._iter_frames(pcm):
                    yield frame

    def echo(self, audio):
        """Process one utterance and yield the spoken reply.

        Args:
            audio: ``(sample_rate, samples)`` tuple supplied by FastRTC.

        Yields:
            ``(OUTPUT_SAMPLE_RATE, samples)`` tuples. Yields nothing when the
            transcription is empty, and one second of silence on any error.
        """
        try:
            stt_time = time.time()
            logging.info("Performing STT")

            transcription = self.openai_client.audio.transcriptions.create(
                model="whisper-1",
                file=self.audio_to_bytes(audio),
                language="id"
            )

            self.prompt = transcription.text
            if not self.prompt or not self.prompt.strip():
                logging.info("STT returned empty string")
                return

            logging.info(f"STT response: {transcription}")
            logging.info(f"STT took {time.time() - stt_time} seconds")

            llm_time = time.time()
            self.full_response = ""

            # Drive the async pipeline from this synchronous generator on a private loop.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            async_gen = self._stream_text_to_audio()
            try:
                while True:
                    try:
                        chunk = loop.run_until_complete(async_gen.__anext__())
                    except StopAsyncIteration:
                        break
                    yield chunk
            finally:
                try:
                    # Finalize the generator even if the consumer stopped early.
                    loop.run_until_complete(async_gen.aclose())
                finally:
                    # Do not leave a closed loop installed on this thread.
                    asyncio.set_event_loop(None)
                    loop.close()

            self.messages.append({"role": "assistant", "content": self.full_response + " "})
            logging.info(f"LLM response: {self.full_response}")
            logging.info(f"LLM took {time.time() - llm_time} seconds")

        except Exception as e:
            logging.error(f"Error in echo function: {e}")
            error_audio = np.zeros(self.OUTPUT_SAMPLE_RATE, dtype=np.float32)
            yield (self.OUTPUT_SAMPLE_RATE, error_audio)

    def reset_conversation(self):
        """Drop the chat history, keeping only the current system prompt."""
        logging.info("Resetting chat")
        self.messages = [{"role": "system", "content": self.sys_prompt}]
        self.full_response = ""

    def create_stream(self):
        """Build the FastRTC ``Stream`` that routes paused speech to ``echo``.

        Returns:
            The created stream, also stored on ``self.stream``.

        Raises:
            Re-raises any construction error after logging it.
        """
        try:
            async def get_credentials():
                return await get_cloudflare_turn_credentials_async(hf_token=HF_TOKEN)

            self.stream = Stream(
                rtc_configuration=get_credentials,
                server_rtc_configuration=get_cloudflare_turn_credentials(ttl=360_000),
                handler=ReplyOnPause(
                    self.echo,
                    algo_options=AlgoOptions(
                        audio_chunk_duration=0.5,
                        started_talking_threshold=0.1,
                        speech_threshold=0.03
                    ),
                    model_options=SileroVadOptions(
                        threshold=0.90,
                        min_speech_duration_ms=250,
                        min_silence_duration_ms=2000,
                        speech_pad_ms=400,
                        max_speech_duration_s=15
                    )
                ),
                modality="audio",
                mode="send-receive"
            )
            return self.stream
        except Exception as e:
            logging.error(f"Error creating stream: {e}")
            raise

    def create_fastapi_app(self):
        """Create the FastAPI app, mount the stream and add ``/reset`` and ``/status``.

        Returns:
            The created app, also stored on ``self.app``.

        Raises:
            Re-raises any construction error after logging it.
        """
        try:
            self.app = fastapi.FastAPI()
            # NOTE(review): wildcard origins together with allow_credentials=True
            # is very permissive - confirm this is intended outside development.
            self.app.add_middleware(
                CORSMiddleware,
                allow_origins=["*"],
                allow_credentials=True,
                allow_methods=["*"],
                allow_headers=["*"],
            )

            if not self.stream:
                self.create_stream()
            self.stream.mount(self.app)

            @self.app.get("/reset")
            async def reset():
                try:
                    self.reset_conversation()
                    return {"status": "success"}
                except Exception as e:
                    logging.error(f"Error in reset endpoint: {e}")
                    return {"status": "error", "message": str(e)}

            @self.app.get("/status")
            async def status():
                try:
                    return {
                        "status": "running",
                        "messages_count": len(self.messages),
                        "last_response": self.full_response
                    }
                except Exception as e:
                    logging.error(f"Error in status endpoint: {e}")
                    return {"status": "error", "message": str(e)}

            return self.app
        except Exception as e:
            logging.error(f"Error creating FastAPI app: {e}")
            raise

    def start_server(self, host: str = "0.0.0.0", port: int = 7860):
        """Serve the FastAPI app with uvicorn, creating the app first if needed."""
        import uvicorn
        if not self.app:
            self.create_fastapi_app()
        logging.info(f"Starting server on {host}:{port}")
        try:
            uvicorn.run(self.app, host=host, port=port, log_level="info")
        except Exception as e:
            logging.error(f"Error starting server: {e}")
            raise

    def launch_ui(self, browser: bool = True):
        """Launch the stream's built-in UI on port 7860.

        Args:
            browser: Currently unused.
        """
        try:
            if not self.stream:
                self.create_stream()
            if not self.app:
                self.create_fastapi_app()
            logging.info("Launching RTC UI...")
            # NOTE(review): ``self.app`` is passed as the first positional
            # argument of ``launch`` and ``browser`` is ignored - confirm both
            # against the UI's ``launch`` signature.
            self.stream.ui.launch(
                self.app,
                server_name="0.0.0.0",
                server_port=7860,
            )
        except Exception as e:
            logging.error(f"Error launching UI: {e}")
            raise

    def get_conversation_history(self):
        """Return a shallow copy of the chat history."""
        return self.messages.copy()

    def set_system_prompt(self, new_prompt: str):
        """Replace the system prompt, both stored and at the head of the history.

        If the history is empty or does not start with a system message, the
        new system message is inserted at the front instead of overwriting
        whatever is there.
        """
        self.sys_prompt = new_prompt
        system_message = {"role": "system", "content": new_prompt}
        if self.messages and self.messages[0].get("role") == "system":
            self.messages[0] = system_message
        else:
            self.messages.insert(0, system_message)

    def get_last_response(self):
        """Return the text of the most recent assistant reply."""
        return self.full_response
|
space/space/app/tests/qwen_llm_test.py
CHANGED
|
@@ -1,14 +1,14 @@
|
|
| 1 |
from rag.retriever.retriever_types import *
|
| 2 |
-
from rag.pipeline.
|
| 3 |
|
| 4 |
import warnings
|
| 5 |
warnings.filterwarnings("ignore")
|
| 6 |
|
| 7 |
-
async def
|
| 8 |
print(" ===== Testing QWEN LLM ==== ")
|
| 9 |
-
"""Example usage of async
|
| 10 |
|
| 11 |
-
config =
|
| 12 |
temperature=0.5,
|
| 13 |
max_length=512,
|
| 14 |
generation_timeout=30
|
|
@@ -23,20 +23,20 @@ async def test_qwen_llm():
|
|
| 23 |
)
|
| 24 |
|
| 25 |
# Using async context manager
|
| 26 |
-
async with
|
| 27 |
await test_qwen_single_generation(llm)
|
| 28 |
await test_qwen_single_rag_generation(llm, contexts)
|
| 29 |
await test_qwen_multiple_template_rag_generation(llm, contexts)
|
| 30 |
await test_qwen_batch_generation(llm, contexts)
|
| 31 |
print(" ===== Testing LLM DONE ==== ")
|
| 32 |
|
| 33 |
-
async def test_qwen_single_generation(llm :
|
| 34 |
print(" * Test Single Generation * ")
|
| 35 |
response = await llm.generate("Jelaskan tentang AI")
|
| 36 |
print(f"Response: {response}")
|
| 37 |
print(" * Test Single Generation Done * ")
|
| 38 |
|
| 39 |
-
async def test_qwen_single_rag_generation(llm :
|
| 40 |
print(" * Test Single RAG Generation * ")
|
| 41 |
rag_response = await llm.rag_generate(
|
| 42 |
question="Apa itu AI dan machine learning?",
|
|
@@ -46,7 +46,7 @@ async def test_qwen_single_rag_generation(llm : QwenLLM, ctx : RetrievalResult):
|
|
| 46 |
print(f"RAG Response: {rag_response}")
|
| 47 |
print(" * Test Single RAG Generation Done * ")
|
| 48 |
|
| 49 |
-
async def test_qwen_multiple_template_rag_generation(llm :
|
| 50 |
print(" * Test Multiple Template Generation * ")
|
| 51 |
multi_responses = await llm.multi_template_generate(
|
| 52 |
question="Apa itu AI?",
|
|
@@ -57,7 +57,7 @@ async def test_qwen_multiple_template_rag_generation(llm : QwenLLM,ctx : Retriev
|
|
| 57 |
print(" * Test Multiple Template Generation Done* ")
|
| 58 |
|
| 59 |
|
| 60 |
-
async def test_qwen_batch_generation(llm :
|
| 61 |
print(" * Test Batch Generation * ")
|
| 62 |
batch_responses = await llm.batch_generate([
|
| 63 |
"Jelaskan tentang Python",
|
|
|
|
| 1 |
from rag.retriever.retriever_types import *
|
| 2 |
+
from rag.pipeline.language_model import LM, LMConfig
|
| 3 |
|
| 4 |
import warnings
|
| 5 |
warnings.filterwarnings("ignore")
|
| 6 |
|
| 7 |
+
async def test_language_model():
|
| 8 |
print(" ===== Testing QWEN LLM ==== ")
|
| 9 |
+
"""Example usage of async LM"""
|
| 10 |
|
| 11 |
+
config = LMConfig(
|
| 12 |
temperature=0.5,
|
| 13 |
max_length=512,
|
| 14 |
generation_timeout=30
|
|
|
|
| 23 |
)
|
| 24 |
|
| 25 |
# Using async context manager
|
| 26 |
+
async with LM(config) as llm:
|
| 27 |
await test_qwen_single_generation(llm)
|
| 28 |
await test_qwen_single_rag_generation(llm, contexts)
|
| 29 |
await test_qwen_multiple_template_rag_generation(llm, contexts)
|
| 30 |
await test_qwen_batch_generation(llm, contexts)
|
| 31 |
print(" ===== Testing LLM DONE ==== ")
|
| 32 |
|
| 33 |
+
async def test_qwen_single_generation(llm : LM):
|
| 34 |
print(" * Test Single Generation * ")
|
| 35 |
response = await llm.generate("Jelaskan tentang AI")
|
| 36 |
print(f"Response: {response}")
|
| 37 |
print(" * Test Single Generation Done * ")
|
| 38 |
|
| 39 |
+
async def test_qwen_single_rag_generation(llm : LM, ctx : RetrievalResult):
|
| 40 |
print(" * Test Single RAG Generation * ")
|
| 41 |
rag_response = await llm.rag_generate(
|
| 42 |
question="Apa itu AI dan machine learning?",
|
|
|
|
| 46 |
print(f"RAG Response: {rag_response}")
|
| 47 |
print(" * Test Single RAG Generation Done * ")
|
| 48 |
|
| 49 |
+
async def test_qwen_multiple_template_rag_generation(llm : LM,ctx : RetrievalResult):
|
| 50 |
print(" * Test Multiple Template Generation * ")
|
| 51 |
multi_responses = await llm.multi_template_generate(
|
| 52 |
question="Apa itu AI?",
|
|
|
|
| 57 |
print(" * Test Multiple Template Generation Done* ")
|
| 58 |
|
| 59 |
|
| 60 |
+
async def test_qwen_batch_generation(llm : LM, ctx : RetrievalResult):
|
| 61 |
print(" * Test Batch Generation * ")
|
| 62 |
batch_responses = await llm.batch_generate([
|
| 63 |
"Jelaskan tentang Python",
|
space/space/space/app/__chat__.py
CHANGED
|
@@ -1,13 +1,14 @@
|
|
| 1 |
from tests.inference_test import test_inference
|
| 2 |
-
|
|
|
|
| 3 |
import warnings
|
| 4 |
warnings.filterwarnings("ignore")
|
| 5 |
import asyncio
|
| 6 |
def run_test():
|
| 7 |
try:
|
| 8 |
# await test_document_retriever()
|
| 9 |
-
# await
|
| 10 |
-
|
| 11 |
except Exception as e:
|
| 12 |
print(e)
|
| 13 |
|
|
|
|
| 1 |
from tests.inference_test import test_inference
|
| 2 |
+
from huggingface_hub import login
|
| 3 |
+
login(new_session=False)
|
| 4 |
import warnings
|
| 5 |
warnings.filterwarnings("ignore")
|
| 6 |
import asyncio
|
| 7 |
def run_test():
|
| 8 |
try:
|
| 9 |
# await test_document_retriever()
|
| 10 |
+
# await test_language_model()
|
| 11 |
+
test_inference()
|
| 12 |
except Exception as e:
|
| 13 |
print(e)
|
| 14 |
|
space/space/space/app/__test__.py
CHANGED
|
@@ -1,8 +1,3 @@
|
|
| 1 |
-
|
| 2 |
-
# from tests.document_retriever_test import test_document_retriever
|
| 3 |
-
# from tests.document_retriever_test import test_document_retriever
|
| 4 |
-
# from tests.qwen_llm_test import test_qwen_llm
|
| 5 |
-
# from tests.inference_test import test_inference
|
| 6 |
from tests.rtc_test import test_rtc
|
| 7 |
import warnings
|
| 8 |
warnings.filterwarnings("ignore")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from tests.rtc_test import test_rtc
|
| 2 |
import warnings
|
| 3 |
warnings.filterwarnings("ignore")
|
space/space/space/app/app.log
ADDED
|
File without changes
|
space/space/space/app/rag/__init__.py
CHANGED
|
@@ -1,17 +1,44 @@
|
|
| 1 |
-
from rag.pipeline.
|
| 2 |
from rag.retriever.langchain_retriever import LangChainRetriever
|
| 3 |
from rag.inference.inferencer import Inferencer, InferencerConfig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
temperature=0.3,
|
| 7 |
max_length=512,
|
| 8 |
-
generation_timeout=
|
| 9 |
repetition_penalty=1.1,
|
| 10 |
-
max_workers =
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
llm =
|
| 15 |
config = config
|
| 16 |
)
|
| 17 |
|
|
@@ -22,29 +49,42 @@ inferencer_config = InferencerConfig(
|
|
| 22 |
)
|
| 23 |
|
| 24 |
document_retriever = LangChainRetriever(
|
| 25 |
-
embedding_model="all-MiniLM-L6-v2",
|
| 26 |
vectorstore_type="chroma",
|
| 27 |
-
vectorstore_path="
|
| 28 |
use_hybrid_search=True,
|
| 29 |
chunk_size=1000,
|
| 30 |
chunk_overlap=200
|
| 31 |
)
|
| 32 |
|
| 33 |
-
|
|
|
|
|
|
|
| 34 |
model=llm,
|
| 35 |
retriever=document_retriever,
|
|
|
|
| 36 |
reranker=None,
|
| 37 |
config=inferencer_config
|
| 38 |
)
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
-
async def get_stream_response(question):
|
| 45 |
-
async for item in inferencer.infer_stream(query = question,
|
| 46 |
-
enable_reranking=False,
|
| 47 |
-
template_type="main_template",
|
| 48 |
-
k=3):
|
| 49 |
-
print("Stream Response :", item)
|
| 50 |
-
yield item
|
|
|
|
| 1 |
+
from rag.pipeline.language_model import LM, LMConfig
|
| 2 |
from rag.retriever.langchain_retriever import LangChainRetriever
|
| 3 |
from rag.inference.inferencer import Inferencer, InferencerConfig
|
| 4 |
+
from rag.agents.customer_service_agent import CSAgent
|
| 5 |
+
from rag.agents.query_maker_agent import QueryMakerAgent
|
| 6 |
+
from langchain_core.documents import Document
|
| 7 |
+
from rag.web_search.duckduckgo_search import DuckDuckGoSearch
|
| 8 |
+
from rag.chat_template import get_chat_template
|
| 9 |
+
from transformers import BitsAndBytesConfig
|
| 10 |
+
import torch
|
| 11 |
|
| 12 |
+
import logging
|
| 13 |
+
import sys
|
| 14 |
+
|
| 15 |
+
logging.basicConfig(
|
| 16 |
+
level=logging.DEBUG,
|
| 17 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s() - %(message)s',
|
| 18 |
+
handlers=[
|
| 19 |
+
logging.FileHandler('app.log'),
|
| 20 |
+
logging.StreamHandler(sys.stdout)
|
| 21 |
+
]
|
| 22 |
+
)
|
| 23 |
+
bnb = BitsAndBytesConfig(
|
| 24 |
+
load_in_4bit=True, # Enable 4-bit quantization
|
| 25 |
+
bnb_4bit_use_double_quant=True, # Use double quantization
|
| 26 |
+
bnb_4bit_quant_type="nf4", # Use NF4 quantization
|
| 27 |
+
bnb_4bit_compute_dtype=torch.bfloat16, # Compute dtype for 4bit base models
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
config = LMConfig(
|
| 32 |
+
model_name = "Qwen/Qwen2.5-1.5B-Instruct",
|
| 33 |
temperature=0.3,
|
| 34 |
max_length=512,
|
| 35 |
+
generation_timeout=100,
|
| 36 |
repetition_penalty=1.1,
|
| 37 |
+
max_workers = 2,
|
| 38 |
+
quantization_config = bnb
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
llm = LM(
|
| 42 |
config = config
|
| 43 |
)
|
| 44 |
|
|
|
|
| 49 |
)
|
| 50 |
|
| 51 |
document_retriever = LangChainRetriever(
|
| 52 |
+
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
| 53 |
vectorstore_type="chroma",
|
| 54 |
+
vectorstore_path="vectorstore/",
|
| 55 |
use_hybrid_search=True,
|
| 56 |
chunk_size=1000,
|
| 57 |
chunk_overlap=200
|
| 58 |
)
|
| 59 |
|
| 60 |
+
ddgs = DuckDuckGoSearch()
|
| 61 |
+
|
| 62 |
+
cs_inferencer = Inferencer(
|
| 63 |
model=llm,
|
| 64 |
retriever=document_retriever,
|
| 65 |
+
# search_engine = ddgs,
|
| 66 |
reranker=None,
|
| 67 |
config=inferencer_config
|
| 68 |
)
|
| 69 |
|
| 70 |
+
query_maker_inferencer = Inferencer(
|
| 71 |
+
model=llm,
|
| 72 |
+
config=inferencer_config
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
cs_agent = CSAgent(
|
| 76 |
+
inferencer = cs_inferencer,
|
| 77 |
+
prompt_template = get_chat_template("customer_service")
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
query_maker_chat_template = get_chat_template("query_maker")
|
| 81 |
+
query_maker_chat_template[1]["content"] = """{question}"""
|
| 82 |
+
|
| 83 |
+
query_maker_agent = QueryMakerAgent(
|
| 84 |
+
inferencer = query_maker_inferencer,
|
| 85 |
+
prompt_template = query_maker_chat_template
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
|
| 90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
space/space/space/app/rag/inference/inferencer.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
from rag.retriever.langchain_retriever import LangChainRetriever
|
| 2 |
-
from rag.pipeline.
|
| 3 |
from rag.retriever.retriever_types import RetrievalResult
|
|
|
|
|
|
|
| 4 |
# from rag.pipeline.reranker import BGEM3Reranker
|
| 5 |
from typing import List, Union, Dict, Any, Optional, AsyncGenerator
|
| 6 |
import asyncio
|
|
@@ -29,15 +31,16 @@ class Inferencer:
|
|
| 29 |
"""
|
| 30 |
|
| 31 |
def __init__(self,
|
| 32 |
-
model:
|
| 33 |
-
retriever: LangChainRetriever,
|
|
|
|
| 34 |
reranker=None,
|
| 35 |
config: Optional[InferencerConfig] = None):
|
| 36 |
"""
|
| 37 |
Initialize Inferencer
|
| 38 |
|
| 39 |
Args:
|
| 40 |
-
model:
|
| 41 |
retriever: LangChainRetriever instance
|
| 42 |
reranker: Reranker instance (optional)
|
| 43 |
config: InferencerConfig (optional)
|
|
@@ -45,6 +48,7 @@ class Inferencer:
|
|
| 45 |
self.model = model
|
| 46 |
self.retriever = retriever
|
| 47 |
self.reranker = reranker
|
|
|
|
| 48 |
self.config = config or InferencerConfig()
|
| 49 |
|
| 50 |
# Setup logging
|
|
@@ -85,6 +89,7 @@ class Inferencer:
|
|
| 85 |
try:
|
| 86 |
start_time = datetime.now()
|
| 87 |
contexts = await self.retriever.retrieve(query, k=k)
|
|
|
|
| 88 |
retrieval_time = (datetime.now() - start_time).total_seconds()
|
| 89 |
|
| 90 |
self.logger.info(f"Retrieved {len(contexts.documents) if hasattr(contexts, 'documents') else len(contexts)} contexts in {retrieval_time:.2f}s")
|
|
@@ -292,7 +297,7 @@ class Inferencer:
|
|
| 292 |
yield chunk
|
| 293 |
|
| 294 |
async def infer(self,
|
| 295 |
-
query:
|
| 296 |
response_type: Union[List[str], str] = None,
|
| 297 |
k: Optional[int] = None,
|
| 298 |
enable_reranking: Optional[bool] = None,
|
|
@@ -321,8 +326,12 @@ class Inferencer:
|
|
| 321 |
|
| 322 |
try:
|
| 323 |
# Step 1: Retrieve contexts
|
| 324 |
-
|
| 325 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
# Step 2: Rerank contexts (if enabled)
|
| 327 |
enable_rerank = enable_reranking if enable_reranking is not None else self.config.enable_reranking
|
| 328 |
if enable_rerank:
|
|
@@ -363,7 +372,34 @@ class Inferencer:
|
|
| 363 |
except Exception as e:
|
| 364 |
self.logger.error(f"Error during inference: {e}")
|
| 365 |
raise
|
| 366 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
async def infer_stream(self,
|
| 368 |
query: str,
|
| 369 |
k: Optional[int] = None,
|
|
@@ -389,8 +425,14 @@ class Inferencer:
|
|
| 389 |
|
| 390 |
try:
|
| 391 |
# Step 1: Retrieve contexts
|
| 392 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
|
|
|
|
| 394 |
# Step 2: Rerank contexts (if enabled)
|
| 395 |
enable_rerank = enable_reranking if enable_reranking is not None else self.config.enable_reranking
|
| 396 |
if enable_rerank:
|
|
|
|
| 1 |
from rag.retriever.langchain_retriever import LangChainRetriever
|
| 2 |
+
from rag.pipeline.language_model import LM, LMConfig
|
| 3 |
from rag.retriever.retriever_types import RetrievalResult
|
| 4 |
+
from rag.web_search.duckduckgo_search import DuckDuckGoSearch
|
| 5 |
+
from langchain_core.documents import Document
|
| 6 |
# from rag.pipeline.reranker import BGEM3Reranker
|
| 7 |
from typing import List, Union, Dict, Any, Optional, AsyncGenerator
|
| 8 |
import asyncio
|
|
|
|
| 31 |
"""
|
| 32 |
|
| 33 |
def __init__(self,
|
| 34 |
+
model: LM,
|
| 35 |
+
retriever: LangChainRetriever = None,
|
| 36 |
+
search_engine = None,
|
| 37 |
reranker=None,
|
| 38 |
config: Optional[InferencerConfig] = None):
|
| 39 |
"""
|
| 40 |
Initialize Inferencer
|
| 41 |
|
| 42 |
Args:
|
| 43 |
+
model: LM instance
|
| 44 |
retriever: LangChainRetriever instance
|
| 45 |
reranker: Reranker instance (optional)
|
| 46 |
config: InferencerConfig (optional)
|
|
|
|
| 48 |
self.model = model
|
| 49 |
self.retriever = retriever
|
| 50 |
self.reranker = reranker
|
| 51 |
+
self.search_engine = search_engine
|
| 52 |
self.config = config or InferencerConfig()
|
| 53 |
|
| 54 |
# Setup logging
|
|
|
|
| 89 |
try:
|
| 90 |
start_time = datetime.now()
|
| 91 |
contexts = await self.retriever.retrieve(query, k=k)
|
| 92 |
+
self.logger.info(f"Retrieved Contexts : {contexts}")
|
| 93 |
retrieval_time = (datetime.now() - start_time).total_seconds()
|
| 94 |
|
| 95 |
self.logger.info(f"Retrieved {len(contexts.documents) if hasattr(contexts, 'documents') else len(contexts)} contexts in {retrieval_time:.2f}s")
|
|
|
|
| 297 |
yield chunk
|
| 298 |
|
| 299 |
async def infer(self,
|
| 300 |
+
query: str,
|
| 301 |
response_type: Union[List[str], str] = None,
|
| 302 |
k: Optional[int] = None,
|
| 303 |
enable_reranking: Optional[bool] = None,
|
|
|
|
| 326 |
|
| 327 |
try:
|
| 328 |
# Step 1: Retrieve contexts
|
| 329 |
+
if(self.search_engine):
|
| 330 |
+
await self.retrieve_from_search_engine(query, k = k)
|
| 331 |
+
if(self.retriever):
|
| 332 |
+
retrieved_contexts = await self.retrieve_context(main_query, k=k)
|
| 333 |
+
else:
|
| 334 |
+
retrieved_contexts = ""
|
| 335 |
# Step 2: Rerank contexts (if enabled)
|
| 336 |
enable_rerank = enable_reranking if enable_reranking is not None else self.config.enable_reranking
|
| 337 |
if enable_rerank:
|
|
|
|
| 372 |
except Exception as e:
|
| 373 |
self.logger.error(f"Error during inference: {e}")
|
| 374 |
raise
|
| 375 |
+
async def retrieve_from_search_engine(self, query: str, k: int = 3):
    """
    Fetch web search results for ``query`` and wrap each one in a Document.

    Results are consumed one at a time from ``self.search_engine.search``.
    Every result becomes a ``Document`` tagged with
    ``{"source": "internet_search", "query": query}`` and, when a retriever
    is configured, is added to it immediately.

    Args:
        query: Search query forwarded to the search engine.
        k: Maximum number of results to request. ``None`` (which ``infer``
            and ``infer_stream`` forward when their caller gave no ``k``)
            falls back to 3, the declared default.

    Returns:
        The list of Documents built from the search results.

    Raises:
        Exception: Any error raised while searching or indexing is logged
            with its traceback and re-raised unchanged.
    """
    from langchain_core.documents import Document

    # Callers pass their own Optional[int] ``k`` through; never hand None on.
    max_results = 3 if k is None else k
    search_results = []

    try:
        # Process results one by one as they come
        async for result in self.search_engine.search(query, max_results=max_results):
            self.logger.info(f"Processing SEO Result: {result[:100]}...")

            doc = Document(
                page_content=result,
                metadata={"source": "internet_search", "query": query}
            )
            search_results.append(doc)

            # The retriever is optional in __init__; without one there is
            # nothing to index into, so only collect the documents.
            if self.retriever is not None:
                await self.retriever.add_documents([doc])

        self.logger.info(f"Processed {len(search_results)} search results")
        return search_results

    except Exception as e:
        self.logger.error(f"Error in retrieve_from_search_engine: {e}", exc_info=True)
        raise
|
| 403 |
async def infer_stream(self,
|
| 404 |
query: str,
|
| 405 |
k: Optional[int] = None,
|
|
|
|
| 425 |
|
| 426 |
try:
|
| 427 |
# Step 1: Retrieve contexts
|
| 428 |
+
if(self.search_engine):
|
| 429 |
+
await self.retrieve_from_search_engine(query, k = k)
|
| 430 |
+
if(self.retriever is not None):
|
| 431 |
+
retrieved_contexts = await self.retrieve_context(query, k=k)
|
| 432 |
+
else:
|
| 433 |
+
retrieved_contexts = ""
|
| 434 |
|
| 435 |
+
|
| 436 |
# Step 2: Rerank contexts (if enabled)
|
| 437 |
enable_rerank = enable_reranking if enable_reranking is not None else self.config.enable_reranking
|
| 438 |
if enable_rerank:
|
space/space/space/app/rag/pipeline/qwen_llm.py
CHANGED
|
@@ -17,7 +17,7 @@ import copy
|
|
| 17 |
@dataclass
|
| 18 |
class QwenConfig:
|
| 19 |
"""Konfigurasi untuk model Qwen 0.5B"""
|
| 20 |
-
model_name: str = "Qwen/Qwen2.5-
|
| 21 |
device: str = "cuda"
|
| 22 |
torch_dtype: torch.dtype = torch.float16
|
| 23 |
max_length: int = 2048
|
|
@@ -286,14 +286,35 @@ class QwenLLM:
|
|
| 286 |
|
| 287 |
formatted_template = []
|
| 288 |
for cht in template_data:
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
|
| 294 |
-
|
| 295 |
-
cht["content"] = cht["content"].format(question=question)
|
| 296 |
-
formatted_template.append(cht)
|
| 297 |
|
| 298 |
self.logger.info("Formatted Template", formatted_template)
|
| 299 |
print("Forrmatted Template", formatted_template)
|
|
|
|
| 17 |
@dataclass
|
| 18 |
class QwenConfig:
|
| 19 |
"""Konfigurasi untuk model Qwen 0.5B"""
|
| 20 |
+
model_name: str = "Qwen/Qwen2.5-1.5B-Instruct"
|
| 21 |
device: str = "cuda"
|
| 22 |
torch_dtype: torch.dtype = torch.float16
|
| 23 |
max_length: int = 2048
|
|
|
|
| 286 |
|
| 287 |
formatted_template = []
|
| 288 |
for cht in template_data:
|
| 289 |
+
# Create a copy of the content to avoid modifying the original
|
| 290 |
+
content = cht["content"]
|
| 291 |
+
|
| 292 |
+
# Format both placeholders at once to avoid KeyError
|
| 293 |
+
if "{context}" in content or "{question}" in content:
|
| 294 |
+
try:
|
| 295 |
+
content = content.format(
|
| 296 |
+
context=formatted_context,
|
| 297 |
+
question=question
|
| 298 |
+
)
|
| 299 |
+
except KeyError as e:
|
| 300 |
+
self.logger.error(f"Missing placeholder in template: {e}")
|
| 301 |
+
# Fallback: format only available placeholders
|
| 302 |
+
if "{context}" in content:
|
| 303 |
+
content = content.replace("{context}", formatted_context)
|
| 304 |
+
if "{question}" in content:
|
| 305 |
+
content = content.replace("{question}", question)
|
| 306 |
+
|
| 307 |
+
# Create new dict with formatted content
|
| 308 |
+
formatted_chat = {
|
| 309 |
+
"role": cht["role"],
|
| 310 |
+
"content": content
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
# Copy other fields if they exist
|
| 314 |
+
if "description" in cht:
|
| 315 |
+
formatted_chat["description"] = cht["description"]
|
| 316 |
|
| 317 |
+
formatted_template.append(formatted_chat)
|
|
|
|
|
|
|
| 318 |
|
| 319 |
self.logger.info("Formatted Template", formatted_template)
|
| 320 |
print("Forrmatted Template", formatted_template)
|
space/space/space/app/rag/prompt_tuner/chat_template.py
CHANGED
|
@@ -8,18 +8,20 @@ def RAG_TEMPLATES():
|
|
| 8 |
|
| 9 |
1. Selalu berikan sapaan yang ramah dan profesional
|
| 10 |
2. Gunakan HANYA informasi dari knowledge base yang tersedia
|
| 11 |
-
3. Berikan jawaban yang jelas, mudah dipahami, dan terstruktur semuanya berdasarkan konteks yang diberikan
|
| 12 |
-
{context}
|
| 13 |
4. Jika informasi tidak tersedia, tawarkan alternatif bantuan atau arahkan ke channel yang tepat
|
| 14 |
5. Gunakan bahasa yang sopan dan empati terhadap kebutuhan pelanggan
|
| 15 |
6. Akhiri dengan penawaran bantuan lebih lanjut
|
|
|
|
| 16 |
""",
|
| 17 |
"description": "Template dengan system prompt untuk customer service professional"
|
| 18 |
},
|
| 19 |
{
|
| 20 |
"role" : "user",
|
| 21 |
-
"content" : """
|
| 22 |
-
|
|
|
|
|
|
|
| 23 |
"""
|
| 24 |
},
|
| 25 |
],
|
|
|
|
| 8 |
|
| 9 |
1. Selalu berikan sapaan yang ramah dan profesional
|
| 10 |
2. Gunakan HANYA informasi dari knowledge base yang tersedia
|
| 11 |
+
3. Berikan jawaban yang jelas, mudah dipahami, dan terstruktur semuanya berdasarkan konteks yang diberikan user.
|
|
|
|
| 12 |
4. Jika informasi tidak tersedia, tawarkan alternatif bantuan atau arahkan ke channel yang tepat
|
| 13 |
5. Gunakan bahasa yang sopan dan empati terhadap kebutuhan pelanggan
|
| 14 |
6. Akhiri dengan penawaran bantuan lebih lanjut
|
| 15 |
+
|
| 16 |
""",
|
| 17 |
"description": "Template dengan system prompt untuk customer service professional"
|
| 18 |
},
|
| 19 |
{
|
| 20 |
"role" : "user",
|
| 21 |
+
"content" : """Dari konteks yang diberikan : {context}
|
| 22 |
+
|
| 23 |
+
berikan jawaban atas pertanyaan saya yaitu : {question}
|
| 24 |
+
|
| 25 |
"""
|
| 26 |
},
|
| 27 |
],
|
space/space/space/app/rag/web_search/duckduckgo_search.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ddgs import DDGS
|
| 2 |
+
from langchain_community.document_loaders import AsyncChromiumLoader
|
| 3 |
+
from langchain_community.document_transformers import BeautifulSoupTransformer
|
| 4 |
+
import re
|
| 5 |
+
import logging
|
| 6 |
+
from typing import AsyncGenerator, List
|
| 7 |
+
|
| 8 |
+
class DuckDuckGoSearch:
    """Search DuckDuckGo and return the paragraph text of the result pages.

    Result URLs come from ``DDGS().text``; the pages are fetched with the
    configured HTML loader and reduced to ``<p>`` text by the configured
    HTML parser.
    """

    # Collapses runs of blank lines into a single newline.
    _BLANK_LINES = re.compile(r"\n\n+")

    def __init__(self, html_loader: "AsyncChromiumLoader | None" = None, html_parser=None):
        """
        Args:
            html_loader: Loader used to fetch pages. Defaults to an
                ``AsyncChromiumLoader`` created with no URLs.
            html_parser: Transformer used to extract text. Defaults to a
                ``BeautifulSoupTransformer``.
        """
        # Fall back to default collaborators when none are supplied
        self.html_loader = html_loader or AsyncChromiumLoader([])
        self.html_parser = html_parser or BeautifulSoupTransformer()
        self.logger = logging.getLogger("ddgs_logger")

    async def get_page(self, urls: List[str]):
        """Load ``urls`` and return the list of transformed documents.

        Returns an empty list when loading or parsing raises; the error is
        logged with its traceback.
        """
        try:
            self.html_loader.urls = urls
            html = await self.html_loader.aload()  # list of documents
            self.logger.info(f"search engine aload result: {len(html)} documents loaded")

            # BeautifulSoupTransformer.transform_documents names this keyword
            # ``unwanted_tags``. The previous ``remove_unwanted_tags=`` keyword
            # is not one of its parameters, so the request to strip <a>
            # elements never reached the parser. "script"/"style" are the
            # library defaults and are kept alongside "a".
            return self.html_parser.transform_documents(
                html,
                unwanted_tags=["script", "style", "a"],
                tags_to_extract=["p"],
            )

        except Exception as e:
            self.logger.error(f"Error loading pages: {e}", exc_info=True)
            return []  # callers treat "no documents" as "nothing to yield"

    def truncate(self, text: str, max_words: int = 400) -> str:
        """Return ``text`` cut to at most ``max_words`` whitespace-separated words.

        Text within the limit is returned unchanged. Longer text is re-joined
        with single spaces and suffixed with ``"..."``. Falsy input yields ``""``.
        """
        if not text:
            return ""

        words = text.split()
        if len(words) <= max_words:
            return text

        return " ".join(words[:max_words]) + "..."

    def _clean_page_text(self, doc) -> str:
        """Return the normalized, truncated text of ``doc``; ``""`` if it has none."""
        content = getattr(doc, "page_content", None)
        if not content:
            return ""
        return self.truncate(self._BLANK_LINES.sub("\n", content).strip())

    async def search(self, query: str, max_results: int = 5) -> AsyncGenerator[str, None]:
        """
        Search for ``query`` and yield the cleaned text of each result page.

        Errors are logged rather than raised, so a failed search simply
        yields nothing.

        Args:
            query: Search query passed to ``DDGS().text``.
            max_results: Maximum number of search hits to request.
        """
        try:
            self.logger.info(f"Searching for: {query} (max_results: {max_results})")

            # Step 1: collect result URLs. DDGS().text returns a regular
            # (synchronous) iterable, hence the plain comprehension.
            # NOTE(review): this call runs synchronously inside a coroutine;
            # confirm that is acceptable for the real-time call path.
            results = DDGS().text(query, max_results=max_results)
            urls = [hit["href"] for hit in results if hit.get("href")]

            self.logger.info(f"Found {len(urls)} URLs to process")

            if not urls:
                self.logger.warning("No URLs found from search results")
                return

            # Step 2: fetch and parse every page in one call
            docs = await self.get_page(urls)

            # Step 3: yield non-empty page texts one by one
            for doc in docs:
                try:
                    text = self._clean_page_text(doc)
                    if text:
                        yield text

                except Exception as e:
                    self.logger.error(f"Error processing document: {e}")
                    continue

        except Exception as e:
            # Deliberately not re-raised: the generator just ends early.
            self.logger.error(f"Error in search method: {e}", exc_info=True)

    async def search_with_metadata(self, query: str, max_results: int = 5) -> AsyncGenerator[dict, None]:
        """
        Like ``search``, but yield a dict per page.

        Each dict has the keys ``content``, ``url``, ``title`` and
        ``word_count``. Errors are logged rather than raised.
        """
        try:
            results = DDGS().text(query, max_results=max_results)

            # Collect URLs and titles
            urls_and_titles = [
                {'url': hit['href'], 'title': hit.get('title', 'No title')}
                for hit in results
                if hit.get('href')
            ]

            if not urls_and_titles:
                return

            # Get page content
            docs = await self.get_page([item['url'] for item in urls_and_titles])

            # Documents are paired with search hits by position.
            # NOTE(review): assumes get_page returns documents in URL order - confirm.
            for i, doc in enumerate(docs):
                try:
                    text = self._clean_page_text(doc)
                    if not text:
                        continue

                    metadata = urls_and_titles[i] if i < len(urls_and_titles) else {}

                    yield {
                        'content': text,
                        'url': metadata.get('url', 'Unknown'),
                        'title': metadata.get('title', 'No title'),
                        'word_count': len(text.split())
                    }

                except Exception as e:
                    self.logger.error(f"Error processing document {i}: {e}")
                    continue

        except Exception as e:
            self.logger.error(f"Error in search_with_metadata: {e}", exc_info=True)
|
space/space/space/app/rtc/__init__.py
CHANGED
|
@@ -2,11 +2,13 @@ from openai import OpenAI
|
|
| 2 |
from elevenlabs.client import ElevenLabs
|
| 3 |
from tts.audio_edge_tts import EdgeTTS
|
| 4 |
from config.constant import OPENAI_API_KEY, ELEVENLABS_API_KEY
|
|
|
|
| 5 |
from rtc.rtc_call import RTCHandler
|
| 6 |
from stt.whisper_stt import WhisperSTT
|
| 7 |
|
| 8 |
-
whisper_stt = WhisperSTT("
|
| 9 |
edge_tts = EdgeTTS("id-ID-ArdiNeural", "+0%", "+0%")
|
|
|
|
| 10 |
rtc_handler = RTCHandler(whisper_stt, edge_tts)
|
| 11 |
|
| 12 |
def handle_rtc():
|
|
|
|
| 2 |
from elevenlabs.client import ElevenLabs
|
| 3 |
from tts.audio_edge_tts import EdgeTTS
|
| 4 |
from config.constant import OPENAI_API_KEY, ELEVENLABS_API_KEY
|
| 5 |
+
# from rtc.rtc_call import RTCHandler
|
| 6 |
from rtc.rtc_call import RTCHandler
|
| 7 |
from stt.whisper_stt import WhisperSTT
|
| 8 |
|
| 9 |
+
whisper_stt = WhisperSTT(model_size = "base", device = "cuda")
|
| 10 |
edge_tts = EdgeTTS("id-ID-ArdiNeural", "+0%", "+0%")
|
| 11 |
+
openai_client = OpenAI(api_key = OPENAI_API_KEY)
|
| 12 |
rtc_handler = RTCHandler(whisper_stt, edge_tts)
|
| 13 |
|
| 14 |
def handle_rtc():
|
space/space/space/app/rtc/rtc_call.py
CHANGED
|
@@ -30,7 +30,7 @@ import threading
|
|
| 30 |
import re
|
| 31 |
|
| 32 |
|
| 33 |
-
from rag import
|
| 34 |
# Load .env
|
| 35 |
load_dotenv()
|
| 36 |
logging.basicConfig(level=logging.INFO)
|
|
@@ -94,7 +94,7 @@ class RTCHandler:
|
|
| 94 |
logging.info("STT returned empty string")
|
| 95 |
return
|
| 96 |
|
| 97 |
-
logging.info(f"STT response: {
|
| 98 |
self.messages.append({"role": "user", "content": prompt})
|
| 99 |
logging.info(f"STT took {time.time() - stt_time} seconds")
|
| 100 |
|
|
@@ -106,7 +106,7 @@ class RTCHandler:
|
|
| 106 |
chunk_size = 1024
|
| 107 |
no_buffer = 0
|
| 108 |
text_buffer = ""
|
| 109 |
-
async for stream_data in
|
| 110 |
print(stream_data)
|
| 111 |
|
| 112 |
if stream_data["type"] == "chunk":
|
|
|
|
| 30 |
import re
|
| 31 |
|
| 32 |
|
| 33 |
+
from rag import cs_agent
|
| 34 |
# Load .env
|
| 35 |
load_dotenv()
|
| 36 |
logging.basicConfig(level=logging.INFO)
|
|
|
|
| 94 |
logging.info("STT returned empty string")
|
| 95 |
return
|
| 96 |
|
| 97 |
+
logging.info(f"STT response: {transcription}")
|
| 98 |
self.messages.append({"role": "user", "content": prompt})
|
| 99 |
logging.info(f"STT took {time.time() - stt_time} seconds")
|
| 100 |
|
|
|
|
| 106 |
chunk_size = 1024
|
| 107 |
no_buffer = 0
|
| 108 |
text_buffer = ""
|
| 109 |
+
async for stream_data in cs_agent.get_result(question = prompt):
|
| 110 |
print(stream_data)
|
| 111 |
|
| 112 |
if stream_data["type"] == "chunk":
|
space/space/space/app/stt/whisper_stt.py
CHANGED
|
@@ -1,31 +1,94 @@
|
|
|
|
|
| 1 |
import whisper
|
|
|
|
| 2 |
from fastrtc.utils import audio_to_int16
|
| 3 |
import io
|
| 4 |
import os
|
| 5 |
import tempfile
|
| 6 |
|
| 7 |
class WhisperSTT:
|
| 8 |
-
def __init__(self, model_size: str = "base"):
|
| 9 |
"""
|
| 10 |
-
Initialize Whisper STT with specified model size
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
"""
|
|
|
|
| 12 |
cache_dir = os.environ.get('WHISPER_CACHE_DIR', '/tmp/.cache/whisper')
|
| 13 |
os.makedirs(cache_dir, exist_ok=True)
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
self.language = "id" # ISO-639-1 code for Bahasa Indonesia
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
def transcribe(self, audio: io.BufferedReader, language: str = "id") -> str:
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
| 21 |
tmp.write(audio.read())
|
| 22 |
tmp.flush()
|
| 23 |
tmp_path = tmp.name
|
| 24 |
|
| 25 |
try:
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
return result.get("text", "")
|
| 28 |
finally:
|
|
|
|
| 29 |
os.remove(tmp_path)
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
import whisper
|
| 3 |
+
import torch
|
| 4 |
from fastrtc.utils import audio_to_int16
|
| 5 |
import io
|
| 6 |
import os
|
| 7 |
import tempfile
|
| 8 |
|
| 9 |
class WhisperSTT:
    """Speech-to-text wrapper around a Whisper model."""

    def __init__(self, model_size: str = "base", device: str = "auto"):
        """
        Initialize Whisper STT with specified model size and device

        Args:
            model_size: Model size (tiny, base, small, medium, large)
            device: Device to use: "auto", "cpu", "cuda", or an indexed CUDA
                device such as "cuda:0". "auto" selects CUDA when available.
        """
        # Set up cache directory
        cache_dir = os.environ.get('WHISPER_CACHE_DIR', '/tmp/.cache/whisper')
        os.makedirs(cache_dir, exist_ok=True)

        # Determine device
        cuda_available = torch.cuda.is_available()
        if device == "auto":
            self.device = "cuda" if cuda_available else "cpu"
        else:
            self.device = device

        # Validate CUDA availability if requested; indexed names ("cuda:0")
        # are covered as well as plain "cuda".
        if self._uses_cuda() and not cuda_available:
            print("Warning: CUDA requested but not available. Falling back to CPU.")
            self.device = "cpu"

        # Load model with device specification
        print(f"Loading Whisper model '{model_size}' on device: {self.device}")
        self.model = whisper.load_model(model_size, device=self.device, download_root=cache_dir)
        self.language = "id"  # ISO-639-1 code for Bahasa Indonesia

        # Print GPU info if using CUDA
        if self._uses_cuda():
            gpu_name = torch.cuda.get_device_name(self.device)
            gpu_memory = torch.cuda.get_device_properties(self.device).total_memory / 1024**3
            print(f"Using GPU: {gpu_name} ({gpu_memory:.1f} GB)")

    def _uses_cuda(self) -> bool:
        """Return True when the selected device is CUDA ("cuda" or "cuda:N")."""
        return str(self.device).startswith("cuda")

    def transcribe(self, audio: io.BufferedReader, language: str = "id") -> str:
        """
        Transcribe audio using Whisper

        Args:
            audio: Audio file buffer
            language: Language code (default: "id" for Indonesian)

        Returns:
            Transcribed text ("" when the result carries no "text" entry)
        """
        # The model is given a file path, so spill the buffer to a temp file.
        # delete=False keeps the file after the handle closes; it is removed
        # in the finally block below.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp.write(audio.read())
            tmp.flush()
            tmp_path = tmp.name

        try:
            result = self.model.transcribe(
                tmp_path,
                language=language,
                # fp16 only on CUDA devices
                fp16=self._uses_cuda()
            )
            return result.get("text", "")
        finally:
            # Clean up temporary file
            os.remove(tmp_path)

    def get_device_info(self) -> dict:
        """
        Get information about the current device being used

        Returns:
            Dictionary with "device" and "cuda_available"; on a usable CUDA
            device also "gpu_name", "gpu_memory_gb",
            "gpu_memory_allocated_gb" and "gpu_memory_reserved_gb".
        """
        info = {
            "device": self.device,
            "cuda_available": torch.cuda.is_available()
        }

        if self._uses_cuda() and torch.cuda.is_available():
            info.update({
                "gpu_name": torch.cuda.get_device_name(self.device),
                "gpu_memory_gb": torch.cuda.get_device_properties(self.device).total_memory / 1024**3,
                "gpu_memory_allocated_gb": torch.cuda.memory_allocated(self.device) / 1024**3,
                "gpu_memory_reserved_gb": torch.cuda.memory_reserved(self.device) / 1024**3
            })

        return info
|
space/space/space/app/tests/ddgs_test.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rag.web_search import ddgs
|
| 2 |
+
def test_ddgs():
    """Run one web search through ``ddgs`` and print the pages it yields."""
    import asyncio

    async def collect_results():
        # NOTE(review): assumes ``ddgs`` is a DuckDuckGoSearch instance, whose
        # search() is an async generator - confirm against rag.web_search.
        # Printing the call result directly would only show the generator
        # object and never run the search.
        return [page async for page in ddgs.search("Perhitungan uang lembur")]

    print("*** searching result : **")
    print(asyncio.run(collect_results()))
|
space/space/space/app/tests/inference_test.py
CHANGED
|
@@ -1,69 +1,15 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import asyncio
|
| 3 |
-
from rag.pipeline.
|
| 4 |
from rag.retriever.langchain_retriever import LangChainRetriever
|
| 5 |
from rag.inference.inferencer import InferencerConfig, Inferencer
|
| 6 |
-
|
| 7 |
-
|
| 8 |
"""Main function that sets up and runs the RAG chatbot interface"""
|
| 9 |
|
| 10 |
# Initialize RAG components
|
| 11 |
print("==== Start Inference Test ===")
|
| 12 |
|
| 13 |
-
# Setup LLM
|
| 14 |
-
config = QwenConfig(
|
| 15 |
-
temperature=0.3,
|
| 16 |
-
max_length=512,
|
| 17 |
-
generation_timeout=30,
|
| 18 |
-
repetition_penalty=1.1,
|
| 19 |
-
do_sample = True,
|
| 20 |
-
)
|
| 21 |
-
|
| 22 |
-
llm = QwenLLM(config=config)
|
| 23 |
-
|
| 24 |
-
# Setup Document Retriever
|
| 25 |
-
document_retriever = LangChainRetriever(
|
| 26 |
-
embedding_model="text-embedding-3-small",
|
| 27 |
-
vectorstore_type="chroma",
|
| 28 |
-
vectorstore_path="./vectorstore",
|
| 29 |
-
use_hybrid_search=True,
|
| 30 |
-
chunk_size=1000,
|
| 31 |
-
chunk_overlap=200
|
| 32 |
-
)
|
| 33 |
-
|
| 34 |
-
# Load initial documents
|
| 35 |
-
file_paths = [
|
| 36 |
-
"../documents/bpjs.pdf",
|
| 37 |
-
"../documents/pph21.pdf",
|
| 38 |
-
"../documents/lembur.pdf",
|
| 39 |
-
"../documents/uu13.pdf",
|
| 40 |
-
"../documents/file.pdf",
|
| 41 |
-
]
|
| 42 |
-
|
| 43 |
-
for file_path in file_paths:
|
| 44 |
-
try:
|
| 45 |
-
result = await document_retriever.add_document_from_file(file_path)
|
| 46 |
-
if result.success:
|
| 47 |
-
print(f"Successfully processed: {result.document_metadata.file_name}")
|
| 48 |
-
print(f"Chunks created: {result.document_metadata.chunk_count}")
|
| 49 |
-
else:
|
| 50 |
-
print(f"Failed to process: {result.error_message}")
|
| 51 |
-
except Exception as e:
|
| 52 |
-
print(f"Error processing {file_path}: {e}")
|
| 53 |
-
|
| 54 |
-
# Setup Inferencer
|
| 55 |
-
inferencer_config = InferencerConfig(
|
| 56 |
-
default_k=2,
|
| 57 |
-
enable_reranking=False,
|
| 58 |
-
default_template_types=["system"]
|
| 59 |
-
)
|
| 60 |
-
|
| 61 |
-
inferencer = Inferencer(
|
| 62 |
-
model=llm,
|
| 63 |
-
retriever=document_retriever,
|
| 64 |
-
reranker=None,
|
| 65 |
-
config=inferencer_config
|
| 66 |
-
)
|
| 67 |
|
| 68 |
print("RAG system initialized successfully!")
|
| 69 |
|
|
@@ -73,16 +19,16 @@ async def test_inference():
|
|
| 73 |
# Create new event loop for this thread
|
| 74 |
loop = asyncio.new_event_loop()
|
| 75 |
asyncio.set_event_loop(loop)
|
| 76 |
-
|
| 77 |
async def stream_response():
|
| 78 |
partial_response = ""
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
):
|
| 85 |
-
|
| 86 |
if stream_data["type"] == "chunk":
|
| 87 |
chunk = stream_data["data"]["chunk"]
|
| 88 |
partial_response += chunk
|
|
@@ -96,9 +42,8 @@ async def test_inference():
|
|
| 96 |
total_time = stream_data['data']['total_time']
|
| 97 |
print(f"\nTotal time: {total_time:.2f}s")
|
| 98 |
|
| 99 |
-
# Execute async generator
|
| 100 |
async_gen = stream_response()
|
| 101 |
-
|
| 102 |
try:
|
| 103 |
while True:
|
| 104 |
result = loop.run_until_complete(async_gen.__anext__())
|
|
@@ -121,7 +66,7 @@ async def test_inference():
|
|
| 121 |
asyncio.set_event_loop(loop)
|
| 122 |
|
| 123 |
async def add_doc():
|
| 124 |
-
result =
|
| 125 |
return result
|
| 126 |
|
| 127 |
result = loop.run_until_complete(add_doc())
|
|
@@ -158,8 +103,7 @@ async def test_inference():
|
|
| 158 |
# Membuat interface Gradio
|
| 159 |
with gr.Blocks(css=css, title="RAG Chatbot") as demo:
|
| 160 |
gr.Markdown("""
|
| 161 |
-
# 🤖
|
| 162 |
-
Chatbot berbasis Retrieval-Augmented Generation (RAG) dengan dukungan streaming response.
|
| 163 |
""")
|
| 164 |
|
| 165 |
# Status indicator
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import asyncio
|
| 3 |
+
from rag.pipeline.language_model import LM, LMConfig
|
| 4 |
from rag.retriever.langchain_retriever import LangChainRetriever
|
| 5 |
from rag.inference.inferencer import InferencerConfig, Inferencer
|
| 6 |
+
from rag import cs_agent, query_maker_agent
|
| 7 |
+
def test_inference():
|
| 8 |
"""Main function that sets up and runs the RAG chatbot interface"""
|
| 9 |
|
| 10 |
# Initialize RAG components
|
| 11 |
print("==== Start Inference Test ===")
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
print("RAG system initialized successfully!")
|
| 15 |
|
|
|
|
| 19 |
# Create new event loop for this thread
|
| 20 |
loop = asyncio.new_event_loop()
|
| 21 |
asyncio.set_event_loop(loop)
|
| 22 |
+
|
| 23 |
async def stream_response():
|
| 24 |
partial_response = ""
|
| 25 |
+
# print("message = ", message)
|
| 26 |
+
formatted_query = await query_maker_agent.get_result(question = message)
|
| 27 |
+
print("Formatted Query = ", formatted_query)
|
| 28 |
+
formatted_query = formatted_query['responses'][0]['rag_response']
|
| 29 |
+
await cs_agent.load_documents()
|
| 30 |
+
async for stream_data in cs_agent.get_result(question = formatted_query):
|
| 31 |
+
|
| 32 |
if stream_data["type"] == "chunk":
|
| 33 |
chunk = stream_data["data"]["chunk"]
|
| 34 |
partial_response += chunk
|
|
|
|
| 42 |
total_time = stream_data['data']['total_time']
|
| 43 |
print(f"\nTotal time: {total_time:.2f}s")
|
| 44 |
|
|
|
|
| 45 |
async_gen = stream_response()
|
| 46 |
+
|
| 47 |
try:
|
| 48 |
while True:
|
| 49 |
result = loop.run_until_complete(async_gen.__anext__())
|
|
|
|
| 66 |
asyncio.set_event_loop(loop)
|
| 67 |
|
| 68 |
async def add_doc():
|
| 69 |
+
result = ""
|
| 70 |
return result
|
| 71 |
|
| 72 |
result = loop.run_until_complete(add_doc())
|
|
|
|
| 103 |
# Membuat interface Gradio
|
| 104 |
with gr.Blocks(css=css, title="RAG Chatbot") as demo:
|
| 105 |
gr.Markdown("""
|
| 106 |
+
# 🤖 SakuraAI, Virtual Assistant
|
|
|
|
| 107 |
""")
|
| 108 |
|
| 109 |
# Status indicator
|
space/space/space/space/space/.env.example
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
OPENAI_API_KEY =
|
| 2 |
+
ELEVENLABS_API_KEY =
|
| 3 |
+
HF_TOKEN =
|
space/space/space/space/space/.gitattributes
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
documents/SPISy[[:space:]]SaaS[[:space:]]To[[:space:]]The[[:space:]]Next[[:space:]]Level.pdf filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
documents/file.pdf filter=lfs diff=lfs merge=lfs -text
|
space/space/space/space/space/.github/workflows/deploy-to-huggingface.yml
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Deploy to Huggingface

on:
  push:
    branches:
      - main

jobs:
  deploy-to-huggingface:
    runs-on: ubuntu-latest

    steps:
      # Checkout repository
      - name: Checkout Repository
        uses: actions/checkout@v3

      # Setup Git
      - name: Setup Git for Huggingface
        run: |
          git config --global user.email "abdan.hafidz@gmail.com"
          git config --global user.name "abdanhafidz"

      # Clone Huggingface Space Repository into ./space
      - name: Clone Huggingface Space
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          git clone https://huggingface.co/spaces/lifedebugger/cs-ai-sakura-dev space

      # Update Git Remote URL and Pull Latest Changes
      - name: Update Remote and Pull Changes
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          cd space
          git remote set-url origin https://lifedebugger:$HF_TOKEN@huggingface.co/spaces/lifedebugger/cs-ai-sakura-dev
          git pull origin main || echo "No changes to pull"

      # Copy Files to Huggingface Space
      # The clone lives at ./space inside the workspace, so it must be excluded
      # from its own source tree; otherwise every deploy copies the Space into
      # itself and adds another space/ nesting level.
      - name: Copy Files to Space
        run: |
          rsync -av --exclude='.git' --exclude='/space' ./ space/

      # Commit and Push to Huggingface Space
      # Push only when something was committed, and let a failed push fail the
      # job instead of reporting it as "No changes to push".
      - name: Commit and Push to Huggingface
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          cd space
          git add .
          if git diff --cached --quiet; then
            echo "No changes to commit"
          else
            git commit -m "Deploy files from GitHub repository"
            git push origin main
          fi
|
space/space/space/space/space/.gitignore
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv/
|
| 2 |
+
venv/
|
| 3 |
+
.vscode/
|
| 4 |
+
__pycache__/
|
| 5 |
+
my_vectorstore/
|
| 6 |
+
FlagEmbedding/
|
| 7 |
+
.env
|
| 8 |
+
vectorstore/
|
| 9 |
+
documents/
|
space/space/space/space/space/Dockerfile
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Base image: Python 3.13
FROM python:3.13

# Add a non-root user for security
RUN useradd -m -u 1001 appuser

# Set working directory
WORKDIR /rag_be

# Point cache directories at a writable location
ENV HF_HOME=/tmp/.cache/huggingface
ENV TRANSFORMERS_CACHE=/tmp/.cache/transformers
ENV TORCH_HOME=/tmp/.cache/torch
ENV XDG_CACHE_HOME=/tmp/.cache
ENV TMPDIR=/tmp
ENV WHISPER_CACHE_DIR=/tmp/.cache/whisper

# Install ffmpeg before any application files are copied so this layer stays
# cached across code changes; the apt lists are removed in the same layer to
# keep the image small.
RUN apt-get update && \
    apt-get install -y --no-install-recommends ffmpeg && \
    rm -rf /var/lib/apt/lists/*

# Copy requirements and install dependencies
COPY requirements.txt ./
RUN pip install --no-cache-dir --upgrade -r requirements.txt

# Copy the application, owned by appuser
COPY --chown=appuser:appuser . /rag_be

# Create the .env file from Hugging Face build secrets.
# NOTE(review): the secret values end up in /rag_be/.env inside an image
# layer; anyone who can pull the image can read them. Prefer reading them from
# runtime environment variables if the platform provides them.
RUN --mount=type=secret,id=OPENAI_API_KEY,mode=0444,required=false \
    --mount=type=secret,id=HF_TOKEN,mode=0444,required=false \
    --mount=type=secret,id=ELEVENLABS_API_KEY,mode=0444,required=false \
    echo "OPENAI_API_KEY=$(cat /run/secrets/OPENAI_API_KEY 2>/dev/null || echo '')" >> .env && \
    echo "HF_TOKEN=$(cat /run/secrets/HF_TOKEN 2>/dev/null || echo '')" >> .env && \
    echo "ELEVENLABS_API_KEY=$(cat /run/secrets/ELEVENLABS_API_KEY 2>/dev/null || echo '')" >> .env

# Build-time diagnostic: list the app directory and show the build user.
RUN ls -l /rag_be/app && whoami && id

# Create the required directories with permissions appuser can write to
RUN mkdir -p /tmp/.cache /tmp/.cache/whisper /tmp/.cache/huggingface /rag_be/vectorstore /tmp/.cache/transformers /tmp/.cache/torch \
    /rag_be/app/vectorstore /rag_be/documents && \
    chmod -R 777 /tmp/.cache /rag_be/app /rag_be/app/vectorstore /rag_be/vectorstore /rag_be/documents && \
    chown -R appuser:appuser /tmp/.cache /rag_be/app /rag_be/app/vectorstore /rag_be/vectorstore /rag_be/documents /rag_be/.env

# Switch to the non-root user
USER appuser

# Expose the port used by Hugging Face Spaces
EXPOSE 7860

# Run the application
CMD ["python", "app/__test__.py"]
|
space/space/space/space/space/README.md
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Cs Ai Sakura Dev
|
| 3 |
+
emoji: 🏢
|
| 4 |
+
colorFrom: indigo
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
**Install The Requirements**
|
| 11 |
+
|
| 12 |
+
1. Create a virtual environment and install the dependencies:
|
| 13 |
+
```
|
| 14 |
+
python3 -m venv env
|
| 15 |
+
source env/bin/activate
|
| 16 |
+
pip install -r requirements.txt
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
2. Set your `OPENAI_API_KEY` in the `.env` file.
|
| 20 |
+
|
| 21 |
+
3. **TO LAUNCH THE GRADIO UI** Run the command below :
|
| 22 |
+
```
|
| 23 |
+
cd app
|
| 24 |
+
python __test__.py
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
4. **TO LAUNCH THE API ENDPOINT (SERVER)** Run the command below :
|
| 28 |
+
```
|
| 29 |
+
cd app
|
| 30 |
+
python __server__.py
|
| 31 |
+
```
|
space/space/space/space/space/app/.gradio/certificate.pem
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-----BEGIN CERTIFICATE-----
|
| 2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
| 3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
| 4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
| 5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
| 6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
| 7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
| 8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
| 9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
| 10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
| 11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
| 12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
| 13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
| 14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
| 15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
| 16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
| 17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
| 18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
| 19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
| 20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
| 21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
| 22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
| 23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
| 24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
| 25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
| 26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
| 27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
| 28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
| 29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
| 30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
| 31 |
+
-----END CERTIFICATE-----
|
space/space/space/space/space/app/__chat__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tests.inference_test import test_inference
|
| 2 |
+
|
| 3 |
+
import warnings
|
| 4 |
+
warnings.filterwarnings("ignore")
|
| 5 |
+
import asyncio
|
| 6 |
+
def run_test():
    """Run the async inference test; print any error instead of raising."""
    try:
        pending = test_inference()
        asyncio.run(pending)
    except Exception as exc:
        print(str(exc))


run_test()
|
space/space/space/space/space/app/__server__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import rtc
|
| 2 |
+
|
| 3 |
+
rtc.handle_rtc_server()
|
space/space/space/space/space/app/__test__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# from tests.document_retriever_test import test_document_retriever
|
| 3 |
+
# from tests.document_retriever_test import test_document_retriever
|
| 4 |
+
# from tests.qwen_llm_test import test_qwen_llm
|
| 5 |
+
# from tests.inference_test import test_inference
|
| 6 |
+
from tests.rtc_test import test_rtc
|
| 7 |
+
import warnings
|
| 8 |
+
warnings.filterwarnings("ignore")
|
| 9 |
+
import asyncio
|
| 10 |
+
def run_test():
    """Run the RTC test entry point; print any error instead of raising."""
    try:
        test_rtc()
    except Exception as exc:
        print(str(exc))


run_test()
|
space/space/space/space/space/app/config/__init__.py
ADDED
|
File without changes
|
space/space/space/space/space/app/config/constant.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dotenv import load_dotenv
|
| 2 |
+
import os
|
| 3 |
+
# Load variables from a .env file before the keys below are read.
load_dotenv()

# Each constant is None when the corresponding variable is not set.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
ELEVENLABS_API_KEY = os.environ.get("ELEVENLABS_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")
|
space/space/space/space/space/app/rag/__init__.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rag.pipeline.qwen_llm import QwenLLM, QwenConfig
|
| 2 |
+
from rag.retriever.langchain_retriever import LangChainRetriever
|
| 3 |
+
from rag.inference.inferencer import Inferencer, InferencerConfig
|
| 4 |
+
|
| 5 |
+
# Generation settings for the Qwen model.
config = QwenConfig(
    temperature=0.3,
    max_length=512,
    generation_timeout=30,
    repetition_penalty=1.1,
    max_workers=1,
    do_sample=True,
)

llm = QwenLLM(config=config)

inferencer_config = InferencerConfig(
    default_k=5,
    enable_reranking=False,
    # InferencerConfig declares this field as a list of template names, and
    # the inferencer forwards it as ``template_types``; a bare string here
    # would be handed on as a str instead of a list.
    default_template_types=["main_template"],
)

# Document retriever backed by a Chroma vector store under ./vectorstore.
document_retriever = LangChainRetriever(
    embedding_model="all-MiniLM-L6-v2",
    vectorstore_type="chroma",
    vectorstore_path="./vectorstore",
    use_hybrid_search=True,
    chunk_size=1000,
    chunk_overlap=200,
)

# Shared inferencer used by get_response / get_stream_response (no reranker).
inferencer = Inferencer(
    model=llm,
    retriever=document_retriever,
    reranker=None,
    config=inferencer_config,
)
|
| 39 |
+
|
| 40 |
+
async def get_response(question):
    """Run the full inference pipeline for ``question`` and return its result.

    The call requests the ``"rag_response"`` response type from the
    module-level ``inferencer``.
    """
    return await inferencer.infer(question, "rag_response")
|
| 43 |
+
|
| 44 |
+
async def get_stream_response(question):
    """Stream inference events for ``question`` from the module-level inferencer.

    Each event is printed and then yielded unchanged. Streaming uses the
    ``"main_template"`` template, three retrieved contexts and no reranking.
    """
    stream_options = {
        "query": question,
        "enable_reranking": False,
        "template_type": "main_template",
        "k": 3,
    }
    async for event in inferencer.infer_stream(**stream_options):
        print("Stream Response :", event)
        yield event
|
space/space/space/space/space/app/rag/inference/inferencer.py
ADDED
|
@@ -0,0 +1,552 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rag.retriever.langchain_retriever import LangChainRetriever
|
| 2 |
+
from rag.pipeline.qwen_llm import QwenLLM, QwenConfig
|
| 3 |
+
from rag.retriever.retriever_types import RetrievalResult
|
| 4 |
+
# from rag.pipeline.reranker import BGEM3Reranker
|
| 5 |
+
from typing import List, Union, Dict, Any, Optional, AsyncGenerator
|
| 6 |
+
import asyncio
|
| 7 |
+
import logging
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
|
| 11 |
+
@dataclass
class InferencerConfig:
    """Configuration for Inferencer.

    ``default_template_types`` may be given as a list of template names, a
    single template name (normalised to a one-element list), or ``None`` to
    use the built-in defaults.
    """
    # Number of contexts retrieved when the caller does not pass ``k``.
    default_k: int = 5
    # NOTE(review): not read anywhere in the visible code -- confirm usage.
    max_contexts: int = 10
    # Default for the reranking step when a call does not override it.
    enable_reranking: bool = False
    # Number of contexts kept after reranking when ``top_k`` is not passed.
    reranker_top_k: int = 5
    # Template names used for "multi_response" when none are passed.
    default_template_types: Optional[Union[List[str], str]] = None
    # When False the inferencer's logger is limited to ERROR.
    enable_logging: bool = True
    # NOTE(review): not read anywhere in the visible code -- confirm usage.
    response_timeout: float = 30.0

    def __post_init__(self):
        """Fill in the default templates and normalise a single name to a list."""
        if self.default_template_types is None:
            # Built per instance so configs never share one mutable list.
            self.default_template_types = ["system", "instruction", "friendly"]
        elif isinstance(self.default_template_types, str):
            # A bare string would otherwise be passed on where a list of
            # template names is expected.
            self.default_template_types = [self.default_template_types]
|
| 25 |
+
|
| 26 |
+
class Inferencer:
|
| 27 |
+
"""
|
| 28 |
+
Advanced RAG Inferencer dengan support untuk streaming, reranking, dan multiple response types
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(self,
|
| 32 |
+
model: QwenLLM,
|
| 33 |
+
retriever: LangChainRetriever,
|
| 34 |
+
reranker=None,
|
| 35 |
+
config: Optional[InferencerConfig] = None):
|
| 36 |
+
"""
|
| 37 |
+
Initialize Inferencer
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
model: QwenLLM instance
|
| 41 |
+
retriever: LangChainRetriever instance
|
| 42 |
+
reranker: Reranker instance (optional)
|
| 43 |
+
config: InferencerConfig (optional)
|
| 44 |
+
"""
|
| 45 |
+
self.model = model
|
| 46 |
+
self.retriever = retriever
|
| 47 |
+
self.reranker = reranker
|
| 48 |
+
self.config = config or InferencerConfig()
|
| 49 |
+
|
| 50 |
+
# Setup logging
|
| 51 |
+
if self.config.enable_logging:
|
| 52 |
+
logging.basicConfig(level=logging.INFO)
|
| 53 |
+
self.logger = logging.getLogger(__name__)
|
| 54 |
+
else:
|
| 55 |
+
self.logger = logging.getLogger(__name__)
|
| 56 |
+
self.logger.setLevel(logging.ERROR)
|
| 57 |
+
|
| 58 |
+
# Model loading flag
|
| 59 |
+
self._model_loaded = False
|
| 60 |
+
|
| 61 |
+
async def _ensure_model_loaded(self):
|
| 62 |
+
"""Pastikan model sudah diload (hanya sekali)"""
|
| 63 |
+
if not self._model_loaded:
|
| 64 |
+
self.logger.info("Loading model...")
|
| 65 |
+
await self.model.load_model()
|
| 66 |
+
self._model_loaded = True
|
| 67 |
+
self.logger.info("Model loaded successfully")
|
| 68 |
+
|
| 69 |
+
async def retrieve_context(self,
                           query: str,
                           k: Optional[int] = None) -> RetrievalResult:
    """
    Fetch context documents for a query from the retriever.

    Args:
        query: Search query
        k: Number of documents to retrieve; a falsy value falls back to
            ``config.default_k``

    Returns:
        Whatever ``self.retriever.retrieve`` returns

    Raises:
        Any exception raised by the retriever, after logging it.
    """
    if not k:
        k = self.config.default_k
    self.logger.info(f"Retrieving {k} contexts for query: {query[:50]}...")

    try:
        started = datetime.now()
        result = await self.retriever.retrieve(query, k=k)
        elapsed = (datetime.now() - started).total_seconds()

        # The result may expose ``.documents`` or be a plain sized container.
        if hasattr(result, 'documents'):
            found = len(result.documents)
        else:
            found = len(result)
        self.logger.info(f"Retrieved {found} contexts in {elapsed:.2f}s")
        return result

    except Exception as exc:
        self.logger.error(f"Error during retrieval: {exc}")
        raise
|
| 96 |
+
|
| 97 |
+
async def rerank_contexts(self,
                          contexts: RetrievalResult,
                          query: str,
                          top_k: Optional[int] = None) -> RetrievalResult:
    """
    Rerank retrieved contexts with the configured reranker.

    Args:
        contexts: Retrieved contexts
        query: Original query
        top_k: Number of contexts to keep; a falsy value falls back to
            ``config.reranker_top_k``

    Returns:
        The reranker's output, or the unchanged ``contexts`` when there is
        no reranker, reranking is disabled in the config, or reranking fails.
    """
    if not (self.reranker and self.config.enable_reranking):
        self.logger.info("Reranking disabled or reranker not available")
        return contexts

    keep = top_k if top_k else self.config.reranker_top_k
    self.logger.info(f"Reranking contexts, keeping top {keep}")

    try:
        started = datetime.now()
        ranked = await self.reranker.rerank(query=query, contexts=contexts, top_k=keep)
        elapsed = (datetime.now() - started).total_seconds()
        self.logger.info(f"Reranking completed in {elapsed:.2f}s")
        return ranked

    except Exception as exc:
        # Best effort: fall back to the original contexts if reranking fails.
        self.logger.error(f"Error during reranking: {exc}")
        return contexts
|
| 135 |
+
|
| 136 |
+
async def generate_response(self,
                            contexts: RetrievalResult,
                            query: Union[str, List[str]],
                            response_type: Union[List[str], str] = None,
                            template_types: Optional[List[str]] = None,
                            max_new_tokens: Optional[int] = None,
                            **generation_kwargs) -> List[Dict[str, Any]]:
    """
    Produce one result entry per requested response type.

    Args:
        contexts: Retrieved contexts
        query: User query or list of queries
        response_type: One of, or a list of, "rag_response",
            "multi_response", "batch_response"; defaults to "rag_response"
        template_types: Template types for multi_response; defaults to
            ``config.default_template_types``
        max_new_tokens: Maximum tokens to generate
        **generation_kwargs: Additional generation parameters

    Returns:
        List of single-key dictionaries, in the order rag, multi, batch.
        For a list of queries the rag/multi values are dictionaries keyed
        "query_0", "query_1", ...

    Raises:
        Any exception raised during generation, after logging it.
    """
    await self._ensure_model_loaded()

    # Normalise the requested response types to a list.
    if response_type is None:
        response_type = ["rag_response"]
    elif isinstance(response_type, str):
        response_type = [response_type]

    if template_types is None:
        template_types = self.config.default_template_types

    responses = []
    many = isinstance(query, list)

    try:
        # RAG response (always uses the "friendly" template).
        if "rag_response" in response_type:
            self.logger.info("Generating RAG response...")
            started = datetime.now()

            if many:
                answers = {}
                for idx, item in enumerate(query):
                    answers[f"query_{idx}"] = await self.model.rag_generate(
                        question=item,
                        contexts=contexts,
                        template_type="friendly",
                        max_new_tokens=max_new_tokens,
                        **generation_kwargs
                    )
            else:
                answers = await self.model.rag_generate(
                    question=query,
                    contexts=contexts,
                    template_type="friendly",
                    max_new_tokens=max_new_tokens,
                    **generation_kwargs
                )
            responses.append({"rag_response": answers})

            elapsed = (datetime.now() - started).total_seconds()
            self.logger.info(f"RAG response generated in {elapsed:.2f}s")

        # Multi-template response.
        if "multi_response" in response_type:
            self.logger.info("Generating multi-template responses...")
            started = datetime.now()

            if many:
                per_template = {}
                for idx, item in enumerate(query):
                    per_template[f"query_{idx}"] = await self.model.multi_template_generate(
                        question=item,
                        contexts=contexts,
                        template_types=template_types,
                        max_new_tokens=max_new_tokens,
                        **generation_kwargs
                    )
            else:
                per_template = await self.model.multi_template_generate(
                    question=query,
                    contexts=contexts,
                    template_types=template_types,
                    max_new_tokens=max_new_tokens,
                    **generation_kwargs
                )
            responses.append({"multi_responses": per_template})

            elapsed = (datetime.now() - started).total_seconds()
            self.logger.info(f"Multi-template responses generated in {elapsed:.2f}s")

        # Batch response (multiple prompts, no RAG context passed).
        if "batch_response" in response_type:
            self.logger.info("Generating batch responses...")
            started = datetime.now()

            prompts = query if many else [query]
            batch_out = await self.model.batch_generate(
                prompts,
                max_new_tokens=max_new_tokens,
                **generation_kwargs
            )
            responses.append({"batch_responses": batch_out})

            elapsed = (datetime.now() - started).total_seconds()
            self.logger.info(f"Batch responses generated in {elapsed:.2f}s")

        return responses

    except Exception as exc:
        self.logger.error(f"Error during response generation: {exc}")
        raise
|
| 261 |
+
|
| 262 |
+
async def generate_response_stream(self,
                                   contexts: RetrievalResult,
                                   query: str,
                                   template_type: str = "main_template",
                                   max_new_tokens: Optional[int] = None,
                                   **generation_kwargs) -> AsyncGenerator[str, None]:
    """
    Stream a RAG response chunk by chunk.

    Args:
        contexts: Retrieved contexts
        query: User query
        template_type: Template type to use
        max_new_tokens: Maximum tokens to generate
        **generation_kwargs: Additional generation parameters

    Yields:
        Chunks exactly as produced by ``self.model.rag_generate_stream``
    """
    await self._ensure_model_loaded()

    self.logger.info(f"Generating streaming RAG response with template: {template_type}")

    stream = self.model.rag_generate_stream(
        question=query,
        contexts=contexts,
        template_type=template_type,
        max_new_tokens=max_new_tokens,
        **generation_kwargs
    )
    async for piece in stream:
        yield piece
|
| 293 |
+
|
| 294 |
+
async def infer(self,
                query: Union[str, List[str]],
                response_type: Union[List[str], str] = None,
                k: Optional[int] = None,
                enable_reranking: Optional[bool] = None,
                template_types: Optional[List[str]] = None,
                max_new_tokens: Optional[int] = None,
                **generation_kwargs) -> Dict[str, Any]:
    """
    Run retrieval, optional reranking and generation for a query.

    Args:
        query: User query or list of queries; retrieval and reranking use
            only the first query of a list
        response_type: Type(s) of response to generate
        k: Number of contexts to retrieve
        enable_reranking: Per-call override; ``None`` uses the config value
        template_types: Template types for multi_response
        max_new_tokens: Maximum tokens to generate
        **generation_kwargs: Additional generation parameters

    Returns:
        Dictionary with "query", "responses", "contexts" and "metadata"

    Raises:
        Any exception raised by a pipeline step, after logging it.
    """
    started = datetime.now()

    # Retrieval works on a single query string.
    if isinstance(query, list):
        main_query = query[0]
    else:
        main_query = query

    try:
        # Step 1: retrieve contexts.
        retrieved = await self.retrieve_context(main_query, k=k)

        # Step 2: rerank contexts (if enabled).
        use_rerank = self.config.enable_reranking if enable_reranking is None else enable_reranking
        contexts = retrieved
        if use_rerank:
            contexts = await self.rerank_contexts(retrieved, main_query)

        # Step 3: generate responses for the original query (or queries).
        responses = await self.generate_response(
            contexts=contexts,
            query=query,
            response_type=response_type,
            template_types=template_types,
            max_new_tokens=max_new_tokens,
            **generation_kwargs
        )

        total_time = (datetime.now() - started).total_seconds()

        if hasattr(contexts, 'documents'):
            num_contexts = len(contexts.documents)
        else:
            num_contexts = len(contexts)

        metadata = {
            "total_time": total_time,
            "retrieval_enabled": True,
            "reranking_enabled": use_rerank,
            "num_contexts": num_contexts,
            "response_types": response_type,
            "timestamp": datetime.now().isoformat()
        }
        result = {
            "query": query,
            "responses": responses,
            "contexts": contexts,
            "metadata": metadata
        }

        self.logger.info(f"Inference completed in {total_time:.2f}s")
        return result

    except Exception as exc:
        self.logger.error(f"Error during inference: {exc}")
        raise
|
| 366 |
+
|
| 367 |
+
async def infer_stream(self,
                       query: str,
                       k: Optional[int] = None,
                       enable_reranking: Optional[bool] = None,
                       template_type: str = "main_template",
                       max_new_tokens: Optional[int] = None,
                       **generation_kwargs) -> AsyncGenerator[Dict[str, Any], None]:
    """
    Run retrieval and optional reranking, then stream the generated answer.

    Args:
        query: User query
        k: Number of contexts to retrieve
        enable_reranking: Per-call override; ``None`` uses the config value
        template_type: Template type to use
        max_new_tokens: Maximum tokens to generate
        **generation_kwargs: Additional generation parameters

    Yields:
        Event dictionaries with a "type" of "metadata" (once, first),
        "chunk" (per generated chunk), "complete" (once, last), or "error"
        (instead of raising when a step fails).
    """
    started = datetime.now()

    try:
        # Step 1: retrieve contexts.
        retrieved = await self.retrieve_context(query, k=k)

        # Step 2: rerank contexts (if enabled).
        use_rerank = self.config.enable_reranking if enable_reranking is None else enable_reranking
        contexts = retrieved
        if use_rerank:
            contexts = await self.rerank_contexts(retrieved, query)

        # Emit setup metadata before any generated text.
        setup_time = (datetime.now() - started).total_seconds()
        if hasattr(contexts, 'documents'):
            num_contexts = len(contexts.documents)
        else:
            num_contexts = len(contexts)
        yield {
            "type": "metadata",
            "data": {
                "query": query,
                "setup_time": setup_time,
                "num_contexts": num_contexts,
                "reranking_enabled": use_rerank,
                "template_type": template_type
            }
        }

        # Step 3: stream the response, carrying the running text along.
        response_start = datetime.now()
        text_so_far = ""

        chunks = self.generate_response_stream(
            contexts=contexts,
            query=query,
            template_type=template_type,
            max_new_tokens=max_new_tokens,
            **generation_kwargs
        )
        async for piece in chunks:
            text_so_far += piece
            chunk_data = {
                "chunk": piece,
                "accumulated_text": text_so_far,
                "generation_time": (datetime.now() - response_start).total_seconds()
            }
            yield {"type": "chunk", "data": chunk_data}

        # Final event with the full response and the contexts used.
        total_time = (datetime.now() - started).total_seconds()
        yield {
            "type": "complete",
            "data": {
                "total_time": total_time,
                "final_response": text_so_far,
                "contexts": contexts
            }
        }

    except Exception as exc:
        self.logger.error(f"Error during streaming inference: {exc}")
        yield {
            "type": "error",
            "data": {
                "error": str(exc),
                "error_time": (datetime.now() - started).total_seconds()
            }
        }
|
| 455 |
+
|
| 456 |
+
async def batch_infer(self,
                      queries: List[str],
                      response_type: Union[List[str], str] = None,
                      k: Optional[int] = None,
                      enable_reranking: Optional[bool] = None,
                      **generation_kwargs) -> List[Dict[str, Any]]:
    """
    Run inference concurrently for multiple queries.

    Every query is dispatched to ``self.infer`` as its own asyncio task and
    all tasks are awaited together. A failure in one query never aborts the
    others: it is reported as an error entry in the returned list.

    Args:
        queries: List of queries
        response_type: Type(s) of response to generate; forwarded unchanged
            to every ``self.infer`` call
        k: Number of contexts to retrieve per query
        enable_reranking: Whether to enable reranking
        **generation_kwargs: Additional generation parameters

    Returns:
        One entry per query, in the same order as ``queries``. A successful
        entry is the object returned by ``self.infer`` with ``"success"``
        set to ``True`` (the result is mutated in place, so it is assumed to
        be a dict). A failed entry is
        ``{"query": ..., "error": ..., "success": False}``.
    """
    self.logger.info(f"Starting batch inference for {len(queries)} queries")

    # Create tasks for concurrent processing
    tasks = [
        asyncio.create_task(
            self.infer(
                query=query,
                response_type=response_type,
                k=k,
                enable_reranking=enable_reranking,
                **generation_kwargs
            )
        )
        for query in queries
    ]

    # Wait for all tasks; failures come back as exception instances
    results = await asyncio.gather(*tasks, return_exceptions=True)

    # Process results
    processed_results = []
    for i, (query, result) in enumerate(zip(queries, results)):
        # BaseException (not Exception): a cancelled child task is returned
        # as asyncio.CancelledError, which has not been an Exception
        # subclass since Python 3.8. Treating it as a normal result would
        # crash on the item assignment below.
        if isinstance(result, BaseException):
            self.logger.error(f"Error processing query {i}: {result}")
            processed_results.append({
                "query": query,
                # Some exceptions (e.g. CancelledError) have an empty
                # message; fall back to the class name so the entry is
                # still informative.
                "error": str(result) or type(result).__name__,
                "success": False
            })
        else:
            result["success"] = True
            processed_results.append(result)

    return processed_results
|
| 509 |
+
|
| 510 |
+
async def get_available_templates(self) -> List[str]:
    """Return the template types reported by the underlying model.

    The model is loaded first (via ``_ensure_model_loaded``) so the query
    always runs against an initialised model.
    """
    await self._ensure_model_loaded()
    template_names = self.model.get_available_templates()
    return template_names
|
| 514 |
+
|
| 515 |
+
async def preview_template(self,
                           template_type: str,
                           sample_query: str = "Apa itu AI?") -> str:
    """Render a preview of a template using placeholder content.

    Args:
        template_type: Name of the template to preview.
        sample_query: Question inserted into the preview.

    Returns:
        Whatever ``self.model.preview_template`` produces for the sample
        question and a fixed placeholder context.
    """
    await self._ensure_model_loaded()
    preview_arguments = {
        "template_type": template_type,
        "sample_question": sample_query,
        "sample_context": "Sample context untuk preview template...",
    }
    return self.model.preview_template(**preview_arguments)
|
| 525 |
+
|
| 526 |
+
async def get_model_info(self) -> Dict[str, Any]:
    """Collect a summary of the model and inferencer state.

    Returns:
        A dict with four keys: ``model_info`` (as reported by the model),
        ``inferencer_config`` (the attribute dict of ``self.config``),
        ``reranker_available`` (whether a reranker is set) and
        ``available_templates``.
    """
    await self._ensure_model_loaded()
    details = await self.model.get_model_info()

    # Filled key by key so the awaits run in a fixed, explicit order.
    summary: Dict[str, Any] = {"model_info": details}
    summary["inferencer_config"] = self.config.__dict__
    summary["reranker_available"] = self.reranker is not None
    summary["available_templates"] = await self.get_available_templates()
    return summary
|
| 537 |
+
|
| 538 |
+
async def close(self):
    """Release resources held by the inferencer.

    Closes the model when one is set (truthy) and logs before and after.
    """
    log = self.logger
    log.info("Closing Inferencer...")
    active_model = self.model
    if active_model:
        await active_model.close()
    log.info("Inferencer closed successfully")
|
| 544 |
+
|
| 545 |
+
async def __aenter__(self):
    """Enter the async context: load the model, then hand back this instance."""
    # The model is loaded on entry so the body of ``async with`` can use it
    # immediately.
    await self._ensure_model_loaded()
    return self
|
| 549 |
+
|
| 550 |
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Leave the async context by closing the inferencer.

    The exception arguments are ignored and ``None`` is returned, so an
    exception raised inside the ``async with`` body is not suppressed.
    """
    await self.close()
    return None
|
space/space/space/space/space/app/rag/pipeline/__init__.py
ADDED
|
File without changes
|