# ThesisBackend — src/multimodal/multimodal_rag_chain.py
# Author: AdarshRajDS
# Milestone: stable multimodal supabase ingestion (commit 5484978)
from pathlib import Path
import os
import shutil
import numpy as np
from src.multimodal.multimodal_retriever import MultimodalRetriever
from src.llm.llm_factory import get_llm
from src.multimodal.clip_embedding import CLIPEmbedding
from src.embeddings.embedding_factory import get_text_embedding
class MultimodalRAG:
    """Retrieval-augmented QA over text and images.

    Pipeline per query:
      1. retrieve mixed text/image chunks,
      2. answer from the text context with the LLM,
      3. re-rank image candidates by blended CLIP + caption similarity,
      4. expose the best images as public URLs (MinIO presigned links,
         with a copy-to-outputs local fallback).
    """

    # At most this many image URLs are returned per query.
    MAX_IMAGES = 3
    # Image candidates whose blended score falls below this are dropped.
    MIN_IMAGE_SCORE = 0.18
    # Captions are weighted higher than raw CLIP similarity.
    CLIP_WEIGHT = 0.25
    CAPTION_WEIGHT = 0.75

    def __init__(self):
        print("DEBUG: Initializing MultimodalRAG")
        self.retriever = MultimodalRetriever()
        self.llm = get_llm()
        # Cross-modal models: CLIP for image<->text similarity, plus the
        # project's text embedder for query<->caption similarity.
        self.clip = CLIPEmbedding()
        self.text_embedder = get_text_embedding()
        # Project root = two levels above this file (src/multimodal/..).
        self.PROJECT_ROOT = Path(__file__).resolve().parents[2]
        print("DEBUG PROJECT_ROOT:", self.PROJECT_ROOT)
        # NOTE(review): HF_HOME is conventionally the HuggingFace cache
        # directory — confirm it is intentionally reused as the data root.
        self.BASE_DATA_DIR = Path(os.getenv("HF_HOME", "data"))
        self.OUTPUT_DIR = (self.PROJECT_ROOT / self.BASE_DATA_DIR / "outputs").resolve()
        self.OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
        self.PUBLIC_BASE_URL = os.getenv("PUBLIC_BASE_URL", "")

    def ask(self, query):
        """Answer *query*; return ``(answer_text, public_image_urls)``."""
        print("\n==============================")
        print("DEBUG QUERY:", query)
        print("==============================")
        # --------------------------------------------------
        # STEP 1 — RETRIEVE
        # --------------------------------------------------
        docs, metas = self.retriever.retrieve(query, k=20)
        print("DEBUG total metas:", len(metas))
        image_metas = [m for m in metas if m.get("type") == "image"]
        print("DEBUG image metas:", len(image_metas))
        # --------------------------------------------------
        # STEP 2 — LLM ANSWER
        # --------------------------------------------------
        context = "\n\n".join(docs)
        prompt = f"""
You are a medical anatomy assistant.
Use the context to answer the question.
Context:
{context}
Question:
{query}
"""
        response = self.llm.invoke(prompt)
        # --------------------------------------------------
        # STEP 3 — CROSS MODAL RERANK
        # --------------------------------------------------
        top_images = self._rerank_images(query, image_metas)
        # --------------------------------------------------
        # STEP 4 — BUILD PUBLIC URLS
        # --------------------------------------------------
        public_images = self._build_public_urls(top_images)
        print("\nDEBUG FINAL IMAGE COUNT:", len(public_images))
        return response.content, public_images

    @staticmethod
    def _cosine(a, b):
        """Cosine similarity between two 1-D vectors, as a Python float."""
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

    def _rerank_images(self, query, image_metas):
        """Score image candidates against *query*; return the metadata of
        the best candidates (at most 5), highest score first.

        Candidates with no usable caption/context text, no resolvable file
        on disk, or a blended score below MIN_IMAGE_SCORE are dropped.
        """
        reranked = []
        # CLIP text embedding expects a list of texts.
        query_clip = self.clip.embed_text([query])[0]
        query_text = self.text_embedder.embed_query(query)
        for m in image_metas:
            raw_caption = m.get("caption") or ""
            context_text = m.get("context") or m.get("nearby_text") or ""
            # Prefer explicit caption; fall back to local context text.
            candidate_text = raw_caption.strip() or context_text.strip()
            image_path = m.get("image_path")
            print("\n--- IMAGE CANDIDATE ---")
            print("DEBUG caption:", raw_caption)
            print("DEBUG candidate_text:", candidate_text[:200])
            print("DEBUG raw image_path:", image_path)
            # Skip if we still don't have enough semantic signal.
            if not candidate_text or len(candidate_text) < 10:
                print("DEBUG skipped: weak caption")
                continue
            if not image_path:
                print("DEBUG skipped: no image path")
                continue
            image_path = Path(image_path)
            if not image_path.is_absolute():
                image_path = self.PROJECT_ROOT / image_path
            print("DEBUG resolved image_path:", image_path)
            if not image_path.exists():
                print("DEBUG skipped: file missing")
                continue
            try:
                print("DEBUG computing CLIP embedding")
                image_embedding = self.clip.embed_image([str(image_path)])[0]
                clip_score = self._cosine(query_clip, image_embedding)
                caption_embedding = self.text_embedder.embed_query(candidate_text)
                caption_score = self._cosine(query_text, caption_embedding)
                # Caption similarity is weighted higher than CLIP.
                final_score = (
                    self.CLIP_WEIGHT * clip_score
                    + self.CAPTION_WEIGHT * caption_score
                )
                print("DEBUG clip_score:", clip_score)
                print("DEBUG caption_score:", caption_score)
                print("DEBUG final_score:", final_score)
                # Filter weak matches.
                if final_score < self.MIN_IMAGE_SCORE:
                    print("DEBUG skipped: score too low")
                    continue
                reranked.append((final_score, m))
            except Exception as e:
                # Best-effort: a failed embedding drops only this candidate.
                print("DEBUG embedding error:", str(e))
                continue
        print("\nDEBUG reranked count:", len(reranked))
        reranked.sort(key=lambda x: x[0], reverse=True)
        top = [meta for _, meta in reranked[:5]]
        print("DEBUG top image candidates:", len(top))
        return top

    def _build_public_urls(self, image_metas):
        """Convert image metadata into at most MAX_IMAGES public URLs.

        Tries a MinIO presigned URL first; on failure (or when no object
        key exists) copies the local file into OUTPUT_DIR and serves it
        under ``{PUBLIC_BASE_URL}/outputs/{filename}``. Duplicate object
        keys / filenames are emitted only once.
        """
        public_images = []
        seen_keys = set()
        for m in image_metas:
            if len(public_images) >= self.MAX_IMAGES:
                break
            object_key = m.get("object_key")
            print("\nDEBUG generating URL for:", object_key)
            # MINIO: preferred path — presigned URL for the stored object.
            if object_key and object_key not in seen_keys:
                try:
                    from app.services.object_storage import get_presigned_url
                    url = get_presigned_url(object_key)
                    print("DEBUG presigned url:", url)
                    if url:
                        public_images.append(url)
                        seen_keys.add(object_key)
                        continue
                except Exception as e:
                    # Fall through to the local copy on any storage error.
                    print("DEBUG MinIO error:", str(e))
            # LOCAL FALLBACK: copy into the publicly served outputs dir.
            src_path = m.get("image_path")
            if not src_path:
                continue
            src_path = Path(src_path)
            if not src_path.is_absolute():
                src_path = self.PROJECT_ROOT / src_path
            if not src_path.exists():
                continue
            filename = src_path.name
            if filename in seen_keys:
                continue
            seen_keys.add(filename)
            dst_path = self.OUTPUT_DIR / filename
            if not dst_path.exists():
                shutil.copy(src_path, dst_path)
            # BUG FIX: the fallback URL previously lost the filename and
            # pointed every image at a literal "(unknown)" path.
            public_url = f"{self.PUBLIC_BASE_URL}/outputs/{filename}"
            print("DEBUG fallback url:", public_url)
            public_images.append(public_url)
        return public_images