# app.py (final version - supports PDF/JPG/PNG + export to repo root)
"""Invoice intake staging app.

Uploaded invoices (images or PDFs) are stored on disk, OCR'd with Tesseract,
parsed for a tax number / invoice number / amount / supplier name, matched
against the supplier table, and recorded in a SQLite staging database.
A Gradio UI exposes upload, listing, manual supplier confirmation and ZIP
export.
"""
import io
import os
import re
import shutil
import traceback
import zipfile
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Tuple

import gradio as gr
import pandas as pd
import pytesseract
from pdf2image import convert_from_bytes
from PIL import Image
from rapidfuzz import process, fuzz
from sqlalchemy import create_engine, Column, Integer, String, DateTime, Float, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

# ---------- config ----------
BASE_DIR = Path("/app")
STORAGE_DIR = BASE_DIR / "storage"
DATA_DIR = BASE_DIR / "data"
SUPPLIER_CSV = BASE_DIR / "suppliers.csv"
DB_PATH = DATA_DIR / "staging.db"

STORAGE_DIR.mkdir(parents=True, exist_ok=True)
DATA_DIR.mkdir(parents=True, exist_ok=True)

# Tesseract language packs used for every OCR call.
OCR_LANG = "chi_sim+eng"
# Display format for timestamps in the list view and the export CSV.
TS_FORMAT = "%Y-%m-%d %H:%M:%S"
# Column order of the two result tables shown in the UI.
UPLOAD_COLUMNS = ["id", "filename", "supplier", "score", "status", "invoice_no", "total"]
LIST_COLUMNS = [
    "id", "filename", "uploaded_at", "supplier_name",
    "invoice_no", "total_amount", "score", "status",
]

# ---------- DB ----------
engine = create_engine(f"sqlite:///{DB_PATH}", connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(bind=engine)
Base = declarative_base()


class Supplier(Base):
    """A known supplier; both the name and the tax number are unique."""

    __tablename__ = "supplier"

    id = Column(Integer, primary_key=True)
    name = Column(String, unique=True, index=True)
    tax_no = Column(String, unique=True, index=True)
    note = Column(String, nullable=True)


class Document(Base):
    """One uploaded invoice file plus the OCR text and extracted fields."""

    __tablename__ = "document"

    id = Column(Integer, primary_key=True)
    filename = Column(String)
    filepath = Column(String)
    uploaded_at = Column(DateTime, default=datetime.utcnow)
    uploader = Column(String, default="unknown")
    ocr_text = Column(Text)
    supplier_id = Column(Integer, nullable=True)
    supplier_name = Column(String, nullable=True)
    supplier_taxno = Column(String, nullable=True)
    invoice_no = Column(String, nullable=True)
    total_amount = Column(Float, nullable=True)
    match_score = Column(Float, nullable=True)
    status = Column(String, default="new")
    raw_extracted = Column(Text, nullable=True)
    error = Column(Text, nullable=True)


Base.metadata.create_all(engine)


@contextmanager
def session_scope():
    """Yield a new session; roll back on error and always close it.

    Committing is left to the caller, so read-only users pay nothing extra.
    """
    session = SessionLocal()
    try:
        yield session
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()


def _format_ts(value: Optional[datetime]) -> str:
    """Format a timestamp for display; a missing timestamp becomes ""."""
    return value.strftime(TS_FORMAT) if value else ""


# ---------- Load Suppliers ----------
def load_suppliers_from_csv():
    """Insert suppliers from SUPPLIER_CSV that are not yet in the database.

    Rows without a name are skipped. A row counts as already present when a
    supplier with the same name exists, or - only when the row has a tax
    number - one with the same tax number. Empty tax numbers are stored as
    NULL so that several suppliers without a tax number can coexist under
    the unique constraint on ``tax_no``.
    """
    if not SUPPLIER_CSV.exists():
        return
    try:
        df = pd.read_csv(SUPPLIER_CSV, dtype=str).fillna("")
    except Exception as e:
        # Best effort: a broken CSV must not prevent the app from starting.
        print("Supplier CSV load error:", e)
        return

    with session_scope() as session:
        for _, row in df.iterrows():
            name = str(row.get("name", "")).strip()
            tax = str(row.get("tax_no", "")).strip() or None
            if not name:
                continue
            duplicate = Supplier.name == name
            if tax:
                duplicate = duplicate | (Supplier.tax_no == tax)
            # The query autoflushes, so rows added earlier in this loop are seen.
            if session.query(Supplier).filter(duplicate).first() is None:
                session.add(Supplier(name=name, tax_no=tax, note=str(row.get("note", ""))))
        session.commit()


load_suppliers_from_csv()

# ---------- OCR & Extraction ----------
# NOTE(review): the file as received contained duplicated ASCII punctuation
# ("[,,]", "[::\s]", replace(",", ",")), which looks like full-width CJK
# punctuation lost in an encoding normalisation. The patterns below accept
# both the ASCII and the full-width forms - confirm against real OCR output.
TAXNO_RE = re.compile(r"([0-9A-Z]{15,20})")
INVOICE_RE = re.compile(
    r"(发票代码|发票号码|发票号|Invoice No|Invoice No\.)[:\uff1a\s]*([A-Za-z0-9\-]+)"
)
AMOUNT_RE = re.compile(r"([0-9]{1,3}(?:[,\uff0c][0-9]{3})*(?:\.[0-9]{1,2})?)")

# Labels that may precede the supplier name on an invoice, tried in order.
NAME_LABELS = ["纳税人名称", "开票单位", "单位名称", "收票单位", "单位:", "单位\uff1a"]


def _ocr_image(file_bytes: bytes) -> str:
    """OCR raw bytes as a single image; raises if they are not an image."""
    with Image.open(io.BytesIO(file_bytes)) as img:
        return pytesseract.image_to_string(img.convert("RGB"), lang=OCR_LANG)


def _ocr_pdf(file_bytes: bytes) -> str:
    """Rasterise a PDF at 300 dpi and OCR each page; blank pages are dropped."""
    pages = convert_from_bytes(file_bytes, dpi=300)
    texts = (pytesseract.image_to_string(page, lang=OCR_LANG) for page in pages)
    return "\n\n".join(t for t in texts if t.strip())


def do_ocr(file_bytes: bytes) -> str:
    """Recognise the text of an image or PDF given as raw bytes.

    Data starting with the ``%PDF`` signature goes straight to the PDF path.
    Anything else is tried as an image first and as a PDF only if it cannot
    be read as an image.

    Returns:
        The recognised text, or "" when the input is empty or OCR fails.
    """
    if not file_bytes:
        return ""

    image_error = None
    if not file_bytes[:1024].lstrip().startswith(b"%PDF"):
        try:
            return _ocr_image(file_bytes)
        except Exception as e:
            # Not a readable image (or OCR failed); fall through to PDF.
            image_error = e

    try:
        return _ocr_pdf(file_bytes)
    except Exception as e:
        if image_error is not None:
            print("OCR error (image):", image_error)
        print("OCR error:", e)
        return ""


def extract_fields(ocr_text: str) -> dict:
    """Pull tax number, invoice number, total and supplier name out of OCR text.

    Heuristics: the first 15-20 character run of digits/capitals is the tax
    number, the token after an invoice-number label is the invoice number,
    the last number-like token is the total, and the supplier name is the
    rest of the line following the first known label found (falling back to
    the first non-empty line, truncated to 120 characters).

    Returns:
        dict with keys "taxno", "invoice_no", "total", "name" (each may be
        None) and "raw" (the OCR text).
    """
    ocr_text = ocr_text or ""
    text = ocr_text.replace("\n", " ")

    tax_candidates = TAXNO_RE.findall(text)
    taxno = tax_candidates[0].strip() if tax_candidates else None

    m = INVOICE_RE.search(text)
    inv = m.group(2).strip() if m else None

    # Normalise full-width thousands separators before matching amounts.
    amounts = AMOUNT_RE.findall(text.replace("\uff0c", ","))
    total = None
    if amounts:
        try:
            total = float(amounts[-1].replace(",", ""))
        except ValueError:
            total = None

    name = None
    for label in NAME_LABELS:
        idx = ocr_text.find(label)
        if idx == -1:
            continue
        first_line = ocr_text[idx: idx + 150].splitlines()[0]
        # Drop the label and any separator colon left between label and name.
        name = first_line.replace(label, "").strip(" \t:\uff1a")
        break
    if not name:
        lines = [ln.strip() for ln in ocr_text.splitlines() if ln.strip()]
        name = lines[0][:120] if lines else None

    return {"taxno": taxno, "invoice_no": inv, "total": total, "name": name, "raw": ocr_text}


def match_supplier(session, extracted: dict, threshold: int = 80) -> Optional[dict]:
    """Find the supplier an extracted invoice belongs to.

    An exact tax-number match wins with score 100. Otherwise the extracted
    name is fuzzy-matched (rapidfuzz ``WRatio``) against all supplier names
    and accepted when the score reaches ``threshold``.

    Returns:
        dict with "supplier_id", "supplier_name", "supplier_taxno" and
        "score", or None when nothing matches.
    """
    tax = (extracted.get("taxno") or "").strip()
    if tax:
        sup = session.query(Supplier).filter(Supplier.tax_no == tax).first()
        if sup:
            return {
                "supplier_id": sup.id,
                "supplier_name": sup.name,
                "supplier_taxno": sup.tax_no,
                "score": 100.0,
            }

    name = (extracted.get("name") or "").strip()
    if not name:
        return None

    by_name = {s.name: s for s in session.query(Supplier).all()}
    if not by_name:
        return None

    best = process.extractOne(name, list(by_name), scorer=fuzz.WRatio)
    if best:
        match_name, score, _ = best
        if score >= threshold:
            s = by_name[match_name]
            return {
                "supplier_id": s.id,
                "supplier_name": s.name,
                "supplier_taxno": s.tax_no,
                "score": float(score),
            }
    return None


# ---------- Save & Record ----------
def save_file_and_record(file_bytes: bytes, filename: str, uploader: str = "unknown"):
    """Store an uploaded file, OCR it, match a supplier and insert a Document.

    ``filename`` may be a bare name or a full path; only its final component
    is used, so the stored file always lands directly inside STORAGE_DIR.

    Returns:
        dict with "id", "filename", "supplier", "score", "status",
        "invoice_no" and "total" describing the new record.
    """
    # Uploads arrive with a temp *path* as their name; keep the basename only.
    base_name = os.path.basename(str(filename).replace("\\", "/")) or "file"
    ts = datetime.utcnow().strftime("%Y%m%d%H%M%S%f")
    path = STORAGE_DIR / f"{ts}_{base_name.replace(' ', '_')}"
    path.write_bytes(file_bytes)

    ocr_text = ""
    try:
        ocr_text = do_ocr(file_bytes)
    except Exception as e:
        print("OCR exception:", e, traceback.format_exc())

    extracted = extract_fields(ocr_text)

    with session_scope() as session:
        match = match_supplier(session, extracted)
        doc = Document(
            filename=base_name,
            filepath=str(path),
            uploaded_at=datetime.utcnow(),
            uploader=uploader,
            ocr_text=ocr_text,
            supplier_id=match["supplier_id"] if match else None,
            supplier_name=match["supplier_name"] if match else None,
            supplier_taxno=match["supplier_taxno"] if match else extracted.get("taxno"),
            invoice_no=extracted.get("invoice_no"),
            total_amount=extracted.get("total"),
            match_score=match["score"] if match else None,
            status="matched" if match else "unknown",
            raw_extracted=str(extracted),
        )
        session.add(doc)
        session.commit()
        session.refresh(doc)
        # Read the attributes while the instance is still attached.
        return {
            "id": doc.id,
            "filename": doc.filename,
            "supplier": doc.supplier_name,
            "score": doc.match_score,
            "status": doc.status,
            "invoice_no": doc.invoice_no,
            "total": doc.total_amount,
        }


# ---------- Export ----------
def export_zip(ids: List[int]):
    """Bundle the given documents and a ``metadata.csv`` into a ZIP archive.

    The archive is written to STORAGE_DIR and then copied to BASE_DIR; a
    failed copy is reported but does not fail the export. Stored files that
    no longer exist on disk are left out of the archive.

    Returns:
        Path of the archive in STORAGE_DIR as a string, or None when no
        document matches ``ids``.
    """
    if not ids:
        return None

    with session_scope() as session:
        docs = session.query(Document).filter(Document.id.in_(ids)).all()
        if not docs:
            return None
        rows = [
            {
                "id": d.id,
                "filename": d.filename,
                "supplier_name": d.supplier_name,
                "supplier_taxno": d.supplier_taxno,
                "invoice_no": d.invoice_no,
                "total_amount": d.total_amount,
                "uploaded_at": _format_ts(d.uploaded_at),
                "status": d.status,
            }
            for d in docs
        ]
        filepaths = [d.filepath for d in docs]

    ts = datetime.utcnow().strftime("%Y%m%d%H%M%S")
    zip_path = STORAGE_DIR / f"export_{ts}.zip"

    # to_csv ignores ``encoding`` for text buffers, so add the BOM explicitly.
    csv_bytes = pd.DataFrame(rows).to_csv(index=False).encode("utf-8-sig")
    with zipfile.ZipFile(zip_path, "w") as zf:
        zf.writestr("metadata.csv", csv_bytes)
        for fp in filepaths:
            if fp and os.path.exists(fp):
                zf.write(fp, arcname=os.path.basename(fp))

    # Also copy the ZIP to the Space root directory so it can be downloaded
    # directly from there.
    try:
        shutil.copy(zip_path, BASE_DIR / f"export_{ts}.zip")
        print(f"✅ ZIP 已复制到根目录: export_{ts}.zip")
    except Exception as e:
        print("复制 ZIP 失败:", e)

    return str(zip_path)


# ---------- Gradio UI ----------
def _upload_name(f) -> str:
    """Return the display name (final path component) of an uploaded item."""
    raw = f if isinstance(f, (str, os.PathLike)) else getattr(f, "name", None)
    return (os.path.basename(str(raw)) if raw else "") or "file"


def _read_upload(f) -> bytes:
    """Read the content of an uploaded item.

    Accepts a path (``str`` / ``PathLike``), an object whose ``name`` is an
    existing file path, or any object with a ``read()`` method.
    """
    if isinstance(f, (str, os.PathLike)):
        return Path(f).read_bytes()
    name = getattr(f, "name", None)
    if isinstance(name, str) and os.path.isfile(name):
        # Read from disk rather than from a handle that may be closed or
        # already positioned at EOF.
        return Path(name).read_bytes()
    if hasattr(f, "read"):
        return f.read()
    raise TypeError(f"unsupported upload object: {type(f).__name__}")


def upload_files(files, uploader):
    """Process every uploaded file and return one result row per file.

    A file that fails to process yields a row with id -1 and status "error"
    instead of aborting the whole batch.
    """
    if not files:
        return pd.DataFrame([], columns=UPLOAD_COLUMNS)

    results = []
    for f in files:
        display_name = _upload_name(f)
        try:
            content = _read_upload(f)
            results.append(save_file_and_record(content, display_name, uploader or "unknown"))
        except Exception as e:
            results.append({
                "id": -1,
                "filename": display_name,
                "supplier": None,
                "score": None,
                "status": "error",
                "invoice_no": None,
                "total": None,
            })
            print("File process error:", e, traceback.format_exc())
    return pd.DataFrame(results, columns=UPLOAD_COLUMNS)


def refresh_list():
    """Return the 200 most recently uploaded documents as a DataFrame."""
    with session_scope() as session:
        docs = (
            session.query(Document)
            .order_by(Document.uploaded_at.desc())
            .limit(200)
            .all()
        )
        rows = [
            {
                "id": d.id,
                "filename": d.filename,
                "uploaded_at": _format_ts(d.uploaded_at),
                "supplier_name": d.supplier_name,
                "invoice_no": d.invoice_no,
                "total_amount": d.total_amount,
                "score": d.match_score,
                "status": d.status,
            }
            for d in docs
        ]
    return pd.DataFrame(rows, columns=LIST_COLUMNS)


def ui_confirm(doc_id: str, supplier_name: str):
    """Manually attach a supplier to a document and mark it as matched.

    Returns:
        "ok" on success, otherwise "Invalid doc id", "Doc not found" or
        "Supplier not found".
    """
    doc_id = str(doc_id or "").strip()
    if not doc_id.isdigit():
        return "Invalid doc id"

    with session_scope() as session:
        supplier = session.query(Supplier).filter(Supplier.name == supplier_name).first()
        doc = session.query(Document).filter(Document.id == int(doc_id)).first()
        if not doc:
            return "Doc not found"
        if not supplier:
            return "Supplier not found"
        doc.supplier_id = supplier.id
        doc.supplier_name = supplier.name
        doc.supplier_taxno = supplier.tax_no
        doc.match_score = 100.0
        doc.status = "matched"
        session.commit()
    return "ok"


def ui_export(txt_ids: str):
    """Parse a comma-separated id list and export those documents.

    Non-numeric entries are ignored. Returns the ZIP path, or "" when no
    valid id was given or no document was found.
    """
    ids = [int(x.strip()) for x in (txt_ids or "").split(",") if x.strip().isdigit()]
    if not ids:
        return ""
    return export_zip(ids) or ""


def _supplier_names() -> List[str]:
    """Return all supplier names for the dropdown, closing the session after."""
    with session_scope() as session:
        return [s.name for s in session.query(Supplier).all()]


with gr.Blocks() as demo:
    gr.Markdown("## 发票收票分拣(Staging)\n上传发票 → 自动识别供应商 → 人工确认/导出")
    with gr.Row():
        with gr.Column(scale=3):
            uploader = gr.Textbox(label="上传人 (可选)", placeholder="front_desk")
            file_inputs = gr.File(label="拖拽或选择发票(图片或 PDF)", file_count="multiple")
            upload_btn = gr.Button("上传并识别")
            upload_out = gr.Dataframe(headers=UPLOAD_COLUMNS, interactive=False)
        with gr.Column(scale=2):
            refresh_btn = gr.Button("刷新列表")
            list_out = gr.Dataframe(headers=LIST_COLUMNS, interactive=False)
            doc_id_in = gr.Textbox(label="文档 ID(从列表复制)")
            supplier_dd = gr.Dropdown(choices=_supplier_names(), label="选择供应商")
            confirm_btn = gr.Button("确认关联")
            confirm_out = gr.Textbox(label="结果")
            export_ids = gr.Textbox(label="导出 ID(如 1,2,3)")
            export_btn = gr.Button("导出为 ZIP")
            export_out = gr.Textbox(label="导出结果")

    upload_btn.click(upload_files, inputs=[file_inputs, uploader], outputs=upload_out)
    refresh_btn.click(refresh_list, inputs=[], outputs=list_out)
    confirm_btn.click(ui_confirm, inputs=[doc_id_in, supplier_dd], outputs=confirm_out)
    export_btn.click(ui_export, inputs=[export_ids], outputs=export_out)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)