# NOTE: removed stray paste artifacts ("Spaces:" / "Runtime error" viewer residue).
| # ============================================================ | |
| # agency_multimodal.py | |
| # ============================================================ | |
import base64
import json
import os
from pathlib import Path

import numpy as np
from PIL import Image
import pytesseract
from transformers import BlipProcessor, BlipForConditionalGeneration

from langchain.agents import initialize_agent, AgentType
from langchain.schema import SystemMessage
from langchain.tools import Tool
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

# from agent import build_agent, VS ← 原本這行註解掉
from agent_pdfimages import VS, IMG_INDEX, rag_tool, SYSTEM_PROMPT  # 🔁 從原檔匯入
| # ============================================================ | |
| # 1️⃣ 基本設定 | |
| # ============================================================ | |
| OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") | |
| if not OPENAI_API_KEY: | |
| raise RuntimeError("❌ Please set OPENAI_API_KEY in .env") | |
| BASE_DIR = Path(__file__).resolve().parent | |
| _emb = OpenAIEmbeddings(model="text-embedding-3-small", openai_api_key=OPENAI_API_KEY) | |
| # ============================================================ | |
| # 2️⃣ 載入 Caption / OCR 模型 | |
| # ============================================================ | |
| # 初始化 BLIP(當 GPT-4o 失敗時用) | |
| _caption_proc = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") | |
| _caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base") | |
| def caption_image(path: str) -> str: | |
| """ | |
| 先嘗試用 GPT-4o 產生影像描述(modality、region、abnormality) | |
| 若 GPT-4o 失敗,則自動 fallback 到 BLIP。 | |
| """ | |
| try: | |
| # -------- GPT-4o 影像理解 -------- | |
| with open(path, "rb") as f: | |
| data = base64.b64encode(f.read()).decode("utf-8") | |
| llm = ChatOpenAI(model="gpt-4o", temperature=0.0, openai_api_key=OPENAI_API_KEY) | |
| user_content = [ | |
| { | |
| "role": "user", | |
| "content": [ | |
| { | |
| "type": "text", | |
| "text": ( | |
| "Describe this medical image in one concise English sentence. " | |
| "Identify the modality (e.g., mammogram, pathology slide, MRI), " | |
| "anatomical region, and any visible abnormalities or lesions." | |
| ), | |
| }, | |
| {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{data}" }}, | |
| ], | |
| } | |
| ] | |
| res = llm.invoke(user_content) | |
| caption = res.content.strip() | |
| # 過濾太短或重複字詞的無效描述 | |
| if len(caption.split()) < 3 or caption.lower().count("mri") > 5: | |
| raise ValueError("Caption too generic or repetitive") | |
| print(f"[Caption] ✅ GPT-4o caption success for {path}") | |
| return caption | |
| except Exception as e: | |
| print(f"[Caption] ⚠️ GPT-4o failed ({e}), fallback to BLIP for {path}") | |
| # -------- BLIP 備援 -------- | |
| try: | |
| image = Image.open(path).convert("RGB") | |
| inputs = _caption_proc(image, return_tensors="pt") | |
| out = _caption_model.generate(**inputs, max_new_tokens=40) | |
| caption = _caption_proc.decode(out[0], skip_special_tokens=True) | |
| print(f"[Caption] ✅ BLIP caption fallback success for {path}") | |
| return caption | |
| except Exception as e2: | |
| print(f"[Caption] ❌ BLIP caption also failed ({e2}) for {path}") | |
| return "" | |
| def ocr_text(path: str) -> str: | |
| """OCR 文字識別""" | |
| try: | |
| return pytesseract.image_to_string(Image.open(path)) | |
| except Exception: | |
| return "" | |
| def caption_score(query: str, caption: str) -> float: | |
| """query 與 caption 的語義相似度""" | |
| if not caption.strip(): | |
| return 0.0 | |
| qv = np.array(_emb.embed_query(query)) | |
| cv = np.array(_emb.embed_query(caption)) | |
| return float(np.dot(qv, cv) / (np.linalg.norm(qv)*np.linalg.norm(cv)+1e-9)) | |
| # ============================================================ | |
| # 3️⃣ VLM 視覺理解 (GPT-4o 看圖 re-rank) | |
| # ============================================================ | |
| def vlm_score_images(query: str, img_items): | |
| """Use GPT-4o to rate image relevance (0–1).""" | |
| llm_vlm = ChatOpenAI(model="gpt-4o", temperature=0.0, openai_api_key=OPENAI_API_KEY) | |
| user_content = [ | |
| { | |
| "type": "text", | |
| "text": ( | |
| f"You are a scoring model. For each image, output only a JSON list of floating-point relevance scores " | |
| f"(0–1) to the query: '{query}'. Example: [0.8, 0.5, 0.3]. No explanations." | |
| ), | |
| } | |
| ] | |
| def encode_img(path): | |
| with open(path, "rb") as f: | |
| data = base64.b64encode(f.read()).decode("utf-8") | |
| return { | |
| "type": "image_url", | |
| "image_url": {"url": f"data:image/jpeg;base64,{data}"} | |
| } | |
| user_content += [encode_img(it["image_path"]) for it in img_items[:5]] | |
| try: | |
| res = llm_vlm.invoke([{"role": "user", "content": user_content}]) | |
| print("[VLM raw output]:", res.content) | |
| import re, json | |
| try: | |
| scores = json.loads(res.content) | |
| except Exception: | |
| nums = re.findall(r"\d+\.\d+", res.content) | |
| scores = [float(x) for x in nums] if nums else [] | |
| for i, s in enumerate(scores): | |
| if i < len(img_items): | |
| img_items[i]["vlm_score"] = float(s) | |
| except Exception as e: | |
| print(f"[VLM] failed: {e}") | |
| return img_items | |
| # ============================================================ | |
| # 4️⃣ 融合邏輯 (含 Caption, OCR, VLM) | |
| # ============================================================ | |
| def _advanced_fusion(text_items, img_items, w_text=0.4, w_img=0.3, w_cap=0.2, w_vlm=0.1): | |
| def _z(x): x = np.array(x, float); return (x - x.mean()) / (x.std() + 1e-9) | |
| t_rank = _z([1/(i+1) for i,_ in enumerate(text_items)]) | |
| i_clip = _z([float(it.get("score",0)) for it in img_items]) | |
| i_rank = _z([1/(i+1) for i,_ in enumerate(img_items)]) | |
| i_cap = _z([float(it.get("caption_score",0)) for it in img_items]) | |
| i_vlm = _z([float(it.get("vlm_score",0)) for it in img_items]) | |
| fused = [] | |
| for i,it in enumerate(text_items): | |
| x=dict(it); x["_f"]=w_text*t_rank[i]; fused.append(x) | |
| for i,it in enumerate(img_items): | |
| x=dict(it) | |
| x["_f"]= w_img*(0.7*i_clip[i]+0.3*i_rank[i]) + w_cap*i_cap[i] + w_vlm*i_vlm[i] | |
| fused.append(x) | |
| fused.sort(key=lambda x:-x["_f"]) | |
| for it in fused: it.pop("_f",None) | |
| return fused | |
| # ============================================================ | |
| # 5️⃣ 多模態檢索主函式 | |
| # ============================================================ | |
| def multimodal_rag(query: str) -> str: | |
| """進階版:文字 + 影像 + Caption + OCR + VLM""" | |
| # 🔹 文字 | |
| try: | |
| docs = VS.max_marginal_relevance_search(query, k=5, fetch_k=12) | |
| except Exception: | |
| docs = VS.similarity_search(query, k=5) | |
| text_items = [{"type": "text", "text": d.page_content, "source": d.metadata.get("source")} for d in docs] | |
| # 🔹 影像 | |
| img_items = IMG_INDEX.query(query, k=5) if IMG_INDEX else [] | |
| # 🔹 Caption + OCR | |
| for it in img_items: | |
| p = it["image_path"] | |
| cap = caption_image(p) | |
| ocr = ocr_text(p) | |
| it["caption"] = cap | |
| it["ocr"] = ocr | |
| it["caption_score"] = caption_score(query, cap + " " + ocr) | |
| # 🔹 VLM 評估 (GPT-4o 看圖) | |
| USE_VLM_RERANK = bool(int(os.getenv("USE_VLM_RERANK", "1"))) | |
| if USE_VLM_RERANK and img_items: | |
| print("[VLM] running GPT-4o scoring...") | |
| img_items = vlm_score_images(query, img_items) | |
| fused = _advanced_fusion(text_items, img_items) | |
| return json.dumps({ | |
| "text_topk": text_items, | |
| "image_topk": img_items, | |
| "fused": fused[:10] | |
| }, ensure_ascii=False, indent=2) | |
| # ============================================================ | |
| # 6️⃣ Tool + Agent | |
| # ============================================================ | |
| multi_tool = Tool( | |
| name="BreastCancerMultiRAG", | |
| func=multimodal_rag, | |
| description=( | |
| "Retrieve 3–5 relevant text chunks and 3–5 relevant images via CLIP. " | |
| "Images are enriched with BLIP captions, OCR, and optionally GPT-4o visual scores. " | |
| "Return JSON with text_topk, image_topk, and fused." | |
| ), | |
| ) | |
| SYSTEM_PROMPT_MM = SYSTEM_PROMPT + ( | |
| "\nYou now have access to `BreastCancerMultiRAG`, which also retrieves and interprets image evidence." | |
| ) | |
| def build_agent_multimodal(): | |
| llm = ChatOpenAI(model="gpt-4o", temperature=0.2, openai_api_key=OPENAI_API_KEY) | |
| agent = initialize_agent( | |
| tools=[multi_tool], | |
| llm=llm, | |
| agent=AgentType.OPENAI_FUNCTIONS, | |
| verbose=True, | |
| handle_parsing_errors=True, | |
| max_iterations=3, | |
| max_execution_time=60, | |
| early_stopping_method="generate", | |
| system_message=SYSTEM_PROMPT_MM, | |
| ) | |
| return agent | |