"""FastAPI service exposing RapidOCR text recognition and RapidTable table recognition."""

import functools
import io
import uuid
from pathlib import Path

import numpy as np
import pandas as pd
from fastapi import FastAPI, File, HTTPException, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from rapid_table import ModelType, RapidTable, RapidTableInput
from rapidocr_onnxruntime import RapidOCR

# Shared OCR engine, built once at import time and reused by every request.
model = RapidOCR()

app = FastAPI()

# Browsers send the Origin header without a trailing slash and the CORS
# middleware compares it against this list as an exact string, so entries
# must not end in "/".
origins = [
    "https://hycjack-fastapi-rapidocr.hf.space",
    "http://localhost",
    "http://localhost:7860",
    "http://127.0.0.1:7860",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
def read_root():
    """Return a static greeting; usable as a liveness check."""
    return {"Hello": "World!"}


@app.post("/ocr")
async def ocr(file: UploadFile = File(...)):
    """Run text OCR on an uploaded image.

    Args:
        file: The uploaded image.

    Returns:
        A list of ``{"box", "rec", "score"}`` records, one per recognised
        text region, or an empty list when the engine reports no text.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents))
        # Image.open is lazy; converting to an array forces the actual decode,
        # so truncated files surface here rather than inside the OCR engine.
        np_array = np.array(image)
    except OSError as exc:
        # Pillow's UnidentifiedImageError is an OSError subclass.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"无法解析上传的图片: {exc}",
        ) from exc

    ocr_result, _ = model(np_array)
    # The result is falsy (None or empty) when nothing was recognised;
    # unpacking it below would raise, so answer with an empty list instead.
    if not ocr_result:
        return []

    out_df = pd.DataFrame(
        [[box, rec, score] for box, rec, score in ocr_result],
        columns=("box", "rec", "score"),
    )
    return out_df.to_dict(orient='records')


# Scratch directory for uploads that must exist on disk while being processed.
TMP_DIR = Path(__file__).parent / "tmp_uploads"
TMP_DIR.mkdir(parents=True, exist_ok=True)

# Only common image MIME types are accepted, to reject non-image uploads.
_ALLOWED_IMAGE_MIME = frozenset(
    {"image/jpeg", "image/png", "image/bmp", "image/tiff"}
)


@functools.lru_cache(maxsize=1)
def _get_table_engine():
    """Build the RapidTable engine on first use and reuse it afterwards."""
    input_args = RapidTableInput(model_type=ModelType.PPSTRUCTURE_ZH)
    return RapidTable(input_args)


@app.post("/ocr_table")
async def ocr_table(file: UploadFile = File(...)):
    """Recognise table structure in an uploaded image.

    The upload is written to a uniquely named file under ``TMP_DIR``, handed
    to the table engine, and removed again whether or not recognition
    succeeded.

    Args:
        file: The uploaded image; must carry a filename and one of the
            accepted image MIME types.

    Returns:
        ``{"pred_htmls": ...}`` holding the engine's ``pred_htmls`` output.

    Raises:
        HTTPException: 400 when the upload has no filename, 415 when its
            content type is not an accepted image type.
    """
    # --- 1. Validate the upload ---
    if not file.filename:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="上传的文件没有文件名",
        )
    if file.content_type not in _ALLOWED_IMAGE_MIME:
        raise HTTPException(
            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            detail=f"不支持的文件类型: {file.content_type}",
        )

    # --- 2. Save the upload under a unique name ---
    # Keep the original suffix (handy when debugging); default to .png when
    # the uploaded name has none.
    suffix = Path(file.filename).suffix.lower() or ".png"
    tmp_path = TMP_DIR / f"{uuid.uuid4().hex}{suffix}"

    try:
        # Read all bytes asynchronously; a synchronous disk write is enough.
        contents = await file.read()
        with open(tmp_path, "wb") as f:
            f.write(contents)
        print(f"文件已保存至 {tmp_path}")

        # --- 3. Run table recognition ---
        table_results = _get_table_engine()(tmp_path)
        print(table_results.pred_htmls)
        return {"pred_htmls": table_results.pred_htmls}
    finally:
        # --- 4. Remove the temporary file ---
        # Best-effort cleanup so the disk does not fill up. There must be no
        # `return` in this clause: it would replace the response above and
        # silently swallow any in-flight exception.
        try:
            if tmp_path.exists():
                tmp_path.unlink()
                print(f"已删除临时文件 {tmp_path}")
        except OSError as exc:
            print(f"删除临时文件 {tmp_path} 失败: {exc}")