| |
|
import logging
import os
from typing import Dict, List, Optional
from urllib.parse import quote

from dotenv import load_dotenv
from llama_index.core import (
    Settings,
    StorageContext,
    load_index_from_storage,
)
from llama_index.embeddings.clip import ClipEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
|
| |
|
| |
|
# Read a .env file (python-dotenv's default lookup) into the process
# environment before anything below checks OPENAI_API_KEY.
load_dotenv()

# Configure the root logger once, at import time: INFO and above, with a
# "timestamp - level - message" layout.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Module-level logger shared by every class and function in this file.
logger = logging.getLogger(__name__)
|
| |
|
class MultimodalRAGConfig:
    """Configuration for the Standalone Multimodal RAG Pipeline.

    All settings are plain class attributes; paths are resolved relative to
    the directory that contains this source file.
    """

    # Absolute path of the directory holding this module.
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))

    # Directory the persisted index is loaded from.
    INDEX_DIR = os.path.join(BASE_DIR, "multimodal_rag_index")
    # Directory for extracted images (not read by the code in this file's
    # visible section -- presumably served by a web layer; verify).
    IMAGES_DIR = os.path.join(BASE_DIR, "extracted_images")

    # Model identifiers passed to the embedding / LLM constructors.
    TEXT_EMBED_MODEL = "text-embedding-3-small"
    IMAGE_EMBED_MODEL = "ViT-B/32"
    LLM_MODEL = "gpt-4o"

    # Number of nodes retrieved per query, used for both text and images.
    TOP_K = 3
|
| |
|
class MultimodalRAGSystem:
    """Standalone Multimodal RAG System (Read-Only).

    Construction loads a previously persisted index from
    ``MultimodalRAGConfig.INDEX_DIR`` and builds a query engine on top of
    it.  :meth:`ask` then answers questions and reports which text and
    image nodes the answer was based on.
    """

    # Maximum number of characters of a text node returned as a preview.
    _SNIPPET_CHARS = 200

    def __init__(self):
        """Create the system and eagerly load the index and query engine.

        Raises:
            FileNotFoundError: If the index directory does not exist.
            ValueError: If ``OPENAI_API_KEY`` is not set in the environment.
        """
        self.config = MultimodalRAGConfig()
        self.index = None
        self.query_engine = None

        self._initialize_system()

    def _initialize_system(self):
        """Validate prerequisites, load the index and build the query engine.

        Populates ``self.index`` and ``self.query_engine``.

        Raises:
            FileNotFoundError: If ``config.INDEX_DIR`` is not a directory.
            ValueError: If ``OPENAI_API_KEY`` is missing or empty.
        """
        logger.info("Initializing Multimodal RAG System...")

        # Fail fast with a clear message instead of a loader error later on.
        # isdir() (rather than exists()) also rejects a stray regular file.
        if not os.path.isdir(self.config.INDEX_DIR):
            logger.error(f"Index directory not found: {self.config.INDEX_DIR}")
            raise FileNotFoundError(f"Index directory not found: {self.config.INDEX_DIR}")

        if not os.getenv("OPENAI_API_KEY"):
            logger.error("OPENAI_API_KEY not found in environment variables.")
            raise ValueError("OPENAI_API_KEY not found.")

        logger.info("Setting up models...")
        text_embed = OpenAIEmbedding(model=self.config.TEXT_EMBED_MODEL)
        image_embed = ClipEmbedding(model_name=self.config.IMAGE_EMBED_MODEL)

        openai_mm_llm = OpenAIMultiModal(
            model=self.config.LLM_MODEL,
            max_new_tokens=512,
        )

        logger.info(f"Loading index from {self.config.INDEX_DIR}...")
        storage_context = StorageContext.from_defaults(persist_dir=self.config.INDEX_DIR)

        # The embedding models are handed to the loader explicitly; they are
        # presumably the ones the index was built with -- verify against the
        # indexing script.
        self.index = load_index_from_storage(
            storage_context,
            embed_model=text_embed,
            image_embed_model=image_embed,
        )

        # The same TOP_K is applied to text and image retrieval.
        self.query_engine = self.index.as_query_engine(
            llm=openai_mm_llm,
            similarity_top_k=self.config.TOP_K,
            image_similarity_top_k=self.config.TOP_K,
        )

        logger.info(f"System Ready! Model: {self.config.LLM_MODEL}")

    def ask(self, query_str: str) -> Dict:
        """Ask a question and return the answer plus its source nodes.

        Args:
            query_str: The question to send to the query engine.

        Returns:
            A dict with three keys:

            * ``"answer"``: the engine response rendered with ``str()``.
            * ``"images"``: one dict per image source node with keys
              ``path``, ``filename``, ``score``, ``page`` and ``file``.
            * ``"texts"``: one dict per text source node with keys
              ``text``, ``score``, ``page``, ``file`` and ``link``.

        Raises:
            RuntimeError: If the query engine has not been initialized.
        """
        if not self.query_engine:
            raise RuntimeError("Query engine not initialized")

        logger.info(f"Querying: {query_str}")

        response = self.query_engine.query(query_str)

        source_images = []
        source_texts = []

        for node_score in response.source_nodes:
            node = node_score.node
            # Nodes flagged with "image_source" metadata are reported as
            # images; everything else is treated as text.
            if node.metadata.get("image_source"):
                source_images.append(self._image_source(node, node_score.score))
            else:
                source_texts.append(self._text_source(node, node_score.score))

        return {
            "answer": str(response),
            "images": source_images,
            "texts": source_texts,
        }

    @staticmethod
    def _image_source(node, score) -> Dict:
        """Describe one image source node as a plain dict."""
        # Prefer the node attribute, fall back to the metadata entry.
        img_path = getattr(node, "image_path", None) or node.metadata.get("image_path")

        if img_path:
            # Only the file name is exposed; the directory part of the
            # stored path is replaced by the /extracted_images/ URL prefix.
            img_filename = os.path.basename(img_path)
            web_path = f"/extracted_images/{img_filename}"
        else:
            web_path = None
            img_filename = "unknown"

        return {
            "path": web_path,
            "filename": img_filename,
            "score": score,
            "page": node.metadata.get("page_number"),
            "file": node.metadata.get("file_name"),
        }

    @classmethod
    def _text_source(cls, node, score) -> Dict:
        """Describe one text source node as a plain dict with a preview and link."""
        file_name = node.metadata.get("file_name", "N/A")
        page_num = node.metadata.get("page_number", "N/A")

        web_link = None
        # A link is only built for a real file name: the "N/A" placeholder
        # and empty/None values yield no link instead of crashing quote().
        if file_name and file_name != "N/A":
            web_link = f"/documents/{quote(str(file_name))}"
            if page_num not in (None, "N/A"):
                web_link += f"#page={page_num}"

        logger.debug("File: %s, Page: %s, Link: %s", file_name, page_num, web_link)

        # Append the ellipsis only when the preview was actually truncated.
        text = getattr(node, "text", None) or ""
        limit = cls._SNIPPET_CHARS
        snippet = text[:limit] + "..." if len(text) > limit else text

        return {
            "text": snippet,
            "score": score,
            "page": page_num,
            "file": file_name,
            "link": web_link,
        }
|
| |
|
| |
|
def main():
    """Run an interactive question/answer loop on stdin/stdout.

    Builds a ``MultimodalRAGSystem``, then repeatedly prompts for a query
    and prints the dict returned by ``ask``.  Entering ``q`` (any case,
    surrounding whitespace ignored), or sending EOF / Ctrl-C at the prompt,
    ends the loop.  Blank input is ignored rather than sent to the model.

    Any other exception (including start-up failures) is logged with its
    traceback, reported as ``Error: ...`` and ends the session.
    """
    try:
        rag = MultimodalRAGSystem()
        while True:
            try:
                q = input("Query (q to quit): ").strip()
            except (EOFError, KeyboardInterrupt):
                # Ctrl-D / Ctrl-C at the prompt is a normal way to leave.
                print()
                break
            if q.lower() == 'q':
                break
            if not q:
                # Nothing to ask; prompt again instead of querying the LLM.
                continue
            print(rag.ask(q))
    except Exception as e:
        # Top-level boundary: keep the traceback in the log, keep the short
        # user-facing message on stdout.
        logger.exception("Interactive session aborted")
        print(f"Error: {e}")


if __name__ == "__main__":
    main()
|
| |
|