|
|
""" |
|
|
Author : Janarddan Sarkar |
|
|
file_name : mistral_ocr_st.py |
|
|
date : 10-03-2025 |
|
|
description : |
|
|
""" |
|
|
import os |
|
|
import json |
|
|
import base64 |
|
|
import streamlit as st |
|
|
from mistralai import Mistral |
|
|
from dotenv import find_dotenv, load_dotenv |
|
|
from mistralai import DocumentURLChunk, ImageURLChunk, TextChunk |
|
|
from mistralai.models import OCRResponse |
|
|
from enum import Enum |
|
|
from pydantic import BaseModel |
|
|
import pycountry |
|
|
|
|
|
|
|
|
# Load environment variables from the .env file located by find_dotenv().
dotenv_path = find_dotenv()
load_dotenv(dotenv_path)

# The API key is read from MISTRAL_API_KEY (None when the variable is unset).
api_key = os.getenv("MISTRAL_API_KEY")

# Module-wide Mistral client shared by the OCR helpers below.
client = Mistral(api_key=api_key)
|
|
|
|
|
|
|
|
# Map each language's `alpha_2` code to its name, skipping pycountry entries
# that do not expose an `alpha_2` attribute.
languages = {
    entry.alpha_2: entry.name
    for entry in pycountry.languages
    if hasattr(entry, 'alpha_2')
}
|
|
|
|
|
|
|
|
class LanguageMeta(type(Enum)):
    """Enum metaclass that fills the class namespace from ``languages``."""

    def __new__(metacls, cls, bases, classdict):
        """Add one namespace entry per language name, then build the class.

        The entry key is the language name upper-cased with spaces replaced
        by underscores; the entry value is the unmodified language name.
        """
        # Only the names are used; the codes (the mapping's keys) are not.
        for display_name in languages.values():
            member_key = display_name.upper().replace(' ', '_')
            classdict[member_key] = display_name
        return super().__new__(metacls, cls, bases, classdict)
|
|
|
|
|
|
|
|
class Language(Enum, metaclass=LanguageMeta):
    # Nothing is declared here: LanguageMeta adds the entries to the class
    # namespace while this class is being created. A comment is used instead
    # of a docstring so the class's __doc__ stays as it was.
    ...
|
|
|
|
|
|
|
|
class StructuredOCR(BaseModel):
    # Model passed as `response_format` when the image OCR text is turned into
    # structured output. Documented with comments rather than a docstring so
    # the model definition itself is left exactly as it was.

    file_name: str
    topics: list[str]
    # Every entry has to be a member of the Language enum.
    languages: list[Language]
    # Untyped dictionary carrying the OCR contents.
    ocr_contents: dict
|
|
|
|
|
def replace_images_in_markdown(markdown_str: str, images_dict: dict) -> str:
    """Point markdown image links at their inline image data.

    Every occurrence of ``![<name>](<name>)`` whose ``<name>`` is a key of
    ``images_dict`` is rewritten to ``![<name>](<data>)``, where ``<data>``
    is the value stored for that key.

    NOTE(review): ``![id](id)`` is assumed to be the placeholder form used in
    the OCR markdown -- confirm against the API output.

    Args:
        markdown_str: Markdown text that may contain image placeholders.
        images_dict: Mapping of image name/id to the string that should
            become the link target (e.g. a base64 data URI).

    Returns:
        The markdown with all matching placeholders rewritten. Placeholders
        with no entry, or with an empty/``None`` entry, are left untouched.
    """
    for img_name, base64_str in images_dict.items():
        if not base64_str:
            # No payload for this image: keep the placeholder instead of
            # failing inside str.replace (which rejects None).
            continue
        markdown_str = markdown_str.replace(
            f"![{img_name}]({img_name})",
            f"![{img_name}]({base64_str})",
        )
    return markdown_str
|
|
|
|
|
def get_combined_markdown(ocr_response: OCRResponse) -> str:
    """Merge the markdown of all OCR pages into a single string.

    For each page, an ``id -> image_base64`` mapping is built from the page's
    images and handed to ``replace_images_in_markdown`` together with the
    page's markdown. The per-page results are joined with a blank line.

    Args:
        ocr_response: OCR response whose ``pages`` are combined.

    Returns:
        The processed markdown of every page, separated by ``"\\n\\n"``.
    """
    return "\n\n".join(
        replace_images_in_markdown(
            page.markdown,
            {image.id: image.image_base64 for image in page.images},
        )
        for page in ocr_response.pages
    )
|
|
|
|
|
def process_pdf(pdf_bytes, file_name):
    """Run OCR on a PDF document.

    The bytes are uploaded through the client's files API with purpose
    ``"ocr"``, a signed URL is requested for the uploaded file, and that URL
    is submitted to the ``mistral-ocr-latest`` model with base64 image data
    included in the response.

    Args:
        pdf_bytes: Content of the PDF file.
        file_name: Name under which the file is uploaded.

    Returns:
        The OCR response; a plain ``dict`` result is converted into an
        ``OCRResponse`` first.
    """
    upload_payload = {"file_name": file_name, "content": pdf_bytes}
    stored_file = client.files.upload(file=upload_payload, purpose="ocr")

    url_info = client.files.get_signed_url(file_id=stored_file.id, expiry=1)

    ocr_result = client.ocr.process(
        document=DocumentURLChunk(document_url=url_info.url),
        model="mistral-ocr-latest",
        include_image_base64=True,
    )

    # Anything that is not a raw dict is returned unchanged.
    if not isinstance(ocr_result, dict):
        return ocr_result
    return OCRResponse(**ocr_result)
|
|
|
|
|
|
|
|
def process_image(image_bytes, file_name):
    """Run OCR on an image and return the result as structured JSON data.

    The image is sent to the ``mistral-ocr-latest`` model as a base64 data
    URL. The markdown of the first returned page is then passed, together
    with the image, to ``pixtral-12b-latest`` with ``StructuredOCR`` as the
    response format.

    Args:
        image_bytes: Raw bytes of the image.
        file_name: Name of the image file. Its extension selects the MIME
            type used in the data URL; when no image type can be derived
            from it, ``image/jpeg`` is used.

    Returns:
        The parsed ``StructuredOCR`` answer, converted to plain
        JSON-compatible Python data.
    """
    import mimetypes  # stdlib; imported here so the function is self-contained

    # PNG uploads are accepted as well, so do not label every image as JPEG.
    mime_type, _ = mimetypes.guess_type(file_name or "")
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"

    encoded_image = base64.b64encode(image_bytes).decode()
    base64_data_url = f"data:{mime_type};base64,{encoded_image}"

    image_response = client.ocr.process(
        document=ImageURLChunk(image_url=base64_data_url),
        model="mistral-ocr-latest",
    )
    # Only the first page of the OCR response is used.
    image_ocr_markdown = image_response.pages[0].markdown

    chat_response = client.chat.parse(
        model="pixtral-12b-latest",
        messages=[
            {
                "role": "user",
                "content": [
                    ImageURLChunk(image_url=base64_data_url),
                    TextChunk(
                        text=(
                            "This is the image's OCR in markdown:\n"
                            f"<BEGIN_IMAGE_OCR>\n{image_ocr_markdown}\n<END_IMAGE_OCR>.\n"
                            "Convert this into a structured JSON response with the OCR contents in a dictionary."
                        )
                    ),
                ],
            },
        ],
        response_format=StructuredOCR,
        temperature=0,
    )

    # Round-trip through JSON so the caller receives plain dicts/lists/strings.
    return json.loads(chat_response.choices[0].message.parsed.model_dump_json())
|
|
|
|
|
|
|
|
|
|
|
st.title("OLMOCR") |
|
|
|
|
|
uploaded_file = st.file_uploader("Upload a PDF or Image", type=["pdf", "png", "jpg", "jpeg"]) |
|
|
|
|
|
if uploaded_file: |
|
|
file_type = uploaded_file.type |
|
|
file_bytes = uploaded_file.read() |
|
|
file_name = uploaded_file.name |
|
|
|
|
|
if st.button("Submit"): |
|
|
st.write(f"**Processing file:** {file_name}") |
|
|
|
|
|
if "pdf" in file_type: |
|
|
pdf_response = process_pdf(file_bytes, file_name) |
|
|
st.markdown(get_combined_markdown(pdf_response)) |
|
|
else: |
|
|
result = process_image(file_bytes, file_name) |
|
|
st.json(result) |
|
|
|