from llm_with_prompt import LLM from typing import List, Dict, Any, Optional import yaml import os current_dir = os.path.dirname(os.path.abspath(__file__)) import logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) #-------------------- # Helper functions #-------------------- config = yaml.load(open(os.path.join(current_dir, "..", "config.yaml")), Loader=yaml.FullLoader)[0] model_config = config["model_config"] USER_PROMPT_TEMPLATE = """ Câu hỏi của sinh viên: {query} Dựa trên các thông tin sau, hãy đưa ra câu trả lời ngắn gọn, chính xác, trích từ nội dung khi cần: {context} """ def build_user_prompt(query: str, chunks_retrieved: List[Dict[str, Any]]) -> str: """ Build prompt for the model to generate answer. """ context = [] for i, chunk in enumerate(chunks_retrieved): context.append(f"Chunk {i+1}**: \n {chunk['doc_id']}: \n {chunk['chunk_for_embedding']}") context = "\n\n".join(context) return USER_PROMPT_TEMPLATE.format(query=query, context=context) class ResponseGenerator: def __init__(self): self.gen_answer_agent = LLM(model_name=model_config["model_name"], prompt_id="gen_answer", version="1.0") def generate(self, query, chunks_retrieved: List[Dict[str, Any]]) -> Dict[str, Any]: """ chunks_retrieved: list of chunk dicts (ranked). We'll take top self.top_k_chunks for prompt building. Returns structured dict: answer_json + grounding metrics + raw model output """ user_prompt = build_user_prompt(query, chunks_retrieved) try: raw_answer = self.gen_answer_agent.generate_response(user_prompt) except Exception as e: logger.exception("Error: %s", e) return self._get_error_message(e) return raw_answer, user_prompt def _get_error_message(self, e: Exception) -> Dict[str, Any]: return { "answer": "Error: " + str(e), "claims": [], "citations": [], "grounding": {"total_claims": 0, "supported_claims": 0, "rate": 0.0}, "raw_model": e } def _safe_not_found(self, raw_model: Optional[Any] = None) -> Dict[str, Any]: return { "answer": "Không tìm thấy quy định", "claims": [], "citations": [], "grounding": {"total_claims": 0, "supported_claims": 0, "rate": 0.0}, "raw_model": raw_model } if __name__ == "__main__": chunks_retrieved = [ { "id": "11.QD::CH3::A5::C1", "doc_id": "1563/QĐ-ĐHTL", "chunk_text": "Nếu sinh viên nghỉ học dưới 10 ngày tại kỳ chính thì không phải nộp học phí.", "chunk_for_embedding": None, "metadata": {"chapter":3, "article":5, "identifiers":["1563/QĐ-ĐHTL"]} }, { "id": "11.QD::CH3::A5::C2", "doc_id": "1563/QĐ-ĐHTL", "chunk_text": "Nếu nghỉ từ 10 ngày đến dưới 20 ngày tại kỳ chính phải nộp 30% học phí.", "metadata": {"chapter":3, "article":5} } ] query = "Em quên đóng học phí quá hạn có bị cấm bảo vệ khóa luận không ạ ?" gen = ResponseGenerator() out = gen.generate(query, chunks_retrieved)