# app.py (final version - supports PDF/JPG/PNG + export to repo root)
import io
import os
import re
import traceback
import zipfile
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Tuple

import gradio as gr
import pandas as pd
import pytesseract
from PIL import Image
from pdf2image import convert_from_bytes
from rapidfuzz import process, fuzz
from sqlalchemy import create_engine, Column, Integer, String, DateTime, Float, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
# ---------- config ----------
# All paths are rooted at /app (the container / HF Space working directory).
BASE_DIR = Path("/app")
STORAGE_DIR = BASE_DIR / "storage"          # uploaded originals + export ZIPs
DATA_DIR = BASE_DIR / "data"                # SQLite database location
SUPPLIER_CSV = BASE_DIR / "suppliers.csv"   # optional seed file for the supplier table
DB_PATH = DATA_DIR / "staging.db"
# Ensure working directories exist before the engine below touches DB_PATH.
STORAGE_DIR.mkdir(parents=True, exist_ok=True)
DATA_DIR.mkdir(parents=True, exist_ok=True)
# ---------- DB ----------
# check_same_thread=False: Gradio handlers can run on threads other than the
# one that created the SQLite connection; the flag permits cross-thread use.
engine = create_engine(f"sqlite:///{DB_PATH}", connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(bind=engine)
Base = declarative_base()
class Supplier(Base):
    """Master record of a known supplier, seeded from suppliers.csv."""
    __tablename__ = "supplier"
    id = Column(Integer, primary_key=True)
    name = Column(String, unique=True, index=True)    # display / fuzzy-matching name
    tax_no = Column(String, unique=True, index=True)  # tax registration number
    note = Column(String, nullable=True)              # free-form remark from the CSV
class Document(Base):
    """One uploaded invoice file plus its OCR, extraction and matching results."""
    __tablename__ = "document"
    id = Column(Integer, primary_key=True)
    filename = Column(String)                        # name given at upload time
    filepath = Column(String)                        # stored path under STORAGE_DIR
    uploaded_at = Column(DateTime, default=datetime.utcnow)
    uploader = Column(String, default="unknown")
    ocr_text = Column(Text)                          # full recognized text
    supplier_id = Column(Integer, nullable=True)     # Supplier.id by convention (no FK constraint declared)
    supplier_name = Column(String, nullable=True)
    supplier_taxno = Column(String, nullable=True)   # matched, or OCR-extracted when unmatched
    invoice_no = Column(String, nullable=True)
    total_amount = Column(Float, nullable=True)
    match_score = Column(Float, nullable=True)       # fuzzy score; 100.0 for exact/manual match
    status = Column(String, default="new")           # "new" | "matched" | "unknown"
    raw_extracted = Column(Text, nullable=True)      # str() of the extract_fields() dict
    error = Column(Text, nullable=True)              # NOTE(review): never written in visible code — reserved?
# Create tables at import time (no-op when they already exist).
Base.metadata.create_all(engine)
# ---------- Load Suppliers ----------
def load_suppliers_from_csv():
    """Seed the supplier table from SUPPLIER_CSV (best-effort, idempotent).

    A missing or unreadable CSV is ignored so the app can still start without
    a seed file.  Rows without a name are skipped; rows whose name or
    (non-empty) tax number already exists are not duplicated.
    """
    if not SUPPLIER_CSV.exists():
        return
    try:
        df = pd.read_csv(SUPPLIER_CSV, dtype=str).fillna("")
    except Exception:
        # Best-effort seeding: a malformed CSV must not prevent startup.
        return
    session = SessionLocal()
    try:
        for _, r in df.iterrows():
            name = str(r.get("name", "")).strip()
            tax = str(r.get("tax_no", "")).strip()
            if not name:
                continue
            # Only use the tax number for the existence check when it is
            # non-empty; otherwise two distinct suppliers with blank tax
            # numbers would collide on tax_no == "" and the second would
            # silently be skipped.
            q = session.query(Supplier)
            if tax:
                exists = q.filter((Supplier.name == name) | (Supplier.tax_no == tax)).first()
            else:
                exists = q.filter(Supplier.name == name).first()
            if not exists:
                session.add(Supplier(name=name, tax_no=tax, note=str(r.get("note", ""))))
        # One commit for the whole batch instead of one per row.
        session.commit()
    finally:
        # Always release the session, even if the commit raises.
        session.close()

load_suppliers_from_csv()
# ---------- OCR & Extraction ----------
# 15-20 uppercase alphanumerics: the shape of a Chinese unified social credit
# code / taxpayer ID (heuristic — other long codes may also match).
TAXNO_RE = re.compile(r"([0-9A-Z]{15,20})")
INVOICE_RE = re.compile(r"(发票代码|发票号码|发票号|Invoice No|Invoice No\.)[::\s]*([A-Za-z0-9\-]+)")
# Amounts like 1,234.56; accepts both ASCII and full-width comma separators.
AMOUNT_RE = re.compile(r"([0-9]{1,3}(?:[,,][0-9]{3})*(?:\.[0-9]{1,2})?)")

def do_ocr(file_bytes: bytes) -> str:
    """OCR an uploaded file (image or PDF) and return the recognized text.

    The bytes are first decoded as an image; if that fails or yields no text,
    they are rasterized as a PDF at 300 dpi and OCR'd page by page.
    Returns "" when nothing could be recognized.
    """
    # 1) Image path (PNG/JPG/...).  `with` closes the decoder handle that the
    # original code leaked.
    try:
        with Image.open(io.BytesIO(file_bytes)) as img:
            text = pytesseract.image_to_string(img.convert("RGB"), lang='chi_sim+eng')
        if text.strip():
            return text
    except Exception:
        # Most likely not an image (e.g. a PDF) — fall through to the PDF path.
        pass
    # 2) PDF path.
    try:
        images = convert_from_bytes(file_bytes, dpi=300)
        texts = []
        for im in images:
            t = pytesseract.image_to_string(im, lang='chi_sim+eng')
            if t.strip():
                texts.append(t)
        return "\n\n".join(texts)
    except Exception as e:
        print("OCR error:", e)
        return ""

def extract_fields(ocr_text: str) -> dict:
    """Heuristically pull tax number, invoice number, amount and supplier name
    out of OCR text.

    Returns a dict with keys: taxno, invoice_no, total, name, raw.
    Fields that cannot be found are None; raw is always the input text.
    """
    text = ocr_text.replace("\n", " ")
    # First long uppercase-alphanumeric run is taken as the tax number.
    tax_candidates = TAXNO_RE.findall(text)
    taxno = tax_candidates[0].strip() if tax_candidates else None
    m = INVOICE_RE.search(text)
    inv = m.group(2).strip() if m else None
    # Normalize full-width commas, then take the LAST amount-looking token —
    # on invoices the grand total usually appears after the line items.
    amounts = AMOUNT_RE.findall(text.replace(",", ","))
    total = None
    if amounts:
        cand = amounts[-1].replace(",", "")
        try:
            total = float(cand)
        except ValueError:  # narrowed from a bare except
            total = None
    # Label-based extraction of the supplier name first.
    heuristics = ["纳税人名称", "开票单位", "单位名称", "收票单位", "单位:", "单位:"]
    name = None
    for h in heuristics:
        if h in ocr_text:
            idx = ocr_text.find(h)
            seg = ocr_text[idx: idx + 150]
            parts = seg.splitlines()
            if parts:
                # Drop the label plus any leftover colon/space separators
                # (the original left a leading ":"/":" on the name).
                name = parts[0].replace(h, "").strip(" ::")
                break
    if not name:
        # Fall back to the first non-empty OCR line, capped at 120 chars.
        lines = [ln.strip() for ln in ocr_text.splitlines() if ln.strip()]
        if lines:
            name = lines[0][:120]
    return {"taxno": taxno, "invoice_no": inv, "total": total, "name": name, "raw": ocr_text}
def match_supplier(session, extracted: dict, threshold: int = 80):
    """Resolve a supplier for the extracted invoice fields.

    An exact tax-number hit wins outright (score 100.0); otherwise the
    extracted name is fuzzy-matched against all supplier names and accepted
    only at or above `threshold`.  Returns a match dict or None.
    """
    taxno = extracted.get("taxno")
    if taxno:
        hit = session.query(Supplier).filter(Supplier.tax_no == taxno.strip()).first()
        if hit:
            return {"supplier_id": hit.id, "supplier_name": hit.name,
                    "supplier_taxno": hit.tax_no, "score": 100.0}
    query_name = (extracted.get("name") or "").strip()
    if not query_name:
        return None
    # Index suppliers by name so the fuzzy winner maps straight back to its row.
    known = {s.name: s for s in session.query(Supplier).all()}
    if not known:
        return None
    best = process.extractOne(query_name, list(known), scorer=fuzz.WRatio)
    if best is None:
        return None
    best_name, score, _ = best
    if score < threshold:
        return None
    winner = known[best_name]
    return {"supplier_id": winner.id, "supplier_name": winner.name,
            "supplier_taxno": winner.tax_no, "score": float(score)}
# ---------- Save & Record ----------
def save_file_and_record(file_bytes: bytes, filename: str, uploader: str = "unknown"):
    """Persist an upload, OCR it, match a supplier, and insert a Document row.

    Returns a summary dict (id, filename, supplier, score, status,
    invoice_no, total) for display in the UI.
    """
    # Microsecond timestamp prefix keeps stored names unique across duplicates.
    ts = datetime.utcnow().strftime("%Y%m%d%H%M%S%f")
    safe_name = f"{ts}_{filename.replace(' ', '_')}"
    path = STORAGE_DIR / safe_name
    with open(path, "wb") as f:
        f.write(file_bytes)
    ocr_text = ""
    try:
        ocr_text = do_ocr(file_bytes)
    except Exception as e:
        # OCR failure is non-fatal: the file is still stored for manual review.
        print("OCR exception:", e, traceback.format_exc())
    extracted = extract_fields(ocr_text)
    session = SessionLocal()
    try:
        match = match_supplier(session, extracted)
        status = "matched" if match else "unknown"
        doc = Document(
            filename=filename,
            filepath=str(path),
            uploaded_at=datetime.utcnow(),
            uploader=uploader,
            ocr_text=ocr_text,
            supplier_id=match["supplier_id"] if match else None,
            supplier_name=match["supplier_name"] if match else None,
            # Without a match, still keep the OCR'd tax number for review.
            supplier_taxno=match["supplier_taxno"] if match else extracted.get("taxno"),
            invoice_no=extracted.get("invoice_no"),
            total_amount=extracted.get("total"),
            match_score=match["score"] if match else None,
            status=status,
            raw_extracted=str(extracted),
        )
        session.add(doc)
        session.commit()
        session.refresh(doc)
        # Snapshot plain values while the instance is still attached, instead
        # of reading attributes from a detached instance after close().
        result = {
            "id": doc.id, "filename": doc.filename, "supplier": doc.supplier_name,
            "score": doc.match_score, "status": doc.status,
            "invoice_no": doc.invoice_no, "total": doc.total_amount,
        }
    finally:
        # Original leaked the session on any exception; always release it.
        session.close()
    return result
# ---------- Export ----------
def export_zip(ids: List[int]):
    """Bundle the given documents (metadata.csv + original files) into a ZIP.

    Returns the ZIP path as a string, or None when no matching documents
    exist.  A copy of the ZIP is also placed in BASE_DIR so users can
    download it straight from the Space's file listing.
    """
    import shutil  # hoisted to function top (was mid-function in the original)
    session = SessionLocal()
    try:
        docs = session.query(Document).filter(Document.id.in_(ids)).all()
        if not docs:
            return None
        ts = datetime.utcnow().strftime("%Y%m%d%H%M%S")
        zip_path = STORAGE_DIR / f"export_{ts}.zip"
        rows = [{
            "id": d.id, "filename": d.filename, "supplier_name": d.supplier_name,
            "supplier_taxno": d.supplier_taxno, "invoice_no": d.invoice_no,
            "total_amount": d.total_amount,
            "uploaded_at": d.uploaded_at.strftime("%Y-%m-%d %H:%M:%S"),
            "status": d.status,
        } for d in docs]
        csv_buffer = io.StringIO()
        # NOTE(review): pandas ignores encoding= for text buffers, so no BOM
        # is actually written here — confirm whether Excel compatibility of
        # metadata.csv matters.
        pd.DataFrame(rows).to_csv(csv_buffer, index=False, encoding="utf-8-sig")
        with zipfile.ZipFile(zip_path, "w") as zf:
            zf.writestr("metadata.csv", csv_buffer.getvalue())
            for d in docs:
                if os.path.exists(d.filepath):
                    zf.write(d.filepath, arcname=os.path.basename(d.filepath))
    finally:
        # Original leaked the session on any exception; always release it.
        session.close()
    # Copy the ZIP to the Space root so users can download it directly.
    try:
        shutil.copy(zip_path, BASE_DIR / f"export_{ts}.zip")
        print(f"✅ ZIP 已复制到根目录: export_{ts}.zip")
    except Exception as e:
        print("复制 ZIP 失败:", e)
    return str(zip_path)
# ---------- Gradio UI ----------
def upload_files(files, uploader):
    """Gradio handler: process each uploaded file, return a summary DataFrame.

    Per-file failures are reported as an "error" row instead of aborting the
    whole batch.
    """
    results = []
    if not files:
        # Keep the column layout stable even when nothing was uploaded.
        return pd.DataFrame([], columns=["id", "filename", "supplier", "score", "status", "invoice_no", "total"])
    for f in files:
        try:
            if hasattr(f, "read"):
                content = f.read()
            else:
                # Gradio may hand back a tempfile wrapper exposing only .name;
                # read via a context manager so the handle is not leaked
                # (the original's bare open(...).read() leaked it).
                with open(f.name, "rb") as fh:
                    content = fh.read()
            res = save_file_and_record(content, getattr(f, "name", "file"), uploader or "unknown")
            results.append(res)
        except Exception as e:
            results.append({
                "id": -1, "filename": getattr(f, "name", "file"),
                "supplier": None, "score": None, "status": "error",
                "invoice_no": None, "total": None,
            })
            print("File process error:", e, traceback.format_exc())
    return pd.DataFrame(results)
def refresh_list():
    """Return the 200 most recent documents as a DataFrame for the UI table."""
    session = SessionLocal()
    try:
        q = session.query(Document).order_by(Document.uploaded_at.desc()).limit(200).all()
        rows = [{
            "id": d.id, "filename": d.filename,
            "uploaded_at": d.uploaded_at.strftime("%Y-%m-%d %H:%M:%S"),
            "supplier_name": d.supplier_name, "invoice_no": d.invoice_no,
            "total_amount": d.total_amount, "score": d.match_score, "status": d.status,
        } for d in q]
    finally:
        # Original leaked the session on any exception; always release it.
        session.close()
    return pd.DataFrame(rows)
def ui_confirm(doc_id: str, supplier_name: str):
    """Manually attach a supplier to a document; returns a short status string."""
    # Tolerate None (empty Gradio textbox) and copy/paste whitespace — the
    # original crashed on None and rejected " 12 ".
    doc_id = (doc_id or "").strip()
    if not doc_id.isdigit():
        return "Invalid doc id"
    doc_id_i = int(doc_id)
    session = SessionLocal()
    try:
        s = session.query(Supplier).filter(Supplier.name == supplier_name).first()
        d = session.query(Document).get(doc_id_i)
        if not d:
            return "Doc not found"
        if not s:
            return "Supplier not found"
        d.supplier_id = s.id
        d.supplier_name = s.name
        d.supplier_taxno = s.tax_no
        # Manual confirmation counts as a perfect match.
        d.match_score = 100.0
        d.status = "matched"
        session.add(d)
        session.commit()
    finally:
        # Single cleanup point replaces the per-branch close() calls and also
        # covers the exception path the original leaked on.
        session.close()
    return "ok"
def ui_export(txt_ids: str):
    """Parse a comma-separated ID string and export those documents as a ZIP.

    Non-numeric tokens are ignored; with no valid IDs, returns "".
    """
    tokens = (piece.strip() for piece in txt_ids.split(","))
    ids = [int(tok) for tok in tokens if tok.isdigit()]
    return export_zip(ids) if ids else ""
# Build the two-column Gradio interface: upload/recognize on the left,
# review/confirm/export on the right.
with gr.Blocks() as demo:
    gr.Markdown("## 发票收票分拣(Staging)\n上传发票 → 自动识别供应商 → 人工确认/导出")
    with gr.Row():
        with gr.Column(scale=3):
            # Upload panel: optional uploader name + multi-file picker.
            uploader = gr.Textbox(label="上传人 (可选)", placeholder="front_desk")
            file_inputs = gr.File(label="拖拽或选择发票(图片或 PDF)", file_count="multiple")
            upload_btn = gr.Button("上传并识别")
            upload_out = gr.Dataframe(headers=["id","filename","supplier","score","status","invoice_no","total"], interactive=False)
        with gr.Column(scale=2):
            # Review panel: recent documents, manual confirmation, ZIP export.
            refresh_btn = gr.Button("刷新列表")
            list_out = gr.Dataframe(headers=["id","filename","uploaded_at","supplier_name","invoice_no","total_amount","score","status"], interactive=False)
            doc_id_in = gr.Textbox(label="文档 ID(从列表复制)")
            # NOTE(review): choices are read once at startup and this session is
            # never closed — suppliers added later won't appear until restart.
            supplier_dd = gr.Dropdown(choices=[s.name for s in SessionLocal().query(Supplier).all()], label="选择供应商")
            confirm_btn = gr.Button("确认关联")
            confirm_out = gr.Textbox(label="结果")
            export_ids = gr.Textbox(label="导出 ID(如 1,2,3)")
            export_btn = gr.Button("导出为 ZIP")
            export_out = gr.Textbox(label="导出结果")
    # Wire the handlers defined above to their buttons.
    upload_btn.click(upload_files, inputs=[file_inputs, uploader], outputs=upload_out)
    refresh_btn.click(refresh_list, inputs=[], outputs=list_out)
    confirm_btn.click(ui_confirm, inputs=[doc_id_in, supplier_dd], outputs=confirm_out)
    export_btn.click(ui_export, inputs=[export_ids], outputs=export_out)

if __name__ == "__main__":
    # Bind on all interfaces for container / Space deployment.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)