Gabriel00A commited on
Commit
01da16a
·
verified ·
1 Parent(s): 5545434

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -25
app.py CHANGED
@@ -4,7 +4,7 @@ import io
4
  import zipfile
5
  from datetime import datetime
6
  from pathlib import Path
7
- from typing import List, Optional
8
 
9
  import gradio as gr
10
  import pandas as pd
@@ -61,7 +61,10 @@ Base.metadata.create_all(engine)
61
  def load_suppliers_from_csv():
62
  if not SUPPLIER_CSV.exists():
63
  return
64
- df = pd.read_csv(SUPPLIER_CSV, dtype=str).fillna("")
 
 
 
65
  session = SessionLocal()
66
  for _, r in df.iterrows():
67
  name = str(r.get("name","")).strip()
@@ -82,8 +85,15 @@ TAXNO_RE = re.compile(r"([0-9A-Z]{15,20})")
82
  INVOICE_RE = re.compile(r"(发票代码|发票号码|发票号|Invoice No|Invoice No\.)[::\s]*([A-Za-z0-9\-]+)")
83
  AMOUNT_RE = re.compile(r"([0-9]{1,3}(?:[,,][0-9]{3})*(?:\.[0-9]{1,2})?)")
84
 
85
- def do_ocr PIL_image_bytes(file_bytes: bytes) -> str:
86
- img = Image.open(io.BytesIO(file_bytes))
 
 
 
 
 
 
 
87
  text = pytesseract.image_to_string(img, lang='chi_sim+eng')
88
  return text
89
 
@@ -118,16 +128,18 @@ def extract_fields(ocr_text: str) -> dict:
118
  if not name:
119
  lines = [ln.strip() for ln in ocr_text.splitlines() if ln.strip()]
120
  if lines:
121
- name = lines[0][:80]
122
  return {"taxno": taxno, "invoice_no": inv, "total": total, "name": name, "raw": ocr_text}
123
 
124
  def match_supplier(session, extracted: dict, threshold:int=80):
 
125
  tax = extracted.get("taxno")
126
  if tax:
127
  sup = session.query(Supplier).filter(Supplier.tax_no==tax.strip()).first()
128
  if sup:
129
  return {"supplier_id": sup.id, "supplier_name": sup.name, "supplier_taxno": sup.tax_no, "score": 100.0}
130
 
 
131
  name = (extracted.get("name") or "").strip()
132
  if not name:
133
  return None
@@ -147,12 +159,12 @@ def match_supplier(session, extracted: dict, threshold:int=80):
147
  # ---------- Storage helpers ----------
148
  def save_file_and_record(file_bytes: bytes, filename: str, uploader: str = "unknown"):
149
  ts = datetime.utcnow().strftime("%Y%m%d%H%M%S%f")
150
- safe_name = f"{ts}_{filename}"
151
  path = STORAGE_DIR / safe_name
152
  with open(path, "wb") as f:
153
  f.write(file_bytes)
154
 
155
- ocr_text = do_ocr PIL_image_bytes(file_bytes)
156
  extracted = extract_fields(ocr_text)
157
  session = SessionLocal()
158
  match = match_supplier(session, extracted)
@@ -195,10 +207,11 @@ def list_documents(limit=200):
195
  session.close()
196
  return rows
197
 
198
- def get_suppliers():
199
  session = SessionLocal()
200
  s = session.query(Supplier).order_by(Supplier.name).all()
201
  session.close()
 
202
  return [(str(x.id), x.name) for x in s]
203
 
204
  def confirm_document(doc_id: int, supplier_id: Optional[int]):
@@ -253,10 +266,19 @@ def export_zip(ids: List[int]):
253
 
254
  # ---------- Gradio UI ----------
255
  def upload_files(files, uploader):
 
256
  results = []
 
 
257
  for f in files:
258
- content = f.read()
259
- res = save_file_and_record(content, f.name, uploader or "unknown")
 
 
 
 
 
 
260
  results.append(res)
261
  return pd.DataFrame(results)
262
 
@@ -266,31 +288,51 @@ def refresh_list():
266
  return pd.DataFrame([], columns=["id","filename","uploaded_at","supplier_name","invoice_no","total_amount","score","status"])
267
  return pd.DataFrame(rows)
268
 
269
- def ui_confirm(doc_id, supplier_id):
270
- ok = confirm_document(int(doc_id), int(supplier_id) if supplier_id else None)
271
- return {"ok": ok}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
- def ui_export(selected_ids):
274
- if not selected_ids:
275
- return None
276
- ids = [int(x) for x in selected_ids]
 
 
277
  zp = export_zip(ids)
278
- return zp
279
 
 
280
  with gr.Blocks() as demo:
281
  gr.Markdown("## 发票收票分拣(Staging) \n上传发票 → 自动识别供应商 → 人工确认/导出")
282
  with gr.Row():
283
  with gr.Column(scale=3):
284
  uploader = gr.Textbox(label="上传人 (可选)", placeholder="front_desk")
285
- file_inputs = gr.File(label="拖拽或选择发票(图片 / PDF 可行,但图片更稳定)", file_count="multiple")
286
  upload_btn = gr.Button("上传并识别")
287
- upload_out = gr.Dataframe(headers=["id","filename","supplier","score","status","invoice_no","total"])
288
  with gr.Column(scale=2):
289
  refresh_btn = gr.Button("刷新列表")
290
- list_out = gr.Dataframe(headers=["id","filename","uploaded_at","supplier_name","invoice_no","total_amount","score","status"])
291
  gr.Markdown("**确认匹配(人工)**")
292
  doc_id_in = gr.Textbox(label="文档 ID(从列表取)")
293
- supplier_dd = gr.Dropdown(choices=[x[1] for x in get_suppliers()], label="选择供应商(或先新增 suppliers.csv)")
 
 
294
  confirm_btn = gr.Button("确认关联")
295
  confirm_out = gr.Textbox()
296
  gr.Markdown("**导出选中(输入逗号分隔的 ID 列表)**")
@@ -300,10 +342,11 @@ with gr.Blocks() as demo:
300
 
301
  upload_btn.click(lambda files,uploader: upload_files(files,uploader), inputs=[file_inputs,uploader], outputs=upload_out)
302
  refresh_btn.click(lambda: refresh_list(), inputs=[], outputs=list_out)
303
- confirm_btn.click(lambda d,s: ui_confirm(d, None if not s else next((sid for sid,name in get_suppliers() if name==s), None)), inputs=[doc_id_in, supplier_dd], outputs=confirm_out)
304
- export_btn.click(lambda txt: ui_export([x.strip() for x in txt.split(",") if x.strip()]), inputs=[export_ids], outputs=export_out)
305
 
306
- gr.Markdown("**提示**:如果供应商不在下拉中,请在 repo 放入或更新 `suppliers.csv` 并 Rebuild 空间,或者修改数据库。")
307
 
308
  if __name__ == "__main__":
 
309
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
4
  import zipfile
5
  from datetime import datetime
6
  from pathlib import Path
7
+ from typing import List, Optional, Tuple
8
 
9
  import gradio as gr
10
  import pandas as pd
 
61
  def load_suppliers_from_csv():
62
  if not SUPPLIER_CSV.exists():
63
  return
64
+ try:
65
+ df = pd.read_csv(SUPPLIER_CSV, dtype=str).fillna("")
66
+ except Exception:
67
+ return
68
  session = SessionLocal()
69
  for _, r in df.iterrows():
70
  name = str(r.get("name","")).strip()
 
85
  INVOICE_RE = re.compile(r"(发票代码|发票号码|发票号|Invoice No|Invoice No\.)[::\s]*([A-Za-z0-9\-]+)")
86
  AMOUNT_RE = re.compile(r"([0-9]{1,3}(?:[,,][0-9]{3})*(?:\.[0-9]{1,2})?)")
87
 
88
+ def do_ocr(image_bytes: bytes) -> str:
89
+ """
90
+ 用 pytesseract 对二进制图片数据做 OCR,返回识别出的纯文本。
91
+ """
92
+ img = Image.open(io.BytesIO(image_bytes))
93
+ try:
94
+ img = img.convert("RGB")
95
+ except Exception:
96
+ pass
97
  text = pytesseract.image_to_string(img, lang='chi_sim+eng')
98
  return text
99
 
 
128
  if not name:
129
  lines = [ln.strip() for ln in ocr_text.splitlines() if ln.strip()]
130
  if lines:
131
+ name = lines[0][:120]
132
  return {"taxno": taxno, "invoice_no": inv, "total": total, "name": name, "raw": ocr_text}
133
 
134
  def match_supplier(session, extracted: dict, threshold:int=80):
135
+ # 1) tax no exact match
136
  tax = extracted.get("taxno")
137
  if tax:
138
  sup = session.query(Supplier).filter(Supplier.tax_no==tax.strip()).first()
139
  if sup:
140
  return {"supplier_id": sup.id, "supplier_name": sup.name, "supplier_taxno": sup.tax_no, "score": 100.0}
141
 
142
+ # 2) name fuzzy match
143
  name = (extracted.get("name") or "").strip()
144
  if not name:
145
  return None
 
159
  # ---------- Storage helpers ----------
160
  def save_file_and_record(file_bytes: bytes, filename: str, uploader: str = "unknown"):
161
  ts = datetime.utcnow().strftime("%Y%m%d%H%M%S%f")
162
+ safe_name = f"{ts}_{filename.replace(' ', '_')}"
163
  path = STORAGE_DIR / safe_name
164
  with open(path, "wb") as f:
165
  f.write(file_bytes)
166
 
167
+ ocr_text = do_ocr(file_bytes)
168
  extracted = extract_fields(ocr_text)
169
  session = SessionLocal()
170
  match = match_supplier(session, extracted)
 
207
  session.close()
208
  return rows
209
 
210
+ def get_suppliers() -> List[Tuple[str,str]]:
211
  session = SessionLocal()
212
  s = session.query(Supplier).order_by(Supplier.name).all()
213
  session.close()
214
+ # return list of (id, name)
215
  return [(str(x.id), x.name) for x in s]
216
 
217
  def confirm_document(doc_id: int, supplier_id: Optional[int]):
 
266
 
267
  # ---------- Gradio UI ----------
268
  def upload_files(files, uploader):
269
+ # files: list of gradio file-like objects
270
  results = []
271
+ if not files:
272
+ return pd.DataFrame([], columns=["id","filename","supplier","score","status","invoice_no","total"])
273
  for f in files:
274
+ # f.read() returns bytes
275
+ try:
276
+ content = f.read()
277
+ except Exception:
278
+ # f may be a local path in some envs
279
+ with open(f.name, "rb") as fh:
280
+ content = fh.read()
281
+ res = save_file_and_record(content, getattr(f, "name", "uploaded_file"), uploader or "unknown")
282
  results.append(res)
283
  return pd.DataFrame(results)
284
 
 
288
  return pd.DataFrame([], columns=["id","filename","uploaded_at","supplier_name","invoice_no","total_amount","score","status"])
289
  return pd.DataFrame(rows)
290
 
291
+ def ui_confirm(doc_id: str, supplier_name: str):
292
+ # doc_id is string from textbox; supplier_name is name selected in dropdown
293
+ try:
294
+ doc_id_i = int(doc_id)
295
+ except:
296
+ return "Invalid doc id"
297
+ if not supplier_name:
298
+ # mark unknown
299
+ ok = confirm_document(doc_id_i, None)
300
+ return "ok" if ok else "fail"
301
+ # find supplier id by name
302
+ session = SessionLocal()
303
+ s = session.query(Supplier).filter(Supplier.name==supplier_name).first()
304
+ session.close()
305
+ if not s:
306
+ return "supplier not found"
307
+ ok = confirm_document(doc_id_i, s.id)
308
+ return "ok" if ok else "fail"
309
 
310
+ def ui_export(txt_ids: str):
311
+ if not txt_ids:
312
+ return ""
313
+ ids = [int(x.strip()) for x in txt_ids.split(",") if x.strip().isdigit()]
314
+ if not ids:
315
+ return ""
316
  zp = export_zip(ids)
317
+ return zp or ""
318
 
319
+ # build UI
320
  with gr.Blocks() as demo:
321
  gr.Markdown("## 发票收票分拣(Staging) \n上传发票 → 自动识别供应商 → 人工确认/导出")
322
  with gr.Row():
323
  with gr.Column(scale=3):
324
  uploader = gr.Textbox(label="上传人 (可选)", placeholder="front_desk")
325
+ file_inputs = gr.File(label="拖拽或选择发票(图片 优先)", file_count="multiple")
326
  upload_btn = gr.Button("上传并识别")
327
+ upload_out = gr.Dataframe(headers=["id","filename","supplier","score","status","invoice_no","total"], interactive=False)
328
  with gr.Column(scale=2):
329
  refresh_btn = gr.Button("刷新列表")
330
+ list_out = gr.Dataframe(headers=["id","filename","uploaded_at","supplier_name","invoice_no","total_amount","score","status"], interactive=False)
331
  gr.Markdown("**确认匹配(人工)**")
332
  doc_id_in = gr.Textbox(label="文档 ID(从列表取)")
333
+ # supplier dropdown choices filled at start; to refresh suppliers update repo or restart
334
+ supplier_choices = [name for (_id, name) in get_suppliers()]
335
+ supplier_dd = gr.Dropdown(choices=supplier_choices, label="选择供应商(或先通过 suppliers.csv 添加)")
336
  confirm_btn = gr.Button("确认关联")
337
  confirm_out = gr.Textbox()
338
  gr.Markdown("**导出选中(输入逗号分隔的 ID 列表)**")
 
342
 
343
  upload_btn.click(lambda files,uploader: upload_files(files,uploader), inputs=[file_inputs,uploader], outputs=upload_out)
344
  refresh_btn.click(lambda: refresh_list(), inputs=[], outputs=list_out)
345
+ confirm_btn.click(lambda d,s: ui_confirm(d, s), inputs=[doc_id_in, supplier_dd], outputs=confirm_out)
346
+ export_btn.click(lambda txt: ui_export(txt), inputs=[export_ids], outputs=export_out)
347
 
348
+ gr.Markdown("**提示**:如果供应商不在下拉中,请在 repo 放入或更新 `suppliers.csv` 并 Rebuild 空间,或在数据库新增。")
349
 
350
  if __name__ == "__main__":
351
+ # Gradio default port for Spaces is 7860
352
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)