# WHEC_Chatbot_Test / query_index.py
# Author: mohamedhassan22
# Last commit: f006f4f "Update query_index.py" (verified)
import logging
import os
from typing import Dict, List, Optional
from urllib.parse import quote

from dotenv import load_dotenv
from llama_index.llms.openai import OpenAI
from llama_index.core import (
    StorageContext,
    load_index_from_storage,
    Settings,
)

# Standalone imports for Multimodal RAG
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.embeddings.clip import ClipEmbedding
# Pull settings such as OPENAI_API_KEY from a local .env file (if any) into os.environ.
load_dotenv()

# Process-wide logging setup: INFO and above, timestamped and level-tagged.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module logger used throughout this file.
logger = logging.getLogger(__name__)
class MultimodalRAGConfig:
    """Static settings for the standalone (read-only) multimodal RAG pipeline.

    All filesystem paths are fixed and resolved relative to the directory
    that contains this module.
    """

    # --- Filesystem locations (hardcoded to the requested paths) ---
    BASE_DIR: str = os.path.dirname(os.path.abspath(__file__))
    INDEX_DIR: str = os.path.join(BASE_DIR, "multimodal_rag_index")
    IMAGES_DIR: str = os.path.join(BASE_DIR, "extracted_images")

    # --- Model identifiers ---
    TEXT_EMBED_MODEL: str = "text-embedding-3-small"  # passed to OpenAIEmbedding
    IMAGE_EMBED_MODEL: str = "ViT-B/32"  # passed to ClipEmbedding
    LLM_MODEL: str = "gpt-4o"  # passed to OpenAIMultiModal for answering

    # How many text nodes AND how many image nodes to retrieve per query.
    TOP_K: int = 3
class MultimodalRAGSystem:
    """
    Standalone Multimodal RAG System (Read-Only).

    Loads the persisted index from ``MultimodalRAGConfig.INDEX_DIR`` and
    answers questions against it. When chat history is supplied, a follow-up
    question is first rewritten into a standalone question before retrieval.
    """

    def __init__(self):
        self.config = MultimodalRAGConfig()
        self.index = None
        self.query_engine = None
        self._initialize_system()
        # Built after _initialize_system() so that a missing OPENAI_API_KEY is
        # reported by the explicit check there before any OpenAI client is
        # constructed here. Temperature 0 keeps question rewrites deterministic.
        self.rewrite_llm = OpenAI(
            model="gpt-4o-mini",
            temperature=0.0,
        )

    def _initialize_system(self):
        """Validate prerequisites, load the persisted index and build the query engine.

        Raises:
            FileNotFoundError: If ``INDEX_DIR`` is missing or is not a directory.
            ValueError: If ``OPENAI_API_KEY`` is not set in the environment.
        """
        logger.info("Initializing Multimodal RAG System...")
        index_dir = self.config.INDEX_DIR
        if not os.path.isdir(index_dir):
            logger.error("Index directory not found: %s", index_dir)
            raise FileNotFoundError(f"Index directory not found: {index_dir}")
        if not os.getenv("OPENAI_API_KEY"):
            logger.error("OPENAI_API_KEY not found in environment variables.")
            raise ValueError("OPENAI_API_KEY not found.")

        # 1. Setup models
        logger.info("Setting up models...")
        # NOTE(review): presumably these must match the embedding models used
        # when the index was built — confirm against the ingestion script.
        text_embed = OpenAIEmbedding(model=self.config.TEXT_EMBED_MODEL)
        image_embed = ClipEmbedding(model_name=self.config.IMAGE_EMBED_MODEL)
        # Multimodal LLM (config.LLM_MODEL) that generates the final answer.
        openai_mm_llm = OpenAIMultiModal(
            model=self.config.LLM_MODEL,
            max_new_tokens=512,
        )

        # 2. Load index
        logger.info("Loading index from %s...", index_dir)
        storage_context = StorageContext.from_defaults(persist_dir=index_dir)
        self.index = load_index_from_storage(
            storage_context,
            embed_model=text_embed,
            image_embed_model=image_embed,
        )

        # 3. Create query engine: TOP_K text nodes and TOP_K image nodes.
        self.query_engine = self.index.as_query_engine(
            llm=openai_mm_llm,
            similarity_top_k=self.config.TOP_K,
            image_similarity_top_k=self.config.TOP_K,
        )
        logger.info("System Ready! Model: %s", self.config.LLM_MODEL)

    def ask(self, query_str: str, chat_history: Optional[List[Dict[str, str]]] = None) -> Dict:
        """
        Query the RAG system with optional chat history for context.

        Args:
            query_str: The user's question.
            chat_history: List of dicts with 'role' and 'content' keys. Only
                the last 4 messages are used to rewrite ``query_str``.

        Returns:
            Dict with 'answer' (str), 'images' (list of image source dicts)
            and 'texts' (list of text source dicts) keys.

        Raises:
            RuntimeError: If the query engine has not been initialized.
        """
        if not self.query_engine:
            raise RuntimeError("Query engine not initialized")

        logger.info("Original question: %s", query_str)
        standalone_question = self._rewrite_question(query_str, chat_history)
        response = self.query_engine.query(standalone_question)

        source_images = []
        source_texts = []
        for node_score in response.source_nodes:
            node = node_score.node
            # NOTE(review): image nodes are recognised by an "image_source"
            # metadata key — presumably set at ingestion time; confirm.
            if node.metadata.get("image_source"):
                source_images.append(self._image_source(node, node_score.score))
            else:
                source_texts.append(self._text_source(node, node_score.score))

        return {
            "answer": str(response),
            "images": source_images,
            "texts": source_texts,
        }

    def _rewrite_question(self, query_str: str, chat_history: Optional[List[Dict[str, str]]]) -> str:
        """Return ``query_str`` rewritten as a standalone question using recent history."""
        if not chat_history:
            return query_str

        # Convert the last 2 turns (4 messages) of history to a context string.
        history_text = "\n".join(
            f"{turn['role'].capitalize()}: {turn['content']}"
            for turn in chat_history[-4:]
        )
        rewrite_prompt = (
            "Given the previous conversation and the follow-up question, "
            "rewrite the follow-up question as a standalone question that includes all necessary context.\n\n"
            f"Conversation history:\n{history_text}\n\n"
            f"Follow-up question:\n{query_str}\n\n"
            "Rewrite this as a standalone question that can be understood without the conversation history. "
            "Only output the rewritten question, nothing else.\n\n"
            "Standalone question:"
        )
        rewritten = self.rewrite_llm.complete(rewrite_prompt).text.strip()
        if not rewritten:
            # An empty completion would make retrieval meaningless; fall back
            # to the user's own wording.
            logger.warning("Question rewrite returned empty text; using original question.")
            return query_str
        logger.info("Rewritten question: %s", rewritten)
        return rewritten

    @staticmethod
    def _image_source(node, score) -> Dict:
        """Build the response entry describing one retrieved image node."""
        # Prefer the node's own image_path attribute, then its metadata.
        img_path = getattr(node, "image_path", None) or node.metadata.get("image_path")
        if img_path:
            img_filename = os.path.basename(img_path)
            # Web path assumes the app serves 'extracted_images' as static files
            # (TODO confirm in app.py). The filename is URL-quoted so spaces and
            # special characters still form a valid URL, like document links.
            web_path = f"/extracted_images/{quote(img_filename)}"
        else:
            web_path = None
            img_filename = "unknown"
        return {
            "path": web_path,
            "filename": img_filename,
            "score": score,
            "page": node.metadata.get("page_number"),
            "file": node.metadata.get("file_name"),
        }

    @staticmethod
    def _text_source(node, score, preview_chars: int = 200) -> Dict:
        """Build the response entry (text preview + document link) for one text node."""
        file_name = node.metadata.get("file_name", "N/A")
        page_num = node.metadata.get("page_number", "N/A")

        web_link = None
        # Skip missing/None/empty names: quote(None) raises TypeError.
        if file_name and file_name != "N/A":
            # URL-encode the filename to handle spaces and special chars safely.
            web_link = f"/documents/{quote(str(file_name))}"
            if page_num not in (None, "N/A"):
                web_link += f"#page={page_num}"
        logger.debug("File: %s, Page: %s, Link: %s", file_name, page_num, web_link)

        text = node.text or ""
        # Only mark the preview as truncated when something was actually cut.
        preview = (text[:preview_chars] + "...") if len(text) > preview_chars else text
        return {
            "text": preview,
            "score": score,
            "page": page_num,
            "file": file_name,
            "link": web_link,
        }
# Main for simple testing
def main():
    """Interactive command-line loop for manually testing the RAG system.

    Reads questions from stdin until ``q`` (case-insensitive), end of input
    or Ctrl-C. Recent user/assistant messages are passed back to ``ask`` as
    chat history so follow-up questions can be rewritten.
    """
    try:
        rag = MultimodalRAGSystem()
    except Exception as e:
        # Startup failures (missing index, missing API key, ...) end the session.
        print(f"Error: {e}")
        return

    chat_hist = []
    while True:
        try:
            q = input("Query (q to quit): ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt: exit quietly instead of erroring.
            print()
            break
        if q.lower() == "q":
            break
        if not q:
            # Don't send empty questions to the retriever/LLM.
            continue

        try:
            result = rag.ask(q, chat_history=chat_hist)
        except Exception as e:
            # One failed query (e.g. a transient API error) shouldn't end the session.
            print(f"Error: {e}")
            continue
        print(f"\nAnswer: {result['answer']}\n")

        # Update history
        chat_hist.append({"role": "user", "content": q})
        chat_hist.append({"role": "assistant", "content": result["answer"]})
        # Keep history reasonable: only the last 6 messages (3 exchanges).
        chat_hist = chat_hist[-6:]


if __name__ == "__main__":
    main()