ocr_extraction / app.py
kawaiipeace's picture
update app.py
3a171a3
import os
import tempfile
from fastapi import FastAPI, UploadFile, File, HTTPException, Header
from fastapi.middleware.cors import CORSMiddleware
from typhoon_ocr import ocr_document
from PyPDF2 import PdfReader
from docx import Document
import gradio as gr
from dotenv import load_dotenv
load_dotenv()
API_KEY = os.getenv("API_KEY")
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
def process_file_with_typhoon(file_path, task_type: str, page_num: int):
ext = os.path.splitext(file_path)[-1].lower()
md_result = ""
doc = Document()
if ext == ".pdf":
reader = PdfReader(file_path)
total_pages = len(reader.pages)
if page_num == 0:
pages_to_process = range(1, total_pages + 1)
elif page_num > 0:
if page_num > total_pages:
raise Exception(f"เลขหน้าที่ใส่มา ({page_num}) เกินจำนวนหน้าจริง ({total_pages}) จ้า")
pages_to_process = [page_num]
else:
raise Exception("เลขหน้าต้องเป็น 0 หรือบวกเท่านั้น ไอ้สัส")
for p in pages_to_process:
md = ocr_document(file_path, task_type=task_type, page_num=p)
md_result += f"\n## หน้า {p}\n{md.strip()}\n"
doc.add_heading(f"หน้า {p}", level=1)
doc.add_paragraph(md.strip())
else:
if page_num != 0 and page_num != 1:
raise Exception("รูปภาพต้องใส่เลขหน้า 0 หรือ 1 เท่านั้น")
md = ocr_document(file_path, task_type=task_type)
md_result = md.strip()
doc.add_paragraph(md_result)
tmp_dir = tempfile.gettempdir()
md_path = os.path.join(tmp_dir, "ocr_extraction_peace.md")
docx_path = os.path.join(tmp_dir, "ocr_extraction_peace.docx")
with open(md_path, "w", encoding="utf-8") as f_md:
f_md.write(md_result)
doc.save(docx_path)
return md_result, md_path, docx_path
@app.post("/api/ocr_document")
async def ocr_endpoint(
file: UploadFile = File(...),
task_type: str = "default",
page_num: int = 0,
x_api_key: str | None = Header(None)
):
if API_KEY and x_api_key != API_KEY:
raise HTTPException(status_code=401, detail="API key ผิดพ่อง")
if page_num < 0:
raise HTTPException(status_code=400, detail="เลขหน้าต้องเป็น 0 หรือบวกเท่านั้น")
suffix = os.path.splitext(file.filename)[-1].lower()
content = await file.read()
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
tmp.write(content)
tmp_path = tmp.name
try:
md, md_path, docx_path = process_file_with_typhoon(tmp_path, task_type, page_num)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
return {
"markdown": md,
"md_file": md_path,
"docx_file": docx_path,
"task_type": task_type,
}
def gradio_handler(file, task_type, page_num):
if file is None:
return "❌ อัปโหลดไฟล์ก่อน....", None, None
if page_num < 0:
return "เลขหน้าต้องเป็น 0 หรือบวกเท่านั้น", None, None
try:
md, md_path, docx_path = process_file_with_typhoon(file.name, task_type, page_num)
except Exception as e:
return f"Error: {str(e)}", None, None
return md, md_path, docx_path
with gr.Blocks() as demo:
gr.Markdown("## 📄 OCR Extraction System by TyphoonOCR")
with gr.Row():
file_input = gr.File(label="📤 อัปโหลด PDF หรือ รูปภาพ", file_types=[".pdf", ".jpg", ".jpeg", ".png"])
page_input = gr.Number(value=0, label="📄 เลขหน้า (0 = แปลงทุกหน้า)")
task_type = gr.Radio(["default", "structure"], label="📌 OCR Mode", value="default")
ocr_output = gr.Textbox(label="📋 ผล Markdown (OCR)", lines=20)
with gr.Row():
btn = gr.Button("🚀 ประมวลผล OCR บัดเดียวนี้")
md_out = gr.File(label="📥 ดาวน์โหลด .md ไฟล์")
docx_out = gr.File(label="📥 ดาวน์โหลด .docx ไฟล์")
btn.click(fn=gradio_handler, inputs=[file_input, task_type, page_input], outputs=[ocr_output, md_out, docx_out])
# demo.launch()
# ถ้าจะ mount เข้า FastAPI ใช้คำสั่งด้านล่างนี้
# แล้วรันคำสั่ง uvicorn app:app --host 0.0.0.0 --port 7860 ใน Terminal
# app = gr.mount_gradio_app(app, demo, path="/")
if __name__ == "__main__":
demo.launch()