# agent.py
import os, json, glob
from pathlib import Path
from typing import List, Dict, Any
import numpy as np
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.tools import Tool
from langchain.agents import initialize_agent, AgentType
from langchain.schema import SystemMessage
# ==== New: packages required by the CLIP image index ====
from PIL import Image
import faiss
from sentence_transformers import SentenceTransformer
# ------------------ Basic settings ------------------
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
BASE_DIR = Path(__file__).resolve().parent
INDEX_DIR = BASE_DIR / "faiss_breast"
EMBED_MODEL = "text-embedding-3-small"
# Image index paths (build the index first with your own indexing tool)
IMAGE_DIR = Path(os.getenv("IMAGE_DIR", str(BASE_DIR / "images")))
IMAGE_IDX_DIR = Path(os.getenv("IMAGE_IDX_DIR", str(BASE_DIR / "faiss_images")))
IMAGE_IDX_PATH = IMAGE_IDX_DIR / "clip.index"
IMAGE_META_PATH = IMAGE_IDX_DIR / "metadata.json"
CLIP_MODEL_NAME = os.getenv("CLIP_MODEL", "clip-ViT-L-14")
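# Example .env for the settings above (a sketch only; the values are
# placeholders, not the ones used in this Space):
#   OPENAI_API_KEY=sk-...
#   IMAGE_DIR=./images
#   IMAGE_IDX_DIR=./faiss_images
#   CLIP_MODEL=clip-ViT-L-14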
# ------------------ Loaded only once: text index ------------------
# ---------- Lazy-load wrapper (just a few lines) ----------
class _LazyVS:
    def __init__(self):
        self._vs = None

    def _ensure(self):
        if self._vs is None:
            self._vs = FAISS.load_local(
                str(INDEX_DIR),
                OpenAIEmbeddings(model=EMBED_MODEL, openai_api_key=OPENAI_API_KEY),
                allow_dangerous_deserialization=True,
            )

    # Proxy only the methods the app actually uses
    def similarity_search(self, *args, **kwargs):
        self._ensure()
        return self._vs.similarity_search(*args, **kwargs)

    def max_marginal_relevance_search(self, *args, **kwargs):
        self._ensure()
        return self._vs.max_marginal_relevance_search(*args, **kwargs)

# Exported under the original name so the rest of the code keeps working unchanged
VS = _LazyVS()
# ------------------ Utility functions ------------------
def _short(s: str, n: int = 700) -> str:
    s = (s or "").strip()
    return s if len(s) <= n else s[:n] + " …"

def _is_image(path: str) -> bool:
    return path.lower().endswith((".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"))
# ------------------ New: CLIP image index class and singleton ------------------
class ClipImageIndex:
    """Text-to-image retrieval in a shared CLIP embedding space."""
    def __init__(self, model_name: str = CLIP_MODEL_NAME, device: str | None = None):
        try:
            import torch
            if device is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
        except Exception:
            device = None
        self.model = SentenceTransformer(model_name, device=device) if device else SentenceTransformer(model_name)
        self.index = None
        self.meta: List[Dict[str, Any]] = []

    def load(self, idx_path: Path, meta_path: Path):
        self.index = faiss.read_index(str(idx_path))
        with open(meta_path, "r", encoding="utf-8") as f:
            self.meta = json.load(f)

    def query(self, text: str, k: int = 5) -> List[Dict[str, Any]]:
        if self.index is None:
            return []
        q = self.model.encode([text], normalize_embeddings=True).astype("float32")
        D, I = self.index.search(q, k)
        out = []
        for rank, idx in enumerate(I[0]):
            if idx == -1:
                continue
            m = self.meta[idx]
            out.append({
                "type": "image",
                "rank": rank + 1,
                "score": float(D[0][rank]),  # CLIP similarity score
                "image_path": m.get("path"),
                "rel_path": m.get("rel_path"),
            })
        return out
# Singleton: load the image index if it exists, otherwise leave it as None (text RAG is unaffected)
IMG_INDEX: ClipImageIndex | None = None
if IMAGE_IDX_PATH.exists() and IMAGE_META_PATH.exists():
    try:
        _idx = ClipImageIndex(CLIP_MODEL_NAME)
        _idx.load(IMAGE_IDX_PATH, IMAGE_META_PATH)
        IMG_INDEX = _idx
        print(f"[agency] Loaded image index: {IMAGE_IDX_DIR}")
    except Exception as e:
        print(f"[agency] WARNING: failed to load image index: {e}")
# ------------------ Retrieval and fusion logic ------------------
K_TEXT = 5
K_IMAGE = 5

def _serialize_text_docs(docs) -> List[Dict[str, Any]]:
    items: List[Dict[str, Any]] = []
    for d in docs:
        meta = d.metadata or {}
        items.append({
            "type": "text",
            "source_file": meta.get("source_file", meta.get("source", "unknown")),
            "page": meta.get("page"),
            "year": meta.get("year"),
            "text": _short(d.page_content),
        })
    return items
def _rank_fusion(text_items: List[Dict], img_items: List[Dict],
                 w_text: float = 0.5, w_img: float = 0.5) -> List[Dict]:
    """
    Simple fusion: text results get a reciprocal-rank score; image results get
    their CLIP score plus a reciprocal-rank score.
    """
    fused = []
    # Text: no native score, so use the rank score 1/(rank+1)
    for i, it in enumerate(text_items):
        it = dict(it)
        it["_fused_score"] = w_text * (1.0 / (i + 1))
        fused.append(it)
    # Images: CLIP score + rank score
    for j, it in enumerate(img_items):
        it = dict(it)
        base = float(it.get("score", 0.0))
        it["_fused_score"] = w_img * (base + 1.0 / (j + 1))
        fused.append(it)
    fused.sort(key=lambda x: -x["_fused_score"])
    for it in fused:
        it.pop("_fused_score", None)
    return fused
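# Worked example of the fusion above (illustrative numbers): with
# w_text = w_img = 0.5, the top text hit scores 0.5 * 1/1 = 0.50 and the
# second 0.5 * 1/2 = 0.25, while an image ranked first with CLIP score 0.30
# scores 0.5 * (0.30 + 1/1) = 0.65, so images with a reasonable CLIP score
# tend to sort above text hits of the same rank.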
def rag_search(query: str) -> str:
    """Run text and (if available) image retrieval together; return JSON with per-modality and fused results."""
    # Text (prefer MMR, fall back to plain similarity search)
    try:
        text_docs = VS.max_marginal_relevance_search(query, k=K_TEXT, fetch_k=max(12, 2 * K_TEXT))
    except Exception:
        text_docs = VS.similarity_search(query, k=K_TEXT)
    text_items = _serialize_text_docs(text_docs)
    # Images (only if an index was loaded)
    img_items = IMG_INDEX.query(query, k=K_IMAGE) if IMG_INDEX else []
    fused = _rank_fusion(text_items, img_items, w_text=0.5, w_img=0.5)
    return json.dumps({
        "text_topk": text_items,
        "image_topk": img_items,
        "fused": fused[:10],
    }, ensure_ascii=False, indent=2)
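# Sketch of the JSON rag_search returns (field names come from the serializers
# above; the values shown are illustrative only):
#   {
#     "text_topk":  [{"type": "text", "source_file": "...", "page": 3, "year": 2013, "text": "..."}, ...],
#     "image_topk": [{"type": "image", "rank": 1, "score": 0.31, "image_path": "...", "rel_path": "..."}, ...],
#     "fused":      [ ...top 10 of both lists, ordered by the fusion score... ]
#   }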
# ------------------ Tool definition (same name as before, now with multimodal fusion) ------------------
rag_tool = Tool(
    name="BreastCancerRAG",
    func=rag_search,
    description=(
        "Retrieve 3–5 relevant TEXT chunks from the breast cancer knowledge base and (if available) "
        "3–5 relevant IMAGES via CLIP, then return a JSON object with 'text_topk', 'image_topk', and a 'fused' list. "
        "Use this tool once per question. If evidence is insufficient, say what else is needed."
    ),
)
# ------------------ System prompt (small addition: mention image evidence) ------------------
SYSTEM_PROMPT = (
    "You are an assistant specializing in breast cancer epidemiology and screening policy.\n"
    "Workflow:\n"
    "1) Call the tool `BreastCancerRAG` once to obtain evidence (text and, if available, images).\n"
    "2) Answer ONLY based on the retrieved evidence. Do NOT fabricate.\n"
    "3) If you reference an image, include its file name or relative path from the tool output.\n"
    "4) If the evidence is insufficient, say so and specify what extra info is needed.\n\n"
    "Answer format:\n"
    "- Use bullet points or short paragraphs.\n"
    "- Add citation tags like [Wu 2013, p.X] or [Yen 2017, p.Y] for text.\n"
    "- Mark general knowledge as '(general knowledge)'."
)
# ------------------ Build the agent ------------------
def build_agent():
    llm_direct = ChatOpenAI(model="gpt-4o", temperature=0.2, openai_api_key=OPENAI_API_KEY)
    agent = initialize_agent(
        tools=[rag_tool],
        llm=llm_direct,
        agent=AgentType.OPENAI_FUNCTIONS,
        verbose=True,
        handle_parsing_errors=True,
        max_iterations=3,
        max_execution_time=60,
        early_stopping_method="force",  # OPENAI_FUNCTIONS agents only support "force"
        # The OPENAI_FUNCTIONS agent receives its system prompt via agent_kwargs
        agent_kwargs={"system_message": SystemMessage(content=SYSTEM_PROMPT)},
    )
    return agent
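# A minimal usage sketch, assuming a LangChain version where AgentExecutor
# supports .invoke(); the question below is only an example.
if __name__ == "__main__":
    agent = build_agent()
    result = agent.invoke({"input": "Summarize the evidence on mammography screening intervals."})
    print(result["output"])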