File size: 8,229 Bytes
bb6a871
 
 
 
f006f4f
bb6a871
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f006f4f
 
 
 
bb6a871
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f006f4f
339a105
f006f4f
 
 
 
 
 
 
 
339a105
f006f4f
 
339a105
f006f4f
339a105
f006f4f
 
339a105
f006f4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339a105
f006f4f
 
 
339a105
 
bb6a871
339a105
 
bb6a871
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f006f4f
bb6a871
 
 
f006f4f
 
 
 
 
 
 
 
 
 
 
 
bb6a871
 
 
 
f006f4f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import os
import logging
from typing import Dict, List, Optional
from dotenv import load_dotenv
from llama_index.llms.openai import OpenAI

from llama_index.core import (
    StorageContext,
    load_index_from_storage,
    Settings
)
# Standalone imports for Multimodal RAG
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.embeddings.clip import ClipEmbedding

# Populate the process environment from a local .env file, if one exists,
# before anything below reads settings such as OPENAI_API_KEY.
load_dotenv()

# Root logging setup: INFO and above, one "time - LEVEL - message" line per record.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

class MultimodalRAGConfig:
    """Static settings for the standalone multimodal RAG pipeline.

    All values are class-level constants. Filesystem locations are resolved
    relative to the directory that contains this module.
    """

    # --- Model identifiers -------------------------------------------------
    LLM_MODEL = "gpt-4o"
    TEXT_EMBED_MODEL = "text-embedding-3-small"
    IMAGE_EMBED_MODEL = "ViT-B/32"

    # --- Retrieval ---------------------------------------------------------
    # Used for both the text top-k and the image top-k of the query engine.
    TOP_K = 3

    # --- Fixed on-disk locations -------------------------------------------
    # Directory holding this module; the two paths below hang off it.
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    IMAGES_DIR = os.path.join(BASE_DIR, "extracted_images")
    INDEX_DIR = os.path.join(BASE_DIR, "multimodal_rag_index")

class MultimodalRAGSystem:
    """
    Standalone Multimodal RAG System (Read-Only)

    Loads an index that was persisted to ``MultimodalRAGConfig.INDEX_DIR``,
    builds a query engine on top of it and answers questions via ``ask()``.
    Nothing in this class writes to the index.
    """

    # Longest text-node preview (in characters) returned by ``ask()``.
    SNIPPET_CHARS = 200
    # Number of most recent chat messages shown to the rewrite LLM
    # (4 messages = the last 2 user/assistant turns).
    HISTORY_WINDOW = 4

    def __init__(self):
        """Create the system and load index, models and query engine.

        Raises:
            FileNotFoundError: If the configured index directory is missing.
            ValueError: If ``OPENAI_API_KEY`` is not set in the environment.
        """
        self.config = MultimodalRAGConfig()
        self.index = None
        self.query_engine = None
        # Built in _initialize_system(), after the environment checks.
        self.rewrite_llm = None
        self._initialize_system()

    def _initialize_system(self):
        """Validate the environment, then build models, index and query engine.

        Raises:
            FileNotFoundError: If ``config.INDEX_DIR`` does not exist.
            ValueError: If ``OPENAI_API_KEY`` is not set.
        """
        logger.info("Initializing Multimodal RAG System...")

        if not os.path.exists(self.config.INDEX_DIR):
            logger.error("Index directory not found: %s", self.config.INDEX_DIR)
            raise FileNotFoundError(f"Index directory not found: {self.config.INDEX_DIR}")

        if not os.getenv("OPENAI_API_KEY"):
            logger.error("OPENAI_API_KEY not found in environment variables.")
            raise ValueError("OPENAI_API_KEY not found.")

        # 1. Setup Models
        logger.info("Setting up models...")

        # Small text-only model used to turn follow-ups into standalone
        # questions. Created only after the key check so a missing key is
        # always reported by the explicit ValueError above.
        self.rewrite_llm = OpenAI(
            model="gpt-4o-mini",
            temperature=0.0
        )

        text_embed = OpenAIEmbedding(model=self.config.TEXT_EMBED_MODEL)
        image_embed = ClipEmbedding(model_name=self.config.IMAGE_EMBED_MODEL)

        # GPT-4o for Multimodal Generation
        openai_mm_llm = OpenAIMultiModal(
            model=self.config.LLM_MODEL,
            max_new_tokens=512
        )

        # 2. Load Index
        logger.info("Loading index from %s...", self.config.INDEX_DIR)
        storage_context = StorageContext.from_defaults(persist_dir=self.config.INDEX_DIR)

        self.index = load_index_from_storage(
            storage_context,
            embed_model=text_embed,
            image_embed_model=image_embed
        )

        # 3. Create Query Engine (same top-k for text and image retrieval)
        self.query_engine = self.index.as_query_engine(
            llm=openai_mm_llm,
            similarity_top_k=self.config.TOP_K,
            image_similarity_top_k=self.config.TOP_K
        )

        logger.info("System Ready! Model: %s", self.config.LLM_MODEL)

    def ask(self, query_str: str, chat_history: Optional[List[Dict[str, str]]] = None) -> Dict:
        """
        Query the RAG system with optional chat history for context.

        When history is supplied, the question is first rewritten into a
        standalone question by the rewrite LLM; that rewritten text is what
        gets sent to the query engine.

        Args:
            query_str: The user's question
            chat_history: List of dicts with 'role' and 'content' keys

        Returns:
            Dict with 'answer', 'images', and 'texts' keys. 'images' and
            'texts' are lists of dicts describing the retrieved source nodes.

        Raises:
            RuntimeError: If the query engine has not been initialized.
        """
        if not self.query_engine:
            raise RuntimeError("Query engine not initialized")

        logger.info("Original question: %s", query_str)

        standalone_question = self._rewrite_question(query_str, chat_history)
        response = self.query_engine.query(standalone_question)

        source_images = []
        source_texts = []
        for node_score in response.source_nodes:
            # Image nodes are recognised by a truthy "image_source" metadata entry.
            if node_score.node.metadata.get("image_source"):
                source_images.append(self._image_source(node_score))
            else:
                source_texts.append(self._text_source(node_score))

        return {
            "answer": str(response),
            "images": source_images,
            "texts": source_texts
        }

    def _rewrite_question(self, query_str: str, chat_history: Optional[List[Dict[str, str]]]) -> str:
        """Return a standalone version of ``query_str`` given the chat history.

        Without history the question is returned unchanged and no LLM call is
        made. If the rewrite LLM returns nothing usable, the original question
        is used rather than querying with an empty string.
        """
        if not chat_history:
            return query_str

        # Only the most recent messages are included in the prompt.
        history_text = "\n".join(
            f"{turn['role'].capitalize()}: {turn['content']}"
            for turn in chat_history[-self.HISTORY_WINDOW:]
        )

        rewrite_prompt = (
            "Given the previous conversation and the follow-up question, "
            "rewrite the follow-up question as a standalone question that includes all necessary context.\n\n"
            f"Conversation history:\n{history_text}\n\n"
            f"Follow-up question:\n{query_str}\n\n"
            "Rewrite this as a standalone question that can be understood without the conversation history. "
            "Only output the rewritten question, nothing else.\n\n"
            "Standalone question:"
        )

        rewritten = (self.rewrite_llm.complete(rewrite_prompt).text or "").strip()
        if not rewritten:
            logger.warning("Rewrite produced an empty question; using the original question.")
            return query_str

        logger.info("Rewritten question: %s", rewritten)
        return rewritten

    @staticmethod
    def _image_source(node_score) -> Dict:
        """Describe one retrieved image node as a plain dict for the caller."""
        node = node_score.node
        # Try to get image path from node attribute or metadata
        img_path = getattr(node, "image_path", None) or node.metadata.get("image_path")

        if img_path:
            # Only the file name is exposed. The surrounding code assumes the
            # web app serves the 'extracted_images' directory as static files.
            img_filename = os.path.basename(img_path)
            web_path = f"/extracted_images/{img_filename}"
        else:
            img_filename = "unknown"
            web_path = None

        return {
            "path": web_path,
            "filename": img_filename,
            "score": node_score.score,
            "page": node.metadata.get("page_number"),
            "file": node.metadata.get("file_name")
        }

    def _text_source(self, node_score) -> Dict:
        """Describe one retrieved text node, including a link to its document."""
        # Imported once per call instead of once per loop iteration.
        from urllib.parse import quote

        node = node_score.node
        file_name = node.metadata.get("file_name", "N/A")
        page_num = node.metadata.get("page_number", "N/A")

        web_link = None
        # A link is only built for a real file name; None/"" would otherwise
        # crash quote() or yield a link to no document.
        if file_name and file_name != "N/A":
            # URL encode the filename to handle spaces and special chars safely
            web_link = f"/documents/{quote(str(file_name))}"

            if page_num is not None and page_num != "N/A":
                web_link += f"#page={page_num}"

            logger.debug("File: %s, Page: %s, Link: %s", file_name, page_num, web_link)

        full_text = getattr(node, "text", None) or ""
        # Add the ellipsis only when something was actually cut off.
        if len(full_text) > self.SNIPPET_CHARS:
            snippet = full_text[:self.SNIPPET_CHARS] + "..."
        else:
            snippet = full_text

        return {
            "text": snippet,
            "score": node_score.score,
            "page": page_num,
            "file": file_name,
            "link": web_link
        }

# Main for simple testing
def main():
    """Run an interactive console session against the RAG system.

    Reads questions from stdin until the user enters ``q`` or input ends
    (Ctrl-D, Ctrl-C at the prompt, or a closed pipe). Each answer is printed,
    and the exchange is appended to a rolling history that is passed to the
    next ``ask()`` call. Any other error is reported as a one-line message.
    """
    try:
        rag = MultimodalRAGSystem()
        chat_hist = []
        while True:
            try:
                q = input("Query (q to quit): ").strip()
            except (EOFError, KeyboardInterrupt):
                # End of input is a normal way to leave, not an error.
                print()
                break

            if q.lower() == 'q':
                break
            if not q:
                # Don't spend an LLM call on an empty line.
                continue

            result = rag.ask(q, chat_history=chat_hist)
            print(f"\nAnswer: {result['answer']}\n")

            # Update history
            chat_hist.append({"role": "user", "content": q})
            chat_hist.append({"role": "assistant", "content": result['answer']})

            # Keep history reasonable: at most the last 6 messages
            chat_hist = chat_hist[-6:]

    except Exception as e:
        # Top-level boundary of the test console: report and exit.
        print(f"Error: {e}")


if __name__ == "__main__":
    main()