|
|
from llm_with_prompt import LLM |
|
|
from typing import List, Dict, Any, Optional |
|
|
import yaml |
|
|
import os |
|
|
|
|
|
current_dir = os.path.dirname(os.path.abspath(__file__)) |
|
|
import logging |
|
|
logging.basicConfig(level=logging.INFO) |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
config = yaml.load(open(os.path.join(current_dir, "..", "config.yaml")), Loader=yaml.FullLoader)[0] |
|
|
model_config = config["model_config"] |
|
|
|
|
|
USER_PROMPT_TEMPLATE = """ |
|
|
Câu hỏi của sinh viên: |
|
|
{query} |
|
|
|
|
|
Dựa trên các thông tin sau, hãy đưa ra câu trả lời ngắn gọn, chính xác, trích từ nội dung khi cần: |
|
|
{context} |
|
|
""" |
|
|
|
|
|
|
|
|
def build_user_prompt(query: str, chunks_retrieved: List[Dict[str, Any]]) -> str: |
|
|
""" |
|
|
Build prompt for the model to generate answer. |
|
|
""" |
|
|
context = [] |
|
|
for i, chunk in enumerate(chunks_retrieved): |
|
|
context.append(f"Chunk {i+1}**: \n {chunk['doc_id']}: \n {chunk['chunk_for_embedding']}") |
|
|
context = "\n\n".join(context) |
|
|
return USER_PROMPT_TEMPLATE.format(query=query, context=context) |
|
|
|
|
|
class ResponseGenerator: |
|
|
def __init__(self): |
|
|
self.gen_answer_agent = LLM(model_name=model_config["model_name"], prompt_id="gen_answer", version="1.0") |
|
|
|
|
|
def generate(self, query, chunks_retrieved: List[Dict[str, Any]]) -> Dict[str, Any]: |
|
|
""" |
|
|
chunks_retrieved: list of chunk dicts (ranked). We'll take top self.top_k_chunks for prompt building. |
|
|
Returns structured dict: answer_json + grounding metrics + raw model output |
|
|
""" |
|
|
user_prompt = build_user_prompt(query, chunks_retrieved) |
|
|
try: |
|
|
raw_answer = self.gen_answer_agent.generate_response(user_prompt) |
|
|
except Exception as e: |
|
|
logger.exception("Error: %s", e) |
|
|
return self._get_error_message(e) |
|
|
return raw_answer, user_prompt |
|
|
|
|
|
def _get_error_message(self, e: Exception) -> Dict[str, Any]: |
|
|
return { |
|
|
"answer": "Error: " + str(e), |
|
|
"claims": [], |
|
|
"citations": [], |
|
|
"grounding": {"total_claims": 0, "supported_claims": 0, "rate": 0.0}, |
|
|
"raw_model": e |
|
|
} |
|
|
|
|
|
def _safe_not_found(self, raw_model: Optional[Any] = None) -> Dict[str, Any]: |
|
|
return { |
|
|
"answer": "Không tìm thấy quy định", |
|
|
"claims": [], |
|
|
"citations": [], |
|
|
"grounding": {"total_claims": 0, "supported_claims": 0, "rate": 0.0}, |
|
|
"raw_model": raw_model |
|
|
} |
|
|
|
|
|
if __name__ == "__main__": |
|
|
chunks_retrieved = [ |
|
|
{ |
|
|
"id": "11.QD::CH3::A5::C1", |
|
|
"doc_id": "1563/QĐ-ĐHTL", |
|
|
"chunk_text": "Nếu sinh viên nghỉ học dưới 10 ngày tại kỳ chính thì không phải nộp học phí.", |
|
|
"chunk_for_embedding": None, |
|
|
"metadata": {"chapter":3, "article":5, "identifiers":["1563/QĐ-ĐHTL"]} |
|
|
}, |
|
|
{ |
|
|
"id": "11.QD::CH3::A5::C2", |
|
|
"doc_id": "1563/QĐ-ĐHTL", |
|
|
"chunk_text": "Nếu nghỉ từ 10 ngày đến dưới 20 ngày tại kỳ chính phải nộp 30% học phí.", |
|
|
"metadata": {"chapter":3, "article":5} |
|
|
} |
|
|
] |
|
|
query = "Em quên đóng học phí quá hạn có bị cấm bảo vệ khóa luận không ạ ?" |
|
|
gen = ResponseGenerator() |
|
|
out = gen.generate(query, chunks_retrieved) |