Gabriel00A committed on
Commit
71d3ff0
·
verified ·
1 Parent(s): 17face9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -158
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py (final robust version - supports PNG/JPG/PDF and returns errors nicely)
2
  import os
3
  import io
4
  import zipfile
@@ -16,8 +16,6 @@ from rapidfuzz import process, fuzz
16
  from sqlalchemy import create_engine, Column, Integer, String, DateTime, Float, Text
17
  from sqlalchemy.ext.declarative import declarative_base
18
  from sqlalchemy.orm import sessionmaker
19
-
20
- # pdf2image for PDF -> image
21
  from pdf2image import convert_from_bytes
22
 
23
  # ---------- config ----------
@@ -56,13 +54,13 @@ class Document(Base):
56
  invoice_no = Column(String, nullable=True)
57
  total_amount = Column(Float, nullable=True)
58
  match_score = Column(Float, nullable=True)
59
- status = Column(String, default="new") # new / matched / unknown / exported
60
  raw_extracted = Column(Text, nullable=True)
61
  error = Column(Text, nullable=True)
62
 
63
  Base.metadata.create_all(engine)
64
 
65
- # load suppliers CSV if present
66
  def load_suppliers_from_csv():
67
  if not SUPPLIER_CSV.exists():
68
  return
@@ -85,44 +83,36 @@ def load_suppliers_from_csv():
85
 
86
  load_suppliers_from_csv()
87
 
88
- # ---------- OCR & extract ----------
89
  TAXNO_RE = re.compile(r"([0-9A-Z]{15,20})")
90
  INVOICE_RE = re.compile(r"(发票代码|发票号码|发票号|Invoice No|Invoice No\.)[::\s]*([A-Za-z0-9\-]+)")
91
  AMOUNT_RE = re.compile(r"([0-9]{1,3}(?:[,,][0-9]{3})*(?:\.[0-9]{1,2})?)")
92
 
93
  def do_ocr(file_bytes: bytes) -> str:
94
  """
95
- 支持图片 PDF
96
- - 先尝试把 bytes 当图片打开(PIL)
97
- - 如果不是图片或识别无结果,尝试 pdf2image 将第一页转为图片再 OCR
98
- 返回拼接的识别文本(若多页可合并)。
99
  """
100
- # 1) try image
101
  try:
102
  img = Image.open(io.BytesIO(file_bytes))
103
  img = img.convert("RGB")
104
  text = pytesseract.image_to_string(img, lang='chi_sim+eng')
105
- if text and text.strip():
106
  return text
107
  except Exception:
108
- # 不是图片或 PIL 打开失败 -> fall through to PDF attempt
109
  pass
110
 
111
- # 2) try PDF -> images
112
  try:
113
- images = convert_from_bytes(file_bytes, dpi=300) # convert all pages (if many, okay)
114
  texts = []
115
  for im in images:
116
- try:
117
- t = pytesseract.image_to_string(im, lang='chi_sim+eng')
118
- except Exception:
119
- t = ""
120
- if t and t.strip():
121
  texts.append(t)
122
  return "\n\n".join(texts)
123
  except Exception as e:
124
- # 无法处理为 PDF
125
- print("do_ocr error:", e)
126
  return ""
127
 
128
  def extract_fields(ocr_text: str) -> dict:
@@ -181,7 +171,7 @@ def match_supplier(session, extracted: dict, threshold:int=80):
181
  return {"supplier_id": s.id, "supplier_name": s.name, "supplier_taxno": s.tax_no, "score": float(score)}
182
  return None
183
 
184
- # ---------- Storage helpers ----------
185
  def save_file_and_record(file_bytes: bytes, filename: str, uploader: str = "unknown"):
186
  ts = datetime.utcnow().strftime("%Y%m%d%H%M%S%f")
187
  safe_name = f"{ts}_{filename.replace(' ', '_')}"
@@ -193,15 +183,12 @@ def save_file_and_record(file_bytes: bytes, filename: str, uploader: str = "unkn
193
  try:
194
  ocr_text = do_ocr(file_bytes)
195
  except Exception as e:
196
- ocr_text = ""
197
  print("OCR exception:", e, traceback.format_exc())
198
 
199
  extracted = extract_fields(ocr_text)
200
  session = SessionLocal()
201
  match = match_supplier(session, extracted)
202
- # default status and error None
203
  status = "matched" if match else "unknown"
204
- error_msg = None
205
 
206
  doc = Document(
207
  filename=filename,
@@ -216,8 +203,7 @@ def save_file_and_record(file_bytes: bytes, filename: str, uploader: str = "unkn
216
  total_amount=extracted.get("total"),
217
  match_score=match["score"] if match else None,
218
  status=status,
219
- raw_extracted=str(extracted),
220
- error=error_msg
221
  )
222
  session.add(doc)
223
  session.commit()
@@ -225,53 +211,11 @@ def save_file_and_record(file_bytes: bytes, filename: str, uploader: str = "unkn
225
  session.close()
226
  return {
227
  "id": doc.id, "filename": doc.filename, "supplier": doc.supplier_name,
228
- "score": doc.match_score, "status": doc.status, "invoice_no": doc.invoice_no, "total": doc.total_amount
 
229
  }
230
 
231
- def list_documents(limit=200):
232
- session = SessionLocal()
233
- q = session.query(Document).order_by(Document.uploaded_at.desc()).limit(limit).all()
234
- rows = []
235
- for d in q:
236
- rows.append({
237
- "id": d.id, "filename": d.filename,
238
- "uploaded_at": d.uploaded_at.strftime("%Y-%m-%d %H:%M:%S"),
239
- "supplier_id": d.supplier_id, "supplier_name": d.supplier_name,
240
- "supplier_taxno": d.supplier_taxno, "invoice_no": d.invoice_no,
241
- "total_amount": d.total_amount, "score": d.match_score, "status": d.status
242
- })
243
- session.close()
244
- return rows
245
-
246
- def get_suppliers() -> List[Tuple[str,str]]:
247
- session = SessionLocal()
248
- s = session.query(Supplier).order_by(Supplier.name).all()
249
- session.close()
250
- return [(str(x.id), x.name) for x in s]
251
-
252
- def confirm_document(doc_id: int, supplier_id: Optional[int]):
253
- session = SessionLocal()
254
- d = session.query(Document).get(doc_id)
255
- if not d:
256
- session.close()
257
- return False
258
- if supplier_id:
259
- s = session.query(Supplier).get(supplier_id)
260
- if not s:
261
- session.close()
262
- return False
263
- d.supplier_id = s.id
264
- d.supplier_name = s.name
265
- d.supplier_taxno = s.tax_no
266
- d.match_score = 100.0
267
- d.status = "matched"
268
- else:
269
- d.status = "unknown"
270
- session.add(d)
271
- session.commit()
272
- session.close()
273
- return True
274
-
275
  def export_zip(ids: List[int]):
276
  session = SessionLocal()
277
  docs = session.query(Document).filter(Document.id.in_(ids)).all()
@@ -297,131 +241,100 @@ def export_zip(ids: List[int]):
297
  if os.path.exists(d.filepath):
298
  zf.write(d.filepath, arcname=os.path.basename(d.filepath))
299
  session.close()
300
- return str(zip_path)
301
 
302
- # ---------- Gradio UI ----------
303
- def _read_uploaded_file(f):
304
- """
305
- Robust reader: f may be a gradio file-like object with read(), or a local file path object.
306
- 返回 bytes 和 文件名
307
- """
308
  try:
309
- content = f.read()
310
- name = getattr(f, "name", None) or getattr(f, "filename", None) or "uploaded_file"
311
- # when gradio gives a SpooledTemporaryFile, name may be a path
312
- return content, os.path.basename(name)
313
- except Exception:
314
- # fallback: if f has .name path, read from disk
315
- try:
316
- path = f.name
317
- with open(path, "rb") as fh:
318
- return fh.read(), os.path.basename(path)
319
- except Exception as e:
320
- raise RuntimeError(f"无法读取上传文件: {e}")
321
 
 
 
 
322
  def upload_files(files, uploader):
323
  results = []
324
  if not files:
325
  return pd.DataFrame([], columns=["id","filename","supplier","score","status","invoice_no","total"])
326
  for f in files:
327
  try:
328
- content, filename = _read_uploaded_file(f)
329
- res = save_file_and_record(content, filename, uploader or "unknown")
330
  results.append(res)
331
  except Exception as e:
332
- # record error into DB so we can inspect later
333
- try:
334
- ts = datetime.utcnow().strftime("%Y%m%d%H%M%S%f")
335
- safe_name = f"err_{ts}_{getattr(f, 'name', 'unknown')}"
336
- p = STORAGE_DIR / safe_name
337
- # try write raw if possible
338
- try:
339
- with open(p, "wb") as fh:
340
- if hasattr(f, "read"):
341
- fh.write(f.read())
342
- except Exception:
343
- pass
344
- except Exception:
345
- pass
346
- # append a visible error row
347
  results.append({
348
- "id": -1,
349
- "filename": getattr(f, "name", "uploaded_file"),
350
- "supplier": None,
351
- "score": None,
352
- "status": "error",
353
- "invoice_no": None,
354
- "total": None,
355
- "error": str(e) + "\n" + traceback.format_exc()
356
  })
357
- print("file processing error:", e, traceback.format_exc())
358
- df = pd.DataFrame(results)
359
- # ensure consistent columns for front-end display
360
- cols = ["id","filename","supplier","score","status","invoice_no","total"]
361
- for c in cols:
362
- if c not in df.columns:
363
- df[c] = None
364
- return df[cols]
365
 
366
  def refresh_list():
367
- rows = list_documents()
368
- if not rows:
369
- return pd.DataFrame([], columns=["id","filename","uploaded_at","supplier_name","invoice_no","total_amount","score","status"])
 
 
 
 
 
 
370
  return pd.DataFrame(rows)
371
 
372
  def ui_confirm(doc_id: str, supplier_name: str):
373
- try:
374
- doc_id_i = int(doc_id)
375
- except:
376
  return "Invalid doc id"
377
- if not supplier_name:
378
- ok = confirm_document(doc_id_i, None)
379
- return "ok" if ok else "fail"
380
  session = SessionLocal()
381
  s = session.query(Supplier).filter(Supplier.name==supplier_name).first()
382
- session.close()
 
 
 
383
  if not s:
384
- return "supplier not found"
385
- ok = confirm_document(doc_id_i, s.id)
386
- return "ok" if ok else "fail"
 
 
 
 
 
 
 
 
387
 
388
  def ui_export(txt_ids: str):
389
- if not txt_ids:
390
- return ""
391
  ids = [int(x.strip()) for x in txt_ids.split(",") if x.strip().isdigit()]
392
  if not ids:
393
  return ""
394
- zp = export_zip(ids)
395
- return zp or ""
396
 
397
  with gr.Blocks() as demo:
398
- gr.Markdown("## 发票收票分拣(Staging) \n上传发票 → 自动识别供应商 → 人工确认/导出")
399
  with gr.Row():
400
  with gr.Column(scale=3):
401
  uploader = gr.Textbox(label="上传人 (可选)", placeholder="front_desk")
402
- file_inputs = gr.File(label="拖拽或选择发票(图片 优先)", file_count="multiple")
403
  upload_btn = gr.Button("上传并识别")
404
  upload_out = gr.Dataframe(headers=["id","filename","supplier","score","status","invoice_no","total"], interactive=False)
405
  with gr.Column(scale=2):
406
  refresh_btn = gr.Button("刷新列表")
407
  list_out = gr.Dataframe(headers=["id","filename","uploaded_at","supplier_name","invoice_no","total_amount","score","status"], interactive=False)
408
- gr.Markdown("**确认匹配人工**")
409
- doc_id_in = gr.Textbox(label="文档 ID(从列表取)")
410
- supplier_choices = [name for (_id, name) in get_suppliers()]
411
- supplier_dd = gr.Dropdown(choices=supplier_choices, label="选择供应商(或先通过 suppliers.csv 添加)")
412
  confirm_btn = gr.Button("确认关联")
413
- confirm_out = gr.Textbox()
414
- gr.Markdown("**导出选中(输入逗号分隔的 ID 列表**")
415
- export_ids = gr.Textbox(label="ID 列表,例如:1,2,3")
416
  export_btn = gr.Button("导出为 ZIP")
417
- export_out = gr.Textbox()
418
-
419
- upload_btn.click(lambda files,uploader: upload_files(files,uploader), inputs=[file_inputs,uploader], outputs=upload_out)
420
- refresh_btn.click(lambda: refresh_list(), inputs=[], outputs=list_out)
421
- confirm_btn.click(lambda d,s: ui_confirm(d, s), inputs=[doc_id_in, supplier_dd], outputs=confirm_out)
422
- export_btn.click(lambda txt: ui_export(txt), inputs=[export_ids], outputs=export_out)
423
 
424
- gr.Markdown("**提示**:如果供应商不在下拉中,请在 repo 放入或更新 `suppliers.csv` 并 Rebuild 空间,或在数据库新增。")
 
 
 
425
 
426
  if __name__ == "__main__":
427
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
1
+ # app.py (final version - supports PDF/JPG/PNG + export to repo root)
2
  import os
3
  import io
4
  import zipfile
 
16
  from sqlalchemy import create_engine, Column, Integer, String, DateTime, Float, Text
17
  from sqlalchemy.ext.declarative import declarative_base
18
  from sqlalchemy.orm import sessionmaker
 
 
19
  from pdf2image import convert_from_bytes
20
 
21
  # ---------- config ----------
 
54
  invoice_no = Column(String, nullable=True)
55
  total_amount = Column(Float, nullable=True)
56
  match_score = Column(Float, nullable=True)
57
+ status = Column(String, default="new")
58
  raw_extracted = Column(Text, nullable=True)
59
  error = Column(Text, nullable=True)
60
 
61
  Base.metadata.create_all(engine)
62
 
63
+ # ---------- Load Suppliers ----------
64
  def load_suppliers_from_csv():
65
  if not SUPPLIER_CSV.exists():
66
  return
 
83
 
84
  load_suppliers_from_csv()
85
 
86
+ # ---------- OCR & Extraction ----------
87
  TAXNO_RE = re.compile(r"([0-9A-Z]{15,20})")
88
  INVOICE_RE = re.compile(r"(发票代码|发票号码|发票号|Invoice No|Invoice No\.)[::\s]*([A-Za-z0-9\-]+)")
89
  AMOUNT_RE = re.compile(r"([0-9]{1,3}(?:[,,][0-9]{3})*(?:\.[0-9]{1,2})?)")
90
 
91
  def do_ocr(file_bytes: bytes) -> str:
92
  """
93
+ 自动识别图片 PDF,返回识别文字。
 
 
 
94
  """
95
+ # 1) 图片
96
  try:
97
  img = Image.open(io.BytesIO(file_bytes))
98
  img = img.convert("RGB")
99
  text = pytesseract.image_to_string(img, lang='chi_sim+eng')
100
+ if text.strip():
101
  return text
102
  except Exception:
 
103
  pass
104
 
105
+ # 2) PDF
106
  try:
107
+ images = convert_from_bytes(file_bytes, dpi=300)
108
  texts = []
109
  for im in images:
110
+ t = pytesseract.image_to_string(im, lang='chi_sim+eng')
111
+ if t.strip():
 
 
 
112
  texts.append(t)
113
  return "\n\n".join(texts)
114
  except Exception as e:
115
+ print("OCR error:", e)
 
116
  return ""
117
 
118
  def extract_fields(ocr_text: str) -> dict:
 
171
  return {"supplier_id": s.id, "supplier_name": s.name, "supplier_taxno": s.tax_no, "score": float(score)}
172
  return None
173
 
174
+ # ---------- Save & Record ----------
175
  def save_file_and_record(file_bytes: bytes, filename: str, uploader: str = "unknown"):
176
  ts = datetime.utcnow().strftime("%Y%m%d%H%M%S%f")
177
  safe_name = f"{ts}_{filename.replace(' ', '_')}"
 
183
  try:
184
  ocr_text = do_ocr(file_bytes)
185
  except Exception as e:
 
186
  print("OCR exception:", e, traceback.format_exc())
187
 
188
  extracted = extract_fields(ocr_text)
189
  session = SessionLocal()
190
  match = match_supplier(session, extracted)
 
191
  status = "matched" if match else "unknown"
 
192
 
193
  doc = Document(
194
  filename=filename,
 
203
  total_amount=extracted.get("total"),
204
  match_score=match["score"] if match else None,
205
  status=status,
206
+ raw_extracted=str(extracted)
 
207
  )
208
  session.add(doc)
209
  session.commit()
 
211
  session.close()
212
  return {
213
  "id": doc.id, "filename": doc.filename, "supplier": doc.supplier_name,
214
+ "score": doc.match_score, "status": doc.status,
215
+ "invoice_no": doc.invoice_no, "total": doc.total_amount
216
  }
217
 
218
+ # ---------- Export ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  def export_zip(ids: List[int]):
220
  session = SessionLocal()
221
  docs = session.query(Document).filter(Document.id.in_(ids)).all()
 
241
  if os.path.exists(d.filepath):
242
  zf.write(d.filepath, arcname=os.path.basename(d.filepath))
243
  session.close()
 
244
 
245
+ # === 新增:复制 ZIP 到 Space 根目录,让用户可直接下载 ===
246
+ import shutil
 
 
 
 
247
  try:
248
+ shutil.copy(zip_path, BASE_DIR / f"export_{ts}.zip")
249
+ print(f" ZIP 已复制到根目录: export_{ts}.zip")
250
+ except Exception as e:
251
+ print("复制 ZIP 失败:", e)
 
 
 
 
 
 
 
 
252
 
253
+ return str(zip_path)
254
+
255
+ # ---------- Gradio UI ----------
256
  def upload_files(files, uploader):
257
  results = []
258
  if not files:
259
  return pd.DataFrame([], columns=["id","filename","supplier","score","status","invoice_no","total"])
260
  for f in files:
261
  try:
262
+ content = f.read() if hasattr(f, "read") else open(f.name, "rb").read()
263
+ res = save_file_and_record(content, getattr(f, "name", "file"), uploader or "unknown")
264
  results.append(res)
265
  except Exception as e:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  results.append({
267
+ "id": -1, "filename": getattr(f, "name", "file"),
268
+ "supplier": None, "score": None, "status": "error",
269
+ "invoice_no": None, "total": None
 
 
 
 
 
270
  })
271
+ print("File process error:", e, traceback.format_exc())
272
+ return pd.DataFrame(results)
 
 
 
 
 
 
273
 
274
  def refresh_list():
275
+ session = SessionLocal()
276
+ q = session.query(Document).order_by(Document.uploaded_at.desc()).limit(200).all()
277
+ rows = [{
278
+ "id": d.id, "filename": d.filename,
279
+ "uploaded_at": d.uploaded_at.strftime("%Y-%m-%d %H:%M:%S"),
280
+ "supplier_name": d.supplier_name, "invoice_no": d.invoice_no,
281
+ "total_amount": d.total_amount, "score": d.match_score, "status": d.status
282
+ } for d in q]
283
+ session.close()
284
  return pd.DataFrame(rows)
285
 
286
  def ui_confirm(doc_id: str, supplier_name: str):
287
+ if not doc_id.isdigit():
 
 
288
  return "Invalid doc id"
289
+ doc_id_i = int(doc_id)
 
 
290
  session = SessionLocal()
291
  s = session.query(Supplier).filter(Supplier.name==supplier_name).first()
292
+ d = session.query(Document).get(doc_id_i)
293
+ if not d:
294
+ session.close()
295
+ return "Doc not found"
296
  if not s:
297
+ session.close()
298
+ return "Supplier not found"
299
+ d.supplier_id = s.id
300
+ d.supplier_name = s.name
301
+ d.supplier_taxno = s.tax_no
302
+ d.match_score = 100.0
303
+ d.status = "matched"
304
+ session.add(d)
305
+ session.commit()
306
+ session.close()
307
+ return "ok"
308
 
309
  def ui_export(txt_ids: str):
 
 
310
  ids = [int(x.strip()) for x in txt_ids.split(",") if x.strip().isdigit()]
311
  if not ids:
312
  return ""
313
+ return export_zip(ids)
 
314
 
315
  with gr.Blocks() as demo:
316
+ gr.Markdown("## 发票收票分拣(Staging)\n上传发票 → 自动识别供应商 → 人工确认/导出")
317
  with gr.Row():
318
  with gr.Column(scale=3):
319
  uploader = gr.Textbox(label="上传人 (可选)", placeholder="front_desk")
320
+ file_inputs = gr.File(label="拖拽或选择发票(图片 PDF)", file_count="multiple")
321
  upload_btn = gr.Button("上传并识别")
322
  upload_out = gr.Dataframe(headers=["id","filename","supplier","score","status","invoice_no","total"], interactive=False)
323
  with gr.Column(scale=2):
324
  refresh_btn = gr.Button("刷新列表")
325
  list_out = gr.Dataframe(headers=["id","filename","uploaded_at","supplier_name","invoice_no","total_amount","score","status"], interactive=False)
326
+ doc_id_in = gr.Textbox(label="文档 ID从列表复制)")
327
+ supplier_dd = gr.Dropdown(choices=[s.name for s in SessionLocal().query(Supplier).all()], label="选择供应商")
 
 
328
  confirm_btn = gr.Button("确认关联")
329
+ confirm_out = gr.Textbox(label="结果")
330
+ export_ids = gr.Textbox(label="导出 ID(如 1,2,3)")
 
331
  export_btn = gr.Button("导出为 ZIP")
332
+ export_out = gr.Textbox(label="导出结果")
 
 
 
 
 
333
 
334
+ upload_btn.click(upload_files, inputs=[file_inputs, uploader], outputs=upload_out)
335
+ refresh_btn.click(refresh_list, inputs=[], outputs=list_out)
336
+ confirm_btn.click(ui_confirm, inputs=[doc_id_in, supplier_dd], outputs=confirm_out)
337
+ export_btn.click(ui_export, inputs=[export_ids], outputs=export_out)
338
 
339
  if __name__ == "__main__":
340
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)