# mistral-ocr / app.py
# pjf67546's picture
# Update app.py
# 6f59c8f verified
import os
import base64
import time
import re
import urllib.request
from io import BytesIO
import gradio as gr
from mistralai import Mistral
from PIL import Image
import markdown
import tempfile
from xhtml2pdf import pisa
# Config: lowercase file extensions that do_ocr routes to the document (PDF)
# path and to the image path respectively.
VALID_DOCUMENT_EXTENSIONS = {".pdf"}
VALID_IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}
# --- Font settings ---
# SimHei is used so that Chinese text renders correctly in the generated PDF.
FONT_URL = "https://github.com/StellarCN/scp_zh/raw/master/fonts/SimHei.ttf"
FONT_FILENAME = "SimHei.ttf"
def download_font():
    """Download the .ttf font file unless it is already present.

    The download is best-effort: any failure is printed rather than raised,
    so the caller keeps running without the font. The data is written to a
    temporary ``.part`` file and only moved to ``FONT_FILENAME`` once the
    transfer has finished, so an interrupted download cannot leave behind a
    truncated file that a later run would accept as a valid font.
    """
    if os.path.exists(FONT_FILENAME):
        return

    print(f"Downloading compatible font from {FONT_URL}...")
    partial_path = FONT_FILENAME + ".part"
    try:
        urllib.request.urlretrieve(FONT_URL, partial_path)
        # Atomic rename: FONT_FILENAME only ever exists as a complete file.
        os.replace(partial_path, FONT_FILENAME)
        print("Font downloaded successfully.")
    except Exception as e:
        print(f"Error downloading font: {e}")
        # Remove whatever was written so the next start retries from scratch.
        try:
            os.remove(partial_path)
        except OSError:
            pass
# Fetch the font at import time; convert_html_to_pdf references it by path.
download_font()
def upload_pdf(content, filename, client):
    """Upload a file for OCR and return a signed URL for it.

    Args:
        content: File content to upload (do_ocr passes the raw bytes).
        filename: Name under which the upload is registered.
        client: Client object exposing ``files.upload`` and
            ``files.get_signed_url``.

    Returns:
        The ``url`` attribute of the signed-URL response.
    """
    payload = {"file_name": filename, "content": content}
    uploaded = client.files.upload(file=payload, purpose="ocr")
    return client.files.get_signed_url(file_id=uploaded.id).url
def process_ocr(document_source, client):
    """Run the ``mistral-ocr-latest`` model on a document source.

    Args:
        document_source: Dict describing the input, as built by do_ocr
            (``{"type": "document_url", ...}`` or ``{"type": "image_url", ...}``).
        client: Client object exposing ``ocr.process``.

    Returns:
        Whatever ``client.ocr.process`` returns; base64 image data is requested.
    """
    ocr_kwargs = {
        "model": "mistral-ocr-latest",
        "document": document_source,
        "include_image_base64": True,
    }
    return client.ocr.process(**ocr_kwargs)
# --- Image and text handling ---
def replace_images_with_placeholders(text):
    """Swap inline base64 Markdown images for short ``[[IMG_n]]`` placeholders.

    This keeps the (very long) base64 payloads out of the text sent for
    translation.

    Args:
        text: Markdown that may contain ``![alt](data:image/...;base64,...)``
            image tags.

    Returns:
        A ``(text_with_placeholders, img_map)`` tuple, where ``img_map`` maps
        each placeholder to the exact image tag it replaced, in order of
        appearance.
    """
    img_map = {}
    # Negated character classes keep every match inside ONE image tag. A lazy
    # ``.*?`` (especially with DOTALL) could start at an earlier non-data image
    # such as ``![a](http://...)`` and run on to the next data-URI image,
    # hiding the text in between from the translator.
    pattern = r'!\[[^\]]*\]\(data:image/[^)]*?;base64,[^)]*\)'

    def replacer(match):
        # Placeholders are numbered by how many images were seen so far.
        placeholder = f"[[IMG_{len(img_map)}]]"
        img_map[placeholder] = match.group(0)
        return placeholder

    text_with_placeholders = re.sub(pattern, replacer, text)
    return text_with_placeholders, img_map
def restore_images_from_placeholders(text, img_map):
    """Put the original image tags back in place of their placeholders.

    Both the exact ``[[IMG_n]]`` form and a single-bracket ``[IMG_n]`` variant
    are substituted, the double-bracket form first.

    Args:
        text: Text containing placeholders.
        img_map: Mapping of ``[[IMG_n]]`` placeholder to original image tag.

    Returns:
        The text with every known placeholder replaced by its image tag.
    """
    for token, image_tag in img_map.items():
        single_bracket = token.replace("[[", "[").replace("]]", "]")
        for needle in (token, single_bracket):
            text = text.replace(needle, image_tag)
    return text
def translate_chunk_safe(text, target_lang, client, model="mistral-large-latest"):
    """Translate one Markdown chunk while keeping embedded images intact.

    Inline base64 images are replaced by placeholders before the request and
    restored in the translated text afterwards.

    Args:
        text: Markdown content to translate.
        target_lang: Name of the target language, inserted into the prompt.
        client: Client object exposing ``chat.complete``.
        model: Chat model name passed to ``chat.complete``.

    Returns:
        The translated Markdown, or a ``[Translation Error: ...]`` string if
        the request or the restoration step raises.
    """
    safe_text, img_map = replace_images_with_placeholders(text)

    instructions = "".join((
        f"You are a professional document translator. Translate the following Markdown content into {target_lang}. ",
        "IMPORTANT RULES:\n",
        "1. Preserve all Markdown formatting strictly (headers, lists, bolding).\n",
        "2. DO NOT touch the placeholders like [[IMG_0]]. Keep them exactly where they are.\n",
        "3. Output ONLY the translated content.\n",
        "4. Do not output code blocks unless the original text had them.\n",
    ))
    conversation = [
        {"role": "system", "content": instructions},
        {"role": "user", "content": safe_text},
    ]

    try:
        reply = client.chat.complete(model=model, messages=conversation)
        translated = reply.choices[0].message.content
        return restore_images_from_placeholders(translated, img_map)
    except Exception as e:
        # Best-effort: the error text takes the place of the translated chunk.
        return f"[Translation Error: {str(e)}]"
# --- PDF generation core (heavily corrected version) ---
def convert_html_to_pdf(html_content, output_path):
    """Wrap an HTML fragment in a styled page and render it to a PDF file.

    Args:
        html_content: HTML body fragment to place inside the page container.
        output_path: Path of the PDF file to write (opened in binary mode).

    Returns:
        True when ``pisa.CreatePDF`` reports no error, False otherwise.
    """
    # The font is referenced by absolute path inside the @font-face rule.
    font_path = os.path.abspath(FONT_FILENAME)
    # CSS strategy:
    # 1. @page defines a fixed A4 size and margins.
    # 2. table-layout: fixed keeps tables from being stretched by content.
    # 3. word-wrap: break-word and word-break: break-all force long strings to wrap.
    # 4. pre (code) blocks are forced to wrap.
    # NOTE(review): the template lines below are kept flush-left exactly as
    # they appear in the source; everything inside the triple quotes is
    # runtime text, including the CSS comments.
    css = f"""
<style>
@font-face {{
font-family: 'SimHei';
src: url('{font_path}');
}}
@page {{
size: A4;
margin: 1.5cm; /* 邊距稍微縮小,給內容更多空間 */
@frame footer_frame {{
-pdf-frame-content: footerContent;
bottom: 0.5cm;
margin-left: 1cm;
margin-right: 1cm;
height: 1cm;
}}
}}
body {{
font-family: 'SimHei', sans-serif;
font-size: 10pt;
line-height: 1.4;
color: #333;
}}
/* 核心:強制內容不超出容器 */
.container {{
width: 100%;
max-width: 100%;
overflow: hidden;
}}
/* 對所有可能包含文字的標籤強制換行 */
p, div, td, th, li, span, a {{
word-wrap: break-word;
word-break: break-all; /* 這是 xhtml2pdf 處理長網址的關鍵 */
white-space: normal;
}}
/* 圖片限制 */
img {{
max-width: 100%;
height: auto;
margin: 10px 0;
}}
/* 表格限制:這是最容易超出邊界的地方 */
table {{
width: 100%;
max-width: 100%;
table-layout: fixed; /* 強制固定佈局,忽略內容寬度 */
border-collapse: collapse;
margin-bottom: 15px;
border: 0.5px solid #ccc;
}}
td, th {{
border: 0.5px solid #888;
padding: 4px;
vertical-align: top;
font-size: 9pt; /* 表格文字稍微縮小以容納更多內容 */
overflow: hidden; /* 防止內容溢出單元格 */
word-wrap: break-word;
word-break: break-all;
}}
th {{ background-color: #f0f0f0; font-weight: bold; }}
/* 程式碼區塊限制 */
pre {{
font-family: 'SimHei', monospace; /* 使用同樣字型避免亂碼 */
background-color: #f5f5f5;
padding: 8px;
border: 1px solid #ddd;
border-radius: 4px;
white-space: pre-wrap; /* 關鍵:讓程式碼換行 */
word-wrap: break-word;
word-break: break-all;
font-size: 9pt;
width: 100%;
}}
h1, h2, h3 {{ color: #202020; margin-top: 15px; margin-bottom: 8px; page-break-after: avoid; }}
h1 {{ font-size: 16pt; border-bottom: 1px solid #eee; }}
h2 {{ font-size: 14pt; }}
h3 {{ font-size: 12pt; }}
.page-break {{ page-break-after: always; }}
</style>
"""
    # Assemble the full HTML document, wrapping the content in the container
    # div; the footerContent div feeds the footer frame declared in @page.
    full_html = f"""
<html>
<head>
<meta charset='utf-8'>
{css}
</head>
<body>
<div class="container">
{html_content}
</div>
<div id="footerContent" style="text-align:right; font-size:8pt; color: #999;">
Translated by Mistral
</div>
</body>
</html>
"""
    # xhtml2pdf needs the destination opened in binary write mode.
    with open(output_path, "wb") as pdf_file:
        pisa_status = pisa.CreatePDF(src=full_html, dest=pdf_file, encoding='utf-8')
    return not pisa_status.err
# --- Gradio logic (unchanged, keeps the UI consistent) ---
def do_ocr(input_type, url, file, api_key, progress=gr.Progress()):
    """Run Mistral OCR on a URL or an uploaded file and return Markdown.

    Args:
        input_type: ``"URL"`` or ``"Upload file"`` (the Source radio value).
        url: Document or image URL, used when ``input_type`` is ``"URL"``.
        file: Uploaded file (object with ``.name`` or a path string), used
            when ``input_type`` is ``"Upload file"``.
        api_key: Mistral API key; when blank the ``MISTRAL`` environment
            variable is used instead.
        progress: Gradio progress tracker.

    Returns:
        ``(full_md, pages_data)``: the Markdown of all pages concatenated
        (each page followed by a blank line) and the list of per-page
        Markdown strings, with OCR image references replaced by inline
        base64 data URLs.

    Raises:
        gr.Error: on a missing API key, missing/invalid input, unsupported
            file type, or a failure while uploading, reading or OCR-ing.
    """
    # Function-scope import so the block does not depend on new file-level imports.
    from urllib.parse import urlparse

    gr.Info("Starting process... please wait.")
    api_key = api_key.strip() if api_key and api_key.strip() else os.environ.get("MISTRAL")
    if not api_key:
        raise gr.Error("Please provide a valid Mistral API key")
    client = Mistral(api_key=api_key)

    progress(0.1, desc="Preparing File...")
    if input_type == "URL":
        # Strip first so whitespace-only input is rejected and trailing
        # whitespace cannot defeat the extension check.
        url = (url or "").strip()
        if not url:
            raise gr.Error("Invalid URL")
        # Inspect only the path component so query strings and fragments
        # (e.g. ".../scan.png?token=...") do not hide an image extension.
        url_path = urlparse(url).path.lower()
        if url_path.endswith(tuple(VALID_IMAGE_EXTENSIONS)):
            document_source = {"type": "image_url", "image_url": url}
        else:
            document_source = {"type": "document_url", "document_url": url}
    elif input_type == "Upload file":
        if not file:
            raise gr.Error("No file uploaded")
        file_path = file.name if hasattr(file, 'name') else file
        file_name = os.path.basename(file_path).lower()
        file_ext = os.path.splitext(file_name)[1]
        if file_ext in VALID_DOCUMENT_EXTENSIONS:
            with open(file_path, "rb") as f:
                content = f.read()
            gr.Info("Uploading PDF...")
            try:
                signed_url = upload_pdf(content, file_name, client)
            except Exception as e:
                raise gr.Error(f"Upload Failed: {e}") from e
            document_source = {"type": "document_url", "document_url": signed_url}
        elif file_ext in VALID_IMAGE_EXTENSIONS:
            buffered = BytesIO()
            try:
                # The context manager closes the underlying file handle.
                with Image.open(file_path) as source_img:
                    # PNG cannot store every Pillow mode (e.g. CMYK JPEGs);
                    # convert those to RGB instead of failing in save().
                    if source_img.mode in {"1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"}:
                        png_ready = source_img
                    else:
                        png_ready = source_img.convert("RGB")
                    png_ready.save(buffered, format="PNG")
            except OSError as e:
                raise gr.Error(f"Could not read image: {e}") from e
            img_str = base64.b64encode(buffered.getvalue()).decode()
            document_source = {"type": "image_url", "image_url": f"data:image/png;base64,{img_str}"}
        else:
            raise gr.Error("Unsupported file type")
    else:
        # Previously fell through and sent ``None`` to the OCR endpoint.
        raise gr.Error("Unsupported source type")

    progress(0.4, desc="Mistral OCR Running...")
    try:
        ocr_response = process_ocr(document_source, client)
    except Exception as e:
        raise gr.Error(f"OCR Failed: {e}") from e

    progress(0.7, desc="Processing Images...")
    pages_data = []
    for page in ocr_response.pages:
        page_md = page.markdown
        for img in page.images:
            if img.image_base64:
                # Drop any "data:...;base64," prefix and keep only the payload.
                # NOTE(review): the payload is always re-labelled image/png
                # regardless of its real format — confirm this is intended.
                raw = img.image_base64
                b64 = raw.split(",", 1)[1] if "," in raw else raw
                data_url = f"data:image/png;base64,{b64}"
                page_md = page_md.replace(f"![{img.id}]({img.id})", f"![Image]({data_url})")
        pages_data.append(page_md)
    # Single join instead of repeated += on strings carrying base64 payloads.
    full_md = "".join(f"{page_md}\n\n" for page_md in pages_data)

    progress(1.0, desc="OCR Done")
    gr.Info("OCR Complete.")
    return full_md, pages_data
def run_translation(pages_data, language, rpm, api_key, progress=gr.Progress()):
    """Translate OCR'd pages and export the result as Markdown, PDF and HTML.

    Args:
        pages_data: List of per-page Markdown strings produced by do_ocr.
        language: Target language name passed to the translator prompt.
        rpm: Requests per minute; the pause between pages is
            ``60 / max(1, rpm)`` seconds.
        api_key: Mistral API key; when blank the ``MISTRAL`` environment
            variable is used instead.
        progress: Gradio progress tracker.

    Returns:
        ``(full_translated_md, output_pdf, output_html)``: the translated
        Markdown and the paths of the generated PDF and backup HTML files.

    Raises:
        gr.Error: when there is nothing to translate or no API key is
            available.
    """
    if not pages_data:
        raise gr.Error("No content. Run OCR first.")
    # Same key resolution as do_ocr: whitespace-only input counts as blank,
    # and a missing key is reported up front instead of failing on each page.
    api_key = api_key.strip() if api_key and api_key.strip() else os.environ.get("MISTRAL")
    if not api_key:
        raise gr.Error("Please provide a valid Mistral API key")
    client = Mistral(api_key=api_key)

    translated_html_parts = []
    translated_md_parts = []
    delay = 60.0 / max(1, float(rpm))
    total = len(pages_data)
    for i, page_md in enumerate(pages_data):
        progress((i / total), desc=f"Translating Page {i+1}/{total}...")
        trans_md = translate_chunk_safe(page_md, language, client)
        translated_md_parts.append(trans_md)
        # Convert markdown -> html
        html_part = markdown.markdown(trans_md, extensions=['tables', 'fenced_code', 'nl2br'])
        translated_html_parts.append(f"{html_part}<div class='page-break'></div>")
        # Rate limiting: no need to wait after the last page.
        if i < total - 1:
            time.sleep(delay)

    progress(0.9, desc="Creating PDF...")
    full_translated_md = "\n\n".join(translated_md_parts)
    full_html_body = "\n".join(translated_html_parts)

    # Reserve a path and close the handle right away; convert_html_to_pdf
    # reopens the file by name.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as pdf_tmp:
        output_pdf = pdf_tmp.name
    success = convert_html_to_pdf(full_html_body, output_pdf)
    if not success:
        gr.Warning("PDF layout warning.")

    with tempfile.NamedTemporaryFile(
        "w", delete=False, suffix=".html", encoding="utf-8"
    ) as html_file:
        output_html = html_file.name
        html_file.write(f"<html><head><meta charset='utf-8'></head><body>{full_html_body}</body></html>")

    gr.Info("Done! Download files below.")
    return full_translated_md, output_pdf, output_html
# UI
# NOTE(review): the nesting of the layout below was reconstructed from a
# flattened source; confirm it matches the intended arrangement.
with gr.Blocks(title="Mistral OCR to PDF", theme=gr.themes.Soft()) as demo:
    gr.Markdown("<h1 style='text-align:center'>Mistral OCR & Translate to PDF</h1>")
    # Carries the per-page OCR output from step 1 to step 2.
    ocr_state = gr.State()
    with gr.Row():
        with gr.Column(scale=1):
            # Deliberately NOT pre-filled from the MISTRAL environment
            # variable: an initial value would be sent to every visitor's
            # browser. do_ocr and run_translation already fall back to the
            # environment variable when this field is left blank.
            api_key = gr.Textbox(
                label="Mistral API Key",
                type="password",
                placeholder="Leave blank to use the server-configured key",
            )
            input_type = gr.Radio(["URL", "Upload file"], label="Source", value="URL")
            with gr.Group():
                url_in = gr.Textbox(label="URL", placeholder="https://arxiv.org/pdf/...", visible=True)
                file_in = gr.File(label="File", visible=False)
            ocr_btn = gr.Button("1. Start OCR", variant="primary")
            gr.Markdown("---")
            lang_in = gr.Dropdown(["Traditional Chinese", "Simplified Chinese", "English"], value="Traditional Chinese", label="Target Language")
            rpm_in = gr.Slider(1, 60, 5, 1, label="Speed (RPM)", info="Keep low for free accounts")
            trans_btn = gr.Button("2. Translate & Download", variant="primary")
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("Preview"):
                    md_out = gr.Markdown(label="Preview")
                with gr.TabItem("Raw OCR"):
                    raw_out = gr.Markdown(label="Original OCR")
            gr.Markdown("### Downloads")
            with gr.Row():
                pdf_out = gr.File(label="Translated PDF (Fixed Width)")
                html_out = gr.File(label="Backup HTML")

    def toggle_in(c):
        """Show the URL box or the file picker depending on the chosen source."""
        return [gr.update(visible=c=="URL"), gr.update(visible=c=="Upload file")]

    input_type.change(toggle_in, input_type, [url_in, file_in])
    ocr_btn.click(
        do_ocr,
        inputs=[input_type, url_in, file_in, api_key],
        outputs=[raw_out, ocr_state],
        show_progress="full"
    )
    trans_btn.click(
        run_translation,
        inputs=[ocr_state, lang_in, rpm_in, api_key],
        outputs=[md_out, pdf_out, html_out],
        show_progress="full"
    )
# Launch the app only when this file is executed as a script; demo.queue()
# is called before launch().
if __name__ == "__main__":
    demo.queue().launch()