|
|
import os |
|
|
import logging |
|
|
from typing import Dict, List, Optional |
|
|
from dotenv import load_dotenv |
|
|
from llama_index.llms.openai import OpenAI |
|
|
|
|
|
from llama_index.core import ( |
|
|
StorageContext, |
|
|
load_index_from_storage, |
|
|
Settings |
|
|
) |
|
|
|
|
|
from llama_index.multi_modal_llms.openai import OpenAIMultiModal |
|
|
from llama_index.embeddings.openai import OpenAIEmbedding |
|
|
from llama_index.embeddings.clip import ClipEmbedding |
|
|
|
|
|
|
|
|
# Load variables from a local .env file into the process environment
# (the system below reads OPENAI_API_KEY via os.getenv).
load_dotenv()

# Process-wide logging: INFO level, timestamped "time - level - message" lines.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger shared by everything defined in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
class MultimodalRAGConfig:
    """Static settings for the standalone multimodal RAG pipeline."""

    # Absolute path of the directory that contains this source file; the
    # other directories below are resolved relative to it.
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))

    # Sub-directories of BASE_DIR used by the pipeline.
    INDEX_DIR = os.path.join(BASE_DIR, "multimodal_rag_index")
    IMAGES_DIR = os.path.join(BASE_DIR, "extracted_images")

    # Model identifiers: text embedding, image embedding, answering LLM.
    TEXT_EMBED_MODEL = "text-embedding-3-small"
    IMAGE_EMBED_MODEL = "ViT-B/32"
    LLM_MODEL = "gpt-4o"

    # Retrieval depth.
    TOP_K = 3
|
|
|
|
|
class MultimodalRAGSystem:
    """Standalone multimodal RAG system (read-only).

    On construction the index persisted under ``self.config.INDEX_DIR`` is
    loaded and a query engine is built from it; nothing in this class writes
    to the index. Questions are answered through :meth:`ask`.
    """

    # Model used to rewrite follow-up questions into standalone questions.
    REWRITE_MODEL = "gpt-4o-mini"
    # Number of most recent chat turns shown to the rewrite model. Must be
    # positive: a slice of ``[-0:]`` would select the whole history.
    HISTORY_TURNS = 4
    # Maximum number of characters of a text node returned as a preview.
    SNIPPET_CHARS = 200

    def __init__(self):
        """Create the rewrite LLM and load the index and query engine.

        Raises:
            FileNotFoundError: If the index directory does not exist.
            ValueError: If ``OPENAI_API_KEY`` is not set in the environment.
        """
        self.config = MultimodalRAGConfig()
        self.index = None
        self.query_engine = None
        self.rewrite_llm = OpenAI(model=self.REWRITE_MODEL, temperature=0.0)
        self._initialize_system()

    def _initialize_system(self):
        """Validate prerequisites, then load the index and build the engine.

        Sets ``self.index`` and ``self.query_engine``.

        Raises:
            FileNotFoundError: If ``self.config.INDEX_DIR`` does not exist.
            ValueError: If ``OPENAI_API_KEY`` is not set in the environment.
        """
        logger.info("Initializing Multimodal RAG System...")

        index_dir = self.config.INDEX_DIR
        if not os.path.exists(index_dir):
            message = f"Index directory not found: {index_dir}"
            logger.error(message)
            raise FileNotFoundError(message)

        if not os.getenv("OPENAI_API_KEY"):
            logger.error("OPENAI_API_KEY not found in environment variables.")
            raise ValueError("OPENAI_API_KEY not found.")

        logger.info("Setting up models...")
        text_embed = OpenAIEmbedding(model=self.config.TEXT_EMBED_MODEL)
        image_embed = ClipEmbedding(model_name=self.config.IMAGE_EMBED_MODEL)
        openai_mm_llm = OpenAIMultiModal(
            model=self.config.LLM_MODEL,
            max_new_tokens=512,
        )

        logger.info("Loading index from %s...", index_dir)
        storage_context = StorageContext.from_defaults(persist_dir=index_dir)
        self.index = load_index_from_storage(
            storage_context,
            embed_model=text_embed,
            image_embed_model=image_embed,
        )

        # The same TOP_K is used for both text and image retrieval.
        self.query_engine = self.index.as_query_engine(
            llm=openai_mm_llm,
            similarity_top_k=self.config.TOP_K,
            image_similarity_top_k=self.config.TOP_K,
        )

        logger.info("System Ready! Model: %s", self.config.LLM_MODEL)

    def ask(self, query_str: str, chat_history: Optional[List[Dict[str, str]]] = None) -> Dict:
        """Query the RAG system, optionally using chat history for context.

        When ``chat_history`` is non-empty, the question is first rewritten
        by the rewrite LLM into a standalone question; the (possibly
        rewritten) question is then sent to the query engine.

        Args:
            query_str: The user's question.
            chat_history: List of dicts with 'role' and 'content' keys.

        Returns:
            Dict with 'answer' (str), 'images' (list of dicts with keys
            'path', 'filename', 'score', 'page', 'file') and 'texts' (list
            of dicts with keys 'text', 'score', 'page', 'file', 'link').

        Raises:
            RuntimeError: If the query engine has not been initialized.
        """
        if self.query_engine is None:
            raise RuntimeError("Query engine not initialized")

        logger.info("Original question: %s", query_str)
        standalone_question = self._rewrite_question(query_str, chat_history)
        logger.info("Rewritten question: %s", standalone_question)

        response = self.query_engine.query(standalone_question)

        source_images = []
        source_texts = []
        for node_score in response.source_nodes:
            node = node_score.node
            # Nodes flagged with "image_source" metadata are reported as
            # images; every other node is reported as a text source.
            if node.metadata.get("image_source"):
                source_images.append(
                    self._image_source_entry(node, node_score.score)
                )
            else:
                entry = self._text_source_entry(node, node_score.score)
                logger.debug(
                    "Text source - File: %s, Page: %s, Link: %s",
                    entry["file"], entry["page"], entry["link"],
                )
                source_texts.append(entry)

        return {
            "answer": str(response),
            "images": source_images,
            "texts": source_texts,
        }

    def _rewrite_question(self, query_str, chat_history):
        """Return a standalone version of ``query_str`` given the history.

        Without history the question is returned unchanged. If the rewrite
        model returns only whitespace, the original question is kept so that
        retrieval never runs on an empty query.
        """
        if not chat_history:
            return query_str

        history_text = "\n".join(
            f"{turn['role'].capitalize()}: {turn['content']}"
            for turn in chat_history[-self.HISTORY_TURNS:]
        )
        rewrite_prompt = (
            "Given the previous conversation and the follow-up question, "
            "rewrite the follow-up question as a standalone question that includes all necessary context.\n\n"
            f"Conversation history:\n{history_text}\n\n"
            f"Follow-up question:\n{query_str}\n\n"
            "Rewrite this as a standalone question that can be understood without the conversation history. "
            "Only output the rewritten question, nothing else.\n\n"
            "Standalone question:"
        )
        rewritten = self.rewrite_llm.complete(rewrite_prompt).text.strip()
        return rewritten or query_str

    @staticmethod
    def _static_url(prefix, name):
        """Return ``prefix/name`` with ``name`` percent-encoded for a URL."""
        # Function-scope import: the module does not import it at top level.
        from urllib.parse import quote

        return f"{prefix}/{quote(str(name))}"

    def _image_source_entry(self, node, score):
        """Build the result dict describing one retrieved image node."""
        # The path may be a node attribute or, failing that, a metadata entry.
        img_path = getattr(node, "image_path", None) or node.metadata.get("image_path")
        if img_path:
            img_filename = os.path.basename(img_path)
            # Encoded like the document links, so names containing spaces
            # or '#' still form a valid URL.
            web_path = self._static_url("/extracted_images", img_filename)
        else:
            img_filename = "unknown"
            web_path = None

        return {
            "path": web_path,
            "filename": img_filename,
            "score": score,
            "page": node.metadata.get("page_number"),
            "file": node.metadata.get("file_name"),
        }

    def _text_source_entry(self, node, score):
        """Build the result dict describing one retrieved text node."""
        file_name = node.metadata.get("file_name", "N/A")
        page_num = node.metadata.get("page_number", "N/A")

        # A link is only produced when a real file name is known; a stored
        # None is treated like a missing value instead of being encoded.
        web_link = None
        if file_name and file_name != "N/A":
            web_link = self._static_url("/documents", file_name)
            if page_num is not None and page_num != "N/A":
                web_link += f"#page={page_num}"

        # Only mark the preview with an ellipsis when text was actually cut.
        text = node.text or ""
        if len(text) > self.SNIPPET_CHARS:
            text = text[:self.SNIPPET_CHARS] + "..."

        return {
            "text": text,
            "score": score,
            "page": page_num,
            "file": file_name,
            "link": web_link,
        }
|
|
|
|
|
|
|
|
def main():
    """Run an interactive question/answer loop on the terminal.

    Builds a ``MultimodalRAGSystem``, then repeatedly reads a query, prints
    the answer and records the exchange as chat history for follow-up
    questions. Entering ``q`` (any case), or pressing Ctrl-D / Ctrl-C at the
    prompt, ends the session. Any other error is logged with its traceback,
    reported on stdout, and ends the session.
    """
    try:
        rag = MultimodalRAGSystem()
        chat_hist = []
        while True:
            try:
                q = input("Query (q to quit): ").strip()
            except (EOFError, KeyboardInterrupt):
                # End of input or an interrupt at the prompt is a normal
                # way to leave, not an error worth reporting.
                print()
                break

            if q.lower() == 'q':
                break
            if not q:
                # Nothing was typed; do not spend a query on an empty string.
                continue

            result = rag.ask(q, chat_history=chat_hist)
            print(f"\nAnswer: {result['answer']}\n")

            chat_hist.append({"role": "user", "content": q})
            chat_hist.append({"role": "assistant", "content": result['answer']})

            # Keep only the six most recent messages (three exchanges).
            del chat_hist[:-6]
    except Exception as e:
        # Top-level boundary: keep the traceback in the log and show a short
        # message to the user.
        logger.exception("Interactive session aborted")
        print(f"Error: {e}")


if __name__ == "__main__":
    main()