# invoice-staging / app.py
# Author: Gabriel00A — commit 71d3ff0 ("Update app.py", verified)
# app.py (final version - supports PDF/JPG/PNG + export to repo root)
import os
import io
import zipfile
import traceback
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Tuple
import gradio as gr
import pandas as pd
from PIL import Image
import pytesseract
import re
from rapidfuzz import process, fuzz
from sqlalchemy import create_engine, Column, Integer, String, DateTime, Float, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from pdf2image import convert_from_bytes
# ---------- config ----------
# All on-disk state lives under BASE_DIR. It defaults to /app; set the
# INVOICE_STAGING_DIR environment variable to run the app from another
# location (e.g. a local checkout) without editing this file.
BASE_DIR = Path(os.environ.get("INVOICE_STAGING_DIR", "/app"))
STORAGE_DIR = BASE_DIR / "storage"  # stored uploads and export ZIPs
DATA_DIR = BASE_DIR / "data"  # SQLite database directory
SUPPLIER_CSV = BASE_DIR / "suppliers.csv"  # optional supplier list, read once at import
DB_PATH = DATA_DIR / "staging.db"
# Create the writable directories up front so later writes never hit a missing parent.
STORAGE_DIR.mkdir(parents=True, exist_ok=True)
DATA_DIR.mkdir(parents=True, exist_ok=True)
# ---------- DB ----------
# Single SQLite file. check_same_thread=False lets a connection be used from a
# thread other than the one that created it (presumably needed because Gradio
# callbacks do not run on the import thread — confirm if the server model changes).
_DB_URL = f"sqlite:///{DB_PATH}"
engine = create_engine(_DB_URL, connect_args={"check_same_thread": False})

# Session factory; callers create one session per operation.
SessionLocal = sessionmaker(bind=engine)

# Declarative base shared by the ORM models that follow.
Base = declarative_base()
class Supplier(Base):
    """Known supplier record, inserted from SUPPLIER_CSV by load_suppliers_from_csv()."""

    __tablename__ = "supplier"
    id = Column(Integer, primary_key=True)
    # Supplier name; OCR-extracted names are fuzzy-matched against this.
    name = Column(String, unique=True, index=True)
    # Tax number; an exact match on it is tried before any name matching.
    # NOTE(review): UNIQUE — two rows with the same non-NULL value (including "") conflict.
    tax_no = Column(String, unique=True, index=True)
    note = Column(String, nullable=True)  # free-form note taken from the CSV "note" column
class Document(Base):
    """One uploaded invoice file together with its OCR text and extracted/matched fields."""

    __tablename__ = "document"
    id = Column(Integer, primary_key=True)
    filename = Column(String)  # name the file was uploaded under
    filepath = Column(String)  # path of the stored copy under STORAGE_DIR
    uploaded_at = Column(DateTime, default=datetime.utcnow)  # naive UTC timestamp
    uploader = Column(String, default="unknown")
    ocr_text = Column(Text)  # full OCR output
    # Supplier.id of the matched supplier; NOTE(review): not declared as a ForeignKey.
    supplier_id = Column(Integer, nullable=True)
    supplier_name = Column(String, nullable=True)
    # Matched supplier's tax number, or the OCR-extracted one when nothing matched.
    supplier_taxno = Column(String, nullable=True)
    invoice_no = Column(String, nullable=True)
    total_amount = Column(Float, nullable=True)
    # Match score; 100.0 for tax-number and manually confirmed matches.
    match_score = Column(Float, nullable=True)
    # Values written in this file: "matched" and "unknown" (column default "new").
    status = Column(String, default="new")
    raw_extracted = Column(Text, nullable=True)  # str() of the extract_fields() result dict
    error = Column(Text, nullable=True)  # NOTE(review): not written anywhere in the visible code
# Create any tables that do not exist yet; existing tables are left untouched (no migrations).
Base.metadata.create_all(bind=engine)
# ---------- Load Suppliers ----------
def load_suppliers_from_csv():
    """Insert suppliers from SUPPLIER_CSV that are not yet in the database.

    Expected columns: ``name`` (rows with an empty name are skipped), and
    optionally ``tax_no`` and ``note``. A row counts as already present when a
    supplier with the same name, or the same non-empty tax number, exists;
    existing suppliers are never updated. A missing file is a no-op; read or
    insert failures are printed and ignored so the app still starts.
    """
    if not SUPPLIER_CSV.exists():
        return
    try:
        # utf-8-sig drops the BOM Excel writes (otherwise the first header becomes
        # "\ufeffname" and no row has a name); BOM-less UTF-8 reads the same.
        df = pd.read_csv(SUPPLIER_CSV, dtype=str, encoding="utf-8-sig").fillna("")
    except Exception as e:
        print("Failed to read supplier CSV:", e)
        return
    # Tolerate stray whitespace in header names such as " name".
    df.columns = [str(c).strip() for c in df.columns]

    session = SessionLocal()
    try:
        for _, r in df.iterrows():
            name = str(r.get("name", "")).strip()
            if not name:
                continue
            # A missing tax number is stored as NULL, not "": tax_no is UNIQUE and
            # SQLite allows many NULLs but only one "". Comparing on "" would also
            # make every later supplier without a tax number look like a duplicate.
            tax = str(r.get("tax_no", "")).strip() or None
            clause = Supplier.name == name
            if tax:
                clause = clause | (Supplier.tax_no == tax)
            exists = session.query(Supplier).filter(clause).first()
            if not exists:
                session.add(Supplier(name=name, tax_no=tax, note=str(r.get("note", ""))))
        session.commit()
    except Exception as e:
        session.rollback()
        print("Failed to load suppliers:", e, traceback.format_exc())
    finally:
        session.close()


load_suppliers_from_csv()
# ---------- OCR & Extraction ----------
# Full-width punctuation is written as escapes (\uff1a FULLWIDTH COLON,
# \uff0c FULLWIDTH COMMA) so it cannot be silently normalised to ASCII.

# Tax-number candidate: a run of 15-20 digits / uppercase letters.
TAXNO_RE = re.compile(r"([0-9A-Z]{15,20})")
# Invoice label, optional ASCII/full-width colon or whitespace, then the value (group 2).
INVOICE_RE = re.compile(r"(发票代码|发票号码|发票号|Invoice No|Invoice No\.)[:\uff1a\s]*([A-Za-z0-9\-]+)")
# Money amount with up to two decimals: either digits grouped in threes by ASCII or
# full-width commas, or a plain digit run. The plain-run branch keeps 12345.67 whole;
# a bare 1-3 digit prefix would split it into "123" and "45.67".
AMOUNT_RE = re.compile(r"((?:[0-9]{1,3}(?:[,\uff0c][0-9]{3})+|[0-9]+)(?:\.[0-9]{1,2})?)")
def do_ocr(file_bytes: bytes) -> str:
    """Run OCR on an uploaded image or PDF and return the recognised text.

    The bytes are first opened with Pillow (e.g. JPG/PNG; only the current,
    i.e. first, frame of a multi-frame image is used). If Pillow cannot identify
    them they are treated as a PDF: every page is rendered at 300 dpi with
    pdf2image and OCR'd, and the non-blank page texts are joined by blank lines.
    Tesseract is called with lang 'chi_sim+eng'. Returns "" when nothing could
    be recognised; errors are printed rather than raised.
    """
    if not file_bytes:
        return ""

    # 1) Image
    try:
        img = Image.open(io.BytesIO(file_bytes))
    except Exception:
        img = None  # Pillow cannot identify the data -> try it as a PDF below
    if img is not None:
        # Pillow recognised the data as an image, so it is not a PDF; report OCR
        # problems here instead of swallowing them and failing again in pdf2image.
        try:
            text = pytesseract.image_to_string(img.convert("RGB"), lang='chi_sim+eng')
        except Exception as e:
            print("OCR error:", e)
            return ""
        return text if text.strip() else ""

    # 2) PDF
    try:
        pages = convert_from_bytes(file_bytes, dpi=300)
        texts = []
        for page in pages:
            t = pytesseract.image_to_string(page, lang='chi_sim+eng')
            if t.strip():
                texts.append(t)
        return "\n\n".join(texts)
    except Exception as e:
        print("OCR error:", e)
        return ""
def extract_fields(ocr_text: str) -> dict:
    """Heuristically pull invoice fields out of raw OCR text.

    Returns a dict with keys:
      taxno      -- first TAXNO_RE candidate, or None
      invoice_no -- value after an invoice label (INVOICE_RE group 2), or None. Any
                    label other than "发票代码" (the invoice *code*) is preferred, so
                    the code is only used when no invoice-number label was found.
      total      -- last money-like number in the text as a float, or None
      name       -- supplier-name guess: text after the first known label that has
                    any, else the first non-empty line (max 120 chars), else None
      raw        -- the input text ("" for None)
    """
    ocr_text = ocr_text or ""
    text = ocr_text.replace("\n", " ")

    tax_candidates = TAXNO_RE.findall(text)
    taxno = tax_candidates[0].strip() if tax_candidates else None

    # 发票代码 (invoice code) is a different field from 发票号码 (invoice number) and
    # may appear earlier in the text, so don't just take the leftmost label.
    inv_matches = list(INVOICE_RE.finditer(text))
    preferred = [m for m in inv_matches if m.group(1) != "发票代码"] or inv_matches
    inv = preferred[0].group(2).strip() if preferred else None

    # Fold full-width commas (U+FF0C) to ASCII so every thousands separator is
    # removed before float() parsing.
    amounts = AMOUNT_RE.findall(text.replace("\uff0c", ","))
    total = None
    if amounts:
        try:
            total = float(amounts[-1].replace(",", ""))
        except ValueError:
            total = None

    name = None
    for label in ("纳税人名称", "开票单位", "单位名称", "收票单位", "单位:", "单位\uff1a"):
        idx = ocr_text.find(label)
        if idx < 0:
            continue
        # Rest of the label's line, looking at most 150 characters ahead.
        line = ocr_text[idx:idx + 150].splitlines()[0]
        # Drop the label and any ASCII/full-width colon that separates it from the value.
        candidate = line.replace(label, "").strip().lstrip(":\uff1a").strip()
        if candidate:
            name = candidate
            break
    if not name:
        lines = [ln.strip() for ln in ocr_text.splitlines() if ln.strip()]
        name = lines[0][:120] if lines else None
    return {"taxno": taxno, "invoice_no": inv, "total": total, "name": name, "raw": ocr_text}
def match_supplier(session, extracted: dict, threshold: int = 80):
    """Resolve extracted invoice fields to a known Supplier.

    An exact match on the extracted tax number wins with score 100.0. Otherwise
    the extracted name is compared to every supplier name with rapidfuzz's
    WRatio scorer, and the best candidate is accepted if its score reaches
    ``threshold``.

    Returns a dict with supplier_id, supplier_name, supplier_taxno and score,
    or None when nothing matched.
    """
    def _as_result(supplier, score):
        return {
            "supplier_id": supplier.id,
            "supplier_name": supplier.name,
            "supplier_taxno": supplier.tax_no,
            "score": score,
        }

    tax_no = extracted.get("taxno")
    if tax_no:
        by_tax = session.query(Supplier).filter(Supplier.tax_no == tax_no.strip()).first()
        if by_tax:
            return _as_result(by_tax, 100.0)

    wanted = (extracted.get("name") or "").strip()
    if not wanted:
        return None
    known = session.query(Supplier).all()
    if not known:
        return None

    by_name = {sup.name: sup for sup in known}
    hit = process.extractOne(wanted, list(by_name), scorer=fuzz.WRatio)
    if not hit:
        return None
    matched_name, hit_score, _ = hit
    if hit_score < threshold:
        return None
    return _as_result(by_name[matched_name], float(hit_score))
# ---------- Save & Record ----------
def save_file_and_record(file_bytes: bytes, filename: str, uploader: str = "unknown"):
    """Store an uploaded file, OCR it, match a supplier and insert a Document row.

    Args:
        file_bytes: Raw file content.
        filename: Name supplied by the client. Only its last path component is
            kept, so a full temp path or "../" segments cannot place the stored
            copy outside STORAGE_DIR (or into missing subdirectories).
        uploader: Free-text name of the uploader.

    Returns:
        Summary dict with keys id, filename, supplier, score, status,
        invoice_no, total. status is "matched" when match_supplier() found a
        supplier, otherwise "unknown".
    """
    display_name = os.path.basename((filename or "").replace("\\", "/")) or "file"
    # Microsecond timestamp prefix makes name collisions between uploads unlikely.
    ts = datetime.utcnow().strftime("%Y%m%d%H%M%S%f")
    path = STORAGE_DIR / f"{ts}_{display_name.replace(' ', '_')}"
    with open(path, "wb") as f:
        f.write(file_bytes)

    # An OCR failure must not prevent the upload from being recorded.
    ocr_text = ""
    try:
        ocr_text = do_ocr(file_bytes)
    except Exception as e:
        print("OCR exception:", e, traceback.format_exc())
    extracted = extract_fields(ocr_text)

    session = SessionLocal()
    try:
        match = match_supplier(session, extracted)
        doc = Document(
            filename=display_name,
            filepath=str(path),
            uploaded_at=datetime.utcnow(),
            uploader=uploader,
            ocr_text=ocr_text,
            supplier_id=match["supplier_id"] if match else None,
            supplier_name=match["supplier_name"] if match else None,
            # Keep the OCR'd tax number even when no supplier matched.
            supplier_taxno=match["supplier_taxno"] if match else extracted.get("taxno"),
            invoice_no=extracted.get("invoice_no"),
            total_amount=extracted.get("total"),
            match_score=match["score"] if match else None,
            status="matched" if match else "unknown",
            raw_extracted=str(extracted),
        )
        session.add(doc)
        session.commit()
        session.refresh(doc)
        return {
            "id": doc.id, "filename": doc.filename, "supplier": doc.supplier_name,
            "score": doc.match_score, "status": doc.status,
            "invoice_no": doc.invoice_no, "total": doc.total_amount,
        }
    finally:
        # Release the connection also when matching or the commit raises.
        session.close()
# ---------- Export ----------
def export_zip(ids: List[int]):
    """Bundle the selected documents and a metadata.csv into a ZIP archive.

    Args:
        ids: Document ids to export; ids that do not exist are ignored.

    Returns:
        Path of the ZIP (under STORAGE_DIR) as a string, or None when ``ids``
        is empty or matches no document. A copy is also written to BASE_DIR;
        a failed copy is printed but does not fail the export.
    """
    import shutil

    if not ids:
        return None

    session = SessionLocal()
    try:
        docs = (
            session.query(Document)
            .filter(Document.id.in_(ids))
            .order_by(Document.id)
            .all()
        )
        rows = [
            {
                "id": d.id, "filename": d.filename, "supplier_name": d.supplier_name,
                "supplier_taxno": d.supplier_taxno, "invoice_no": d.invoice_no,
                "total_amount": d.total_amount,
                "uploaded_at": d.uploaded_at.strftime("%Y-%m-%d %H:%M:%S") if d.uploaded_at else "",
                "status": d.status,
            }
            for d in docs
        ]
        file_paths = [d.filepath for d in docs]
    finally:
        session.close()
    if not rows:
        return None

    ts = datetime.utcnow().strftime("%Y%m%d%H%M%S")
    zip_path = STORAGE_DIR / f"export_{ts}.zip"
    csv_text = pd.DataFrame(rows).to_csv(index=False)
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        # Encode here: to_csv() ignores `encoding` for text buffers, so the UTF-8 BOM
        # that lets Excel detect the Chinese text has to be added explicitly.
        zf.writestr("metadata.csv", csv_text.encode("utf-8-sig"))
        for fp in file_paths:
            if fp and os.path.exists(fp):
                zf.write(fp, arcname=os.path.basename(fp))

    # Also copy the ZIP to BASE_DIR (the Space root) so users can download it directly.
    try:
        shutil.copy(zip_path, BASE_DIR / f"export_{ts}.zip")
        print(f"✅ ZIP 已复制到根目录: export_{ts}.zip")
    except Exception as e:
        print("复制 ZIP 失败:", e)
    return str(zip_path)
# ---------- Gradio UI ----------
def _upload_display_name(f) -> str:
    """Return the base file name of one upload item (a path string or a file-like object)."""
    raw = f if isinstance(f, (str, os.PathLike)) else getattr(f, "name", None)
    if not isinstance(raw, (str, os.PathLike)):
        return "file"
    return os.path.basename(os.fspath(raw)) or "file"


def _read_upload(f) -> bytes:
    """Read one upload item fully, closing any file this function opens itself."""
    if hasattr(f, "read"):
        return f.read()
    path = f if isinstance(f, (str, os.PathLike)) else f.name
    with open(path, "rb") as fh:
        return fh.read()


def upload_files(files, uploader):
    """Gradio handler: store, OCR and match every uploaded file.

    Args:
        files: Upload items from gr.File — path strings or objects exposing
            ``read()`` or ``name``; None/empty means nothing was uploaded.
        uploader: Optional uploader name; "unknown" when blank.

    Returns:
        DataFrame with one row per file and columns id, filename, supplier,
        score, status, invoice_no, total. A file that fails gets id -1 and
        status "error"; the remaining files are still processed.
    """
    columns = ["id", "filename", "supplier", "score", "status", "invoice_no", "total"]
    if not files:
        return pd.DataFrame([], columns=columns)
    results = []
    for f in files:
        display_name = _upload_display_name(f)
        try:
            content = _read_upload(f)
            results.append(save_file_and_record(content, display_name, uploader or "unknown"))
        except Exception as e:
            results.append({
                "id": -1, "filename": display_name,
                "supplier": None, "score": None, "status": "error",
                "invoice_no": None, "total": None,
            })
            print("File process error:", e, traceback.format_exc())
    return pd.DataFrame(results, columns=columns)
def refresh_list():
    """Return the 200 most recently uploaded documents, newest first, as a DataFrame.

    The columns are always id, filename, uploaded_at, supplier_name, invoice_no,
    total_amount, score and status, even when there are no documents yet.
    """
    columns = ["id", "filename", "uploaded_at", "supplier_name", "invoice_no",
               "total_amount", "score", "status"]
    session = SessionLocal()
    try:
        docs = session.query(Document).order_by(Document.uploaded_at.desc()).limit(200).all()
        rows = [{
            "id": d.id, "filename": d.filename,
            "uploaded_at": d.uploaded_at.strftime("%Y-%m-%d %H:%M:%S") if d.uploaded_at else "",
            "supplier_name": d.supplier_name, "invoice_no": d.invoice_no,
            "total_amount": d.total_amount, "score": d.match_score, "status": d.status,
        } for d in docs]
    finally:
        session.close()
    return pd.DataFrame(rows, columns=columns)
def ui_confirm(doc_id: str, supplier_name: str):
    """Manually link a document to the supplier chosen in the UI.

    Args:
        doc_id: Document id as typed by the user; surrounding whitespace is ignored.
        supplier_name: Exact Supplier.name selected in the dropdown.

    Returns:
        "ok" on success, otherwise one of "Invalid doc id", "Doc not found"
        or "Supplier not found".
    """
    doc_id = (doc_id or "").strip()
    # isdecimal() (unlike isdigit()) only accepts characters int() can parse, e.g. not "²".
    if not doc_id.isdecimal():
        return "Invalid doc id"
    doc_id_i = int(doc_id)
    session = SessionLocal()
    try:
        d = session.query(Document).get(doc_id_i)
        if not d:
            return "Doc not found"
        s = session.query(Supplier).filter(Supplier.name == supplier_name).first()
        if not s:
            return "Supplier not found"
        # A human confirmation is recorded as a certain match.
        d.supplier_id = s.id
        d.supplier_name = s.name
        d.supplier_taxno = s.tax_no
        d.match_score = 100.0
        d.status = "matched"
        session.commit()
        return "ok"
    finally:
        session.close()
def ui_export(txt_ids: str):
    """Gradio handler: export the documents whose ids are listed in ``txt_ids``.

    Ids may be separated by ASCII or full-width commas and/or whitespace;
    tokens that are not plain integers are ignored.

    Returns:
        The ZIP path from export_zip(), or "" when no valid id was given or
        nothing was exported.
    """
    tokens = re.split(r"[,\uff0c\s]+", txt_ids or "")
    ids = [int(tok) for tok in tokens if tok.isdecimal()]
    if not ids:
        return ""
    return export_zip(ids) or ""
def _supplier_choices() -> List[str]:
    """Return all supplier names for the confirm dropdown, closing the DB session afterwards."""
    session = SessionLocal()
    try:
        return [s.name for s in session.query(Supplier).all()]
    finally:
        session.close()


# Layout: upload panel on the left; document list, manual confirmation and export on the right.
with gr.Blocks() as demo:
    gr.Markdown("## 发票收票分拣(Staging)\n上传发票 → 自动识别供应商 → 人工确认/导出")
    with gr.Row():
        with gr.Column(scale=3):
            uploader = gr.Textbox(label="上传人 (可选)", placeholder="front_desk")
            file_inputs = gr.File(label="拖拽或选择发票(图片或 PDF)", file_count="multiple")
            upload_btn = gr.Button("上传并识别")
            upload_out = gr.Dataframe(headers=["id","filename","supplier","score","status","invoice_no","total"], interactive=False)
        with gr.Column(scale=2):
            refresh_btn = gr.Button("刷新列表")
            list_out = gr.Dataframe(headers=["id","filename","uploaded_at","supplier_name","invoice_no","total_amount","score","status"], interactive=False)
            doc_id_in = gr.Textbox(label="文档 ID(从列表复制)")
            # Choices are read once, when the UI is built, and not refreshed afterwards.
            supplier_dd = gr.Dropdown(choices=_supplier_choices(), label="选择供应商")
            confirm_btn = gr.Button("确认关联")
            confirm_out = gr.Textbox(label="结果")
            export_ids = gr.Textbox(label="导出 ID(如 1,2,3)")
            export_btn = gr.Button("导出为 ZIP")
            export_out = gr.Textbox(label="导出结果")
    upload_btn.click(upload_files, inputs=[file_inputs, uploader], outputs=upload_out)
    refresh_btn.click(refresh_list, inputs=[], outputs=list_out)
    confirm_btn.click(ui_confirm, inputs=[doc_id_in, supplier_dd], outputs=confirm_out)
    export_btn.click(ui_export, inputs=[export_ids], outputs=export_out)
if __name__ == "__main__":
    # Listen on all interfaces, port 7860, without creating a Gradio share link.
    launch_options = dict(server_name="0.0.0.0", server_port=7860, share=False)
    demo.launch(**launch_options)