File size: 3,322 Bytes
e21a5f5
0125000
 
1cdc160
778d159
0125000
 
b9d651c
08606b0
244bf38
 
 
 
 
 
eaff201
5bf1b3e
 
 
 
 
 
 
 
 
 
 
eaff201
 
 
 
5bf1b3e
 
08606b0
 
1cdc160
b9d651c
 
1cdc160
 
 
 
b9d651c
 
 
 
 
 
 
7b3952c
593a1d9
0553381
7b3952c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244bf38
7b3952c
 
 
 
 
 
0553381
f281de0
7b3952c
 
 
0553381
 
 
244bf38
0553381
 
7b3952c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import functools
import io
import uuid
from pathlib import Path

import numpy as np
import pandas as pd
from fastapi import FastAPI, File, UploadFile, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from rapid_table import ModelType, RapidTable, RapidTableInput
from rapidocr_onnxruntime import RapidOCR


# Shared RapidOCR instance used by the /ocr endpoint; created once at import.
model = RapidOCR()
app = FastAPI()

# Browser origins allowed to call this API. An Origin header is always
# scheme://host[:port] with no path or trailing slash, and CORSMiddleware
# matches entries verbatim, so an entry ending in "/" would never match.
origins = [
    "https://hycjack-fastapi-rapidocr.hf.space",
    "http://localhost",
    "http://localhost:7860",
    "http://127.0.0.1:7860",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")
def read_root():
    """Greeting endpoint for ``GET /``.

    Returns:
        The constant JSON body ``{"Hello": "World!"}``.
    """
    payload = dict(Hello="World!")
    return payload
    
@app.post("/ocr")
async def ocr(file: UploadFile = File(...)):
    """Run RapidOCR on an uploaded image and return the recognized text lines.

    Args:
        file: The uploaded image, in any format Pillow can decode.

    Returns:
        A list of ``{"box": ..., "rec": ..., "score": ...}`` records, one per
        detected text line, in the order RapidOCR reports them. An empty list
        is returned when no text is detected.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents))
        # Image.open is lazy; force decoding here so corrupt or non-image
        # uploads become a client error instead of a 500 further down.
        image.load()
    except OSError as exc:  # PIL.UnidentifiedImageError is an OSError
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"无法解析图片: {exc}",
        ) from exc

    # L / RGB / RGBA arrays are passed through exactly as before. Other modes
    # would otherwise reach RapidOCR as arrays it misreads (palette indices
    # treated as gray, CMYK treated as RGBA, bilevel as a bool array), so
    # normalize them; keep an alpha channel where the source has one.
    if image.mode in ("P", "PA", "LA"):
        image = image.convert("RGBA")
    elif image.mode in ("1", "CMYK", "YCbCr"):
        image = image.convert("RGB")
    np_array = np.array(image)

    ocr_result, _elapse = model(np_array)
    # RapidOCR yields a falsy result (None in rapidocr_onnxruntime) when it
    # detects no text; unpacking that used to raise TypeError -> HTTP 500.
    if not ocr_result:
        return []

    # Each entry is a (box, text, score) triple.
    out_df = pd.DataFrame(
        [[box, rec, score] for box, rec, score in ocr_result],
        columns=("box", "rec", "score"),
    )
    return out_df.to_dict(orient="records")

# Scratch directory next to this module where /ocr_table writes uploads
# before handing them to RapidTable; created (if missing) at import time.
TMP_DIR = Path(__file__).parent.joinpath("tmp_uploads")
TMP_DIR.mkdir(exist_ok=True, parents=True)
@functools.lru_cache(maxsize=1)
def _get_table_engine() -> RapidTable:
    """Return the shared RapidTable engine, building it on first use.

    The engine used to be constructed inside every /ocr_table request; caching
    it means whatever model setup ``RapidTable.__init__`` performs (presumably
    loading its ONNX models) happens once per process instead of per call.
    """
    input_args = RapidTableInput(model_type=ModelType.PPSTRUCTURE_ZH)
    return RapidTable(input_args)


@app.post("/ocr_table")
async def ocr_table(file: UploadFile = File(...)):
    """Recognize a table in an uploaded image and return it as HTML.

    The upload is written to a uniquely named file under ``TMP_DIR``, passed
    to RapidTable, and removed again afterwards (best effort).

    Args:
        file: The uploaded image. Only JPEG, PNG, BMP and TIFF MIME types are
            accepted.

    Returns:
        ``{"pred_htmls": ...}`` holding RapidTable's predicted HTML.

    Raises:
        HTTPException: 400 if the upload has no filename; 415 if its content
            type is not one of the accepted image types.
    """
    # ------------------- 1. Validate the request -------------------
    if not file.filename:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="上传的文件没有文件名",
        )

    # Accept only common image MIME types so non-image uploads are rejected.
    allowed_mime = {"image/jpeg", "image/png", "image/bmp", "image/tiff"}
    if file.content_type not in allowed_mime:
        raise HTTPException(
            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            detail=f"不支持的文件类型: {file.content_type}",
        )

    # ------------------- 2. Save the upload -------------------
    # Unique name; keep the original suffix to ease debugging. The filename is
    # client-controlled, so only a plain ASCII alphanumeric extension is kept
    # (e.g. ".jpg"); none or anything odd (such as an embedded NUL, which made
    # open() raise) falls back to ".png".
    suffix = Path(file.filename).suffix.lower()
    ext = suffix[1:]
    if not (ext.isascii() and ext.isalnum()):
        suffix = ".png"
    tmp_path = TMP_DIR / f"{uuid.uuid4().hex}{suffix}"

    try:
        # Async read of the upload; a synchronous write to disk is fine here.
        contents = await file.read()
        tmp_path.write_bytes(contents)
        print(f"文件已保存至 {tmp_path}")

        table_results = _get_table_engine()(tmp_path)
        print(table_results.pred_htmls)
        return {"pred_htmls": table_results.pred_htmls}
    finally:
        # ------------------- 3. Clean up the temp file -------------------
        # Always try to delete the upload so the disk does not fill up; a
        # failed delete is reported but must not mask the request's outcome.
        try:
            if tmp_path.exists():
                tmp_path.unlink()
                print(f"已删除临时文件 {tmp_path}")
        except Exception as exc:
            print(f"删除临时文件 {tmp_path} 失败: {exc}")