dhanvanth183 commited on
Commit
4b81143
·
verified ·
1 Parent(s): f1e15fb

Upload mistralocr_gemma.py

Browse files
Files changed (1) hide show
  1. mistralocr_gemma.py +408 -0
mistralocr_gemma.py ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import base64
3
+ import tempfile
4
+ import os
5
+ from mistralai import Mistral
6
+ from PIL import Image
7
+ import io
8
+ from mistralai import DocumentURLChunk, ImageURLChunk
9
+ from mistralai.models import OCRResponse
10
+ #from dotenv import find_dotenv, load_dotenv
11
+
12
+ from openai import OpenAI
13
+ import os
14
+ from dotenv import load_dotenv
15
+
16
+ # OCR Processing Functions
17
+ def upload_pdf(client, content, filename):
18
+ """Uploads a PDF to Mistral's API and retrieves a signed URL for processing."""
19
+ if client is None:
20
+ raise ValueError("Mistral client is not initialized")
21
+
22
+ with tempfile.TemporaryDirectory() as temp_dir:
23
+ temp_path = os.path.join(temp_dir, filename)
24
+
25
+ with open(temp_path, "wb") as tmp:
26
+ tmp.write(content)
27
+
28
+ try:
29
+ with open(temp_path, "rb") as file_obj:
30
+ file_upload = client.files.upload(
31
+ file={"file_name": filename, "content": file_obj},
32
+ purpose="ocr"
33
+ )
34
+
35
+ signed_url = client.files.get_signed_url(file_id=file_upload.id)
36
+ return signed_url.url
37
+ except Exception as e:
38
+ raise ValueError(f"Error uploading PDF: {str(e)}")
39
+ finally:
40
+ if os.path.exists(temp_path):
41
+ os.remove(temp_path)
42
+
43
+ def replace_images_in_markdown(markdown_str: str, images_dict: dict) -> str:
44
+ """Replace image placeholders with base64 encoded images in markdown."""
45
+ for img_name, base64_str in images_dict.items():
46
+ markdown_str = markdown_str.replace(f"![{img_name}]({img_name})", f"![{img_name}]({base64_str})")
47
+ return markdown_str
48
+
49
+ def get_combined_markdown(ocr_response: OCRResponse) -> str:
50
+ """Combine markdown from all pages with their respective images."""
51
+ markdowns: list[str] = []
52
+ for page in ocr_response.pages:
53
+ image_data = {}
54
+ for img in page.images:
55
+ image_data[img.id] = img.image_base64
56
+ markdowns.append(replace_images_in_markdown(page.markdown, image_data))
57
+
58
+ return "\n\n".join(markdowns)
59
+
60
+
61
+ def process_ocr(client, document_source):
62
+ """Process document with OCR API based on source type"""
63
+ if client is None:
64
+ raise ValueError("Mistral client is not initialized")
65
+
66
+ if document_source["type"] == "document_url":
67
+ return client.ocr.process(
68
+ document=DocumentURLChunk(document_url=document_source["document_url"]),
69
+ model="mistral-ocr-latest",
70
+ include_image_base64=True
71
+ )
72
+ elif document_source["type"] == "image_url":
73
+ return client.ocr.process(
74
+ document=ImageURLChunk(image_url=document_source["image_url"]),
75
+ model="mistral-ocr-latest",
76
+ include_image_base64=True
77
+ )
78
+ else:
79
+ raise ValueError(f"Unsupported document source type: {document_source['type']}")
80
+
81
+
82
+ load_dotenv()
83
+
84
+ def generate_response(context, query):
85
+ """Generate a response using OpenRouter API"""
86
+ try:
87
+ # Initialize OpenRouter client
88
+ openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
89
+ if not openrouter_api_key:
90
+ return "Error: OpenRouter API key not found in environment variables."
91
+
92
+ openrouter_client = OpenAI(
93
+ base_url="https://openrouter.ai/api/v1",
94
+ api_key=openrouter_api_key,
95
+ default_headers={
96
+ "HTTP-Referer": "https://YourAppName.com",
97
+ "X-Title": "DocumentChatApp",
98
+ "User-Agent": "YourApp/1.0"
99
+ }
100
+ )
101
+
102
+ # Check for empty context
103
+ if not context or len(context) < 10:
104
+ return "Error: No document content available to answer your question."
105
+
106
+ # Create a prompt with the document content and query
107
+ prompt = f"""I have a document with the following content:
108
+
109
+ {context}
110
+
111
+ Based on this document, please answer the following question:
112
+ {query}
113
+
114
+ If you can find information related to the query in the document, please answer based on that information.
115
+ If the document doesn't specifically mention the exact information asked, please try to infer from related content or clearly state that the specific information isn't available in the document.
116
+ """
117
+
118
+ # Generate response using OpenRouter
119
+ response = openrouter_client.chat.completions.create(
120
+ model="google/gemma-3-27b-it:free",
121
+ messages=[
122
+ {"role": "system", "content": "You are a helpful document analysis assistant."},
123
+ {"role": "user", "content": prompt}
124
+ ],
125
+ temperature=0.7,
126
+ max_tokens=2048
127
+ )
128
+
129
+ return response.choices[0].message.content
130
+
131
+ except Exception as e:
132
+ print(f"Error generating response: {str(e)}")
133
+ import traceback
134
+ print(traceback.format_exc())
135
+ return f"Error generating response: {str(e)}"
136
+
137
+
138
+ def initialize_mistral_client(api_key):
139
+ """
140
+ Initialize and return a Mistral client
141
+
142
+ Args:
143
+ api_key (str): Mistral API key
144
+
145
+ Returns:
146
+ Mistral client object
147
+ """
148
+ try:
149
+ from mistralai import Mistral
150
+
151
+ # Validate API key
152
+ if not api_key:
153
+ raise ValueError("API key cannot be empty")
154
+
155
+ # Create and return Mistral client
156
+ return Mistral(api_key=api_key)
157
+
158
+ except ImportError:
159
+ raise ImportError("Mistral AI library is not installed. Please install it using 'pip install mistralai'")
160
+ except Exception as e:
161
+ raise ValueError(f"Error initializing Mistral client: {str(e)}")
162
+
163
+
164
+ def display_pdf(file_path):
165
+ """
166
+ Display PDF in Streamlit app
167
+
168
+ Args:
169
+ file_path (str): Path to the PDF file
170
+ """
171
+ try:
172
+ # Open the PDF file in binary read mode
173
+ with open(file_path, "rb") as file:
174
+ # Read the file
175
+ base64_pdf = base64.b64encode(file.read()).decode('utf-8')
176
+
177
+ # Embedding PDF in HTML
178
+ pdf_display = f'<iframe src="data:application/pdf;base64,{base64_pdf}" width="700" height="1000" type="application/pdf"></iframe>'
179
+
180
+ # Render PDF
181
+ st.markdown(pdf_display, unsafe_allow_html=True)
182
+
183
+ except FileNotFoundError:
184
+ st.error(f"PDF file not found at {file_path}")
185
+ except PermissionError:
186
+ st.error(f"Permission denied accessing the PDF file at {file_path}")
187
+ except Exception as e:
188
+ st.error(f"Error displaying PDF: {str(e)}")
189
+
190
+ def main():
191
+ # Load environment variables
192
+ load_dotenv()
193
+
194
+ # Get API keys from environment variables
195
+ mistral_api_key = os.getenv("MISTRAL_API_KEY")
196
+ openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
197
+
198
+ st.set_page_config(page_title="Document OCR & Chat", layout="wide")
199
+
200
+ # Remove API key input sections from sidebar
201
+ st.sidebar.header("Document Processing")
202
+
203
+ # Initialize Mistral client
204
+ mistral_client = None
205
+ if mistral_api_key:
206
+ try:
207
+ mistral_client = initialize_mistral_client(mistral_api_key)
208
+ st.sidebar.success("✅ Mistral API connected successfully")
209
+ except Exception as e:
210
+ st.sidebar.error(f"Failed to initialize Mistral client: {str(e)}")
211
+
212
+ # Check OpenRouter API key
213
+ if not openrouter_api_key:
214
+ st.sidebar.warning("⚠️ OpenRouter API key is missing. Please check your .env file.")
215
+
216
+ # Initialize session state
217
+ if "messages" not in st.session_state:
218
+ st.session_state.messages = []
219
+
220
+ if "document_content" not in st.session_state:
221
+ st.session_state.document_content = ""
222
+
223
+ if "document_loaded" not in st.session_state:
224
+ st.session_state.document_loaded = False
225
+
226
+ # Document upload section
227
+ st.subheader("Document Upload")
228
+
229
+ # Only show document upload if Mistral client is initialized
230
+ if mistral_client:
231
+ input_method = st.radio("Select Input Type:", ["PDF Upload", "Image Upload", "URL"])
232
+
233
+ document_source = None
234
+
235
+ if input_method == "URL":
236
+ url = st.text_input("Document URL:")
237
+ if url and st.button("Load Document from URL"):
238
+ document_source = {
239
+ "type": "document_url",
240
+ "document_url": url
241
+ }
242
+
243
+ elif input_method == "PDF Upload":
244
+ uploaded_file = st.file_uploader("Choose PDF file", type=["pdf"])
245
+ if uploaded_file and st.button("Process PDF"):
246
+ content = uploaded_file.read()
247
+
248
+ # Save the uploaded PDF temporarily for display purposes
249
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
250
+ tmp.write(content)
251
+ pdf_path = tmp.name
252
+
253
+ try:
254
+ # Prepare document source for OCR processing
255
+ document_source = {
256
+ "type": "document_url",
257
+ "document_url": upload_pdf(mistral_client, content, uploaded_file.name)
258
+ }
259
+
260
+ # Display the uploaded PDF
261
+ st.header("Uploaded PDF")
262
+ display_pdf(pdf_path)
263
+ except Exception as e:
264
+ st.error(f"Error processing PDF: {str(e)}")
265
+ # Clean up the temporary file
266
+ if os.path.exists(pdf_path):
267
+ os.unlink(pdf_path)
268
+
269
+ elif input_method == "Image Upload":
270
+ uploaded_image = st.file_uploader("Choose Image file", type=["png", "jpg", "jpeg"])
271
+ if uploaded_image and st.button("Process Image"):
272
+ try:
273
+ # Display the uploaded image
274
+ image = Image.open(uploaded_image)
275
+ st.image(image, caption="Uploaded Image", use_column_width=True)
276
+
277
+ # Convert image to base64
278
+ buffered = io.BytesIO()
279
+ image.save(buffered, format="PNG")
280
+ img_str = base64.b64encode(buffered.getvalue()).decode()
281
+
282
+ # Prepare document source for OCR processing
283
+ document_source = {
284
+ "type": "image_url",
285
+ "image_url": f"data:image/png;base64,{img_str}"
286
+ }
287
+ except Exception as e:
288
+ st.error(f"Error processing image: {str(e)}")
289
+
290
+ # Process document if source is provided
291
+ if document_source:
292
+ with st.spinner("Processing document..."):
293
+ try:
294
+ ocr_response = process_ocr(mistral_client, document_source)
295
+
296
+ if ocr_response and ocr_response.pages:
297
+ # Extract all text without page markers for clean content
298
+ raw_content = []
299
+ display_content = []
300
+
301
+ for i, page in enumerate(ocr_response.pages):
302
+ page_content = page.markdown.strip()
303
+ if page_content: # Only add non-empty pages
304
+ raw_content.append(page_content)
305
+ display_content.append(f"Page {i + 1}:\n{page_content}")
306
+
307
+ # Join all content into one clean string for the model
308
+ final_content = "\n\n".join(raw_content)
309
+ display_formatted = "\n\n----------\n\n".join(display_content)
310
+
311
+ # Store both versions
312
+ st.session_state.document_content = final_content
313
+ st.session_state.display_content = display_formatted
314
+ st.session_state.document_loaded = True
315
+ st.session_state.ocr_response = ocr_response
316
+
317
+ # Markdown Download Section
318
+ st.subheader("Download Markdown")
319
+
320
+ # Full Document Download
321
+ full_markdown = "\n\n----------\n\n".join(display_content)
322
+ st.download_button(
323
+ label="Download Full Document Markdown",
324
+ data=full_markdown,
325
+ file_name="document_ocr_output.md",
326
+ mime="text/markdown"
327
+ )
328
+
329
+ # Page-wise Download Dropdown
330
+ page_options = [f"Page {i + 1}" for i in range(len(ocr_response.pages)) if
331
+ ocr_response.pages[i].markdown.strip()]
332
+ selected_page = st.selectbox("Select a page to download", page_options)
333
+
334
+ if selected_page:
335
+ page_index = page_options.index(selected_page)
336
+ page_markdown = ocr_response.pages[page_index].markdown.strip()
337
+
338
+ st.download_button(
339
+ label=f"Download {selected_page} Markdown",
340
+ data=page_markdown,
341
+ file_name=f"{selected_page.lower().replace(' ', '_')}_ocr_output.md",
342
+ mime="text/markdown"
343
+ )
344
+
345
+ # Success message
346
+ st.success(
347
+ f"Document processed successfully! Extracted {len(final_content)} characters from {len(raw_content)} pages."
348
+ )
349
+ else:
350
+ st.warning("No content extracted from document.")
351
+
352
+ except Exception as e:
353
+ st.error(f"Processing error: {str(e)}")
354
+
355
+ # Main area: Display chat interface
356
+ st.title("Document OCR & Chat")
357
+
358
+ # Document preview area
359
+ if "document_loaded" in st.session_state and st.session_state.document_loaded:
360
+ with st.expander("Document Content", expanded=False):
361
+ # Show the display version with page numbers
362
+ if "display_content" in st.session_state:
363
+ st.markdown(st.session_state.display_content)
364
+ else:
365
+ st.markdown(st.session_state.document_content)
366
+
367
+ # Chat interface
368
+ st.subheader("Chat with your document")
369
+
370
+ # Display chat messages
371
+ for message in st.session_state.messages:
372
+ with st.chat_message(message["role"]):
373
+ st.markdown(message["content"])
374
+
375
+ # Input for user query
376
+ if prompt := st.chat_input("Ask a question about your document..."):
377
+ # Check if Google API key is available
378
+ if not openrouter_api_key :
379
+ st.error("Openrouter API key is required for generating responses.")
380
+ else:
381
+ # Add user message to chat history
382
+ st.session_state.messages.append({"role": "user", "content": prompt})
383
+
384
+ # Display user message
385
+ with st.chat_message("user"):
386
+ st.markdown(prompt)
387
+
388
+ # Show thinking spinner
389
+ with st.chat_message("assistant"):
390
+ with st.spinner("Thinking..."):
391
+ # Get document content from session state
392
+ document_content = st.session_state.document_content
393
+
394
+ # Generate response directly
395
+ response = generate_response(document_content, prompt)
396
+
397
+ # Display response
398
+ st.markdown(response)
399
+
400
+ # Add assistant message to chat history
401
+ st.session_state.messages.append({"role": "assistant", "content": response})
402
+ else:
403
+ # Show a welcome message if no document is loaded
404
+ st.info("👈 Please upload a document using the sidebar to start chatting.")
405
+
406
+
407
+ if __name__ == "__main__":
408
+ main()