Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| import os | |
| import shutil | |
| import numpy as np | |
| from src.multimodal.multimodal_retriever import MultimodalRetriever | |
| from src.llm.llm_factory import get_llm | |
| from src.multimodal.clip_embedding import CLIPEmbedding | |
| from src.embeddings.embedding_factory import get_text_embedding | |
class MultimodalRAG:
    """Retrieval-augmented QA over a mixed text + image corpus.

    Per query, ``ask`` runs four stages:
      1. retrieve mixed-modality chunks via ``MultimodalRetriever``;
      2. answer with the LLM over the concatenated text context;
      3. rerank image candidates with a blended CLIP-image / caption-text
         cosine score against the query;
      4. publish surviving images as URLs — MinIO presigned URL when an
         ``object_key`` is present, else a local copy served under
         ``PUBLIC_BASE_URL/outputs/``.
    """

    # Rerank / publishing knobs (hoisted from magic numbers in ask()).
    CLIP_WEIGHT = 0.25      # weight of CLIP image-vs-query similarity
    CAPTION_WEIGHT = 0.75   # caption weighted higher: text signal is stronger
    MIN_FINAL_SCORE = 0.18  # candidates scoring below this are dropped
    MIN_CAPTION_LEN = 10    # shorter captions carry too little semantic signal
    TOP_K_IMAGES = 5        # images kept after rerank
    MAX_IMAGES = 3          # cap on URLs returned to the caller

    def __init__(self):
        print("DEBUG: Initializing MultimodalRAG")
        self.retriever = MultimodalRetriever()
        self.llm = get_llm()
        # cross modal models
        self.clip = CLIPEmbedding()
        self.text_embedder = get_text_embedding()
        # project root: two levels above this file (src/<pkg>/file.py -> repo root)
        self.PROJECT_ROOT = Path(__file__).resolve().parents[2]
        print("DEBUG PROJECT_ROOT:", self.PROJECT_ROOT)
        # NOTE(review): reusing HF_HOME as the data directory looks accidental
        # (HF_HOME normally points at the Hugging Face cache) — confirm intent.
        self.BASE_DATA_DIR = Path(os.getenv("HF_HOME", "data"))
        self.OUTPUT_DIR = (self.PROJECT_ROOT / self.BASE_DATA_DIR / "outputs").resolve()
        self.OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
        self.PUBLIC_BASE_URL = os.getenv("PUBLIC_BASE_URL", "")

    @staticmethod
    def _cosine(a, b):
        """Cosine similarity of two 1-D vectors.

        Returns 0.0 when either vector has zero norm instead of dividing
        by zero (the original inline computation could produce NaN).
        """
        a = np.asarray(a, dtype=float)
        b = np.asarray(b, dtype=float)
        denom = float(np.linalg.norm(a) * np.linalg.norm(b))
        if denom == 0.0:
            return 0.0
        return float(np.dot(a, b) / denom)

    def ask(self, query):
        """Answer ``query``.

        Returns:
            tuple: ``(answer_text, public_image_urls)`` where
            ``public_image_urls`` is a list of at most ``MAX_IMAGES`` strings.
        """
        print("\n==============================")
        print("DEBUG QUERY:", query)
        print("==============================")
        # --------------------------------------------------
        # STEP 1 — RETRIEVE
        # --------------------------------------------------
        docs, metas = self.retriever.retrieve(query, k=20)
        print("DEBUG total metas:", len(metas))
        image_metas = [m for m in metas if m.get("type") == "image"]
        print("DEBUG image metas:", len(image_metas))
        # --------------------------------------------------
        # STEP 2 — LLM ANSWER
        # --------------------------------------------------
        context = "\n\n".join(docs)
        prompt = f"""
You are a medical anatomy assistant.
Use the context to answer the question.
Context:
{context}
Question:
{query}
"""
        response = self.llm.invoke(prompt)
        # --------------------------------------------------
        # STEP 3 — CROSS MODAL RERANK
        # --------------------------------------------------
        top_images = self._rerank_images(query, image_metas)
        # --------------------------------------------------
        # STEP 4 — BUILD PUBLIC URLS
        # --------------------------------------------------
        public_images = self._build_public_urls(top_images)
        print("\nDEBUG FINAL IMAGE COUNT:", len(public_images))
        return response.content, public_images

    def _rerank_images(self, query, image_metas):
        """Score image candidates against ``query``; return best metas.

        Blends CLIP image similarity with caption/context text similarity
        (weights ``CLIP_WEIGHT``/``CAPTION_WEIGHT``), filters out weak
        captions, missing files and low scores, and keeps the top
        ``TOP_K_IMAGES`` metadata dicts.
        """
        reranked = []
        # CLIP text embedding expects a list of texts
        query_clip = self.clip.embed_text([query])[0]
        query_text = self.text_embedder.embed_query(query)
        for m in image_metas:
            raw_caption = m.get("caption") or ""
            context_text = m.get("context") or m.get("nearby_text") or ""
            # Prefer explicit caption; fall back to local context text
            candidate_text = raw_caption.strip() or context_text.strip()
            image_path = m.get("image_path")
            print("\n--- IMAGE CANDIDATE ---")
            print("DEBUG caption:", raw_caption)
            print("DEBUG candidate_text:", candidate_text[:200])
            print("DEBUG raw image_path:", image_path)
            # skip if we still don't have enough semantic signal
            if not candidate_text or len(candidate_text) < self.MIN_CAPTION_LEN:
                print("DEBUG skipped: weak caption")
                continue
            if not image_path:
                print("DEBUG skipped: no image path")
                continue
            image_path = Path(image_path)
            if not image_path.is_absolute():
                image_path = self.PROJECT_ROOT / image_path
            print("DEBUG resolved image_path:", image_path)
            if not image_path.exists():
                print("DEBUG skipped: file missing")
                continue
            try:
                print("DEBUG computing CLIP embedding")
                image_embedding = self.clip.embed_image([str(image_path)])[0]
                clip_score = self._cosine(query_clip, image_embedding)
                caption_embedding = self.text_embedder.embed_query(candidate_text)
                caption_score = self._cosine(query_text, caption_embedding)
                # caption weighted higher
                final_score = (
                    self.CLIP_WEIGHT * clip_score
                    + self.CAPTION_WEIGHT * caption_score
                )
                print("DEBUG clip_score:", clip_score)
                print("DEBUG caption_score:", caption_score)
                print("DEBUG final_score:", final_score)
                # filter weak matches
                if final_score < self.MIN_FINAL_SCORE:
                    print("DEBUG skipped: score too low")
                    continue
                reranked.append((final_score, m))
            except Exception as e:
                # best-effort: a bad image must not abort the whole rerank
                print("DEBUG embedding error:", str(e))
                continue
        print("\nDEBUG reranked count:", len(reranked))
        # sort best images
        reranked.sort(key=lambda x: x[0], reverse=True)
        top = [x[1] for x in reranked[: self.TOP_K_IMAGES]]
        print("DEBUG top image candidates:", len(top))
        return top

    def _build_public_urls(self, image_metas):
        """Turn image metas into at most ``MAX_IMAGES`` public URLs.

        Tries a MinIO presigned URL first (keyed by ``object_key``); on
        failure or absence, copies the local file into ``OUTPUT_DIR`` and
        serves it under ``PUBLIC_BASE_URL/outputs/<filename>``.
        Deduplicates by object key / filename via ``seen_keys``.
        """
        public_images = []
        seen_keys = set()
        for m in image_metas:
            if len(public_images) >= self.MAX_IMAGES:
                break
            object_key = m.get("object_key")
            print("\nDEBUG generating URL for:", object_key)
            # MINIO — preferred path when the object lives in object storage
            if object_key and object_key not in seen_keys:
                try:
                    from app.services.object_storage import get_presigned_url
                    url = get_presigned_url(object_key)
                    print("DEBUG presigned url:", url)
                    if url:
                        public_images.append(url)
                        seen_keys.add(object_key)
                        continue
                except Exception as e:
                    # fall through to the local copy below
                    print("DEBUG MinIO error:", str(e))
            # LOCAL FALLBACK — copy into OUTPUT_DIR and serve statically
            src_path = m.get("image_path")
            if not src_path:
                continue
            src_path = Path(src_path)
            if not src_path.is_absolute():
                src_path = self.PROJECT_ROOT / src_path
            if not src_path.exists():
                continue
            filename = src_path.name
            if filename in seen_keys:
                continue
            seen_keys.add(filename)
            dst_path = self.OUTPUT_DIR / filename
            if not dst_path.exists():
                shutil.copy(src_path, dst_path)
            # BUG FIX: the original emitted a literal placeholder instead of
            # the copied file's name, so every fallback URL was broken. The
            # URL must point at the file just copied into OUTPUT_DIR.
            public_url = f"{self.PUBLIC_BASE_URL}/outputs/{filename}"
            print("DEBUG fallback url:", public_url)
            public_images.append(public_url)
        return public_images