snote / scripts /response_generator.py
xuanbao01's picture
Upload folder using huggingface_hub
44c5827 verified
from llm_with_prompt import LLM
from typing import List, Dict, Any, Optional
import yaml
import os
current_dir = os.path.dirname(os.path.abspath(__file__))
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
#--------------------
# Helper functions
#--------------------
config = yaml.load(open(os.path.join(current_dir, "..", "config.yaml")), Loader=yaml.FullLoader)[0]
model_config = config["model_config"]
USER_PROMPT_TEMPLATE = """
Câu hỏi của sinh viên:
{query}
Dựa trên các thông tin sau, hãy đưa ra câu trả lời ngắn gọn, chính xác, trích từ nội dung khi cần:
{context}
"""
def build_user_prompt(query: str, chunks_retrieved: List[Dict[str, Any]]) -> str:
"""
Build prompt for the model to generate answer.
"""
context = []
for i, chunk in enumerate(chunks_retrieved):
context.append(f"Chunk {i+1}**: \n {chunk['doc_id']}: \n {chunk['chunk_for_embedding']}")
context = "\n\n".join(context)
return USER_PROMPT_TEMPLATE.format(query=query, context=context)
class ResponseGenerator:
def __init__(self):
self.gen_answer_agent = LLM(model_name=model_config["model_name"], prompt_id="gen_answer", version="1.0")
def generate(self, query, chunks_retrieved: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
chunks_retrieved: list of chunk dicts (ranked). We'll take top self.top_k_chunks for prompt building.
Returns structured dict: answer_json + grounding metrics + raw model output
"""
user_prompt = build_user_prompt(query, chunks_retrieved)
try:
raw_answer = self.gen_answer_agent.generate_response(user_prompt)
except Exception as e:
logger.exception("Error: %s", e)
return self._get_error_message(e)
return raw_answer, user_prompt
def _get_error_message(self, e: Exception) -> Dict[str, Any]:
return {
"answer": "Error: " + str(e),
"claims": [],
"citations": [],
"grounding": {"total_claims": 0, "supported_claims": 0, "rate": 0.0},
"raw_model": e
}
def _safe_not_found(self, raw_model: Optional[Any] = None) -> Dict[str, Any]:
return {
"answer": "Không tìm thấy quy định",
"claims": [],
"citations": [],
"grounding": {"total_claims": 0, "supported_claims": 0, "rate": 0.0},
"raw_model": raw_model
}
if __name__ == "__main__":
chunks_retrieved = [
{
"id": "11.QD::CH3::A5::C1",
"doc_id": "1563/QĐ-ĐHTL",
"chunk_text": "Nếu sinh viên nghỉ học dưới 10 ngày tại kỳ chính thì không phải nộp học phí.",
"chunk_for_embedding": None,
"metadata": {"chapter":3, "article":5, "identifiers":["1563/QĐ-ĐHTL"]}
},
{
"id": "11.QD::CH3::A5::C2",
"doc_id": "1563/QĐ-ĐHTL",
"chunk_text": "Nếu nghỉ từ 10 ngày đến dưới 20 ngày tại kỳ chính phải nộp 30% học phí.",
"metadata": {"chapter":3, "article":5}
}
]
query = "Em quên đóng học phí quá hạn có bị cấm bảo vệ khóa luận không ạ ?"
gen = ResponseGenerator()
out = gen.generate(query, chunks_retrieved)