# NOTE(review): the lines above/below were scraped from a hosting page
# ("Spaces" / "Sleeping" status badges) and are not part of the program.
import logging
import os
from typing import Dict, List, Optional
from urllib.parse import quote

from dotenv import load_dotenv

from llama_index.core import (
    Settings,
    StorageContext,
    load_index_from_storage,
)
from llama_index.embeddings.clip import ClipEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
# Standalone imports for Multimodal RAG
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
| # Load environment variables | |
| load_dotenv() | |
| # Configure logging | |
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') | |
| logger = logging.getLogger(__name__) | |
class MultimodalRAGConfig:
    """Configuration for the Standalone Multimodal RAG Pipeline."""

    # Directory containing this module; all data paths are anchored here.
    BASE_DIR: str = os.path.dirname(os.path.abspath(__file__))

    # Fixed on-disk locations of the persisted index and extracted images.
    INDEX_DIR: str = os.path.join(BASE_DIR, "multimodal_rag_index")
    IMAGES_DIR: str = os.path.join(BASE_DIR, "extracted_images")

    # Model identifiers.
    TEXT_EMBED_MODEL: str = "text-embedding-3-small"  # OpenAI text embeddings
    IMAGE_EMBED_MODEL: str = "ViT-B/32"               # CLIP image embeddings
    LLM_MODEL: str = "gpt-4o"                         # multimodal generation model

    # Retrieve the top 3 text nodes and top 3 image nodes per query.
    TOP_K: int = 3
class MultimodalRAGSystem:
    """
    Standalone Multimodal RAG System (Read-Only).

    Loads a previously persisted multimodal index (text + image embeddings)
    and answers questions against it with a multimodal LLM. Retrieved text
    and image sources are returned alongside the answer so a web front end
    can render citations.
    """

    # Maximum characters of a source-text snippet returned to the caller.
    _SNIPPET_LEN = 200

    def __init__(self):
        self.config = MultimodalRAGConfig()
        self.index = None
        self.query_engine = None
        # Small, deterministic model dedicated to follow-up question rewriting.
        self.rewrite_llm = OpenAI(
            model="gpt-4o-mini",
            temperature=0.0
        )
        self._initialize_system()

    def _initialize_system(self):
        """Validate prerequisites, load the persisted index, and build the query engine.

        Raises:
            FileNotFoundError: If the persisted index directory does not exist.
            ValueError: If OPENAI_API_KEY is not set in the environment.
        """
        logger.info("Initializing Multimodal RAG System...")
        if not os.path.exists(self.config.INDEX_DIR):
            logger.error(f"Index directory not found: {self.config.INDEX_DIR}")
            raise FileNotFoundError(f"Index directory not found: {self.config.INDEX_DIR}")
        if not os.getenv("OPENAI_API_KEY"):
            logger.error("OPENAI_API_KEY not found in environment variables.")
            raise ValueError("OPENAI_API_KEY not found.")

        # 1. Setup models: OpenAI for text embeddings, CLIP for image embeddings.
        logger.info("Setting up models...")
        text_embed = OpenAIEmbedding(model=self.config.TEXT_EMBED_MODEL)
        image_embed = ClipEmbedding(model_name=self.config.IMAGE_EMBED_MODEL)

        # GPT-4o for multimodal answer generation.
        openai_mm_llm = OpenAIMultiModal(
            model=self.config.LLM_MODEL,
            max_new_tokens=512
        )

        # 2. Load the persisted index (read-only; nothing is written back).
        logger.info(f"Loading index from {self.config.INDEX_DIR}...")
        storage_context = StorageContext.from_defaults(persist_dir=self.config.INDEX_DIR)
        self.index = load_index_from_storage(
            storage_context,
            embed_model=text_embed,
            image_embed_model=image_embed
        )

        # 3. Query engine retrieving both text and image nodes.
        self.query_engine = self.index.as_query_engine(
            llm=openai_mm_llm,
            similarity_top_k=self.config.TOP_K,
            image_similarity_top_k=self.config.TOP_K
        )
        logger.info(f"System Ready! Model: {self.config.LLM_MODEL}")

    def ask(self, query_str: str, chat_history: Optional[List[Dict[str, str]]] = None) -> Dict:
        """
        Query the RAG system with optional chat history for context.

        Args:
            query_str: The user's question.
            chat_history: List of dicts with 'role' and 'content' keys.

        Returns:
            Dict with 'answer', 'images', and 'texts' keys.

        Raises:
            RuntimeError: If the query engine has not been initialized.
        """
        if not self.query_engine:
            raise RuntimeError("Query engine not initialized")
        logger.info(f"Original question: {query_str}")

        # Rewrite a follow-up into a standalone question if history exists.
        standalone_question = query_str
        if chat_history:
            standalone_question = self._rewrite_question(query_str, chat_history)
            logger.info(f"Rewritten question: {standalone_question}")

        response = self.query_engine.query(standalone_question)

        # Split retrieved nodes into image and text citations.
        # Nodes flagged with "image_source" metadata are treated as images.
        source_images = []
        source_texts = []
        for node_score in response.source_nodes:
            if node_score.node.metadata.get("image_source"):
                source_images.append(self._image_source(node_score))
            else:
                source_texts.append(self._text_source(node_score))

        return {
            "answer": str(response),
            "images": source_images,
            "texts": source_texts
        }

    def _rewrite_question(self, query_str: str, chat_history: List[Dict[str, str]]) -> str:
        """Rewrite a follow-up question into a standalone one via the rewrite LLM."""
        # Only the last 2 turns (4 messages) are needed for disambiguation.
        history_text = "\n".join(
            f"{turn['role'].capitalize()}: {turn['content']}"
            for turn in chat_history[-4:]
        )
        rewrite_prompt = (
            "Given the previous conversation and the follow-up question, "
            "rewrite the follow-up question as a standalone question that includes all necessary context.\n\n"
            f"Conversation history:\n{history_text}\n\n"
            f"Follow-up question:\n{query_str}\n\n"
            "Rewrite this as a standalone question that can be understood without the conversation history. "
            "Only output the rewritten question, nothing else.\n\n"
            "Standalone question:"
        )
        return self.rewrite_llm.complete(rewrite_prompt).text.strip()

    @staticmethod
    def _image_source(node_score) -> Dict:
        """Build the citation payload for one retrieved image node."""
        node = node_score.node
        # The image path may live on the node object or in its metadata.
        img_path = getattr(node, "image_path", None) or node.metadata.get("image_path")
        if img_path:
            img_filename = os.path.basename(img_path)
            # app.py is assumed to serve 'extracted_images' as a static dir,
            # so a root-relative web path is enough for the front end.
            web_path = f"/extracted_images/{img_filename}"
        else:
            web_path = None
            img_filename = "unknown"
        return {
            "path": web_path,
            "filename": img_filename,
            "score": node_score.score,
            "page": node.metadata.get("page_number"),
            "file": node.metadata.get("file_name")
        }

    def _text_source(self, node_score) -> Dict:
        """Build the citation payload for one retrieved text node."""
        node = node_score.node
        file_name = node.metadata.get("file_name", "N/A")
        page_num = node.metadata.get("page_number", "N/A")
        web_link = None
        if file_name != "N/A":
            # URL-encode the filename to handle spaces and special chars safely.
            safe_filename = quote(file_name)
            web_link = f"/documents/{safe_filename}"
            if page_num != "N/A":
                web_link += f"#page={page_num}"
        # Diagnostic only, so log at DEBUG rather than INFO.
        logger.debug(f"File: {file_name}, Page: {page_num}, Link: {web_link}")
        # Append an ellipsis only when the snippet was actually truncated.
        snippet = node.text
        if len(snippet) > self._SNIPPET_LEN:
            snippet = snippet[:self._SNIPPET_LEN] + "..."
        return {
            "text": snippet,
            "score": node_score.score,
            "page": page_num,
            "file": file_name,
            "link": web_link
        }
# Minimal interactive console loop for smoke-testing the pipeline.
def main():
    try:
        rag = MultimodalRAGSystem()
        history = []
        while True:
            question = input("Query (q to quit): ")
            if question.lower() == 'q':
                break
            result = rag.ask(question, chat_history=history)
            print(f"\nAnswer: {result['answer']}\n")
            # Record the exchange for follow-up question rewriting.
            history.append({"role": "user", "content": question})
            history.append({"role": "assistant", "content": result['answer']})
            # Keep only the last 6 messages (3 exchanges) of history.
            del history[:-6]
    except Exception as e:
        print(f"Error: {e}")


if __name__ == "__main__":
    main()