# agent_ml.py (RAG2) — uploaded by Donlagon007 ("Upload 35 files", commit 0070833)
# ============================================================
# agency_multimodal.py
# ============================================================
import os, json, base64
from pathlib import Path
import numpy as np
from PIL import Image
import pytesseract
from transformers import BlipProcessor, BlipForConditionalGeneration
# from agent import build_agent, VS ← 原本這行註解掉
from agent_pdfimages import VS, IMG_INDEX, rag_tool, SYSTEM_PROMPT # 🔁 從原檔匯入
from langchain.tools import Tool
from langchain.agents import initialize_agent, AgentType
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
# ============================================================
# 1. Basic configuration
# ============================================================
# Fail fast if the API key is missing — every downstream call needs it.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise RuntimeError("❌ Please set OPENAI_API_KEY in .env")

# Directory containing this file; base for any relative paths.
BASE_DIR = Path(__file__).resolve().parent

# Shared embedding client used for query/caption similarity scoring.
_emb = OpenAIEmbeddings(model="text-embedding-3-small", openai_api_key=OPENAI_API_KEY)
# ============================================================
# 2. Caption / OCR model loading
# ============================================================
# Local BLIP captioner — used as a fallback when the GPT-4o vision call fails.
_caption_proc = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
_caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
def caption_image(path: str) -> str:
    """Caption a medical image.

    Tries GPT-4o vision first, asking for modality, anatomical region, and any
    visible abnormality.  If the call fails — or returns a degenerate caption —
    it falls back to the local BLIP model.  Returns "" when both attempts fail.
    """
    try:
        # ---- GPT-4o vision pass ----
        with open(path, "rb") as fh:
            b64 = base64.b64encode(fh.read()).decode("utf-8")
        vision_llm = ChatOpenAI(model="gpt-4o", temperature=0.0, openai_api_key=OPENAI_API_KEY)
        prompt_text = (
            "Describe this medical image in one concise English sentence. "
            "Identify the modality (e.g., mammogram, pathology slide, MRI), "
            "anatomical region, and any visible abnormalities or lesions."
        )
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt_text},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}},
                ],
            }
        ]
        reply = vision_llm.invoke(messages)
        caption = reply.content.strip()
        # Reject captions that are too short or suspiciously repetitive;
        # raising here routes us into the BLIP fallback below.
        if len(caption.split()) < 3 or caption.lower().count("mri") > 5:
            raise ValueError("Caption too generic or repetitive")
        print(f"[Caption] ✅ GPT-4o caption success for {path}")
        return caption
    except Exception as e:
        print(f"[Caption] ⚠️ GPT-4o failed ({e}), fallback to BLIP for {path}")
        # ---- BLIP fallback ----
        try:
            img = Image.open(path).convert("RGB")
            blip_inputs = _caption_proc(img, return_tensors="pt")
            token_ids = _caption_model.generate(**blip_inputs, max_new_tokens=40)
            caption = _caption_proc.decode(token_ids[0], skip_special_tokens=True)
            print(f"[Caption] ✅ BLIP caption fallback success for {path}")
            return caption
        except Exception as e2:
            print(f"[Caption] ❌ BLIP caption also failed ({e2}) for {path}")
            return ""
def ocr_text(path: str) -> str:
    """Run Tesseract OCR on the image at *path*; return "" on any failure.

    Best-effort by design: OCR output only supplements captions, so errors
    (missing file, unreadable image, missing tesseract binary) are swallowed
    rather than propagated.
    """
    try:
        # Context manager closes the underlying file handle — the original
        # Image.open(...) leaked it.
        with Image.open(path) as img:
            return pytesseract.image_to_string(img)
    except Exception:
        return ""
def caption_score(query: str, caption: str) -> float:
    """Cosine similarity between *query* and *caption* via OpenAI embeddings.

    Returns 0.0 for a blank caption (no embedding call made).
    """
    if not caption.strip():
        return 0.0
    q_vec = np.asarray(_emb.embed_query(query))
    c_vec = np.asarray(_emb.embed_query(caption))
    denom = np.linalg.norm(q_vec) * np.linalg.norm(c_vec) + 1e-9
    return float(np.dot(q_vec, c_vec) / denom)
# ============================================================
# 3. VLM visual scoring (GPT-4o image re-rank)
# ============================================================
def _parse_scores(text: str):
    """Parse a list of numeric scores out of a model reply.

    Tolerates markdown code fences around the JSON and, as a last resort,
    pulls any numbers (integers OR decimals — the original regex missed
    bare "0"/"1") out of free-form text.  Returns [] when nothing parses.
    """
    import re
    cleaned = re.sub(r"```(?:json)?", "", text).strip()
    try:
        parsed = json.loads(cleaned)
        if isinstance(parsed, list):
            return parsed
    except Exception:
        pass
    return [float(x) for x in re.findall(r"-?\d+(?:\.\d+)?", cleaned)]


def vlm_score_images(query: str, img_items):
    """Score image relevance to *query* with GPT-4o vision, in place.

    Sends up to the first five images in one multimodal message and asks the
    model for a bare JSON list of floats in [0, 1].  Each parsed score is
    stored on the corresponding item under "vlm_score".  Any failure
    (network, parsing, ...) is logged and leaves the items untouched.

    Returns the same img_items list for convenient chaining.
    """
    llm_vlm = ChatOpenAI(model="gpt-4o", temperature=0.0, openai_api_key=OPENAI_API_KEY)
    user_content = [
        {
            "type": "text",
            "text": (
                f"You are a scoring model. For each image, output only a JSON list of floating-point relevance scores "
                f"(0–1) to the query: '{query}'. Example: [0.8, 0.5, 0.3]. No explanations."
            ),
        }
    ]

    def encode_img(path):
        # Inline the image as a base64 data URL — the format GPT-4o vision expects.
        with open(path, "rb") as f:
            data = base64.b64encode(f.read()).decode("utf-8")
        return {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{data}"},
        }

    # Cap at 5 images to keep the request payload bounded.
    user_content += [encode_img(it["image_path"]) for it in img_items[:5]]
    try:
        res = llm_vlm.invoke([{"role": "user", "content": user_content}])
        print("[VLM raw output]:", res.content)
        for i, s in enumerate(_parse_scores(res.content)):
            if i < len(img_items):
                img_items[i]["vlm_score"] = float(s)
    except Exception as e:
        print(f"[VLM] failed: {e}")
    return img_items
# ============================================================
# 4️⃣ 融合邏輯 (含 Caption, OCR, VLM)
# ============================================================
def _advanced_fusion(text_items, img_items, w_text=0.4, w_img=0.3, w_cap=0.2, w_vlm=0.1):
def _z(x): x = np.array(x, float); return (x - x.mean()) / (x.std() + 1e-9)
t_rank = _z([1/(i+1) for i,_ in enumerate(text_items)])
i_clip = _z([float(it.get("score",0)) for it in img_items])
i_rank = _z([1/(i+1) for i,_ in enumerate(img_items)])
i_cap = _z([float(it.get("caption_score",0)) for it in img_items])
i_vlm = _z([float(it.get("vlm_score",0)) for it in img_items])
fused = []
for i,it in enumerate(text_items):
x=dict(it); x["_f"]=w_text*t_rank[i]; fused.append(x)
for i,it in enumerate(img_items):
x=dict(it)
x["_f"]= w_img*(0.7*i_clip[i]+0.3*i_rank[i]) + w_cap*i_cap[i] + w_vlm*i_vlm[i]
fused.append(x)
fused.sort(key=lambda x:-x["_f"])
for it in fused: it.pop("_f",None)
return fused
# ============================================================
# 5. Multimodal retrieval entry point
# ============================================================
def multimodal_rag(query: str) -> str:
    """Advanced retrieval: text + images + captions + OCR + optional VLM re-rank.

    Returns a JSON string with three keys:
      * "text_topk"  – top text chunks from the vector store
      * "image_topk" – top CLIP-retrieved images, enriched with caption/OCR
      * "fused"      – top-10 fused ranking across both modalities
    """
    # --- Text retrieval: prefer MMR for diversity, fall back to plain similarity.
    try:
        docs = VS.max_marginal_relevance_search(query, k=5, fetch_k=12)
    except Exception:
        docs = VS.similarity_search(query, k=5)
    text_items = [
        {"type": "text", "text": d.page_content, "source": d.metadata.get("source")}
        for d in docs
    ]

    # --- Image retrieval (the CLIP index may be absent).
    img_items = IMG_INDEX.query(query, k=5) if IMG_INDEX else []

    # --- Enrich each image with a caption, OCR text, and a semantic score.
    for it in img_items:
        p = it["image_path"]
        cap = caption_image(p)
        ocr = ocr_text(p)
        it["caption"] = cap
        it["ocr"] = ocr
        it["caption_score"] = caption_score(query, cap + " " + ocr)

    # --- Optional GPT-4o visual re-ranking, toggled via USE_VLM_RERANK.
    #     The original bool(int(...)) crashed on values like "true"; any
    #     non-integer value now falls back to the default ("on").
    try:
        use_vlm = bool(int(os.getenv("USE_VLM_RERANK", "1")))
    except ValueError:
        use_vlm = True
    if use_vlm and img_items:
        print("[VLM] running GPT-4o scoring...")
        img_items = vlm_score_images(query, img_items)

    fused = _advanced_fusion(text_items, img_items)
    return json.dumps(
        {
            "text_topk": text_items,
            "image_topk": img_items,
            "fused": fused[:10],
        },
        ensure_ascii=False,
        indent=2,
    )
# ============================================================
# 6. Tool + Agent
# ============================================================
# LangChain Tool wrapper so an agent can invoke the multimodal retriever.
multi_tool = Tool(
    name="BreastCancerMultiRAG",
    func=multimodal_rag,
    description=(
        "Retrieve 3–5 relevant text chunks and 3–5 relevant images via CLIP. "
        "Images are enriched with BLIP captions, OCR, and optionally GPT-4o visual scores. "
        "Return JSON with text_topk, image_topk, and fused."
    ),
)

# Base system prompt (imported from agent_pdfimages) extended with a note
# advertising the new multimodal tool.
SYSTEM_PROMPT_MM = SYSTEM_PROMPT + (
    "\nYou now have access to `BreastCancerMultiRAG`, which also retrieves and interprets image evidence."
)
def build_agent_multimodal():
    """Build the multimodal RAG agent (GPT-4o + the BreastCancerMultiRAG tool).

    Uses the OpenAI-functions agent type with bounded iterations/time so a
    misbehaving tool call cannot loop indefinitely.
    """
    chat_model = ChatOpenAI(model="gpt-4o", temperature=0.2, openai_api_key=OPENAI_API_KEY)
    # NOTE(review): initialize_agent forwards unrecognized kwargs onward;
    # confirm `system_message` is actually honored here — OPENAI_FUNCTIONS
    # agents conventionally take it via agent_kwargs={"system_message": ...}.
    return initialize_agent(
        tools=[multi_tool],
        llm=chat_model,
        agent=AgentType.OPENAI_FUNCTIONS,
        verbose=True,
        handle_parsing_errors=True,
        max_iterations=3,        # cap on reasoning/tool-call loops
        max_execution_time=60,   # seconds
        early_stopping_method="generate",
        system_message=SYSTEM_PROMPT_MM,
    )