Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import base64 | |
| import tempfile | |
| import os | |
| from mistralai import Mistral | |
| from PIL import Image | |
| import io | |
| from mistralai import DocumentURLChunk, ImageURLChunk | |
| from mistralai.models import OCRResponse | |
| #from dotenv import find_dotenv, load_dotenv | |
| from openai import OpenAI | |
| import os | |
| from dotenv import load_dotenv | |
| # OCR Processing Functions | |
def upload_pdf(client, content, filename):
    """Upload a PDF to Mistral's API and return a signed URL for processing.

    The bytes are staged in a private temporary directory, uploaded with
    ``purpose="ocr"``, and the signed URL of the uploaded file is returned.

    Args:
        client: Initialized Mistral client exposing ``files.upload`` and
            ``files.get_signed_url``.
        content: Raw bytes of the PDF.
        filename: File name supplied by the uploader. Only its final path
            component is used, both on disk and in the upload request.

    Returns:
        The ``url`` attribute of the signed-URL response.

    Raises:
        ValueError: If ``client`` is None, or if the upload or the
            signed-URL request fails (the original error is chained).
    """
    if client is None:
        raise ValueError("Mistral client is not initialized")

    # The name comes from the user's browser: drop any directory part so it
    # can neither escape the temporary directory nor point at a missing
    # sub-directory. Backslashes are normalized so Windows-style paths are
    # handled on every platform.
    safe_name = os.path.basename(filename.replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        safe_name = "document.pdf"

    # TemporaryDirectory removes the staged file on exit, so no manual
    # cleanup is needed.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = os.path.join(temp_dir, safe_name)
        with open(temp_path, "wb") as tmp:
            tmp.write(content)
        try:
            with open(temp_path, "rb") as file_obj:
                file_upload = client.files.upload(
                    file={"file_name": safe_name, "content": file_obj},
                    purpose="ocr",
                )
            signed_url = client.files.get_signed_url(file_id=file_upload.id)
            return signed_url.url
        except Exception as e:
            raise ValueError(f"Error uploading PDF: {str(e)}") from e
def replace_images_in_markdown(markdown_str: str, images_dict: dict) -> str:
    """Replace image placeholders in markdown with their base64 payloads.

    For every ``name -> payload`` entry, the markdown image reference
    ``![name](name)`` is rewritten to ``![name](payload)`` so the image is
    embedded in the document instead of pointing at a missing file.

    Args:
        markdown_str: Markdown text containing ``![name](name)`` placeholders.
        images_dict: Mapping of image name to its base64 string.
            NOTE(review): presumably the payload is already a complete
            ``data:`` URI, as in Mistral's OCR examples; confirm.

    Returns:
        The markdown with every matching placeholder replaced. Entries with
        an empty or missing payload are left untouched.
    """
    for img_name, base64_str in images_dict.items():
        if not base64_str:
            # Nothing to embed; keep the original reference.
            continue
        markdown_str = markdown_str.replace(
            f"![{img_name}]({img_name})",
            f"![{img_name}]({base64_str})",
        )
    return markdown_str
def get_combined_markdown(ocr_response: OCRResponse) -> str:
    """Return the markdown of every page, joined by blank lines.

    For each page, a mapping of image id to base64 data is built from
    ``page.images`` and handed to ``replace_images_in_markdown`` together
    with the page's markdown.
    """
    rendered_pages = [
        replace_images_in_markdown(
            page.markdown,
            {img.id: img.image_base64 for img in page.images},
        )
        for page in ocr_response.pages
    ]
    return "\n\n".join(rendered_pages)
def process_ocr(client, document_source):
    """Send a document or image to the OCR endpoint and return its response.

    Args:
        client: Initialized Mistral client exposing ``ocr.process``.
        document_source: Dict with a ``"type"`` key of ``"document_url"`` or
            ``"image_url"`` and a key of the same name holding the URL.

    Returns:
        Whatever ``client.ocr.process`` returns for the
        ``mistral-ocr-latest`` model with base64 images included.

    Raises:
        ValueError: If ``client`` is None or the source type is unsupported.
    """
    if client is None:
        raise ValueError("Mistral client is not initialized")

    source_type = document_source["type"]
    if source_type == "document_url":
        document = DocumentURLChunk(document_url=document_source["document_url"])
    elif source_type == "image_url":
        document = ImageURLChunk(image_url=document_source["image_url"])
    else:
        raise ValueError(f"Unsupported document source type: {source_type}")

    # Both source kinds share the same model and image settings.
    return client.ocr.process(
        document=document,
        model="mistral-ocr-latest",
        include_image_base64=True,
    )
# Load environment variables via python-dotenv at import time; main() calls load_dotenv() again on start-up.
| load_dotenv() | |
def generate_response(context, query):
    """Answer ``query`` about ``context`` through the OpenRouter chat API.

    Args:
        context: Text of the document the question refers to.
        query: The user's question.

    Returns:
        The model's reply text. Failures are never raised: a missing
        ``OPENROUTER_API_KEY``, a context shorter than 10 characters, or any
        exception during the request is reported as an ``"Error..."`` string.
    """
    try:
        api_key = os.getenv("OPENROUTER_API_KEY")
        if not api_key:
            return "Error: OpenRouter API key not found in environment variables."

        # OpenRouter speaks the OpenAI protocol, so the OpenAI client is
        # pointed at its base URL.
        request_headers = {
            "HTTP-Referer": "EnhancedRag",
            "X-Title": "DocumentChatApp",
            "User-Agent": "YourApp/1.0",
        }
        openrouter_client = OpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=api_key,
            default_headers=request_headers,
        )

        # Refuse to query the model without meaningful document text.
        if not (context and len(context) >= 10):
            return "Error: No document content available to answer your question."

        user_prompt = f"""I have a document with the following content:
{context}
Based on this document, please answer the following question:
{query}
If you can find information related to the query in the document, please answer based on that information.
If the document doesn't specifically mention the exact information asked, please try to infer from related content or clearly state that the specific information isn't available in the document.
"""
        chat_messages = [
            {"role": "system", "content": "You are a helpful document analysis assistant."},
            {"role": "user", "content": user_prompt},
        ]
        completion = openrouter_client.chat.completions.create(
            model="meta-llama/llama-3.3-70b-instruct:free",
            messages=chat_messages,
            temperature=0.7,
            max_tokens=2048,
        )
        return completion.choices[0].message.content
    except Exception as e:
        # Log the failure with its traceback to stdout and hand the caller a
        # displayable message instead of raising.
        import traceback

        error_text = f"Error generating response: {str(e)}"
        print(error_text)
        print(traceback.format_exc())
        return error_text
def initialize_mistral_client(api_key):
    """Create and return a Mistral client for ``api_key``.

    Args:
        api_key (str): Mistral API key.

    Returns:
        A ``mistralai.Mistral`` client object.

    Raises:
        ValueError: If ``api_key`` is empty ("API key cannot be empty"), or
            if constructing the client fails (original error chained).
        ImportError: If the ``mistralai`` package cannot be imported.
    """
    # Validate before importing and outside any try block, so an empty key is
    # reported as exactly that: not re-wrapped by the generic handler below
    # and not masked by a missing-library error.
    if not api_key:
        raise ValueError("API key cannot be empty")

    try:
        from mistralai import Mistral
    except ImportError as e:
        raise ImportError(
            "Mistral AI library is not installed. Please install it using 'pip install mistralai'"
        ) from e

    try:
        return Mistral(api_key=api_key)
    except Exception as e:
        raise ValueError(f"Error initializing Mistral client: {str(e)}") from e
def display_pdf(file_path):
    """Embed the PDF at ``file_path`` in the Streamlit page.

    The file is read as bytes, base64-encoded and rendered through an
    ``<iframe>`` with a ``data:`` URI. Problems are shown with ``st.error``
    rather than raised.

    Args:
        file_path (str): Path to the PDF file.
    """
    try:
        with open(file_path, "rb") as pdf_file:
            raw_bytes = pdf_file.read()
        encoded_pdf = base64.b64encode(raw_bytes).decode('utf-8')
        # The iframe source carries the whole document inline.
        iframe_html = f'<iframe src="data:application/pdf;base64,{encoded_pdf}" width="700" height="1000" type="application/pdf"></iframe>'
        st.markdown(iframe_html, unsafe_allow_html=True)
    except FileNotFoundError:
        st.error(f"PDF file not found at {file_path}")
    except PermissionError:
        st.error(f"Permission denied accessing the PDF file at {file_path}")
    except Exception as e:
        st.error(f"Error displaying PDF: {str(e)}")
def _collect_document_source(mistral_client):
    """Render the input widgets and return the OCR source chosen on this run, or None."""
    input_method = st.radio("Select Input Type:", ["PDF Upload", "Image Upload", "URL"])
    document_source = None

    if input_method == "URL":
        url = st.text_input("Document URL:")
        if url and st.button("Load Document from URL"):
            document_source = {"type": "document_url", "document_url": url}

    elif input_method == "PDF Upload":
        uploaded_file = st.file_uploader("Choose PDF file", type=["pdf"])
        if uploaded_file and st.button("Process PDF"):
            content = uploaded_file.read()
            # Save the uploaded PDF temporarily for display purposes
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
                tmp.write(content)
                pdf_path = tmp.name
            try:
                document_source = {
                    "type": "document_url",
                    "document_url": upload_pdf(mistral_client, content, uploaded_file.name),
                }
                st.header("Uploaded PDF")
                display_pdf(pdf_path)
            except Exception as e:
                st.error(f"Error processing PDF: {str(e)}")
            finally:
                # delete=False above means the file must be removed by hand,
                # whatever happened in between.
                if os.path.exists(pdf_path):
                    os.unlink(pdf_path)

    elif input_method == "Image Upload":
        uploaded_image = st.file_uploader("Choose Image file", type=["png", "jpg", "jpeg"])
        if uploaded_image and st.button("Process Image"):
            try:
                image = Image.open(uploaded_image)
                st.image(image, caption="Uploaded Image", use_column_width=True)
                # Re-encode as PNG and pass it to OCR as a data URI
                buffered = io.BytesIO()
                image.save(buffered, format="PNG")
                img_str = base64.b64encode(buffered.getvalue()).decode()
                document_source = {
                    "type": "image_url",
                    "image_url": f"data:image/png;base64,{img_str}",
                }
            except Exception as e:
                st.error(f"Error processing image: {str(e)}")

    return document_source


def _process_document(mistral_client, document_source):
    """Run OCR on ``document_source`` and store the extracted text in session state."""
    with st.spinner("Processing document..."):
        try:
            ocr_response = process_ocr(mistral_client, document_source)
            if not (ocr_response and ocr_response.pages):
                st.warning("No content extracted from document.")
                return

            # Keep each non-empty page together with its real 1-based page
            # number, so labels and content can never drift apart.
            page_markdowns = []
            for page_number, page in enumerate(ocr_response.pages, start=1):
                page_content = page.markdown.strip()
                if page_content:
                    page_markdowns.append((f"Page {page_number}", page_content))

            # Clean text for the model, page-labelled text for display
            final_content = "\n\n".join(text for _, text in page_markdowns)
            display_formatted = "\n\n----------\n\n".join(
                f"{label}:\n{text}" for label, text in page_markdowns
            )

            st.session_state.document_content = final_content
            st.session_state.display_content = display_formatted
            st.session_state.page_markdowns = page_markdowns
            st.session_state.document_loaded = True
            st.session_state.ocr_response = ocr_response

            st.success(
                f"Document processed successfully! Extracted {len(final_content)} characters from {len(page_markdowns)} pages."
            )
        except Exception as e:
            st.error(f"Processing error: {str(e)}")


def _render_download_section():
    """Render the markdown download widgets from the OCR text kept in session state."""
    st.subheader("Download Markdown")

    st.download_button(
        label="Download Full Document Markdown",
        data=st.session_state.display_content,
        file_name="document_ocr_output.md",
        mime="text/markdown",
    )

    # Look the page up by its label instead of by position in the filtered
    # list, so empty pages earlier in the document cannot shift the result.
    pages_by_label = dict(st.session_state.page_markdowns)
    selected_page = st.selectbox("Select a page to download", list(pages_by_label))
    if selected_page:
        st.download_button(
            label=f"Download {selected_page} Markdown",
            data=pages_by_label[selected_page],
            file_name=f"{selected_page.lower().replace(' ', '_')}_ocr_output.md",
            mime="text/markdown",
        )


def _render_chat(openrouter_api_key):
    """Render the chat history and answer a newly submitted question."""
    st.subheader("Chat with your document")

    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input("Ask a question about your document..."):
        # Check if the OpenRouter API key is available
        if not openrouter_api_key:
            st.error("Openrouter API key is required for generating responses.")
            return

        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = generate_response(st.session_state.document_content, prompt)
                st.markdown(response)

        st.session_state.messages.append({"role": "assistant", "content": response})


def main():
    """Run the Streamlit app: load a document, OCR it, and chat about it.

    API keys are read from the environment (``MISTRAL_API_KEY`` and
    ``OPENROUTER_API_KEY``) after loading a ``.env`` file. Upload and OCR
    widgets are only shown when the Mistral client could be created. The
    chat is only shown once a document has been processed.

    OCR results live in ``st.session_state``. The download widgets are
    rendered from that state on every run rather than only on the run in
    which a "Process" button was clicked; otherwise they would disappear as
    soon as the user touches the page selector, because Streamlit reruns the
    script on every interaction and the button is only true for one run.
    """
    # Load environment variables
    load_dotenv()
    mistral_api_key = os.getenv("MISTRAL_API_KEY")
    openrouter_api_key = os.getenv("OPENROUTER_API_KEY")

    st.set_page_config(page_title="Document OCR & Chat", layout="wide")
    st.sidebar.header("Document Processing")

    # Initialize Mistral client
    mistral_client = None
    if mistral_api_key:
        try:
            mistral_client = initialize_mistral_client(mistral_api_key)
            st.sidebar.success("✅ Mistral API connected successfully")
        except Exception as e:
            st.sidebar.error(f"Failed to initialize Mistral client: {str(e)}")
    else:
        # Without this the upload area is silently empty.
        st.sidebar.warning("⚠️ Mistral API key is missing. Please check your .env file.")

    # Check OpenRouter API key
    if not openrouter_api_key:
        st.sidebar.warning("⚠️ OpenRouter API key is missing. Please check your .env file.")

    # Initialize session state
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "document_content" not in st.session_state:
        st.session_state.document_content = ""
    if "document_loaded" not in st.session_state:
        st.session_state.document_loaded = False

    # Document upload section (needs a working Mistral client)
    st.subheader("Document Upload")
    if mistral_client:
        document_source = _collect_document_source(mistral_client)
        if document_source:
            _process_document(mistral_client, document_source)
        if st.session_state.document_loaded and "page_markdowns" in st.session_state:
            _render_download_section()

    # Main area: document preview and chat
    st.title("Document OCR & Chat")
    if st.session_state.document_loaded:
        with st.expander("Document Content", expanded=False):
            # Prefer the version with page numbers
            if "display_content" in st.session_state:
                st.markdown(st.session_state.display_content)
            else:
                st.markdown(st.session_state.document_content)
        _render_chat(openrouter_api_key)
    else:
        # Show a welcome message if no document is loaded
        st.info("👈 Please upload a document using the sidebar to start chatting.")
# Script entry point: run the app only when this file is executed directly, not when it is imported.
| if __name__ == "__main__": | |
| main() | |