Spaces:
Sleeping
Sleeping
import functools
import io
import uuid
from pathlib import Path

import numpy as np
import pandas as pd
from fastapi import FastAPI, File, HTTPException, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from rapid_table import ModelType, RapidTable, RapidTableInput
from rapidocr_onnxruntime import RapidOCR
# Single RapidOCR engine, built once at import time and shared by all requests.
model = RapidOCR()

app = FastAPI()

# Browser origins allowed to call this API cross-origin. The CORS middleware
# compares the request's Origin header against these strings exactly, and an
# Origin header never carries a path or trailing slash, so every entry must be
# a bare ``scheme://host[:port]`` (a trailing "/" makes the entry unmatchable).
origins = [
    "https://hycjack-fastapi-rapidocr.hf.space",
    "http://localhost",
    "http://localhost:7860",
    "http://127.0.0.1:7860",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Root endpoint: always answers with the same fixed greeting payload.
def read_root():
    greeting = {"Hello": "World!"}
    return greeting
async def ocr(file: UploadFile = File(...)):
    """Run RapidOCR on an uploaded image and return the recognised text lines.

    Each element of the returned list is a dict with keys ``box`` (the
    detection box as returned by RapidOCR), ``rec`` (recognised text) and
    ``score`` (the score RapidOCR reports for that line). An image in which
    nothing is detected yields an empty list.

    Raises:
        HTTPException(400): the uploaded bytes cannot be decoded as an image.
    """
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents))
        # np.array() forces the full decode, so truncated/corrupt data fails
        # here (as OSError) instead of deep inside the OCR engine.
        np_array = np.array(image)
    except (OSError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError is a subclass of OSError.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"无法将上传的文件解析为图片: {exc}",
        ) from exc

    # NOTE(review): model() is synchronous and blocks the event loop while
    # it runs; consider offloading it if concurrent throughput matters.
    ocr_result, _elapse = model(np_array)

    # When nothing is detected RapidOCR gives back a falsy result (None in
    # rapidocr_onnxruntime); the old zip(*ocr_result) unpacking crashed on
    # that and surfaced as an HTTP 500.
    if not ocr_result:
        return []

    out_df = pd.DataFrame(
        [[box, rec, score] for box, rec, score in ocr_result],
        columns=("box", "rec", "score"),
    )
    return out_df.to_dict(orient="records")
# Staging directory (next to this module) for uploads handed to RapidTable;
# created eagerly at import time so request handlers can write into it.
TMP_DIR = Path(__file__).parent.joinpath("tmp_uploads")
TMP_DIR.mkdir(exist_ok=True, parents=True)
async def ocr_table(file: UploadFile = File(...)):
    """Recognise table structure in an uploaded image with RapidTable.

    The upload is written to a uniquely named file under ``TMP_DIR`` (the
    engine is invoked with a file path). That file is always removed
    afterwards, whether recognition succeeds or raises.

    Returns:
        ``{"pred_htmls": ...}`` -- the ``pred_htmls`` attribute of the
        RapidTable result, passed through unchanged.

    Raises:
        HTTPException(400): the upload has no filename or is empty.
        HTTPException(415): the declared content type is not an accepted
            image type.
    """
    # ---- 1. Validate the request ----
    if not file.filename:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="上传的文件没有文件名",
        )
    # content_type is declared by the client, so this filters honest
    # mistakes but does not prove the bytes really are an image.
    suffix = _ALLOWED_IMAGE_TYPES.get(file.content_type)
    if suffix is None:
        raise HTTPException(
            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            detail=f"不支持的文件类型: {file.content_type}",
        )

    contents = await file.read()
    if not contents:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="上传的文件为空",
        )

    # ---- 2. Persist to a unique temp path and run the engine ----
    # The suffix comes from the validated MIME type, never from the
    # client-supplied filename.
    tmp_path = TMP_DIR / f"{uuid.uuid4().hex}{suffix}"
    try:
        tmp_path.write_bytes(contents)
        print(f"文件已保存至 {tmp_path}")
        # NOTE(review): the engine call is synchronous and blocks the event
        # loop while it runs; consider offloading it if concurrency matters.
        table_results = _get_table_engine()(tmp_path)
        return {"pred_htmls": table_results.pred_htmls}
    finally:
        # ---- 3. Best-effort cleanup so uploads don't fill the disk ----
        # A cleanup failure is reported but must not mask the request's own
        # result or exception.
        try:
            tmp_path.unlink()
        except FileNotFoundError:
            pass  # write never happened, or file already gone
        except OSError as exc:
            print(f"删除临时文件 {tmp_path} 失败: {exc}")
        else:
            print(f"已删除临时文件 {tmp_path}")


# Accepted upload MIME types mapped to the suffix used for the temp copy.
_ALLOWED_IMAGE_TYPES = {
    "image/jpeg": ".jpg",
    "image/png": ".png",
    "image/bmp": ".bmp",
    "image/tiff": ".tiff",
}


@functools.lru_cache(maxsize=1)
def _get_table_engine():
    """Build the RapidTable engine on first use and reuse it afterwards.

    Constructing the engine loads model weights, so doing it per request (as
    before) made every call pay that cost.
    """
    return RapidTable(RapidTableInput(model_type=ModelType.PPSTRUCTURE_ZH))