Donlagon007 committed on
Commit
61ded4d
·
verified ·
1 Parent(s): 5ec62e7

Update agent_pdfimages.py

Browse files
Files changed (1) hide show
  1. agent_pdfimages.py +216 -201
agent_pdfimages.py CHANGED
@@ -1,201 +1,216 @@
1
- # agent.py
2
-
3
- import os, json, glob
4
- from pathlib import Path
5
- from typing import List, Dict, Any
6
- import numpy as np
7
-
8
- from dotenv import load_dotenv
9
- from langchain_openai import ChatOpenAI, OpenAIEmbeddings
10
- from langchain_community.vectorstores import FAISS
11
- from langchain.tools import Tool
12
- from langchain.agents import initialize_agent, AgentType
13
-
14
- # ==== 新增:CLIP 影像索引需要的套件 ====
15
- from PIL import Image
16
- import faiss
17
- from sentence_transformers import SentenceTransformer
18
-
19
- # ------------------ 基本設定 ------------------
20
- load_dotenv()
21
- OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
22
-
23
- BASE_DIR = Path(__file__).resolve().parent
24
- INDEX_DIR = BASE_DIR / "faiss_breast"
25
- EMBED_MODEL = "text-embedding-3-small"
26
-
27
- # 影像索引路徑(請先用你的重建工具建立)
28
- IMAGE_DIR = Path(os.getenv("IMAGE_DIR", str(BASE_DIR / "images")))
29
- IMAGE_IDX_DIR = Path(os.getenv("IMAGE_IDX_DIR", str(BASE_DIR / "faiss_images")))
30
- IMAGE_IDX_PATH = IMAGE_IDX_DIR / "clip.index"
31
- IMAGE_META_PATH = IMAGE_IDX_DIR / "metadata.json"
32
- CLIP_MODEL_NAME = os.getenv("CLIP_MODEL", "clip-ViT-L-14")
33
-
34
- # ------------------ 只載入一次:文字索引 ------------------
35
- VS = FAISS.load_local(
36
- str(INDEX_DIR),
37
- OpenAIEmbeddings(model=EMBED_MODEL, openai_api_key=OPENAI_API_KEY),
38
- allow_dangerous_deserialization=True
39
- )
40
-
41
- # ------------------ 工具函式 ------------------
42
- def _short(s: str, n: int = 700) -> str:
43
- s = (s or "").strip()
44
- return s if len(s) <= n else s[:n] + " …"
45
-
46
- def _is_image(path: str) -> bool:
47
- return path.lower().endswith((".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"))
48
-
49
- # ------------------ 新增:CLIP 影像索引類別與單例 ------------------
50
- class ClipImageIndex:
51
- """文字↔影像同空間檢索(CLIP)"""
52
- def __init__(self, model_name: str = CLIP_MODEL_NAME, device: str | None = None):
53
- try:
54
- import torch
55
- if device is None:
56
- device = "cuda" if torch.cuda.is_available() else "cpu"
57
- except Exception:
58
- device = None
59
- self.model = SentenceTransformer(model_name, device=device) if device else SentenceTransformer(model_name)
60
- self.index = None
61
- self.meta: List[Dict[str, Any]] = []
62
-
63
- def load(self, idx_path: Path, meta_path: Path):
64
- self.index = faiss.read_index(str(idx_path))
65
- import json as _json
66
- with open(meta_path, "r", encoding="utf-8") as f:
67
- self.meta = _json.load(f)
68
-
69
- def query(self, text: str, k: int = 5) -> List[Dict[str, Any]]:
70
- if self.index is None:
71
- return []
72
- q = self.model.encode([text], normalize_embeddings=True).astype("float32")
73
- D, I = self.index.search(q, k)
74
- out = []
75
- for rank, idx in enumerate(I[0]):
76
- if idx == -1:
77
- continue
78
- m = self.meta[idx]
79
- out.append({
80
- "type": "image",
81
- "rank": rank + 1,
82
- "score": float(D[0][rank]), # CLIP 相似分數
83
- "image_path": m.get("path"),
84
- "rel_path": m.get("rel_path")
85
- })
86
- return out
87
-
88
- # 單例:若影像索引存在則載入,否則為 None(不影響文字 RAG)
89
- IMG_INDEX: ClipImageIndex | None = None
90
- if IMAGE_IDX_PATH.exists() and IMAGE_META_PATH.exists():
91
- try:
92
- _idx = ClipImageIndex(CLIP_MODEL_NAME)
93
- _idx.load(IMAGE_IDX_PATH, IMAGE_META_PATH)
94
- IMG_INDEX = _idx
95
- print(f"[agency] Loaded image index: {IMAGE_IDX_DIR}")
96
- except Exception as e:
97
- print(f"[agency] WARNING: failed to load image index: {e}")
98
-
99
- # ------------------ 檢索與融合邏輯 ------------------
100
- K_TEXT = 5
101
- K_IMAGE = 5
102
-
103
- def _serialize_text_docs(docs) -> List[Dict[str, Any]]:
104
- items: List[Dict[str, Any]] = []
105
- for d in docs:
106
- meta = d.metadata or {}
107
- items.append({
108
- "type": "text",
109
- "source_file": meta.get("source_file", meta.get("source", "unknown")),
110
- "page": meta.get("page"),
111
- "year": meta.get("year"),
112
- "text": _short(d.page_content)
113
- })
114
- return items
115
-
116
- def _rank_fusion(text_items: List[Dict], img_items: List[Dict],
117
- w_text: float = 0.5, w_img: float = 0.5) -> List[Dict]:
118
- """
119
- 簡易融合:文字結果用「倒數排名分數」;影像結果用 CLIP score + 倒數排名分數。
120
- """
121
- fused = []
122
-
123
- # 文字:沒有原生分數,用排名分數 1/(rank+1)
124
- for i, it in enumerate(text_items):
125
- it = dict(it)
126
- it["_fused_score"] = w_text * (1.0 / (i + 1))
127
- fused.append(it)
128
-
129
- # 影像:用 CLIP score + 排名分數
130
- for j, it in enumerate(img_items):
131
- it = dict(it)
132
- base = float(it.get("score", 0.0))
133
- it["_fused_score"] = w_img * (base + 1.0 / (j + 1))
134
- fused.append(it)
135
-
136
- fused.sort(key=lambda x: -x["_fused_score"])
137
- for it in fused:
138
- it.pop("_fused_score", None)
139
- return fused
140
-
141
- def rag_search(query: str) -> str:
142
- """同時做文字 +(若可用)影像檢索,回傳 JSON(含 per-modality 與 fused)。"""
143
- # 文字(MMR 優先)
144
- try:
145
- text_docs = VS.max_marginal_relevance_search(query, k=K_TEXT, fetch_k=max(12, 2*K_TEXT))
146
- except Exception:
147
- text_docs = VS.similarity_search(query, k=K_TEXT)
148
- text_items = _serialize_text_docs(text_docs)
149
-
150
- # 影像(若有索引)
151
- img_items = IMG_INDEX.query(query, k=K_IMAGE) if IMG_INDEX else []
152
-
153
- fused = _rank_fusion(text_items, img_items, w_text=0.5, w_img=0.5)
154
-
155
- return json.dumps({
156
- "text_topk": text_items,
157
- "image_topk": img_items,
158
- "fused": fused[:10]
159
- }, ensure_ascii=False, indent=2)
160
-
161
- # ------------------ Tool 定義(沿用原名,內含多模態融合) ------------------
162
- rag_tool = Tool(
163
- name="BreastCancerRAG",
164
- func=rag_search,
165
- description=(
166
- "Retrieve 3–5 relevant TEXT chunks from the breast cancer knowledge base and (if available) "
167
- "3–5 relevant IMAGES via CLIP, then return a JSON object with 'text_topk', 'image_topk', and a 'fused' list. "
168
- "Use this tool once per question. If evidence is insufficient, say what else is needed."
169
- ),
170
- )
171
-
172
- # ------------------ System Prompt(小幅增補:提示有影像) ------------------
173
- SYSTEM_PROMPT = (
174
- "You are an assistant specializing in breast cancer epidemiology and screening policy.\n"
175
- "Workflow:\n"
176
- "1) Call the tool `BreastCancerRAG` once to obtain evidence (text and, if available, images).\n"
177
- "2) Answer ONLY based on the retrieved evidence. Do NOT fabricate.\n"
178
- "3) If you reference an image, include its file name or relative path from the tool output.\n"
179
- "4) If the evidence is insufficient, say so and specify what extra info is needed.\n\n"
180
- "Answer format:\n"
181
- "- Use bullet points or short paragraphs.\n"
182
- "- Add citation tags like [Wu 2013, p.X] or [Yen 2017, p.Y] for text.\n"
183
- "- Mark general knowledge as '(general knowledge)'."
184
- )
185
-
186
- # ------------------ 建立 Agent ------------------
187
- def build_agent():
188
- llm_direct = ChatOpenAI(model="gpt-4o", temperature=0.2, openai_api_key=OPENAI_API_KEY)
189
- agent = initialize_agent(
190
- tools=[rag_tool],
191
- llm=llm_direct,
192
- agent=AgentType.OPENAI_FUNCTIONS,
193
- verbose=True,
194
- handle_parsing_errors=True,
195
- max_iterations=3,
196
- max_execution_time=60,
197
- early_stopping_method="generate",
198
- system_message=SYSTEM_PROMPT,
199
- memory=memory,
200
- )
201
- return agent
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # agent.py
2
+
3
+ import os, json, glob
4
+ from pathlib import Path
5
+ from typing import List, Dict, Any
6
+ import numpy as np
7
+
8
+ from dotenv import load_dotenv
9
+ from langchain_openai import ChatOpenAI, OpenAIEmbeddings
10
+ from langchain_community.vectorstores import FAISS
11
+ from langchain.tools import Tool
12
+ from langchain.agents import initialize_agent, AgentType
13
+
14
+ # ==== 新增:CLIP 影像索引需要的套件 ====
15
+ from PIL import Image
16
+ import faiss
17
+ from sentence_transformers import SentenceTransformer
18
+
19
# ------------------ Basic configuration ------------------
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

BASE_DIR = Path(__file__).resolve().parent
INDEX_DIR = BASE_DIR / "faiss_breast"   # persisted FAISS text index directory
EMBED_MODEL = "text-embedding-3-small"  # OpenAI embedding model used to build/query the text index

# Image-index paths (build these beforehand with your index-rebuilding tool)
IMAGE_DIR = Path(os.getenv("IMAGE_DIR", str(BASE_DIR / "images")))
IMAGE_IDX_DIR = Path(os.getenv("IMAGE_IDX_DIR", str(BASE_DIR / "faiss_images")))
IMAGE_IDX_PATH = IMAGE_IDX_DIR / "clip.index"
IMAGE_META_PATH = IMAGE_IDX_DIR / "metadata.json"
CLIP_MODEL_NAME = os.getenv("CLIP_MODEL", "clip-ViT-L-14")
33
+
34
+ # ------------------ 只載入一次:文字索引 ------------------
35
+
36
+ # ---------- Lazy-load แค่ 12 บรรทัด ----------
37
+ class _LazyVS:
38
+ def __init__(self):
39
+ self._vs = None
40
+ def _ensure(self):
41
+ if self._vs is None:
42
+ self._vs = FAISS.load_local(
43
+ str(INDEX_DIR),
44
+ OpenAIEmbeddings(model=EMBED_MODEL, openai_api_key=OPENAI_API_KEY),
45
+ allow_dangerous_deserialization=True,
46
+ )
47
+ # proxy เมธอดที่ app ใช้อยู่
48
+ def similarity_search(self, *args, **kwargs):
49
+ self._ensure(); return self._vs.similarity_search(*args, **kwargs)
50
+ def max_marginal_relevance_search(self, *args, **kwargs):
51
+ self._ensure(); return self._vs.max_marginal_relevance_search(*args, **kwargs)
52
+
53
+ # >>> ส่งออกชื่อเดิมให้โค้ดที่เหลือใช้ได้เหมือนเดิม
54
+ VS = _LazyVS()
55
+
56
+ # ------------------ 工具函式 ------------------
57
+ def _short(s: str, n: int = 700) -> str:
58
+ s = (s or "").strip()
59
+ return s if len(s) <= n else s[:n] + " …"
60
+
61
+ def _is_image(path: str) -> bool:
62
+ return path.lower().endswith((".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"))
63
+
64
# ------------------ CLIP image index (text ↔ image retrieval) ------------------
class ClipImageIndex:
    """Text-to-image retrieval in a shared CLIP embedding space."""

    def __init__(self, model_name: str = CLIP_MODEL_NAME, device: str | None = None):
        # Prefer CUDA when torch is importable; if torch is missing entirely,
        # let sentence-transformers pick its own default device.
        try:
            import torch
            if device is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
        except Exception:
            device = None
        self.model = SentenceTransformer(model_name, device=device) if device else SentenceTransformer(model_name)
        self.index = None                     # faiss index, populated by load()
        self.meta: List[Dict[str, Any]] = []  # per-vector metadata records, parallel to the index

    def load(self, idx_path: Path, meta_path: Path):
        """Load a prebuilt FAISS index and its JSON metadata sidecar."""
        self.index = faiss.read_index(str(idx_path))
        # Use the module-level json import; the previous local
        # `import json as _json` was redundant.
        with open(meta_path, "r", encoding="utf-8") as f:
            self.meta = json.load(f)

    def query(self, text: str, k: int = 5) -> List[Dict[str, Any]]:
        """Return up to *k* image hits ranked by CLIP similarity to *text*.

        Returns an empty list when no index has been loaded.
        """
        if self.index is None:
            return []
        q = self.model.encode([text], normalize_embeddings=True).astype("float32")
        D, I = self.index.search(q, k)
        out = []
        for rank, idx in enumerate(I[0]):
            # faiss pads with -1 when fewer than k hits exist; also guard
            # against a metadata file shorter than the index.
            if idx < 0 or idx >= len(self.meta):
                continue
            m = self.meta[idx]
            out.append({
                "type": "image",
                "rank": rank + 1,
                "score": float(D[0][rank]),  # CLIP similarity score
                "image_path": m.get("path"),
                "rel_path": m.get("rel_path"),
            })
        return out
102
+
103
# Singleton: load the image index only if it exists on disk; otherwise leave
# it as None so text-only RAG keeps working unaffected.
IMG_INDEX: ClipImageIndex | None = None
if IMAGE_IDX_PATH.exists() and IMAGE_META_PATH.exists():
    try:
        _idx = ClipImageIndex(CLIP_MODEL_NAME)
        _idx.load(IMAGE_IDX_PATH, IMAGE_META_PATH)
        IMG_INDEX = _idx
        print(f"[agency] Loaded image index: {IMAGE_IDX_DIR}")
    except Exception as e:
        # Best-effort: a broken image index must not take down the text agent.
        print(f"[agency] WARNING: failed to load image index: {e}")

# ------------------ Retrieval and fusion logic ------------------
K_TEXT = 5   # top-k text chunks retrieved per query
K_IMAGE = 5  # top-k images retrieved per query
117
+
118
def _serialize_text_docs(docs) -> List[Dict[str, Any]]:
    """Convert retrieved LangChain documents into plain JSON-serializable dicts."""
    def _one(doc) -> Dict[str, Any]:
        meta = doc.metadata or {}
        return {
            "type": "text",
            "source_file": meta.get("source_file", meta.get("source", "unknown")),
            "page": meta.get("page"),
            "year": meta.get("year"),
            "text": _short(doc.page_content),
        }

    return [_one(d) for d in docs]
130
+
131
+ def _rank_fusion(text_items: List[Dict], img_items: List[Dict],
132
+ w_text: float = 0.5, w_img: float = 0.5) -> List[Dict]:
133
+ """
134
+ 簡易融合:文字結果用「倒數排名分數」;影像結果用 CLIP score + 倒數排名分數。
135
+ """
136
+ fused = []
137
+
138
+ # 文字:沒有原生分數,用排名分數 1/(rank+1)
139
+ for i, it in enumerate(text_items):
140
+ it = dict(it)
141
+ it["_fused_score"] = w_text * (1.0 / (i + 1))
142
+ fused.append(it)
143
+
144
+ # 影像:用 CLIP score + 排名分數
145
+ for j, it in enumerate(img_items):
146
+ it = dict(it)
147
+ base = float(it.get("score", 0.0))
148
+ it["_fused_score"] = w_img * (base + 1.0 / (j + 1))
149
+ fused.append(it)
150
+
151
+ fused.sort(key=lambda x: -x["_fused_score"])
152
+ for it in fused:
153
+ it.pop("_fused_score", None)
154
+ return fused
155
+
156
def rag_search(query: str) -> str:
    """Run text retrieval plus (when available) image retrieval and return a
    JSON string containing per-modality top-k lists and a fused ranking."""
    # Text side: prefer MMR for diversity; fall back to plain similarity
    # search if MMR is unavailable or fails.
    try:
        docs = VS.max_marginal_relevance_search(query, k=K_TEXT, fetch_k=max(12, 2*K_TEXT))
    except Exception:
        docs = VS.similarity_search(query, k=K_TEXT)
    text_items = _serialize_text_docs(docs)

    # Image side: only when the CLIP index was loaded at startup.
    img_items = IMG_INDEX.query(query, k=K_IMAGE) if IMG_INDEX else []

    payload = {
        "text_topk": text_items,
        "image_topk": img_items,
        "fused": _rank_fusion(text_items, img_items, w_text=0.5, w_img=0.5)[:10],
    }
    return json.dumps(payload, ensure_ascii=False, indent=2)
175
+
176
# ------------------ Tool definition (keeps the original name; fuses modalities internally) ------------------
rag_tool = Tool(
    name="BreastCancerRAG",
    func=rag_search,
    description=(
        "Retrieve 3–5 relevant TEXT chunks from the breast cancer knowledge base and (if available) "
        "3–5 relevant IMAGES via CLIP, then return a JSON object with 'text_topk', 'image_topk', and a 'fused' list. "
        "Use this tool once per question. If evidence is insufficient, say what else is needed."
    ),
)
185
+ )
186
+
187
# ------------------ System prompt (small addition: mentions that images may be available) ------------------
SYSTEM_PROMPT = (
    "You are an assistant specializing in breast cancer epidemiology and screening policy.\n"
    "Workflow:\n"
    "1) Call the tool `BreastCancerRAG` once to obtain evidence (text and, if available, images).\n"
    "2) Answer ONLY based on the retrieved evidence. Do NOT fabricate.\n"
    "3) If you reference an image, include its file name or relative path from the tool output.\n"
    "4) If the evidence is insufficient, say so and specify what extra info is needed.\n\n"
    "Answer format:\n"
    "- Use bullet points or short paragraphs.\n"
    "- Add citation tags like [Wu 2013, p.X] or [Yen 2017, p.Y] for text.\n"
    "- Mark general knowledge as '(general knowledge)'."
)
200
+
201
# ------------------ Agent construction ------------------
def build_agent():
    """Build and return the OpenAI-functions agent wired to the RAG tool.

    Fixes vs. the previous revision:
    - ``memory=memory`` referenced a name that is never defined in this
      module and raised NameError on every call; the executor is now built
      without memory.
    - ``system_message`` passed directly to ``initialize_agent`` is forwarded
      to AgentExecutor, which rejects unknown fields; it must reach the agent
      itself via ``agent_kwargs`` as a SystemMessage.
    """
    # Local import: SystemMessage is only needed here.
    from langchain.schema import SystemMessage

    llm_direct = ChatOpenAI(model="gpt-4o", temperature=0.2, openai_api_key=OPENAI_API_KEY)
    agent = initialize_agent(
        tools=[rag_tool],
        llm=llm_direct,
        agent=AgentType.OPENAI_FUNCTIONS,
        verbose=True,
        handle_parsing_errors=True,
        max_iterations=3,
        max_execution_time=60,
        # NOTE(review): some agent types only support "force" here — confirm
        # "generate" is accepted for OPENAI_FUNCTIONS in the pinned version.
        early_stopping_method="generate",
        agent_kwargs={"system_message": SystemMessage(content=SYSTEM_PROMPT)},
    )
    return agent