Spaces:
Sleeping
Sleeping
File size: 4,431 Bytes
a9e6507 a4137be a9e6507 a4137be a9e6507 a4137be a9e6507 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 | # rag.py
import os
from dotenv import load_dotenv
# 新增:用於定義結構化輸出格式
from typing import List, Dict, Any
from pydantic import BaseModel, Field
from langchain_core.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain_core.prompts import PromptTemplate # 確保導入這個,用於 HumanMessage 的子模板
from langchain_core.prompts import ChatPromptTemplate
from models.model_wrapper import get_llm
from db.postgres_init import get_vectordb
from langchain_core.exceptions import OutputParserException
# --- 🎯 食譜 Pydantic 結構定義 ---
class Ingredient(BaseModel):
name: str = Field(description="材料名稱,例如:豬五花肉")
amount: str = Field(description="份量/數量,例如:300克 或 2大匙")
class Recipe(BaseModel):
"""用於儲存完整食譜的 JSON 結構"""
title: str = Field(description="食譜的繁體中文名稱")
ingredients: List[Ingredient] = Field(description="所有材料的清單")
steps: List[str] = Field(description="詳細的步驟說明")
notes: List[str] = Field(description="食譜的額外提醒或替代食材建議")
# 定義一個包含多個 Recipe 的容器
class RecipeList(BaseModel):
"""用於儲存多個完整食譜的列表結構"""
recipes: List[Recipe] = Field(
description="一個包含多個食譜物件(Recipe)的列表。"
)
# -----------------------------------
# --- 將 PROMPT_TMPL 內容拆分 ---
# 1. 系統提示 (System Prompt) - 放置角色、格式和主要限制
SYS_TMPL = """
你是一位專業中文料理師傅。
請生成一個完整、易懂的繁體中文菜譜。
請嚴格以 **純 JSON 格式** 輸出,且內容必須符合指定的 Schema。
以下是輸出限制:
- 優先使用資料庫中的資訊
- 禁止憑空編造不存在的材料
- 若必要,請在 notes 加上替代食材建議
"""
# 2. 用戶/輸入提示 (Human Prompt) - 放置變量輸入
HUMAN_TMPL = """
使用者需求:{query}
請根據需求和資料庫上下文,生成**兩個**不同的食譜。
資料庫上下文:
{context}
"""
# --- 建立 ChatPromptTemplate ---
# A. 建立 System Message Template
system_message_prompt = SystemMessagePromptTemplate.from_template(SYS_TMPL)
# B. 建立 Human Message Template (使用 PromptTemplate 包裝變量)
human_message_prompt = HumanMessagePromptTemplate(
prompt=PromptTemplate(
input_variables=["query", "context"],
template=HUMAN_TMPL
)
)
# C. 組合 ChatPromptTemplate
base_prompt = ChatPromptTemplate.from_messages([
system_message_prompt,
human_message_prompt
])
# --- 建立 RAG function ---
def build_rag_chain(k=4):
db = get_vectordb()
llm = get_llm()
retriever = db.as_retriever(search_kwargs={"k": k})
structured_llm = llm.with_structured_output(RecipeList)
def get_context_and_query(query: str):
# 這裡的 retriever.invoke() 現在會對 PostgreSQL 執行向量相似性搜索
docs = retriever.invoke(query)
context = "\n".join([d.page_content for d in docs])
return {"context": context, "query": query, "docs": docs}
# 調整 rag 函式以返回更清晰的結果
def rag(query: str):
docs = []
# ----------------------------------------------------
# 1. 執行檢索 (RunnableLambda 讓我們在 LCEL 外執行並拿到中間結果)
try:
input_data = get_context_and_query(query)
docs = input_data.pop("docs")
# ----------------------------------------------------
# 2. 建立 PromptValue
prompt_value = base_prompt.invoke(input_data)
# ----------------------------------------------------
# 3. 呼叫 LLM 並解析 JSON 輸出
answer = structured_llm.invoke(prompt_value)
result_dict = answer.dict()
final_list = result_dict.get('recipes', [])
return {"result": final_list, "source_documents": docs}
except OutputParserException as e:
return {"result": {"error": "LLM 輸出格式錯誤,無法解析 JSON"}, "source_documents": docs}
except Exception as e:
return {"result": {"error": f"LLM 呼叫失敗: {e}"}, "source_documents": docs}
# ----------------------------------------------------
return rag |