Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,7 +4,7 @@ import io
|
|
| 4 |
import zipfile
|
| 5 |
from datetime import datetime
|
| 6 |
from pathlib import Path
|
| 7 |
-
from typing import List, Optional
|
| 8 |
|
| 9 |
import gradio as gr
|
| 10 |
import pandas as pd
|
|
@@ -61,7 +61,10 @@ Base.metadata.create_all(engine)
|
|
| 61 |
def load_suppliers_from_csv():
|
| 62 |
if not SUPPLIER_CSV.exists():
|
| 63 |
return
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
| 65 |
session = SessionLocal()
|
| 66 |
for _, r in df.iterrows():
|
| 67 |
name = str(r.get("name","")).strip()
|
|
@@ -82,8 +85,15 @@ TAXNO_RE = re.compile(r"([0-9A-Z]{15,20})")
|
|
| 82 |
INVOICE_RE = re.compile(r"(发票代码|发票号码|发票号|Invoice No|Invoice No\.)[::\s]*([A-Za-z0-9\-]+)")
|
| 83 |
AMOUNT_RE = re.compile(r"([0-9]{1,3}(?:[,,][0-9]{3})*(?:\.[0-9]{1,2})?)")
|
| 84 |
|
| 85 |
-
def do_ocr
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
text = pytesseract.image_to_string(img, lang='chi_sim+eng')
|
| 88 |
return text
|
| 89 |
|
|
@@ -118,16 +128,18 @@ def extract_fields(ocr_text: str) -> dict:
|
|
| 118 |
if not name:
|
| 119 |
lines = [ln.strip() for ln in ocr_text.splitlines() if ln.strip()]
|
| 120 |
if lines:
|
| 121 |
-
name = lines[0][:
|
| 122 |
return {"taxno": taxno, "invoice_no": inv, "total": total, "name": name, "raw": ocr_text}
|
| 123 |
|
| 124 |
def match_supplier(session, extracted: dict, threshold:int=80):
|
|
|
|
| 125 |
tax = extracted.get("taxno")
|
| 126 |
if tax:
|
| 127 |
sup = session.query(Supplier).filter(Supplier.tax_no==tax.strip()).first()
|
| 128 |
if sup:
|
| 129 |
return {"supplier_id": sup.id, "supplier_name": sup.name, "supplier_taxno": sup.tax_no, "score": 100.0}
|
| 130 |
|
|
|
|
| 131 |
name = (extracted.get("name") or "").strip()
|
| 132 |
if not name:
|
| 133 |
return None
|
|
@@ -147,12 +159,12 @@ def match_supplier(session, extracted: dict, threshold:int=80):
|
|
| 147 |
# ---------- Storage helpers ----------
|
| 148 |
def save_file_and_record(file_bytes: bytes, filename: str, uploader: str = "unknown"):
|
| 149 |
ts = datetime.utcnow().strftime("%Y%m%d%H%M%S%f")
|
| 150 |
-
safe_name = f"{ts}_{filename}"
|
| 151 |
path = STORAGE_DIR / safe_name
|
| 152 |
with open(path, "wb") as f:
|
| 153 |
f.write(file_bytes)
|
| 154 |
|
| 155 |
-
ocr_text = do_ocr
|
| 156 |
extracted = extract_fields(ocr_text)
|
| 157 |
session = SessionLocal()
|
| 158 |
match = match_supplier(session, extracted)
|
|
@@ -195,10 +207,11 @@ def list_documents(limit=200):
|
|
| 195 |
session.close()
|
| 196 |
return rows
|
| 197 |
|
| 198 |
-
def get_suppliers():
|
| 199 |
session = SessionLocal()
|
| 200 |
s = session.query(Supplier).order_by(Supplier.name).all()
|
| 201 |
session.close()
|
|
|
|
| 202 |
return [(str(x.id), x.name) for x in s]
|
| 203 |
|
| 204 |
def confirm_document(doc_id: int, supplier_id: Optional[int]):
|
|
@@ -253,10 +266,19 @@ def export_zip(ids: List[int]):
|
|
| 253 |
|
| 254 |
# ---------- Gradio UI ----------
|
| 255 |
def upload_files(files, uploader):
|
|
|
|
| 256 |
results = []
|
|
|
|
|
|
|
| 257 |
for f in files:
|
| 258 |
-
|
| 259 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
results.append(res)
|
| 261 |
return pd.DataFrame(results)
|
| 262 |
|
|
@@ -266,31 +288,51 @@ def refresh_list():
|
|
| 266 |
return pd.DataFrame([], columns=["id","filename","uploaded_at","supplier_name","invoice_no","total_amount","score","status"])
|
| 267 |
return pd.DataFrame(rows)
|
| 268 |
|
| 269 |
-
def ui_confirm(doc_id,
|
| 270 |
-
|
| 271 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
|
| 273 |
-
def ui_export(
|
| 274 |
-
if not
|
| 275 |
-
return
|
| 276 |
-
ids = [int(x) for x in
|
|
|
|
|
|
|
| 277 |
zp = export_zip(ids)
|
| 278 |
-
return zp
|
| 279 |
|
|
|
|
| 280 |
with gr.Blocks() as demo:
|
| 281 |
gr.Markdown("## 发票收票分拣(Staging) \n上传发票 → 自动识别供应商 → 人工确认/导出")
|
| 282 |
with gr.Row():
|
| 283 |
with gr.Column(scale=3):
|
| 284 |
uploader = gr.Textbox(label="上传人 (可选)", placeholder="front_desk")
|
| 285 |
-
file_inputs = gr.File(label="拖拽或选择发票(图片
|
| 286 |
upload_btn = gr.Button("上传并识别")
|
| 287 |
-
upload_out = gr.Dataframe(headers=["id","filename","supplier","score","status","invoice_no","total"])
|
| 288 |
with gr.Column(scale=2):
|
| 289 |
refresh_btn = gr.Button("刷新列表")
|
| 290 |
-
list_out = gr.Dataframe(headers=["id","filename","uploaded_at","supplier_name","invoice_no","total_amount","score","status"])
|
| 291 |
gr.Markdown("**确认匹配(人工)**")
|
| 292 |
doc_id_in = gr.Textbox(label="文档 ID(从列表取)")
|
| 293 |
-
|
|
|
|
|
|
|
| 294 |
confirm_btn = gr.Button("确认关联")
|
| 295 |
confirm_out = gr.Textbox()
|
| 296 |
gr.Markdown("**导出选中(输入逗号分隔的 ID 列表)**")
|
|
@@ -300,10 +342,11 @@ with gr.Blocks() as demo:
|
|
| 300 |
|
| 301 |
upload_btn.click(lambda files,uploader: upload_files(files,uploader), inputs=[file_inputs,uploader], outputs=upload_out)
|
| 302 |
refresh_btn.click(lambda: refresh_list(), inputs=[], outputs=list_out)
|
| 303 |
-
confirm_btn.click(lambda d,s: ui_confirm(d,
|
| 304 |
-
export_btn.click(lambda txt: ui_export(
|
| 305 |
|
| 306 |
-
gr.Markdown("**提示**:如果供应商不在下拉中,请在 repo 放入或更新 `suppliers.csv` 并 Rebuild
|
| 307 |
|
| 308 |
if __name__ == "__main__":
|
|
|
|
| 309 |
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
|
|
|
|
| 4 |
import zipfile
|
| 5 |
from datetime import datetime
|
| 6 |
from pathlib import Path
|
| 7 |
+
from typing import List, Optional, Tuple
|
| 8 |
|
| 9 |
import gradio as gr
|
| 10 |
import pandas as pd
|
|
|
|
| 61 |
def load_suppliers_from_csv():
|
| 62 |
if not SUPPLIER_CSV.exists():
|
| 63 |
return
|
| 64 |
+
try:
|
| 65 |
+
df = pd.read_csv(SUPPLIER_CSV, dtype=str).fillna("")
|
| 66 |
+
except Exception:
|
| 67 |
+
return
|
| 68 |
session = SessionLocal()
|
| 69 |
for _, r in df.iterrows():
|
| 70 |
name = str(r.get("name","")).strip()
|
|
|
|
| 85 |
INVOICE_RE = re.compile(r"(发票代码|发票号码|发票号|Invoice No|Invoice No\.)[::\s]*([A-Za-z0-9\-]+)")
|
| 86 |
AMOUNT_RE = re.compile(r"([0-9]{1,3}(?:[,,][0-9]{3})*(?:\.[0-9]{1,2})?)")
|
| 87 |
|
| 88 |
+
def do_ocr(image_bytes: bytes) -> str:
    """Run Tesseract OCR on raw image bytes and return the recognised text.

    Args:
        image_bytes: Encoded image data (whatever Pillow can decode).

    Returns:
        The text produced by ``pytesseract.image_to_string`` using the
        ``chi_sim+eng`` language pack, or ``""`` when the input is empty or
        cannot be opened as an image.
    """
    import logging

    log = logging.getLogger(__name__)

    # Nothing to recognise; avoid handing an empty buffer to Pillow.
    if not image_bytes:
        return ""

    try:
        source = Image.open(io.BytesIO(image_bytes))
    except OSError:
        # Pillow's UnidentifiedImageError is an OSError subclass. A non-image
        # upload should yield "no text" rather than abort the caller.
        log.warning("do_ocr: input could not be opened as an image")
        return ""

    # The context manager releases the decoder/file handle when OCR is done.
    with source:
        try:
            img = source.convert("RGB")
        except Exception:
            # Best-effort normalisation (kept from the original): if the
            # conversion fails, OCR is attempted on the image as opened.
            log.debug("do_ocr: RGB conversion failed", exc_info=True)
            img = source
        return pytesseract.image_to_string(img, lang='chi_sim+eng')
|
| 99 |
|
|
|
|
| 128 |
if not name:
|
| 129 |
lines = [ln.strip() for ln in ocr_text.splitlines() if ln.strip()]
|
| 130 |
if lines:
|
| 131 |
+
name = lines[0][:120]
|
| 132 |
return {"taxno": taxno, "invoice_no": inv, "total": total, "name": name, "raw": ocr_text}
|
| 133 |
|
| 134 |
def match_supplier(session, extracted: dict, threshold:int=80):
|
| 135 |
+
# 1) tax no exact match
|
| 136 |
tax = extracted.get("taxno")
|
| 137 |
if tax:
|
| 138 |
sup = session.query(Supplier).filter(Supplier.tax_no==tax.strip()).first()
|
| 139 |
if sup:
|
| 140 |
return {"supplier_id": sup.id, "supplier_name": sup.name, "supplier_taxno": sup.tax_no, "score": 100.0}
|
| 141 |
|
| 142 |
+
# 2) name fuzzy match
|
| 143 |
name = (extracted.get("name") or "").strip()
|
| 144 |
if not name:
|
| 145 |
return None
|
|
|
|
| 159 |
# ---------- Storage helpers ----------
|
| 160 |
def save_file_and_record(file_bytes: bytes, filename: str, uploader: str = "unknown"):
|
| 161 |
ts = datetime.utcnow().strftime("%Y%m%d%H%M%S%f")
|
| 162 |
+
safe_name = f"{ts}_{filename.replace(' ', '_')}"
|
| 163 |
path = STORAGE_DIR / safe_name
|
| 164 |
with open(path, "wb") as f:
|
| 165 |
f.write(file_bytes)
|
| 166 |
|
| 167 |
+
ocr_text = do_ocr(file_bytes)
|
| 168 |
extracted = extract_fields(ocr_text)
|
| 169 |
session = SessionLocal()
|
| 170 |
match = match_supplier(session, extracted)
|
|
|
|
| 207 |
session.close()
|
| 208 |
return rows
|
| 209 |
|
| 210 |
+
def get_suppliers() -> List[Tuple[str,str]]:
    """Return every supplier as ``(id, name)`` pairs, ordered by name.

    Returns:
        A list of tuples whose first item is the supplier id converted to
        ``str`` and whose second item is the supplier name.
    """
    session = SessionLocal()
    try:
        suppliers = session.query(Supplier).order_by(Supplier.name).all()
        # Build the plain tuples while the session is still open so no ORM
        # attribute is read from a detached instance.
        return [(str(sup.id), sup.name) for sup in suppliers]
    finally:
        # Always release the session, even when the query raises.
        session.close()
|
| 216 |
|
| 217 |
def confirm_document(doc_id: int, supplier_id: Optional[int]):
|
|
|
|
| 266 |
|
| 267 |
# ---------- Gradio UI ----------
|
| 268 |
def upload_files(files, uploader):
    """Store and process each uploaded file, returning one result row per file.

    Args:
        files: Iterable of uploads from the Gradio file input. Each item may
            be a file-like object exposing ``read()``/``name`` or a plain
            filesystem path (``str``/``Path``).
        uploader: Name of the person uploading; falls back to ``"unknown"``
            when empty.

    Returns:
        A ``pandas.DataFrame`` built from the values returned by
        ``save_file_and_record``; an empty frame with the upload-table
        columns when no files were supplied.
    """
    columns = ["id","filename","supplier","score","status","invoice_no","total"]
    if not files:
        return pd.DataFrame([], columns=columns)

    results = []
    for f in files:
        if isinstance(f, (str, Path)):
            # Some Gradio versions hand over the temp-file path directly.
            source_path = str(f)
            with open(source_path, "rb") as fh:
                content = fh.read()
        else:
            source_path = getattr(f, "name", "") or ""
            try:
                content = f.read()
            except (AttributeError, OSError, ValueError):
                # No usable read() (missing, closed or failing handle):
                # fall back to reading the file the wrapper points at.
                with open(source_path, "rb") as fh:
                    content = fh.read()

        # Pass only the base name on: the upload's ``name`` is normally a full
        # temp path, and directory separators would leak into the storage
        # path built from this filename.
        display_name = Path(source_path).name or "uploaded_file"
        res = save_file_and_record(content, display_name, uploader or "unknown")
        results.append(res)
    return pd.DataFrame(results)
|
| 284 |
|
|
|
|
| 288 |
return pd.DataFrame([], columns=["id","filename","uploaded_at","supplier_name","invoice_no","total_amount","score","status"])
|
| 289 |
return pd.DataFrame(rows)
|
| 290 |
|
| 291 |
+
def ui_confirm(doc_id: str, supplier_name: str):
    """Confirm the supplier for a document from the UI inputs.

    Args:
        doc_id: Document id as typed into the textbox.
        supplier_name: Supplier name chosen in the dropdown; when empty the
            document is confirmed with no supplier (``None``).

    Returns:
        ``"Invalid doc id"`` if ``doc_id`` is not an integer,
        ``"supplier not found"`` if no supplier has that exact name,
        otherwise ``"ok"`` or ``"fail"`` according to the truthiness of
        ``confirm_document``'s result.
    """
    try:
        doc_id_i = int(doc_id)
    except (TypeError, ValueError):
        # TypeError covers None (empty textbox); ValueError covers non-numeric
        # text. Anything else should surface rather than be swallowed.
        return "Invalid doc id"

    if not supplier_name:
        # No supplier selected: confirm with a null supplier id.
        ok = confirm_document(doc_id_i, None)
        return "ok" if ok else "fail"

    # Resolve the supplier id from its name.
    session = SessionLocal()
    try:
        supplier = session.query(Supplier).filter(Supplier.name==supplier_name).first()
        found = bool(supplier)
        # Read the id while the session is open to avoid touching a detached row.
        supplier_id = supplier.id if found else None
    finally:
        session.close()

    if not found:
        return "supplier not found"
    ok = confirm_document(doc_id_i, supplier_id)
    return "ok" if ok else "fail"
|
| 309 |
|
| 310 |
+
def ui_export(txt_ids: str):
    """Export the documents whose ids are listed in a comma-separated string.

    Args:
        txt_ids: Text such as ``"1, 2, 3"``. Both the ASCII comma and the
            full-width comma (``,``) are accepted as separators. Tokens that
            are not plain decimal integers are ignored.

    Returns:
        Whatever ``export_zip`` returns for the parsed ids, or ``""`` when
        the input is empty, contains no valid id, or ``export_zip`` returns a
        falsy value.
    """
    if not txt_ids:
        return ""

    # Normalise the full-width comma so either separator can be typed.
    tokens = (tok.strip() for tok in txt_ids.replace(",", ",").split(","))
    # isdecimal() (not isdigit()) so every accepted token is parseable by
    # int(); isdigit() also accepts characters such as "²" that int() rejects.
    ids = [int(tok) for tok in tokens if tok.isdecimal()]
    if not ids:
        return ""

    zp = export_zip(ids)
    return zp or ""
|
| 318 |
|
| 319 |
+
# build UI
|
| 320 |
with gr.Blocks() as demo:
|
| 321 |
gr.Markdown("## 发票收票分拣(Staging) \n上传发票 → 自动识别供应商 → 人工确认/导出")
|
| 322 |
with gr.Row():
|
| 323 |
with gr.Column(scale=3):
|
| 324 |
uploader = gr.Textbox(label="上传人 (可选)", placeholder="front_desk")
|
| 325 |
+
file_inputs = gr.File(label="拖拽或选择发票(图片 优先)", file_count="multiple")
|
| 326 |
upload_btn = gr.Button("上传并识别")
|
| 327 |
+
upload_out = gr.Dataframe(headers=["id","filename","supplier","score","status","invoice_no","total"], interactive=False)
|
| 328 |
with gr.Column(scale=2):
|
| 329 |
refresh_btn = gr.Button("刷新列表")
|
| 330 |
+
list_out = gr.Dataframe(headers=["id","filename","uploaded_at","supplier_name","invoice_no","total_amount","score","status"], interactive=False)
|
| 331 |
gr.Markdown("**确认匹配(人工)**")
|
| 332 |
doc_id_in = gr.Textbox(label="文档 ID(从列表取)")
|
| 333 |
+
# supplier dropdown choices filled at start; to refresh suppliers update repo or restart
|
| 334 |
+
supplier_choices = [name for (_id, name) in get_suppliers()]
|
| 335 |
+
supplier_dd = gr.Dropdown(choices=supplier_choices, label="选择供应商(或先通过 suppliers.csv 添加)")
|
| 336 |
confirm_btn = gr.Button("确认关联")
|
| 337 |
confirm_out = gr.Textbox()
|
| 338 |
gr.Markdown("**导出选中(输入逗号分隔的 ID 列表)**")
|
|
|
|
| 342 |
|
| 343 |
upload_btn.click(lambda files,uploader: upload_files(files,uploader), inputs=[file_inputs,uploader], outputs=upload_out)
|
| 344 |
refresh_btn.click(lambda: refresh_list(), inputs=[], outputs=list_out)
|
| 345 |
+
confirm_btn.click(lambda d,s: ui_confirm(d, s), inputs=[doc_id_in, supplier_dd], outputs=confirm_out)
|
| 346 |
+
export_btn.click(lambda txt: ui_export(txt), inputs=[export_ids], outputs=export_out)
|
| 347 |
|
| 348 |
+
gr.Markdown("**提示**:如果供应商不在下拉中,请在 repo 放入或更新 `suppliers.csv` 并 Rebuild 空间,或在数据库新增。")
|
| 349 |
|
| 350 |
if __name__ == "__main__":
|
| 351 |
+
# Gradio default port for Spaces is 7860
|
| 352 |
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
|