File size: 6,307 Bytes
2266343
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ba7476
cb9efc0
 
 
 
 
 
 
 
 
 
 
 
 
 
8ba7476
2266343
 
 
cb9efc0
8ba7476
 
2266343
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173

import os
import logging
from typing import Dict, List, Optional
from urllib.parse import quote

from dotenv import load_dotenv

from llama_index.core import (
    StorageContext,
    load_index_from_storage,
    Settings
)
# Standalone imports for Multimodal RAG
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.embeddings.clip import ClipEmbedding

# Populate os.environ from a .env file via python-dotenv, so settings such as
# OPENAI_API_KEY are visible to the os.getenv checks done at startup.
load_dotenv()

# Process-wide logging: INFO and above, formatted as "time - LEVEL - message".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

class MultimodalRAGConfig:
    """Static settings for the standalone multimodal RAG pipeline.

    Every setting is a plain class attribute. Filesystem locations are
    resolved relative to the directory that contains this module.
    """

    # Absolute directory of this source file; the paths below are built from it.
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))

    # Fixed on-disk locations: the persisted index and the extracted-image folder.
    INDEX_DIR = os.path.join(BASE_DIR, "multimodal_rag_index")
    IMAGES_DIR = os.path.join(BASE_DIR, "extracted_images")

    # Model identifiers: OpenAI text embedding model, CLIP image embedding
    # model, and the multimodal generation model.
    TEXT_EMBED_MODEL = "text-embedding-3-small"
    IMAGE_EMBED_MODEL = "ViT-B/32"
    LLM_MODEL = "gpt-4o"

    # Retrieval depth, used for both the text and the image similarity search.
    TOP_K = 3

class MultimodalRAGSystem:
    """Standalone, read-only multimodal RAG system.

    Loads a previously persisted multimodal index from
    ``MultimodalRAGConfig.INDEX_DIR`` and answers questions with an OpenAI
    multimodal LLM. Each answer is returned together with the text and image
    sources that were retrieved for it. The index is only loaded and queried,
    never modified.
    """

    # Maximum number of characters of each text source included in a response.
    TEXT_PREVIEW_CHARS = 200

    def __init__(self):
        self.config = MultimodalRAGConfig()
        self.index = None
        self.query_engine = None

        self._initialize_system()

    def _initialize_system(self):
        """Validate prerequisites, build the models, then load index and query engine.

        Raises:
            FileNotFoundError: if the persisted index directory does not exist.
            ValueError: if ``OPENAI_API_KEY`` is not set in the environment.
        """
        logger.info("Initializing Multimodal RAG System...")

        if not os.path.exists(self.config.INDEX_DIR):
            logger.error(f"Index directory not found: {self.config.INDEX_DIR}")
            raise FileNotFoundError(f"Index directory not found: {self.config.INDEX_DIR}")

        if not os.getenv("OPENAI_API_KEY"):
            logger.error("OPENAI_API_KEY not found in environment variables.")
            raise ValueError("OPENAI_API_KEY not found.")

        # 1. Setup models: OpenAI embeddings for text, CLIP embeddings for images.
        logger.info("Setting up models...")
        text_embed = OpenAIEmbedding(model=self.config.TEXT_EMBED_MODEL)
        image_embed = ClipEmbedding(model_name=self.config.IMAGE_EMBED_MODEL)

        # GPT-4o for multimodal generation.
        openai_mm_llm = OpenAIMultiModal(
            model=self.config.LLM_MODEL,
            max_new_tokens=512
        )

        # 2. Load index. Queries are embedded with the models passed here.
        # NOTE(review): they presumably must match the models the index was
        # built with. Confirm against the ingestion script.
        logger.info(f"Loading index from {self.config.INDEX_DIR}...")
        storage_context = StorageContext.from_defaults(persist_dir=self.config.INDEX_DIR)

        self.index = load_index_from_storage(
            storage_context,
            embed_model=text_embed,
            image_embed_model=image_embed
        )

        # 3. Create the query engine. TOP_K applies separately to text and image hits.
        self.query_engine = self.index.as_query_engine(
            llm=openai_mm_llm,
            similarity_top_k=self.config.TOP_K,
            image_similarity_top_k=self.config.TOP_K
        )

        logger.info(f"System Ready! Model: {self.config.LLM_MODEL}")

    @staticmethod
    def _image_source_entry(node_score) -> Dict:
        """Build the response entry for one retrieved image node.

        The image path comes from the node's ``image_path`` attribute, falling
        back to its ``image_path`` metadata. Only the file name is kept, and it
        is exposed under ``/extracted_images/``.
        """
        node = node_score.node
        img_path = getattr(node, "image_path", None) or node.metadata.get("image_path")

        if img_path:
            img_filename = os.path.basename(img_path)
            # We assume the web app (app.py) serves 'extracted_images' as static
            # files. Percent-encode the name so spaces and special characters
            # still form a valid URL, the same treatment as document links.
            web_path = f"/extracted_images/{quote(img_filename)}"
        else:
            img_filename = "unknown"
            web_path = None

        return {
            "path": web_path,
            "filename": img_filename,
            "score": node_score.score,
            "page": node.metadata.get("page_number"),
            "file": node.metadata.get("file_name"),
        }

    @staticmethod
    def _text_source_entry(node_score, preview_chars: int = 200) -> Dict:
        """Build the response entry for one retrieved text node.

        Args:
            node_score: retrieved node together with its similarity score.
            preview_chars: maximum characters of node text to include. An
                ellipsis is appended only when the text was actually cut.

        Returns:
            dict with ``text`` (preview), ``score``, ``page``, ``file`` and
            ``link``. ``link`` is a ``/documents/<name>[#page=N]`` URL, or None
            when the file name is unknown.
        """
        node = node_score.node
        file_name = node.metadata.get("file_name", "N/A")
        page_num = node.metadata.get("page_number", "N/A")

        web_link = None
        # Metadata may hold an explicit None; quote(None) would raise TypeError.
        if file_name and file_name != "N/A":
            # URL-encode the filename to handle spaces and special chars safely.
            web_link = f"/documents/{quote(str(file_name))}"
            if page_num is not None and page_num != "N/A":
                # Page fragment, presumably consumed by the browser's PDF viewer.
                web_link += f"#page={page_num}"

        text = node.text or ""
        preview = text[:preview_chars]
        if len(text) > preview_chars:
            preview += "..."

        return {
            "text": preview,
            "score": node_score.score,
            "page": page_num,
            "file": file_name,
            "link": web_link,
        }

    def ask(self, query_str: str) -> Dict:
        """Answer a question and collect the sources the answer was built from.

        Args:
            query_str: natural-language question sent to the query engine.

        Returns:
            dict with keys:
              - ``"answer"``: the engine response rendered as a string.
              - ``"images"``: image source entries (path, filename, score, page, file).
              - ``"texts"``: text source entries (text, score, page, file, link).

        Raises:
            RuntimeError: if the query engine has not been initialized.
        """
        if not self.query_engine:
            raise RuntimeError("Query engine not initialized")

        logger.info(f"Querying: {query_str}")

        response = self.query_engine.query(query_str)

        source_images = []
        source_texts = []

        for node_score in response.source_nodes:
            # NOTE(review): image nodes are recognised only by an "image_source"
            # metadata key, which must be set at ingestion time. Confirm.
            if node_score.node.metadata.get("image_source"):
                source_images.append(self._image_source_entry(node_score))
                continue

            entry = self._text_source_entry(node_score, self.TEXT_PREVIEW_CHARS)
            if entry["link"] is not None:
                # Link-construction trace, visible only with DEBUG logging enabled.
                logger.debug(
                    "Source link: file=%s, page=%s, link=%s",
                    entry["file"], entry["page"], entry["link"],
                )
            source_texts.append(entry)

        return {
            "answer": str(response),
            "images": source_images,
            "texts": source_texts
        }

# Main for simple testing
def main():
    """Run an interactive question loop against the RAG system on stdin/stdout.

    Enter ``q`` (any case) to quit; end-of-input (Ctrl-D) or Ctrl-C also exits.
    Blank input is ignored. A startup failure ends the session. A failed query
    is logged with its traceback, and the loop continues with the next question.
    """
    try:
        rag = MultimodalRAGSystem()
    except Exception as e:
        # Top-level boundary: report why startup failed, including the traceback.
        logger.exception("Failed to initialize Multimodal RAG System: %s", e)
        return

    while True:
        try:
            q = input("Query (q to quit): ").strip()
        except (EOFError, KeyboardInterrupt):
            # Treat end-of-input / interrupt as a normal quit, not an error.
            print()
            break

        if not q:
            continue
        if q.lower() == 'q':
            break

        try:
            print(rag.ask(q))
        except Exception as e:
            # Keep the session alive after a single failed query.
            logger.exception("Query failed: %s", e)


if __name__ == "__main__":
    main()