# File size: 5,411 Bytes
# 0400df3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import requests
import os
import re

from typing import List
from utils import encode_image
from PIL import Image
from google import genai
import torch
import subprocess
import psutil
import torch
from transformers import AutoModel, AutoTokenizer
from google import genai


class Rag:
    """Answer queries about document page images using Google Gemini.

    Also retains a legacy helper that builds an OpenAI/Ollama-style chat
    payload from base64-encoded images.
    """

    def _clean_raw_token_response(self, response_text):
        """Strip raw special-token artifacts from a model response.

        Some backends occasionally return undecoded special tokens
        (e.g. ``<bos>``, ``<unused42>``, ``[multimodal]``) instead of
        clean text. Remove them and collapse whitespace; if almost
        nothing remains, return a user-facing error message instead.

        Args:
            response_text: Raw text returned by the model (may be None or "").

        Returns:
            The original text when no token artifacts are present, the
            cleaned text otherwise, or an explanatory error string when
            the response was essentially all raw tokens.
        """
        if not response_text:
            return response_text

        # Patterns that indicate an undecoded (raw token) response.
        token_patterns = [
            r'<unused\d+>',     # unused vocabulary slots
            r'<bos>',           # beginning of sequence
            r'<eos>',           # end of sequence
            r'<unk>',           # unknown tokens
            r'<mask>',          # mask tokens
            r'<pad>',           # padding tokens
            r'\[multimodal\]',  # multimodal placeholder
        ]

        if not any(re.search(pattern, response_text) for pattern in token_patterns):
            return response_text

        print("⚠️  Detected raw token response, attempting to clean...")

        # Remove unused tokens, special tokens, multimodal placeholders,
        # then collapse runs of whitespace.
        cleaned_text = re.sub(r'<unused\d+>', '', response_text)
        cleaned_text = re.sub(r'<(bos|eos|unk|mask|pad)>', '', cleaned_text)
        cleaned_text = re.sub(r'\[multimodal\]', '', cleaned_text)
        cleaned_text = re.sub(r'\s+', ' ', cleaned_text).strip()

        # If the response was mostly tokens, surface a helpful error
        # instead of near-empty gibberish.
        if len(cleaned_text) < 10:
            return "❌ **Model Response Error**: The model returned raw token IDs instead of decoded text. This may be due to model configuration issues. Please try:\n\n1. Restarting the Ollama server\n2. Using a different model\n3. Checking model compatibility with multimodal inputs"

        return cleaned_text

    def get_answer_from_gemini(self, query: str, image_paths: List[str]) -> str:
        """Ask Gemini 2.5 Pro to answer *query* using the given page images.

        Args:
            query: Natural-language question.
            image_paths: Paths to page images to attach; unreadable paths
                are skipped (best-effort).

        Returns:
            The model's text answer, or an "Error: ..." string on failure.
        """
        print(f"Querying Gemini 2.5 Pro for query={query}, image_paths={image_paths}")

        # Use environment variable GEMINI_API_KEY
        api_key = os.environ.get('GEMINI_API_KEY')
        if not api_key:
            return "Error: GEMINI_API_KEY is not set."

        images = []
        try:
            # BUG FIX: `from google import genai` imports the google-genai
            # SDK, which has no `configure`/`GenerativeModel` attributes
            # (those belong to the older google-generativeai package).
            # Use the Client API of the imported SDK instead.
            client = genai.Client(api_key=api_key)

            for p in image_paths:
                try:
                    images.append(Image.open(p))
                except Exception as exc:
                    # Best-effort: skip unreadable images, but say why
                    # instead of silently swallowing the error.
                    print(f"Skipping image {p}: {exc}")

            response = client.models.generate_content(
                model='gemini-2.5-pro',
                contents=[*images, query],
            )
            return response.text
        except Exception as e:
            print(f"Gemini error: {e}")
            return f"Error: {str(e)}"
        finally:
            # BUG FIX: close image file handles (the original leaked them).
            for img in images:
                try:
                    img.close()
                except Exception:
                    pass

    def get_answer_from_openai(self, query, imagesPaths):
        """Legacy entry point (formerly Ollama/OpenAI); delegates to Gemini.

        Loads environment variables (e.g. GEMINI_API_KEY) from a .env file
        first, then forwards an enhanced query to get_answer_from_gemini.

        Returns:
            The Gemini answer string, or None on failure.
        """
        # Lazy import so python-dotenv is only required on this path.
        import dotenv

        # Load the .env file (no-op if none is found).
        dotenv_file = dotenv.find_dotenv()
        dotenv.load_dotenv(dotenv_file)

        # This function formerly used Ollama. Replace with Gemini 2.5 Pro.
        print(f"Querying Gemini (replacement for Ollama) for query={query}, imagesPaths={imagesPaths}")
        try:
            enhanced_query = f"Use all {len(imagesPaths)} pages to answer comprehensively.\n\nQuery: {query}"
            return self.get_answer_from_gemini(enhanced_query, imagesPaths)
        except Exception as e:
            print(f"Gemini replacement error: {e}")
            return None

    def __get_openai_api_payload(self, query: str, imagesPaths: List[str]):
        """Build an OpenAI-compatible chat payload with base64 images.

        Args:
            query: The user's text question.
            imagesPaths: Image files to inline as data-URL attachments.

        Returns:
            A dict ready to POST to an OpenAI-style /chat/completions API.
        """
        # One data-URL image_url entry per attached page.
        image_payload = []
        for imagePath in imagesPaths:
            base64_image = encode_image(imagePath)
            image_payload.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{base64_image}"
                }
            })

        payload = {
            "model": "Llama3.2-vision",  # change model here as needed
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": query
                        },
                        *image_payload
                    ]
                }
            ],
            "max_tokens": 1024  # reduce token size to reduce processing time
        }

        return payload


# if __name__ == "__main__":
#     rag = Rag()
    
#     query = "Based on attached images, how many new cases were reported during second wave peak"
#     imagesPaths = ["covid_slides_page_8.png", "covid_slides_page_8.png"]
    
#     rag.get_answer_from_gemini(query, imagesPaths)