Spaces:
Sleeping
Sleeping
| # rag.py | |
| import os | |
| from dotenv import load_dotenv | |
| # 新增:用於定義結構化輸出格式 | |
| from typing import List, Dict, Any | |
| from pydantic import BaseModel, Field | |
| from langchain_core.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate | |
| from langchain_core.prompts import PromptTemplate # 確保導入這個,用於 HumanMessage 的子模板 | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from models.model_wrapper import get_llm | |
| from db.postgres_init import get_vectordb | |
| from langchain_core.exceptions import OutputParserException | |
| # --- 🎯 食譜 Pydantic 結構定義 --- | |
| class Ingredient(BaseModel): | |
| name: str = Field(description="材料名稱,例如:豬五花肉") | |
| amount: str = Field(description="份量/數量,例如:300克 或 2大匙") | |
| class Recipe(BaseModel): | |
| """用於儲存完整食譜的 JSON 結構""" | |
| title: str = Field(description="食譜的繁體中文名稱") | |
| ingredients: List[Ingredient] = Field(description="所有材料的清單") | |
| steps: List[str] = Field(description="詳細的步驟說明") | |
| notes: List[str] = Field(description="食譜的額外提醒或替代食材建議") | |
| # 定義一個包含多個 Recipe 的容器 | |
| class RecipeList(BaseModel): | |
| """用於儲存多個完整食譜的列表結構""" | |
| recipes: List[Recipe] = Field( | |
| description="一個包含多個食譜物件(Recipe)的列表。" | |
| ) | |
| # ----------------------------------- | |
| # --- 將 PROMPT_TMPL 內容拆分 --- | |
| # 1. 系統提示 (System Prompt) - 放置角色、格式和主要限制 | |
| SYS_TMPL = """ | |
| 你是一位專業中文料理師傅。 | |
| 請生成一個完整、易懂的繁體中文菜譜。 | |
| 請嚴格以 **純 JSON 格式** 輸出,且內容必須符合指定的 Schema。 | |
| 以下是輸出限制: | |
| - 優先使用資料庫中的資訊 | |
| - 禁止憑空編造不存在的材料 | |
| - 若必要,請在 notes 加上替代食材建議 | |
| """ | |
| # 2. 用戶/輸入提示 (Human Prompt) - 放置變量輸入 | |
| HUMAN_TMPL = """ | |
| 使用者需求:{query} | |
| 請根據需求和資料庫上下文,生成**兩個**不同的食譜。 | |
| 資料庫上下文: | |
| {context} | |
| """ | |
| # --- 建立 ChatPromptTemplate --- | |
| # A. 建立 System Message Template | |
| system_message_prompt = SystemMessagePromptTemplate.from_template(SYS_TMPL) | |
| # B. 建立 Human Message Template (使用 PromptTemplate 包裝變量) | |
| human_message_prompt = HumanMessagePromptTemplate( | |
| prompt=PromptTemplate( | |
| input_variables=["query", "context"], | |
| template=HUMAN_TMPL | |
| ) | |
| ) | |
| # C. 組合 ChatPromptTemplate | |
| base_prompt = ChatPromptTemplate.from_messages([ | |
| system_message_prompt, | |
| human_message_prompt | |
| ]) | |
| # --- 建立 RAG function --- | |
| def build_rag_chain(k=4): | |
| db = get_vectordb() | |
| llm = get_llm() | |
| retriever = db.as_retriever(search_kwargs={"k": k}) | |
| structured_llm = llm.with_structured_output(RecipeList) | |
| def get_context_and_query(query: str): | |
| # 這裡的 retriever.invoke() 現在會對 PostgreSQL 執行向量相似性搜索 | |
| docs = retriever.invoke(query) | |
| context = "\n".join([d.page_content for d in docs]) | |
| return {"context": context, "query": query, "docs": docs} | |
| # 調整 rag 函式以返回更清晰的結果 | |
| def rag(query: str): | |
| docs = [] | |
| # ---------------------------------------------------- | |
| # 1. 執行檢索 (RunnableLambda 讓我們在 LCEL 外執行並拿到中間結果) | |
| try: | |
| input_data = get_context_and_query(query) | |
| docs = input_data.pop("docs") | |
| # ---------------------------------------------------- | |
| # 2. 建立 PromptValue | |
| prompt_value = base_prompt.invoke(input_data) | |
| # ---------------------------------------------------- | |
| # 3. 呼叫 LLM 並解析 JSON 輸出 | |
| answer = structured_llm.invoke(prompt_value) | |
| result_dict = answer.dict() | |
| final_list = result_dict.get('recipes', []) | |
| return {"result": final_list, "source_documents": docs} | |
| except OutputParserException as e: | |
| return {"result": {"error": "LLM 輸出格式錯誤,無法解析 JSON"}, "source_documents": docs} | |
| except Exception as e: | |
| return {"result": {"error": f"LLM 呼叫失敗: {e}"}, "source_documents": docs} | |
| # ---------------------------------------------------- | |
| return rag |