# Invoice / main.py
# Author: Corin1998 — "Upload main.py" (commit f359272, verified)
"""
Mini Invoice/Estimate (Quote) SaaS — single-file FastAPI app
SQLite + SQLModel + FastAPI
"""
from __future__ import annotations

import hmac
import os
import pathlib
import tempfile
from datetime import datetime, date, timezone
from typing import Optional, List, Any

from fastapi import FastAPI, Depends, HTTPException, Header, Query, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field, model_validator
from sqlmodel import Field as SQLField, Session, SQLModel, create_engine, select, Relationship
from sqlalchemy import func
from sqlalchemy.exc import SQLAlchemyError
# --------------------------
# Auth
# --------------------------
# Shared API key expected in the X-API-Key header. Falls back to "dev" when the
# API_KEY environment variable is unset, so real deployments must set it.
API_KEY = os.getenv("API_KEY", "dev")


async def require_api_key(x_api_key: str | None = Header(default=None)):
    """FastAPI dependency that rejects requests lacking the correct API key.

    Args:
        x_api_key: Value of the ``X-API-Key`` request header, or None if absent.

    Raises:
        HTTPException: 401 when the header is missing or does not match API_KEY.
    """
    # hmac.compare_digest takes time independent of where the inputs differ, so
    # the key cannot be guessed byte-by-byte from response timing. Comparing
    # UTF-8 bytes avoids the TypeError compare_digest raises for non-ASCII str.
    if x_api_key is None or not hmac.compare_digest(
        x_api_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid or missing X-API-Key")
# --------------------------
# DB(/data → /tmp → ./ の順に自動フォールバック)
# --------------------------
def _ensure_dir(path: str) -> bool:
try:
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
return True
except Exception:
return False
def _make_sqlite_url(db_path: str) -> str:
return f"sqlite:///{db_path}"
# Resolve the database URL:
#  * sqlite DATABASE_URL -> used if its directory is usable, else /tmp/app.db;
#  * any other DATABASE_URL (e.g. postgresql://...) -> used unchanged;
#  * no DATABASE_URL -> first usable of /data/app.db, /tmp/app.db, ./app.db.
DB_URL = os.getenv("DATABASE_URL")
if DB_URL and DB_URL.startswith("sqlite"):
    raw = DB_URL.replace("sqlite:///", "", 1)
    raw = raw if raw.startswith("/") else os.path.abspath(raw)
    dir_ = os.path.dirname(raw)
    if not _ensure_dir(dir_):
        fallback = "/tmp/app.db"
        _ensure_dir(os.path.dirname(fallback))
        DB_URL = _make_sqlite_url(fallback)
elif not DB_URL:
    # Only auto-pick a SQLite file when no DATABASE_URL was given at all.
    # Previously this was a bare `else`, which silently replaced a non-sqlite
    # DATABASE_URL with a local SQLite file.
    candidates = ["/data/app.db", "/tmp/app.db", os.path.abspath("./app.db")]
    chosen = None
    for p in candidates:
        if _ensure_dir(os.path.dirname(p)):
            chosen = p
            break
    # Explicit raise rather than `assert`: assertions are stripped under
    # `python -O`, which would turn this into the URL "sqlite:///None".
    if chosen is None:
        raise RuntimeError("No writable location found for SQLite DB")
    DB_URL = _make_sqlite_url(chosen)
# True when the resolved URL points at SQLite (affects driver options and
# search operators elsewhere in this module).
IS_SQLITE = DB_URL.startswith("sqlite")

# check_same_thread=False lets the sqlite3 driver accept a connection used from
# a thread other than the one that opened it; other drivers get no extra args.
engine = create_engine(
    DB_URL,
    echo=False,
    connect_args=({"check_same_thread": False} if IS_SQLITE else {}),
)


def get_session():
    """FastAPI dependency yielding one Session; it is closed when the generator finishes."""
    session = Session(engine)
    with session:
        yield session
# --------------------------
# Models
# --------------------------
# Customer that quotes and invoices are issued to. Only `name` is required; it
# is indexed (the invoice wizard looks customers up by name).
class Customer(SQLModel, table=True):
    id: Optional[int] = SQLField(default=None, primary_key=True)
    name: str = SQLField(index=True)
    email: Optional[str] = None
    phone: Optional[str] = None
    address: Optional[str] = None
    city: Optional[str] = None
    country: Optional[str] = None
    # One-to-many back-references, paired with Quote.customer / Invoice.customer.
    quotes: List["Quote"] = Relationship(back_populates="customer")
    invoices: List["Invoice"] = Relationship(back_populates="customer")
# Product catalogue entry with a unit price (declared >= 0) and a currency
# code defaulting to "JPY". Line items may reference it via product_id.
# NOTE(review): SQLModel table=True models skip pydantic validation when built
# via the constructor, so ge=0 is likely not enforced there — confirm.
class Product(SQLModel, table=True):
    id: Optional[int] = SQLField(default=None, primary_key=True)
    name: str
    unit_price: float = SQLField(ge=0)
    currency: str = SQLField(default="JPY")
    sku: Optional[str] = None
    description: Optional[str] = None
# Quote (estimate) header for a customer; its lines are QuoteItem rows.
class Quote(SQLModel, table=True):
    id: Optional[int] = SQLField(default=None, primary_key=True)
    customer_id: int = SQLField(foreign_key="customer.id")
    status: str = SQLField(default="draft", index=True)  # draft/sent/accepted/expired
    # Today's date in UTC. datetime.utcnow() is deprecated since Python 3.12;
    # datetime.now(timezone.utc).date() yields the same calendar date.
    issue_date: date = SQLField(default_factory=lambda: datetime.now(timezone.utc).date())
    valid_until: Optional[date] = None
    notes: Optional[str] = None
    customer: Optional[Customer] = Relationship(back_populates="quotes")
    items: List["QuoteItem"] = Relationship(back_populates="quote")
# One line of a quote. description/unit_price are stored on the line itself;
# product_id optionally links to a catalogue Product. compute_totals treats
# tax_rate as a fraction of quantity * unit_price (e.g. 0.1 for 10%).
class QuoteItem(SQLModel, table=True):
    id: Optional[int] = SQLField(default=None, primary_key=True)
    quote_id: int = SQLField(foreign_key="quote.id")
    product_id: Optional[int] = SQLField(foreign_key="product.id", default=None)
    description: str
    quantity: float = SQLField(gt=0, default=1)
    unit_price: float = SQLField(ge=0)
    tax_rate: float = SQLField(ge=0, default=0.0)
    quote: Optional[Quote] = Relationship(back_populates="items")
# Invoice header for a customer; its lines are InvoiceItem rows.
class Invoice(SQLModel, table=True):
    id: Optional[int] = SQLField(default=None, primary_key=True)
    customer_id: int = SQLField(foreign_key="customer.id")
    status: str = SQLField(default="unpaid", index=True)  # unpaid/paid/void
    # Today's date in UTC. datetime.utcnow() is deprecated since Python 3.12;
    # datetime.now(timezone.utc).date() yields the same calendar date.
    issue_date: date = SQLField(default_factory=lambda: datetime.now(timezone.utc).date())
    due_date: Optional[date] = None
    notes: Optional[str] = None
    # Payment details (optional; filled in by the pay endpoint)
    paid_at: Optional[datetime] = None
    paid_amount: Optional[float] = None
    payment_method: Optional[str] = None
    customer: Optional[Customer] = Relationship(back_populates="invoices")
    items: List["InvoiceItem"] = Relationship(back_populates="invoice")
# One line of an invoice (mirrors QuoteItem; create_invoice copies quote lines
# into this table). compute_totals treats tax_rate as a fraction of
# quantity * unit_price.
class InvoiceItem(SQLModel, table=True):
    id: Optional[int] = SQLField(default=None, primary_key=True)
    invoice_id: int = SQLField(foreign_key="invoice.id")
    product_id: Optional[int] = SQLField(foreign_key="product.id", default=None)
    description: str
    quantity: float = SQLField(gt=0, default=1)
    unit_price: float = SQLField(ge=0)
    tax_rate: float = SQLField(ge=0, default=0.0)
    invoice: Optional[Invoice] = Relationship(back_populates="items")
# --------------------------
# DTOs (request bodies)
# --------------------------
class CreateQuoteIn(BaseModel):
    # Body for POST /quotes: target customer plus optional expiry date and notes.
    customer_id: int
    valid_until: date | None = None
    notes: str | None = None
class QuoteItemIn(BaseModel):
    # Body for POST /quotes/{id}/items.
    # The bounds match WizardItemIn and the QuoteItem column declarations.
    # SQLModel table models skip validation on direct construction, so without
    # them here a negative quantity or price would be stored as-is.
    description: str
    quantity: float = Field(gt=0, default=1)
    unit_price: float = Field(ge=0)
    tax_rate: float = Field(ge=0, default=0.0)
class CreateInvoiceIn(BaseModel):
    # Body for POST /invoices. When quote_id is given, that quote's line items
    # are copied onto the new invoice.
    customer_id: int
    due_date: date | None = None
    notes: str | None = None
    quote_id: int | None = None  # optional: source quote to copy lines from
class InvoiceItemIn(BaseModel):
    # Body for POST /invoices/{id}/items.
    # The bounds match WizardItemIn and the InvoiceItem column declarations.
    # SQLModel table models skip validation on direct construction, so without
    # them here a negative quantity or price would be stored as-is.
    description: str
    quantity: float = Field(gt=0, default=1)
    unit_price: float = Field(ge=0)
    tax_rate: float = Field(ge=0, default=0.0)
class PayIn(BaseModel):
    # Body for POST /invoices/{id}/pay. When paid_at is omitted the endpoint
    # stamps the current time.
    paid_amount: float
    paid_at: datetime | None = None
    payment_method: str | None = "bank_transfer"
class EmailIn(BaseModel):
    # Body for POST /invoices/{id}/email; attach_pdf decides whether the
    # rendered invoice PDF is attached.
    to: str
    subject: str
    body: str
    attach_pdf: bool = Field(default=True)
# ---- Wizard (bulk registration from the front end) ----
class WizardItemIn(BaseModel):
    # One invoice line in the wizard payload: quantity > 0, unit_price >= 0,
    # tax_rate >= 0.
    description: str
    quantity: float = Field(default=1, gt=0)
    unit_price: float = Field(ge=0)
    tax_rate: float = Field(default=0.0, ge=0)
class WizardCustomerIn(BaseModel):
    # Customer part of the wizard payload: set `id` to reuse an existing
    # customer, otherwise `name` (plus optional contact fields) to find/create one.
    id: Optional[int] = None
    name: Optional[str] = None
    email: Optional[str] = None
    phone: Optional[str] = None
    address: Optional[str] = None
    city: Optional[str] = None
    country: Optional[str] = None
class WizardInvoiceIn(BaseModel):
    # Body for POST /wizard/invoice: customer, optional due date/notes, and lines.
    customer: WizardCustomerIn
    due_date: date | None = None
    notes: str | None = None
    items: list[WizardItemIn]

    @model_validator(mode="before")
    @classmethod
    def coerce_dates(cls, v: Any):
        """Normalise a loosely formatted ``due_date`` before field validation.

        Accepts ISO dates and slash/dash variants without zero padding, such as
        "2024/1/5". Blank or unparseable strings become None, because the field
        is optional and a sloppy form value should not block invoice creation.
        Non-dict input is returned unchanged for pydantic to handle.
        """
        if not isinstance(v, dict):
            return v
        d = v.get("due_date")
        if not isinstance(d, str):
            return v
        v = dict(v)  # work on a copy; don't mutate the caller's payload
        s = d.strip().replace("/", "-")
        if not s:
            # An empty form field means "no due date", not a validation error.
            v["due_date"] = None
            return v
        try:
            v["due_date"] = date.fromisoformat(s)
        except ValueError:
            try:
                # strptime tolerates missing zero-padding ("2024-1-5").
                v["due_date"] = datetime.strptime(s, "%Y-%m-%d").date()
            except ValueError:
                v["due_date"] = None
        return v
# --------------------------
# Totals (hardened)
# --------------------------
class MoneyBreakdown(BaseModel):
    # Monetary summary of a document's lines as produced by compute_totals
    # (each value rounded to 2 decimals there).
    subtotal: float = Field(...)
    tax: float = Field(...)
    total: float = Field(...)
def round2(v: float) -> float:
    """Round *v* to two decimal places and return the result as a float.

    The value goes through its fixed-point text form, so rounding follows the
    exact binary value: 2.675 -> 2.67, because 2.675 is stored slightly below.
    Integer input comes back as a float.
    """
    text = format(v, ".2f")
    return float(text)
def _get(v: Any, key: str, default=0.0):
if isinstance(v, dict):
return v.get(key, default)
return getattr(v, key, default)
def compute_totals(items: list[Any]) -> MoneyBreakdown:
    """Sum line items into subtotal, tax and total.

    Each item may be a dict or an object exposing ``quantity``, ``unit_price``
    and ``tax_rate``; missing or None values count as 0. ``tax_rate`` is applied
    as a fraction of the line amount (0.1 == 10%). ``None`` or an empty list
    gives all zeros.

    Returns:
        MoneyBreakdown whose figures are rounded to 2 decimals and satisfy
        ``subtotal + tax == total`` for the rounded values.
    """
    subtotal = 0.0
    tax = 0.0
    for it in items or ():
        qty = float(_get(it, "quantity", 0) or 0)
        unit = float(_get(it, "unit_price", 0) or 0)
        rate = float(_get(it, "tax_rate", 0) or 0)
        line = qty * unit
        subtotal += line
        tax += line * rate
    subtotal_r = round2(subtotal)
    tax_r = round2(tax)
    # Derive the total from the *rounded* parts so the printed breakdown always
    # adds up. Rounding the raw sum separately can differ by 0.01, e.g.
    # 10.004 + 10.004 gives 10.00 + 10.00 shown next to a total of 20.01.
    return MoneyBreakdown(subtotal=subtotal_r, tax=tax_r, total=round2(subtotal_r + tax_r))
# --------------------------
# App
# --------------------------
app = FastAPI(title="Mini Invoice/Estimate SaaS", version="0.1.0")

# Allowed CORS origins: comma-separated list from CORS_ALLOW_ORIGINS (default "*").
_cors_origins = [
    o.strip() for o in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",") if o.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # The CORS spec does not allow credentialed requests with a wildcard origin.
    # This API authenticates with the X-API-Key header rather than cookies, so
    # credentials are only enabled when an explicit origin allow-list is set.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Show an "Authorize" button (X-API-Key) in Swagger UI.
def custom_openapi():
    """Build the OpenAPI schema once, adding an X-API-Key security scheme.

    Every operation gets ``security: [{"APIKeyHeader": []}]`` unless it already
    declares its own. The schema is cached on ``app.openapi_schema``.
    """
    cached = app.openapi_schema
    if cached:
        return cached
    schema = get_openapi(
        title=app.title,
        version=app.version,
        description="Mini Invoice/Estimate SaaS API",
        routes=app.routes,
    )
    components = schema.setdefault("components", {})
    schemes = components.setdefault("securitySchemes", {})
    schemes["APIKeyHeader"] = {
        "type": "apiKey",
        "name": "X-API-Key",
        "in": "header",
    }
    operations = (
        op
        for path_item in schema.get("paths", {}).values()
        for op in path_item.values()
    )
    for op in operations:
        op.setdefault("security", [{"APIKeyHeader": []}])
    app.openapi_schema = schema
    return schema


app.openapi = custom_openapi
@app.on_event("startup")
def on_startup():
    """Create any missing tables for the SQLModel table models at startup."""
    metadata = SQLModel.metadata
    metadata.create_all(bind=engine)
# ---- UI (/app) ----
# Serve the front-end from the "static" directory (resolved against the working
# directory); html=True enables index.html handling for directory requests.
app.mount(path="/app", app=StaticFiles(directory="static", html=True), name="app")
# -------- Customers --------
@app.post("/customers", dependencies=[Depends(require_api_key)])
def create_customer(payload: Customer, session: Session = Depends(get_session)):
    """Create a customer from the request body and return the stored row.

    Raises:
        HTTPException: 422 if ``name`` is blank; 400 on database or other errors.
    """
    try:
        has_name = bool(payload.name) and bool(str(payload.name).strip())
        if not has_name:
            raise HTTPException(422, "name は必須です")
        session.add(payload)
        session.commit()
        session.refresh(payload)
    except HTTPException:
        raise
    except SQLAlchemyError as e:
        session.rollback()
        raise HTTPException(400, f"DBエラー: {e.__class__.__name__}")
    except Exception as e:
        raise HTTPException(400, f"不正なリクエスト: {e}")
    return payload
@app.get("/customers", dependencies=[Depends(require_api_key)])
def list_customers(
    q: Optional[str] = Query(default=None, description="free text search (name/email/phone)"),
    limit: int = Query(50, ge=1, le=200),
    offset: int = Query(0, ge=0),
    session: Session = Depends(get_session),
):
    """List customers, optionally filtered by a case-insensitive substring.

    Args:
        q: Text searched for literally (``%``/``_`` are not wildcards) in
            name, email and phone.
        limit: Page size (1-200).
        offset: Number of rows to skip. Rows are ordered by id, so pages are stable.

    Returns:
        ``{"data": [...], "pagination": {"total", "limit", "offset"}}``, where
        ``total`` counts all matches, not just the returned page.
    """
    base = select(Customer)
    count_stmt = select(func.count(Customer.id))
    if q:
        # Escape LIKE metacharacters so user input is matched literally.
        escaped = q.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        pattern = f"%{escaped}%"
        # ilike() renders as ILIKE where supported and lower(x) LIKE lower(y)
        # elsewhere (e.g. SQLite), so one expression serves every backend.
        cond = (
            Customer.name.ilike(pattern, escape="\\")
            | Customer.email.ilike(pattern, escape="\\")
            | Customer.phone.ilike(pattern, escape="\\")
        )
        base = base.where(cond)
        count_stmt = count_stmt.where(cond)
    # For a single-column select, SQLModel's exec() returns a ScalarResult.
    # It has no .scalar() (the old code raised AttributeError); .one() returns the count.
    total = session.exec(count_stmt).one() or 0
    # Explicit ORDER BY: LIMIT/OFFSET without it gives unstable pages.
    rows = session.exec(base.order_by(Customer.id).offset(offset).limit(limit)).all()
    return {"data": rows, "pagination": {"total": total, "limit": limit, "offset": offset}}
# -------- Products --------
@app.post("/products", dependencies=[Depends(require_api_key)])
def create_product(payload: Product, session: Session = Depends(get_session)):
    """Create a product from the request body and return the stored row.

    Raises:
        HTTPException: 422 if ``name`` is blank; 400 on database or other errors.
    """
    try:
        has_name = bool(payload.name) and bool(str(payload.name).strip())
        if not has_name:
            raise HTTPException(422, "name は必須です")
        session.add(payload)
        session.commit()
        session.refresh(payload)
    except HTTPException:
        raise
    except SQLAlchemyError as e:
        session.rollback()
        raise HTTPException(400, f"DBエラー: {e.__class__.__name__}")
    except Exception as e:
        raise HTTPException(400, f"不正なリクエスト: {e}")
    return payload
@app.get("/products", dependencies=[Depends(require_api_key)])
def list_products(
    limit: int = Query(50, ge=1, le=200),
    offset: int = Query(0, ge=0),
    session: Session = Depends(get_session),
):
    """Return one page of products ordered by id, plus the total product count.

    Returns:
        ``{"data": [...], "pagination": {"total", "limit", "offset"}}``.
    """
    # For a single-column select, SQLModel's exec() returns a ScalarResult.
    # It has no .scalar() (the old code raised AttributeError); .one() returns the count.
    total = session.exec(select(func.count(Product.id))).one() or 0
    # Explicit ORDER BY: LIMIT/OFFSET without it gives unstable pages.
    rows = session.exec(
        select(Product).order_by(Product.id).offset(offset).limit(limit)
    ).all()
    return {"data": rows, "pagination": {"total": total, "limit": limit, "offset": offset}}
# -------- Quotes --------
@app.post("/quotes", dependencies=[Depends(require_api_key)])
def create_quote(payload: CreateQuoteIn, session: Session = Depends(get_session)):
    """Create an item-less quote (status defaults to "draft") for an existing customer.

    Raises:
        HTTPException: 400 if the customer does not exist.
    """
    customer = session.get(Customer, payload.customer_id)
    if customer is None:
        raise HTTPException(400, "Customer not found")
    quote = Quote(
        customer_id=payload.customer_id,
        valid_until=payload.valid_until,
        notes=payload.notes,
    )
    session.add(quote)
    session.commit()
    session.refresh(quote)
    return quote
@app.get("/quotes/{quote_id}", dependencies=[Depends(require_api_key)])
def get_quote(quote_id: int, session: Session = Depends(get_session)):
    """Return a quote with its line items and computed totals.

    Raises:
        HTTPException: 404 if the quote does not exist.
    """
    quote = session.get(Quote, quote_id)
    if quote is None:
        raise HTTPException(404, "Quote not found")
    stmt = select(QuoteItem).where(QuoteItem.quote_id == quote_id)
    lines = session.exec(stmt).all()
    return {"quote": quote, "items": lines, "totals": compute_totals(lines)}
@app.post("/quotes/{quote_id}/items", dependencies=[Depends(require_api_key)])
def add_quote_item(quote_id: int, payload: QuoteItemIn, session: Session = Depends(get_session)):
    """Append a line item to an existing quote and return the stored item.

    Raises:
        HTTPException: 404 if the quote does not exist.
    """
    if session.get(Quote, quote_id) is None:
        raise HTTPException(404, "Quote not found")
    new_line = QuoteItem(
        quote_id=quote_id,
        description=payload.description,
        quantity=payload.quantity,
        unit_price=payload.unit_price,
        tax_rate=payload.tax_rate,
    )
    session.add(new_line)
    session.commit()
    session.refresh(new_line)
    return new_line
# -------- Invoices --------
@app.post("/invoices", dependencies=[Depends(require_api_key)])
def create_invoice(payload: CreateInvoiceIn, session: Session = Depends(get_session)):
    """Create an invoice, optionally copying the line items of a quote.

    The invoice and any copied items are committed in a single transaction.

    Raises:
        HTTPException: 400 if the customer does not exist; 404 if ``quote_id``
            is set but the quote does not exist (nothing is written then).
    """
    if not session.get(Customer, payload.customer_id):
        raise HTTPException(400, "Customer not found")
    # Validate the source quote *before* writing anything. Previously the
    # invoice was committed first, so a bad quote_id left an orphan invoice.
    q_items = []
    if payload.quote_id:
        if not session.get(Quote, payload.quote_id):
            raise HTTPException(404, "Quote not found to copy")
        q_items = session.exec(
            select(QuoteItem).where(QuoteItem.quote_id == payload.quote_id)
        ).all()
    inv = Invoice(customer_id=payload.customer_id, due_date=payload.due_date, notes=payload.notes)
    session.add(inv)
    # flush() assigns inv.id without committing, so the invoice and its items
    # are persisted together by the single commit below.
    session.flush()
    for it in q_items:
        session.add(InvoiceItem(
            invoice_id=inv.id,
            product_id=it.product_id,
            description=it.description,
            quantity=it.quantity,
            unit_price=it.unit_price,
            tax_rate=it.tax_rate,
        ))
    session.commit()
    session.refresh(inv)
    return inv
@app.get("/invoices/{invoice_id}", dependencies=[Depends(require_api_key)])
def get_invoice(invoice_id: int, session: Session = Depends(get_session)):
    """Return an invoice with its line items and computed totals.

    Raises:
        HTTPException: 404 if the invoice does not exist.
    """
    invoice = session.get(Invoice, invoice_id)
    if invoice is None:
        raise HTTPException(404, "Invoice not found")
    stmt = select(InvoiceItem).where(InvoiceItem.invoice_id == invoice_id)
    lines = session.exec(stmt).all()
    return {"invoice": invoice, "items": lines, "totals": compute_totals(lines)}
@app.post("/invoices/{invoice_id}/items", dependencies=[Depends(require_api_key)])
def add_invoice_item(invoice_id: int, payload: InvoiceItemIn, session: Session = Depends(get_session)):
    """Append a line item to an existing invoice and return the stored item.

    Raises:
        HTTPException: 404 if the invoice does not exist.
    """
    if session.get(Invoice, invoice_id) is None:
        raise HTTPException(404, "Invoice not found")
    new_line = InvoiceItem(
        invoice_id=invoice_id,
        description=payload.description,
        quantity=payload.quantity,
        unit_price=payload.unit_price,
        tax_rate=payload.tax_rate,
    )
    session.add(new_line)
    session.commit()
    session.refresh(new_line)
    return new_line
# ---- Payment registration ----
@app.post("/invoices/{invoice_id}/pay", dependencies=[Depends(require_api_key)])
def pay_invoice(invoice_id: int, payload: PayIn, session: Session = Depends(get_session)):
    """Record a payment on an invoice and set its status to "paid".

    ``paid_at`` defaults to the current UTC time as a naive datetime.

    Raises:
        HTTPException: 404 if the invoice does not exist; 409 if it is void.
    """
    inv = session.get(Invoice, invoice_id)
    if not inv:
        raise HTTPException(404, "Invoice not found")
    # A voided invoice must not be silently turned back into a "paid" one.
    if inv.status == "void":
        raise HTTPException(409, "Invoice is void")
    inv.paid_amount = payload.paid_amount
    # datetime.utcnow() is deprecated since Python 3.12; this yields the same
    # naive UTC timestamp.
    inv.paid_at = payload.paid_at or datetime.now(timezone.utc).replace(tzinfo=None)
    inv.payment_method = payload.payment_method
    inv.status = "paid"
    session.add(inv)
    session.commit()
    session.refresh(inv)
    return {"ok": True, "invoice": inv}
# ---- PDF generation & download ----
@app.get("/invoices/{invoice_id}/pdf", dependencies=[Depends(require_api_key)])
def invoice_pdf(invoice_id: int, session: Session = Depends(get_session)):
    """Render an invoice to PDF and return it as a download.

    The PDF is written to a unique temporary file, which is deleted after the
    response has been sent.

    Raises:
        HTTPException: 404 if the invoice does not exist.
    """
    inv = session.get(Invoice, invoice_id)
    if not inv:
        raise HTTPException(404, "Invoice not found")
    cust = session.get(Customer, inv.customer_id)
    items = session.exec(select(InvoiceItem).where(InvoiceItem.invoice_id == invoice_id)).all()
    from pdf_export import write_invoice_pdf  # lazy import
    totals = compute_totals(items)
    # A per-request temp file, instead of the shared /tmp/invoice_<id>.pdf,
    # stops concurrent downloads of one invoice from overwriting each other's
    # output. It also keeps generated PDFs from piling up in /tmp.
    fd, out_path = tempfile.mkstemp(prefix=f"invoice_{invoice_id}_", suffix=".pdf")
    os.close(fd)
    try:
        write_invoice_pdf(out_path, inv, items, cust, totals)
    except Exception:
        os.remove(out_path)  # don't leave an empty/partial file behind
        raise
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, out_path)
    return FileResponse(
        out_path,
        media_type="application/pdf",
        filename=f"invoice_{invoice_id}.pdf",
        background=cleanup,
    )
# ---- Email delivery (optional PDF attachment) ----
@app.post("/invoices/{invoice_id}/email", dependencies=[Depends(require_api_key)])
def email_invoice(invoice_id: int, payload: EmailIn, session: Session = Depends(get_session)):
    """Send an email about an invoice, optionally attaching its rendered PDF.

    When ``attach_pdf`` is true, the PDF is written as ``invoice_<id>.pdf``
    inside a private temporary directory. That directory is removed once
    ``send_email_smtp`` returns.

    Returns:
        ``{"ok": ok, "detail": detail}`` as reported by ``send_email_smtp``.

    Raises:
        HTTPException: 404 if the invoice does not exist.
    """
    inv = session.get(Invoice, invoice_id)
    if not inv:
        raise HTTPException(404, "Invoice not found")
    # A private temp dir, instead of the shared /tmp/invoice_<id>.pdf, avoids
    # clashes between concurrent requests and guarantees cleanup. It keeps the
    # same file name, in case the mailer derives the attachment name from it.
    with tempfile.TemporaryDirectory() as tmpdir:
        attachment_path = None
        if payload.attach_pdf:
            cust = session.get(Customer, inv.customer_id)
            items = session.exec(
                select(InvoiceItem).where(InvoiceItem.invoice_id == invoice_id)
            ).all()
            totals = compute_totals(items)
            from pdf_export import write_invoice_pdf
            attachment_path = os.path.join(tmpdir, f"invoice_{invoice_id}.pdf")
            write_invoice_pdf(attachment_path, inv, items, cust, totals)
        from mailer import send_email_smtp
        ok, detail = send_email_smtp(
            to=payload.to,
            subject=payload.subject,
            body=payload.body,
            attachment_path=attachment_path,
        )
    return {"ok": ok, "detail": detail}
# ---- Wizard: create customer -> invoice -> line items in one request ----
def _wizard_resolve_customer(session: Session, c: WizardCustomerIn) -> Customer:
    """Return the customer referenced by id, an existing one with the same name, or a new pending one."""
    if c.id:
        cust = session.get(Customer, c.id)
        if not cust:
            raise HTTPException(404, "Customer id not found")
        return cust
    name = (c.name or "").strip()
    if not name:
        # Also rejects whitespace-only names, which used to be stored as "".
        raise HTTPException(422, "Either customer.id or customer.name is required")
    # Look up by the *stripped* name, i.e. the form that is stored, so that
    # " Acme " reuses the existing "Acme" instead of creating a duplicate.
    existing = session.exec(select(Customer).where(Customer.name == name)).first()
    if existing:
        return existing
    cust = Customer(
        name=name,
        email=(c.email or None),
        phone=(c.phone or None),
        address=(c.address or None),
        city=(c.city or None),
        country=(c.country or None),
    )
    session.add(cust)
    return cust


@app.post("/wizard/invoice", dependencies=[Depends(require_api_key)])
def wizard_create_invoice(payload: WizardInvoiceIn, session: Session = Depends(get_session)):
    """Create (or reuse) a customer, an invoice and its line items in one transaction.

    Returns:
        ``{"customer_id", "invoice_id", "totals"}``.

    Raises:
        HTTPException: 404 for an unknown customer id; 422 when neither id nor
            name is given; 400 for any other failure. Nothing is persisted
            when an error occurs.
    """
    try:
        cust = _wizard_resolve_customer(session, payload.customer)
        session.flush()  # assigns cust.id for a newly created customer
        cust_id = cust.id
        inv = Invoice(customer_id=cust_id, due_date=payload.due_date, notes=payload.notes)
        session.add(inv)
        session.flush()  # assigns inv.id without committing
        inv_id = inv.id
        session.add_all([
            InvoiceItem(
                invoice_id=inv_id,
                description=it.description,
                quantity=it.quantity,
                unit_price=it.unit_price,
                tax_rate=it.tax_rate,
            )
            for it in payload.items
        ])
        # One commit: customer, invoice and items succeed or fail together.
        session.commit()
        totals = compute_totals(payload.items)
        return {"customer_id": cust_id, "invoice_id": inv_id, "totals": totals}
    except HTTPException:
        session.rollback()
        raise
    except Exception as e:
        session.rollback()
        raise HTTPException(400, f"Wizard failed: {e}") from e
# ---- Root (UI) ----
@app.get("/", response_class=FileResponse)
def root_ui():
    """Serve the single-page UI entry point, static/app.html."""
    page = "static/app.html"
    return FileResponse(page)
# -------- ChatGPT router (registered last) --------
# Imported here rather than with the other imports, presumably so that
# openai_integration can import from this module without a circular import — confirm.
from openai_integration import router as ai_router

app.include_router(router=ai_router, prefix="/ai", tags=["ai"])