"""Streamlit app that OCRs documents with Mistral and answers questions via OpenRouter.

Environment variables (read from the process environment or a ``.env`` file):
    MISTRAL_API_KEY     -- required for document upload and OCR.
    OPENROUTER_API_KEY  -- required for chat responses.
"""

import base64
import io
import os
import tempfile
import traceback

import streamlit as st
from dotenv import load_dotenv
from mistralai import DocumentURLChunk, ImageURLChunk, Mistral
from mistralai.models import OCRResponse
from openai import OpenAI
from PIL import Image

# Populate os.environ from a .env file at import time so helpers such as
# generate_response() can read their keys even when called outside main().
load_dotenv()


# ---------------------------------------------------------------------------
# OCR processing functions
# ---------------------------------------------------------------------------


def upload_pdf(client, content, filename):
    """Upload a PDF to Mistral's file API and return a signed URL for OCR.

    Args:
        client: An initialized ``Mistral`` client.
        content (bytes): Raw PDF bytes.
        filename (str): Name to upload the file under (only its final path
            component is used).

    Returns:
        str: Signed URL pointing at the uploaded file.

    Raises:
        ValueError: If ``client`` is None or the upload / URL signing fails.
    """
    if client is None:
        raise ValueError("Mistral client is not initialized")

    # Keep only the final path component so a crafted name cannot make
    # os.path.join() write outside the temporary directory.
    safe_name = os.path.basename(filename) or "upload.pdf"

    # TemporaryDirectory deletes everything inside it on exit, so no manual
    # cleanup of temp_path is needed.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = os.path.join(temp_dir, safe_name)
        with open(temp_path, "wb") as tmp:
            tmp.write(content)

        try:
            with open(temp_path, "rb") as file_obj:
                file_upload = client.files.upload(
                    file={"file_name": safe_name, "content": file_obj},
                    purpose="ocr",
                )
            signed_url = client.files.get_signed_url(file_id=file_upload.id)
            return signed_url.url
        except Exception as e:
            raise ValueError(f"Error uploading PDF: {str(e)}") from e


def replace_images_in_markdown(markdown_str: str, images_dict: dict) -> str:
    """Replace ``![id](id)`` image placeholders with the matching image data.

    Args:
        markdown_str: Markdown produced by the OCR API for one page.
        images_dict: Maps image id -> base64 image string as returned by the API.

    Returns:
        The markdown with each matching placeholder's link target replaced.
    """
    for img_name, base64_str in images_dict.items():
        markdown_str = markdown_str.replace(
            f"![{img_name}]({img_name})", f"![{img_name}]({base64_str})"
        )
    return markdown_str


def get_combined_markdown(ocr_response: OCRResponse) -> str:
    """Join every page's markdown (images inlined) with a blank line between pages."""
    markdowns: list[str] = []
    for page in ocr_response.pages:
        image_data = {img.id: img.image_base64 for img in page.images}
        markdowns.append(replace_images_in_markdown(page.markdown, image_data))
    return "\n\n".join(markdowns)


def process_ocr(client, document_source):
    """Run Mistral OCR on a document or image source.

    Args:
        client: An initialized ``Mistral`` client.
        document_source (dict): Either
            ``{"type": "document_url", "document_url": <url>}`` or
            ``{"type": "image_url", "image_url": <url or data URL>}``.

    Returns:
        The OCR response from ``client.ocr.process`` (image base64 included).

    Raises:
        ValueError: If ``client`` is None or the source type is missing/unsupported.
    """
    if client is None:
        raise ValueError("Mistral client is not initialized")

    # .get() so a source without "type" is reported as unsupported rather than
    # surfacing as a bare KeyError.
    source_type = document_source.get("type")
    if source_type == "document_url":
        document = DocumentURLChunk(document_url=document_source["document_url"])
    elif source_type == "image_url":
        document = ImageURLChunk(image_url=document_source["image_url"])
    else:
        raise ValueError(f"Unsupported document source type: {source_type}")

    return client.ocr.process(
        document=document,
        model="mistral-ocr-latest",
        include_image_base64=True,
    )


# ---------------------------------------------------------------------------
# LLM / client helpers
# ---------------------------------------------------------------------------


def generate_response(context, query):
    """Answer ``query`` about ``context`` using an OpenRouter-hosted model.

    Never raises: every failure is returned as a string starting with "Error".

    Args:
        context (str): Document text to ground the answer in.
        query (str): The user's question.

    Returns:
        str: The model's answer, or an error message.
    """
    # Validate inputs before constructing any client.
    openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
    if not openrouter_api_key:
        return "Error: OpenRouter API key not found in environment variables."

    if not context or len(context) < 10:
        return "Error: No document content available to answer your question."

    prompt = f"""I have a document with the following content:

{context}

Based on this document, please answer the following question:
{query}

If you can find information related to the query in the document, please answer based on that information. If the document doesn't specifically mention the exact information asked, please try to infer from related content or clearly state that the specific information isn't available in the document.
"""

    try:
        openrouter_client = OpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=openrouter_api_key,
            default_headers={
                "HTTP-Referer": "EnhancedRag",
                "X-Title": "DocumentChatApp",
                "User-Agent": "YourApp/1.0",
            },
        )
        response = openrouter_client.chat.completions.create(
            model="meta-llama/llama-3.3-70b-instruct:free",
            messages=[
                {"role": "system", "content": "You are a helpful document analysis assistant."},
                {"role": "user", "content": prompt},
            ],
            temperature=0.7,
            max_tokens=2048,
        )
        content = response.choices[0].message.content
    except Exception as e:
        print(f"Error generating response: {str(e)}")
        print(traceback.format_exc())
        return f"Error generating response: {str(e)}"

    # message.content may be None; never hand None to the chat UI.
    return content or "Error: The model returned an empty response."


def initialize_mistral_client(api_key):
    """Create a Mistral client.

    Args:
        api_key (str): Mistral API key.

    Returns:
        Mistral: A client authenticated with ``api_key``.

    Raises:
        ValueError: If ``api_key`` is empty or client construction fails.
    """
    if not api_key:
        raise ValueError("API key cannot be empty")
    try:
        return Mistral(api_key=api_key)
    except Exception as e:
        raise ValueError(f"Error initializing Mistral client: {str(e)}") from e


def display_pdf(file_path):
    """Embed the PDF at ``file_path`` in the Streamlit page.

    File errors are reported with ``st.error`` instead of being raised.

    Args:
        file_path (str): Path to the PDF file.
    """
    try:
        with open(file_path, "rb") as file:
            base64_pdf = base64.b64encode(file.read()).decode("utf-8")

        # Inline the PDF as a data: URL inside an iframe.
        # NOTE(review): some browsers refuse to render data: PDFs in iframes —
        # confirm on the target browsers.
        pdf_display = (
            f'<iframe src="data:application/pdf;base64,{base64_pdf}" '
            'width="100%" height="800" type="application/pdf"></iframe>'
        )
        st.markdown(pdf_display, unsafe_allow_html=True)
    except FileNotFoundError:
        st.error(f"PDF file not found at {file_path}")
    except PermissionError:
        st.error(f"Permission denied accessing the PDF file at {file_path}")
    except Exception as e:
        st.error(f"Error displaying PDF: {str(e)}")


# ---------------------------------------------------------------------------
# UI helpers used by main()
# ---------------------------------------------------------------------------


def _init_mistral_from_env(mistral_api_key):
    """Create the Mistral client, report its status in the sidebar, and return it (or None)."""
    if not mistral_api_key:
        st.sidebar.warning("⚠️ Mistral API key is missing. Please check your .env file.")
        return None
    try:
        client = initialize_mistral_client(mistral_api_key)
    except Exception as e:
        st.sidebar.error(f"Failed to initialize Mistral client: {str(e)}")
        return None
    st.sidebar.success("✅ Mistral API connected successfully")
    return client


def _init_session_state():
    """Seed session-state keys the app reads, without overwriting existing values."""
    defaults = {
        "messages": [],
        "document_content": "",
        "display_content": "",
        "document_loaded": False,
        "ocr_response": None,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _prepare_pdf_source(mistral_client, uploaded_file):
    """Upload a PDF for OCR, preview it, and return its OCR source dict (None on failure)."""
    # getvalue() returns the whole buffer regardless of the stream position.
    content = uploaded_file.getvalue()

    # display_pdf() takes a path, so persist a temporary copy for the preview.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
        tmp.write(content)
        pdf_path = tmp.name

    try:
        document_source = {
            "type": "document_url",
            "document_url": upload_pdf(mistral_client, content, uploaded_file.name),
        }
        st.header("Uploaded PDF")
        display_pdf(pdf_path)
        return document_source
    except Exception as e:
        st.error(f"Error processing PDF: {str(e)}")
        return None
    finally:
        # Always remove the preview copy, whether or not the upload succeeded.
        if os.path.exists(pdf_path):
            os.unlink(pdf_path)


def _prepare_image_source(uploaded_image):
    """Preview an uploaded image and return it as a PNG data-URL OCR source (None on failure)."""
    try:
        image = Image.open(uploaded_image)
        st.image(image, caption="Uploaded Image", use_column_width=True)

        # Pillow cannot write CMYK (common in print JPEGs) as PNG; normalize first.
        if image.mode == "CMYK":
            image = image.convert("RGB")

        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        return {
            "type": "image_url",
            "image_url": f"data:image/png;base64,{img_str}",
        }
    except Exception as e:
        st.error(f"Error processing image: {str(e)}")
        return None


def _collect_document_source(mistral_client):
    """Render the input picker and return an OCR source dict once the user submits, else None."""
    input_method = st.radio("Select Input Type:", ["PDF Upload", "Image Upload", "URL"])

    if input_method == "URL":
        url = st.text_input("Document URL:")
        if url and st.button("Load Document from URL"):
            return {"type": "document_url", "document_url": url}
        return None

    if input_method == "PDF Upload":
        uploaded_file = st.file_uploader("Choose PDF file", type=["pdf"])
        if uploaded_file and st.button("Process PDF"):
            return _prepare_pdf_source(mistral_client, uploaded_file)
        return None

    uploaded_image = st.file_uploader("Choose Image file", type=["png", "jpg", "jpeg"])
    if uploaded_image and st.button("Process Image"):
        return _prepare_image_source(uploaded_image)
    return None


def _extract_pages(ocr_response):
    """Return (raw_content, display_content) lists for the non-empty pages of an OCR response."""
    raw_content = []
    display_content = []
    for i, page in enumerate(ocr_response.pages):
        page_content = page.markdown.strip()
        if page_content:
            raw_content.append(page_content)
            # Number by the page's real position so labels match the source document.
            display_content.append(f"Page {i + 1}:\n{page_content}")
    return raw_content, display_content


def _process_document(mistral_client, document_source):
    """Run OCR on ``document_source`` and store the extracted text in session state."""
    with st.spinner("Processing document..."):
        try:
            ocr_response = process_ocr(mistral_client, document_source)
            if not ocr_response or not ocr_response.pages:
                st.warning("No content extracted from document.")
                return

            raw_content, display_content = _extract_pages(ocr_response)
            # Treat a document whose pages are all blank as "no content" too,
            # instead of loading an empty document into the chat.
            if not raw_content:
                st.warning("No content extracted from document.")
                return

            # Clean text (no page markers) for the model; numbered text for display.
            final_content = "\n\n".join(raw_content)
            st.session_state.document_content = final_content
            st.session_state.display_content = "\n\n----------\n\n".join(display_content)
            st.session_state.document_loaded = True
            st.session_state.ocr_response = ocr_response

            st.success(
                f"Document processed successfully! Extracted {len(final_content)} characters from {len(raw_content)} pages."
            )
        except Exception as e:
            st.error(f"Processing error: {str(e)}")


def _render_downloads():
    """Render full-document and per-page markdown download controls from session state."""
    ocr_response = st.session_state.ocr_response

    st.subheader("Download Markdown")
    st.download_button(
        label="Download Full Document Markdown",
        data=st.session_state.display_content,
        file_name="document_ocr_output.md",
        mime="text/markdown",
    )

    # Map each label straight to its page's markdown so skipping empty pages
    # cannot shift which page a label refers to.
    page_markdowns = {
        f"Page {i + 1}": page.markdown.strip()
        for i, page in enumerate(ocr_response.pages)
        if page.markdown.strip()
    }
    selected_page = st.selectbox("Select a page to download", list(page_markdowns))
    if selected_page:
        st.download_button(
            label=f"Download {selected_page} Markdown",
            data=page_markdowns[selected_page],
            file_name=f"{selected_page.lower().replace(' ', '_')}_ocr_output.md",
            mime="text/markdown",
        )


def _render_document_and_chat(openrouter_api_key):
    """Render the document preview, the chat history, and handle a new chat question."""
    with st.expander("Document Content", expanded=False):
        # Prefer the page-numbered version; fall back to the clean text.
        st.markdown(st.session_state.display_content or st.session_state.document_content)

    st.subheader("Chat with your document")

    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input("Ask a question about your document..."):
        # Chat responses come from OpenRouter, so its key is required here.
        if not openrouter_api_key:
            st.error("Openrouter API key is required for generating responses.")
            return

        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = generate_response(st.session_state.document_content, prompt)
            st.markdown(response)

        st.session_state.messages.append({"role": "assistant", "content": response})


def main():
    """Run the app: upload + OCR, markdown downloads, document preview, and chat."""
    load_dotenv()
    mistral_api_key = os.getenv("MISTRAL_API_KEY")
    openrouter_api_key = os.getenv("OPENROUTER_API_KEY")

    # set_page_config must be the first Streamlit call of the script run.
    st.set_page_config(page_title="Document OCR & Chat", layout="wide")
    st.title("Document OCR & Chat")

    st.sidebar.header("Document Processing")
    mistral_client = _init_mistral_from_env(mistral_api_key)
    if not openrouter_api_key:
        st.sidebar.warning("⚠️ OpenRouter API key is missing. Please check your .env file.")

    _init_session_state()

    st.subheader("Document Upload")
    if mistral_client:
        document_source = _collect_document_source(mistral_client)
        if document_source:
            _process_document(mistral_client, document_source)
    else:
        st.info("Document upload is unavailable until the Mistral API client is initialized.")

    # Downloads are rendered from session state (not only on the processing run)
    # so they survive the reruns triggered by the selectbox and download buttons.
    if st.session_state.ocr_response is not None:
        _render_downloads()

    if st.session_state.document_loaded:
        _render_document_and_chat(openrouter_api_key)
    else:
        st.info("Please upload a document above to start chatting.")


if __name__ == "__main__":
    main()