# fastapi_rapidocr / main.py
# Author: HycJack -- commit f281de0 ("return html")
from fastapi import FastAPI, File, UploadFile, HTTPException, status
from PIL import Image
from fastapi.middleware.cors import CORSMiddleware
from rapidocr_onnxruntime import RapidOCR
from rapid_table import ModelType, RapidTable, RapidTableInput
import io
import numpy as np
import pandas as pd
import uuid
from pathlib import Path
# Shared OCR engine: built once at import time and reused by the /ocr handler.
model = RapidOCR()
app = FastAPI()

# Browser origins allowed to call this API.  Each entry must be a bare origin
# (scheme://host[:port]) with no trailing slash: the CORS middleware compares
# entries verbatim against the request's Origin header, which never carries a
# path, so an entry such as "https://host/" can never match.
origins = [
    "https://hycjack-fastapi-rapidocr.hf.space",
    "http://localhost",
    "http://localhost:7860",
    "http://127.0.0.1:7860",
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
def read_root():
    """Answer requests to the root route with a fixed greeting payload."""
    greeting = {"Hello": "World!"}
    return greeting
@app.post("/ocr")
async def ocr(file: UploadFile = File(...)):
    """Run text recognition on an uploaded image.

    Args:
        file: The uploaded image, sent as multipart form data.

    Returns:
        A list with one dict per OCR result row, each holding the keys
        ``"box"``, ``"rec"`` and ``"score"``.  The list is empty when the
        engine reports no result.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents))
        # Image.open only parses the header; building the array forces the
        # full decode, so truncated/corrupt files are expected to fail here.
        np_array = np.array(image)
    except OSError as exc:
        # Pillow signals unreadable image data with OSError (and subclasses).
        # That is a client error, not a server fault.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="无法解析上传的图片",
        ) from exc

    ocr_result, _elapse = model(np_array)
    # RapidOCR is documented to return None (not an empty list) when no text
    # is detected; unpacking that would raise, so short-circuit instead.
    if not ocr_result:
        return []

    # Each result row is expected to be a (box, text, score) triple.
    out_df = pd.DataFrame(
        [[box, rec, score] for box, rec, score in ocr_result],
        columns=("box", "rec", "score"),
    )
    return out_df.to_dict(orient="records")
# Staging directory for uploads, located next to this module; it is created
# at import time when it does not exist yet.
TMP_DIR = Path(__file__).parent.joinpath("tmp_uploads")
TMP_DIR.mkdir(exist_ok=True, parents=True)
# Table-recognition engine shared by all requests; built lazily on first use
# so the model is not reloaded for every upload.
_table_engine = None


def _get_table_engine() -> RapidTable:
    """Return the shared RapidTable engine, constructing it on first call."""
    global _table_engine
    if _table_engine is None:
        input_args = RapidTableInput(model_type=ModelType.PPSTRUCTURE_ZH)
        _table_engine = RapidTable(input_args)
    return _table_engine


@app.post("/ocr_table")
async def ocr_table(file: UploadFile = File(...)):
    """Recognise table structure in an uploaded image.

    The upload is staged in ``TMP_DIR`` under a unique name, passed to the
    table engine, and removed again before the response is sent.

    Args:
        file: The uploaded image, sent as multipart form data.

    Returns:
        A dict ``{"pred_htmls": ...}`` carrying the engine's ``pred_htmls``
        attribute.

    Raises:
        HTTPException: 400 when the upload has no filename or is empty;
            415 when its content type is not an accepted image type.
    """
    # ---- 1. Validate the request ----
    if not file.filename:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="上传的文件没有文件名",
        )
    # Accept only common image MIME types to reject non-image uploads early.
    allowed_mime = {"image/jpeg", "image/png", "image/bmp", "image/tiff"}
    if file.content_type not in allowed_mime:
        raise HTTPException(
            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            detail=f"不支持的文件类型: {file.content_type}",
        )

    contents = await file.read()
    if not contents:
        # An empty body can never be decoded; report it as a client error
        # instead of letting the engine fail with a server error.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="上传的文件为空",
        )

    # ---- 2. Stage the upload on disk ----
    # Unique name; the original suffix is kept to ease debugging, with .png
    # as the fallback when the upload has none.
    suffix = Path(file.filename).suffix.lower() or ".png"
    tmp_path = TMP_DIR / f"{uuid.uuid4().hex}{suffix}"
    try:
        tmp_path.write_bytes(contents)
        print(f"文件已保存至 {tmp_path}")

        # ---- 3. Run table recognition ----
        table_results = _get_table_engine()(tmp_path)
        print(table_results.pred_htmls)
        return {"pred_htmls": table_results.pred_htmls}
    finally:
        # ---- 4. Remove the staged file ----
        # Best-effort cleanup so uploads do not pile up on disk.  A failure
        # here is reported but must not replace the request's own outcome,
        # hence no return statement in this block.
        try:
            if tmp_path.exists():
                tmp_path.unlink()
                print(f"已删除临时文件 {tmp_path}")
        except OSError as exc:
            print(f"删除临时文件 {tmp_path} 失败: {exc}")