import os
import re
from typing import List

import requests
import torch
from PIL import Image
from ollama import chat
from google import genai

from utils import encode_image
class Rag:

    def _clean_raw_token_response(self, response_text):
        """
        Clean raw token responses that contain undecoded token IDs.
        This handles cases where models return raw tokens instead of decoded text.
        """
        if not response_text:
            return response_text

        # Raw token patterns to check for
        token_patterns = [
            r'<unused\d+>',     # unused tokens
            r'<bos>',           # beginning of sequence
            r'<eos>',           # end of sequence
            r'<unk>',           # unknown tokens
            r'<mask>',          # mask tokens
            r'<pad>',           # padding tokens
            r'\[multimodal\]',  # multimodal tokens
        ]

        # If the response contains raw tokens, try to clean them
        has_raw_tokens = any(re.search(pattern, response_text) for pattern in token_patterns)
        if has_raw_tokens:
            print("⚠️ Detected raw token response, attempting to clean...")

            cleaned_text = response_text
            # Remove unused tokens
            cleaned_text = re.sub(r'<unused\d+>', '', cleaned_text)
            # Remove special tokens
            cleaned_text = re.sub(r'<(bos|eos|unk|mask|pad)>', '', cleaned_text)
            # Remove multimodal tokens
            cleaned_text = re.sub(r'\[multimodal\]', '', cleaned_text)
            # Collapse extra whitespace
            cleaned_text = re.sub(r'\s+', ' ', cleaned_text).strip()

            # If what remains is still mostly tokens, return an error message
            if len(cleaned_text.strip()) < 10:
                return "❌ **Model Response Error**: The model returned raw token IDs instead of decoded text. This may be due to model configuration issues. Please try:\n\n1. Restarting the Ollama server\n2. Using a different model\n3. Checking model compatibility with multimodal inputs"

            return cleaned_text

        return response_text
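    # A quick illustration of the cleanup above (hypothetical input, not from the
    # original code): special tokens are stripped and whitespace is collapsed.
    #   Rag()._clean_raw_token_response("<bos>The peak was <unused42> in January.<eos>")
    #   returns "The peak was in January."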
    def get_answer_from_gemini(self, query, imagePaths):
        print(f"Querying Gemini for query={query}, imagePaths={imagePaths}")

        try:
            # Read the API key from the environment (GEMINI_API_KEY is assumed here)
            # rather than hardcoding a secret in source.
            client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])

            images = [Image.open(path) for path in imagePaths]
            # contents must be a flat list of parts (the images plus the text
            # query), not a list nested inside a list
            response = client.models.generate_content(
                model="gemini-2.5-pro",
                contents=[*images, query],
            )

            print(response.text)
            answer = response.text
            return answer
        except Exception as e:
            print(f"An error occurred while querying Gemini: {e}")
            return f"Error: {str(e)}"
    def get_answer_from_openai(self, query, imagesPaths):
        # Despite the name, this method queries a local Ollama server.
        import dotenv

        # Load environment variables from the .env file
        dotenv_file = dotenv.find_dotenv()
        dotenv.load_dotenv(dotenv_file)

        torch.cuda.empty_cache()  # release CUDA memory so that Ollama can use the GPU

        os.environ['OLLAMA_FLASH_ATTENTION'] = os.environ['flashattn']  # "1" enables flash attention
        if os.environ['ollama'] == "minicpm-v":
            os.environ['ollama'] = "minicpm-v:8b-2.6-q8_0"  # switch to the quantized version
        elif os.environ['ollama'] == "gemma3":
            os.environ['ollama'] = "gemma3:12b"  # switch to the larger 12b variant when needed
            # Extra environment variables for Gemma3 to prevent raw-token issues
            os.environ['OLLAMA_KEEP_ALIVE'] = "5m"
            os.environ['OLLAMA_ORIGINS'] = "*"

        # Close model thread (colpali)
        print(f"Querying Ollama for query={query}, imagesPaths={imagesPaths}")
        try:
            # Enhanced prompt for more detailed responses with explicit page usage
            enhanced_query = f"""
            Please provide a comprehensive and detailed answer to the following query.
            Use ALL available information from the provided document images to give a thorough response.

            Query: {query}

            CRITICAL INSTRUCTIONS:
            - You have been provided with {len(imagesPaths)} document page(s)
            - You MUST reference information from ALL {len(imagesPaths)} page(s) in your response
            - Do not skip any pages - each page contains relevant information
            - If you mention one page, you must also mention the others
            - Ensure your response reflects the complete information from all pages

            Instructions for detailed response:
            1. Provide extensive background information and context
            2. Include specific details, examples, and data points from ALL documents
            3. Explain concepts thoroughly with step-by-step breakdowns
            4. Provide comprehensive analysis rather than simple answers when requested
            5. Explicitly reference each page and what information it contributes
            6. Cross-reference information between pages when relevant
            7. Ensure no page is left unmentioned in your analysis

            SPECIAL INSTRUCTIONS FOR TABULAR DATA:
            - If the query requests a table, list, or structured data, organize your response in a clear, structured format
            - Use numbered lists, bullet points, or clear categories when appropriate
            - Include specific data points or comparisons when available
            - Structure information in a way that can be easily converted to a table format

            IMPORTANT: Respond with natural, human-readable text only. Do not include any special tokens, codes, or technical identifiers in your response.

            Make sure to acknowledge and use information from all {len(imagesPaths)} provided pages.
            """
            # Try the current model first
            current_model = os.environ['ollama']

            # Set different options based on the model
            if "gemma3" in current_model.lower():
                # Specific options for Gemma3 to prevent raw-token issues
                model_options = {
                    "num_predict": 1024,    # shorter responses for Gemma3
                    "stop": ["<eos>", "<|endoftext|>", "</s>", "<|im_end|>"],  # more stop tokens
                    "top_k": 20,            # lower top_k for more focused generation
                    "top_p": 0.8,           # lower top_p for more deterministic output
                    "repeat_penalty": 1.2,  # higher repeat penalty
                    "seed": 42,             # consistent results
                    "temperature": 0.7,     # lower temperature for more focused responses
                }
            else:
                # Default options for other models
                model_options = {
                    "num_predict": 2048,    # limit response length
                    "stop": ["<eos>", "<|endoftext|>", "</s>"],  # stop at end tokens
                    "top_k": 40,            # reduce randomness
                    "top_p": 0.9,           # nucleus sampling
                    "repeat_penalty": 1.1,  # prevent repetition
                    "seed": 42,             # consistent results
                    # temperature is a sampling option, so it belongs here
                    # rather than inside the message dict
                    "temperature": float(os.environ['temperature']),
                }

            response = chat(
                model=current_model,
                messages=[
                    {
                        'role': 'user',
                        'content': enhanced_query,
                        'images': imagesPaths,
                    }
                ],
                options=model_options,
            )
            answer = response.message.content

            # Clean the response to handle raw-token issues
            cleaned_answer = self._clean_raw_token_response(answer)

            # If the cleaned answer is still problematic, try fallback models
            if cleaned_answer and "❌ **Model Response Error**" in cleaned_answer:
                print(f"⚠️ Primary model {current_model} failed, trying fallback models...")

                # Fallback models to try, in order
                fallback_models = [
                    "llama3.2-vision:latest",
                    "llava:latest",
                    "bakllava:latest",
                    "llama3.2:latest"
                ]

                for fallback_model in fallback_models:
                    try:
                        print(f"🔄 Trying fallback model: {fallback_model}")
                        response = chat(
                            model=fallback_model,
                            messages=[
                                {
                                    'role': 'user',
                                    'content': enhanced_query,
                                    'images': imagesPaths,
                                }
                            ],
                            options={
                                "num_predict": 2048,
                                "stop": ["<eos>", "<|endoftext|>", "</s>"],
                                "top_k": 40,
                                "top_p": 0.9,
                                "repeat_penalty": 1.1,
                                "seed": 42,
                                "temperature": float(os.environ['temperature']),
                            },
                        )

                        fallback_answer = response.message.content
                        cleaned_fallback = self._clean_raw_token_response(fallback_answer)

                        if cleaned_fallback and "❌ **Model Response Error**" not in cleaned_fallback:
                            print(f"✅ Fallback model {fallback_model} succeeded")
                            return cleaned_fallback
                    except Exception as fallback_error:
                        print(f"❌ Fallback model {fallback_model} failed: {fallback_error}")
                        continue

                # If all fallbacks fail, return the original error
                return cleaned_answer

            print(f"Original response: {answer}")
            print(f"Cleaned response: {cleaned_answer}")
            return cleaned_answer
        except Exception as e:
            print(f"An error occurred while querying Ollama: {e}")
            return None
    def __get_openai_api_payload(self, query: str, imagesPaths: List[str]):
        image_payload = []
        for imagePath in imagesPaths:
            base64_image = encode_image(imagePath)
            image_payload.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{base64_image}"
                }
            })

        payload = {
            "model": "Llama3.2-vision",  # change model here as needed
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": query
                        },
                        *image_payload
                    ]
                }
            ],
            "max_tokens": 1024  # cap response length to reduce processing time
        }

        return payload
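    # __get_openai_api_payload builds an OpenAI-style payload, but nothing in this
    # class actually sends it. Below is a minimal sketch of how it could be posted
    # to an OpenAI-compatible endpoint; the URL (Ollama's OpenAI-compatibility
    # route) and this helper itself are illustrative assumptions, not original
    # behavior.
    def _post_openai_payload_sketch(self, query: str, imagesPaths: List[str]):
        url = "http://localhost:11434/v1/chat/completions"  # assumed local endpoint
        payload = self.__get_openai_api_payload(query, imagesPaths)
        resp = requests.post(url, json=payload, timeout=120)
        resp.raise_for_status()
        # OpenAI-style responses carry the text at choices[0].message.content
        return resp.json()["choices"][0]["message"]["content"]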
# if __name__ == "__main__":
#     rag = Rag()
#     query = "Based on attached images, how many new cases were reported during second wave peak"
#     imagesPaths = ["covid_slides_page_8.png", "covid_slides_page_8.png"]
#     rag.get_answer_from_gemini(query, imagesPaths)
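# A similar smoke test for the Ollama path (assumes a running Ollama server and that
# the 'ollama', 'temperature', and 'flashattn' environment variables are set, e.g. via .env):
#     os.environ.setdefault("ollama", "gemma3")
#     os.environ.setdefault("temperature", "0.7")
#     os.environ.setdefault("flashattn", "1")
#     print(Rag().get_answer_from_openai(query, imagesPaths))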