# RAG2 / agent_pdfimages.py
# (Hugging Face Hub page residue, kept as comments: "Donlagon007's picture",
#  "Update agent_pdfimages.py", "61ded4d verified" — not part of the program.)
# agent.py
import os, json, glob
from pathlib import Path
from typing import List, Dict, Any
import numpy as np
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.tools import Tool
from langchain.agents import initialize_agent, AgentType
# ==== 新增:CLIP 影像索引需要的套件 ====
from PIL import Image
import faiss
from sentence_transformers import SentenceTransformer
# ------------------ Basic configuration ------------------
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
BASE_DIR = Path(__file__).resolve().parent
INDEX_DIR = BASE_DIR / "faiss_breast"  # FAISS text-index directory (built offline)
EMBED_MODEL = "text-embedding-3-small"  # OpenAI embedding model used for text chunks
# Image-index paths (build these first with your rebuild tool)
IMAGE_DIR = Path(os.getenv("IMAGE_DIR", str(BASE_DIR / "images")))
IMAGE_IDX_DIR = Path(os.getenv("IMAGE_IDX_DIR", str(BASE_DIR / "faiss_images")))
IMAGE_IDX_PATH = IMAGE_IDX_DIR / "clip.index"  # serialized FAISS index of CLIP vectors
IMAGE_META_PATH = IMAGE_IDX_DIR / "metadata.json"  # per-vector image metadata (path, rel_path)
CLIP_MODEL_NAME = os.getenv("CLIP_MODEL", "clip-ViT-L-14")
# ------------------ Text index: loaded once, lazily ------------------
class _LazyVS:
    """Lazy proxy around the FAISS text vector store.

    Defers the (slow) disk load until the first search call, then caches
    the loaded store for every subsequent query.
    """

    def __init__(self):
        self._vs = None  # populated on first use by _ensure()

    def _ensure(self):
        # Load the index from disk exactly once.
        if self._vs is not None:
            return
        embedder = OpenAIEmbeddings(model=EMBED_MODEL, openai_api_key=OPENAI_API_KEY)
        self._vs = FAISS.load_local(
            str(INDEX_DIR),
            embedder,
            allow_dangerous_deserialization=True,
        )

    # Proxy only the two search methods the app actually uses.
    def similarity_search(self, *args, **kwargs):
        self._ensure()
        return self._vs.similarity_search(*args, **kwargs)

    def max_marginal_relevance_search(self, *args, **kwargs):
        self._ensure()
        return self._vs.max_marginal_relevance_search(*args, **kwargs)

# Exported under the original name so the rest of the code works unchanged.
VS = _LazyVS()
# ------------------ 工具函式 ------------------
def _short(s: str, n: int = 700) -> str:
s = (s or "").strip()
return s if len(s) <= n else s[:n] + " …"
def _is_image(path: str) -> bool:
return path.lower().endswith((".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"))
# ------------------ CLIP image-index class and singleton ------------------
class ClipImageIndex:
    """Text ↔ image retrieval in a shared CLIP embedding space."""

    def __init__(self, model_name: str = CLIP_MODEL_NAME, device: str | None = None):
        """Load the CLIP model; prefer CUDA when torch reports it available.

        If torch is not importable, fall back to SentenceTransformer's own
        device selection.
        """
        try:
            import torch
            if device is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
        except Exception:
            device = None
        self.model = (
            SentenceTransformer(model_name, device=device)
            if device
            else SentenceTransformer(model_name)
        )
        self.index = None  # FAISS index, set by load()
        self.meta: List[Dict[str, Any]] = []  # per-vector metadata, parallel to index rows

    def load(self, idx_path: Path, meta_path: Path):
        """Load a prebuilt FAISS index and its metadata JSON from disk."""
        self.index = faiss.read_index(str(idx_path))
        # Use the module-level json import; the original re-imported json
        # locally as _json, shadowing it for no benefit.
        with open(meta_path, "r", encoding="utf-8") as f:
            self.meta = json.load(f)

    def query(self, text: str, k: int = 5) -> List[Dict[str, Any]]:
        """Return up to *k* image hits for *text*, ranked by CLIP similarity.

        Returns an empty list when no index is loaded, so text-only RAG
        keeps working.
        """
        if self.index is None:
            return []
        q = self.model.encode([text], normalize_embeddings=True).astype("float32")
        D, I = self.index.search(q, k)
        out: List[Dict[str, Any]] = []
        for rank, idx in enumerate(I[0]):
            # FAISS pads results with -1 when fewer than k vectors exist;
            # also guard against metadata shorter than the index.
            if idx == -1 or idx >= len(self.meta):
                continue
            m = self.meta[idx]
            out.append({
                "type": "image",
                "rank": rank + 1,
                "score": float(D[0][rank]),  # CLIP similarity score as reported by FAISS
                "image_path": m.get("path"),
                "rel_path": m.get("rel_path"),
            })
        return out
# Singleton: load the image index if its files exist, else leave it None
# (image retrieval is optional and must not break text-only RAG).
IMG_INDEX: ClipImageIndex | None = None
if IMAGE_IDX_PATH.exists() and IMAGE_META_PATH.exists():
    try:
        _loaded = ClipImageIndex(CLIP_MODEL_NAME)
        _loaded.load(IMAGE_IDX_PATH, IMAGE_META_PATH)
        IMG_INDEX = _loaded
        print(f"[agency] Loaded image index: {IMAGE_IDX_DIR}")
    except Exception as e:
        print(f"[agency] WARNING: failed to load image index: {e}")
# ------------------ Retrieval and fusion logic ------------------
K_TEXT = 5  # number of text chunks to retrieve per query
K_IMAGE = 5  # number of images to retrieve per query (when the image index is loaded)
def _serialize_text_docs(docs) -> List[Dict[str, Any]]:
    """Flatten LangChain documents into JSON-serializable dicts for tool output.

    Each dict carries the source file, page, year (when present in metadata)
    and a truncated copy of the chunk text.
    """
    def _one(doc) -> Dict[str, Any]:
        meta = doc.metadata or {}
        return {
            "type": "text",
            "source_file": meta.get("source_file", meta.get("source", "unknown")),
            "page": meta.get("page"),
            "year": meta.get("year"),
            "text": _short(doc.page_content),
        }

    return [_one(d) for d in docs]
def _rank_fusion(text_items: List[Dict], img_items: List[Dict],
w_text: float = 0.5, w_img: float = 0.5) -> List[Dict]:
"""
簡易融合:文字結果用「倒數排名分數」;影像結果用 CLIP score + 倒數排名分數。
"""
fused = []
# 文字:沒有原生分數,用排名分數 1/(rank+1)
for i, it in enumerate(text_items):
it = dict(it)
it["_fused_score"] = w_text * (1.0 / (i + 1))
fused.append(it)
# 影像:用 CLIP score + 排名分數
for j, it in enumerate(img_items):
it = dict(it)
base = float(it.get("score", 0.0))
it["_fused_score"] = w_img * (base + 1.0 / (j + 1))
fused.append(it)
fused.sort(key=lambda x: -x["_fused_score"])
for it in fused:
it.pop("_fused_score", None)
return fused
def rag_search(query: str) -> str:
    """Run text retrieval plus (when an index is loaded) CLIP image retrieval.

    Returns a JSON string with per-modality top-k lists and a fused ranking.
    """
    # Prefer MMR for diversity; fall back to plain similarity search on error.
    try:
        docs = VS.max_marginal_relevance_search(query, k=K_TEXT, fetch_k=max(12, 2 * K_TEXT))
    except Exception:
        docs = VS.similarity_search(query, k=K_TEXT)
    text_items = _serialize_text_docs(docs)

    # Image hits only when the optional CLIP index was loaded at import time.
    img_items = IMG_INDEX.query(query, k=K_IMAGE) if IMG_INDEX else []

    payload = {
        "text_topk": text_items,
        "image_topk": img_items,
        "fused": _rank_fusion(text_items, img_items, w_text=0.5, w_img=0.5)[:10],
    }
    return json.dumps(payload, ensure_ascii=False, indent=2)
# ------------------ Tool definition (original name kept; performs multimodal fusion) ------------------
rag_tool = Tool(
    name="BreastCancerRAG",
    func=rag_search,
    description=(
        "Retrieve 3–5 relevant TEXT chunks from the breast cancer knowledge base and (if available) "
        "3–5 relevant IMAGES via CLIP, then return a JSON object with 'text_topk', 'image_topk', and a 'fused' list. "
        "Use this tool once per question. If evidence is insufficient, say what else is needed."
    ),
)
# ------------------ System prompt (small addition: mentions images) ------------------
SYSTEM_PROMPT = (
    "You are an assistant specializing in breast cancer epidemiology and screening policy.\n"
    "Workflow:\n"
    "1) Call the tool `BreastCancerRAG` once to obtain evidence (text and, if available, images).\n"
    "2) Answer ONLY based on the retrieved evidence. Do NOT fabricate.\n"
    "3) If you reference an image, include its file name or relative path from the tool output.\n"
    "4) If the evidence is insufficient, say so and specify what extra info is needed.\n\n"
    "Answer format:\n"
    "- Use bullet points or short paragraphs.\n"
    "- Add citation tags like [Wu 2013, p.X] or [Yen 2017, p.Y] for text.\n"
    "- Mark general knowledge as '(general knowledge)'."
)
# ------------------ Build the agent ------------------
def build_agent():
    """Build the OpenAI-functions agent wired to the multimodal RAG tool.

    Returns a LangChain agent executor; callers invoke it with a user query.
    """
    llm_direct = ChatOpenAI(model="gpt-4o", temperature=0.2, openai_api_key=OPENAI_API_KEY)
    agent = initialize_agent(
        tools=[rag_tool],
        llm=llm_direct,
        agent=AgentType.OPENAI_FUNCTIONS,
        verbose=True,
        handle_parsing_errors=True,
        max_iterations=3,
        max_execution_time=60,
        early_stopping_method="generate",
        system_message=SYSTEM_PROMPT,
        # NOTE(review): the original passed `memory=memory`, but no `memory`
        # object is defined anywhere in this module, so calling build_agent()
        # raised NameError. Dropped until a real memory instance (e.g. a
        # ConversationBufferMemory) is actually created and wired in.
    )
    return agent