"""FastAPI backend for the Document Capture Demo."""
# NOTE(review): removed stray merge/paste artifacts ("Seth" / "update" /
# "33a43cb") that made the module fail to parse.
import os
import time
from typing import List, Dict, Optional
from fastapi import FastAPI, UploadFile, File, Depends, Form, HTTPException, Body
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.orm import Session
from pydantic import BaseModel
from .db import Base, engine, SessionLocal
from .models import ExtractionRecord, User, ShareToken
from .schemas import ExtractionRecordBase, ExtractionStage
from .openrouter_client import extract_fields_from_document
from .auth import get_current_user, get_db, verify_token
from .auth_routes import router as auth_router
from .api_key_auth import get_user_from_api_key
# Allowed MIME types for uploads: PDF plus common raster-image formats.
ALLOWED_CONTENT_TYPES = [
    "application/pdf",
    "image/png",
    "image/jpeg",
    "image/jpg",
    "image/tiff",
    "image/tif"
]
# Allowed file extensions (fallback validation when the client sends a
# generic/unknown content type)
ALLOWED_EXTENSIONS = [".pdf", ".png", ".jpg", ".jpeg", ".tiff", ".tif"]
# Maximum file size: 4 MB
MAX_FILE_SIZE = 4 * 1024 * 1024  # 4 MB in bytes
# Ensure data dir exists for SQLite
os.makedirs("data", exist_ok=True)
# Create all ORM tables declared on Base (no-op for tables that already exist)
Base.metadata.create_all(bind=engine)
app = FastAPI(title="Document Capture Demo – Backend")
# Include auth routes
app.include_router(auth_router)
# CORS (for safety we allow all; you can tighten later)
# NOTE(review): browsers reject credentialed requests when
# allow_origins=["*"] is combined with allow_credentials=True — confirm
# whether cookies/credentials are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def get_db():
    """Yield a SQLAlchemy session and guarantee it is closed afterwards."""
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
async def get_current_user_or_api_key_user(
    api_key_user: Optional[User] = Depends(get_user_from_api_key),
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(HTTPBearer(auto_error=False)),
    db: Session = Depends(get_db),
) -> User:
    """
    Flexible authentication supporting both JWT Bearer tokens and API keys.

    API-key authentication wins when present; otherwise the Bearer token is
    verified and the matching user loaded. If neither yields a user, a 401
    is raised.
    """
    # API key takes precedence when the dependency resolved a user.
    if api_key_user is not None:
        return api_key_user

    # Fall back to JWT verification when a Bearer credential was supplied.
    if credentials is not None:
        matched_user = None
        try:
            payload = verify_token(credentials.credentials)
            matched_user = (
                db.query(User)
                .filter(User.id == int(payload.get("sub")))
                .first()
            )
        except Exception:
            matched_user = None  # invalid/expired token → fall through to 401
        if matched_user is not None:
            return matched_user

    # Neither mechanism produced a user.
    raise HTTPException(
        status_code=401,
        detail="Authentication required. Provide either a Bearer token or X-API-Key header.",
        headers={"WWW-Authenticate": "Bearer"},
    )
@app.get("/ping")
def ping():
    """Healthcheck endpoint."""
    payload = {"status": "ok", "message": "backend alive"}
    return payload
def make_stages(total_ms: int, status: str) -> Dict[str, ExtractionStage]:
    """
    Build synthetic stage timing data for the History UI.

    The total duration is split across four fixed stages; a non-positive
    total falls back to 1000 ms. When the overall status is not
    "completed", aiAnalysis is marked failed and the downstream stages
    are marked skipped.
    """
    duration = total_ms if total_ms > 0 else 1000
    ok = status == "completed"
    # (stage name, share of total time, stage status, variation)
    plan = [
        ("uploading", 0.15, "completed", "normal"),
        ("aiAnalysis", 0.55, "completed" if ok else "failed", "normal"),
        ("dataExtraction", 0.2, "completed" if ok else "skipped", "fast"),
        ("outputRendering", 0.1, "completed" if ok else "skipped", "normal"),
    ]
    return {
        name: ExtractionStage(
            time=int(duration * share),
            status=stage_status,
            variation=variation,
        )
        for name, share, stage_status, variation in plan
    }
@app.post("/api/extract")
async def extract_document(
    file: UploadFile = File(...),
    key_fields: Optional[str] = Form(None),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user_or_api_key_user),
):
    """
    Main extraction endpoint for document parsing.

    Supports both JWT Bearer token and API key authentication:
      1. JWT Bearer token: header "Authorization: Bearer <token>"
      2. API key:          header "X-API-Key: <api_key>"

    Parameters:
    - file: document file (PDF, PNG, JPEG, TIFF) - max 4MB
    - key_fields: optional comma-separated list of specific fields to
      extract (e.g., "Invoice Number,Invoice Date")

    Returns JSON with extracted fields, text, confidence, and metadata.
    Raises 403 when the trial limit is reached, 400 for invalid files.
    """
    import base64
    import json

    # Check trial limit (5 documents per user)
    TRIAL_LIMIT = 5
    extraction_count = (
        db.query(ExtractionRecord)
        .filter(ExtractionRecord.user_id == current_user.id)
        .count()
    )
    if extraction_count >= TRIAL_LIMIT:
        raise HTTPException(
            status_code=403,
            detail=f"Trial limit reached. You have processed {extraction_count} documents. The trial allows up to {TRIAL_LIMIT} documents. Please upgrade to continue."
        )

    start = time.time()
    content = await file.read()
    content_type = file.content_type or "application/octet-stream"
    file_size = len(content)
    size_mb = file_size / 1024 / 1024
    size_str = f"{size_mb:.2f} MB"

    # Validate file size BEFORE base64-encoding so oversized uploads are
    # rejected without paying the extra ~33% memory cost of encoding them.
    if file_size > MAX_FILE_SIZE:
        raise HTTPException(
            status_code=400,
            detail=f"File size exceeds 4 MB limit. Your file is {size_mb:.2f} MB."
        )

    # Validate file type (content-type header, extension as fallback)
    file_extension = ""
    if file.filename:
        file_extension = "." + file.filename.split(".")[-1].lower()
    is_valid_type = (
        content_type in ALLOWED_CONTENT_TYPES or
        file_extension in ALLOWED_EXTENSIONS
    )
    if not is_valid_type:
        raise HTTPException(
            status_code=400,
            detail="Only PDF, PNG, JPG, and TIFF files are allowed."
        )

    # Base64-encode the validated file content for storage/preview.
    file_base64 = base64.b64encode(content).decode("utf-8")

    try:
        print(f"[INFO] Starting extraction for file: {file.filename}, type: {content_type}, size: {size_str}")
        if key_fields:
            print(f"[INFO] Key fields requested: {key_fields}")
        extracted = await extract_fields_from_document(content, content_type, file.filename, key_fields)
        total_ms = int((time.time() - start) * 1000)
        print(f"[INFO] Extraction completed. Response keys: {list(extracted.keys())}")
        print(f"[INFO] Fields extracted: {extracted.get('fields', {})}")

        confidence = float(extracted.get("confidence", 90))
        fields = extracted.get("fields", {})
        # "Fields" at the root level is populated when the user supplied key_fields
        root_fields = extracted.get("Fields", {})
        # full_text holds the plain-text rendering of the document
        full_text = extracted.get("full_text", "")
        if full_text:
            full_text_words = len(str(full_text).split())
            print(f"[INFO] Full text extracted: {full_text_words} words")

        # If `fields` has page_X keys it is already structured table data;
        # otherwise attach full_text/pages so the UI can show raw text.
        if not fields or (isinstance(fields, dict) and not any(k.startswith("page_") for k in fields.keys())):
            if full_text:
                fields["full_text"] = full_text
            pages_data = extracted.get("pages", [])
            if pages_data and isinstance(pages_data, list):
                print(f"[INFO] Extracted text from {len(pages_data)} page(s)")
                fields["pages"] = pages_data

        # Carry the root-level key_fields result along with the payload
        if root_fields:
            fields["Fields"] = root_fields

        # Count extracted fields. For structured page data this is
        # table rows + root key-fields; otherwise the number of plain keys
        # (excluding the bookkeeping keys) plus root key-fields.
        if isinstance(fields, dict):
            if any(k.startswith("page_") for k in fields.keys()):
                table_rows_count = 0
                for page_key, page_data in fields.items():
                    if page_key.startswith("page_") and isinstance(page_data, dict):
                        table_rows = page_data.get("table", [])
                        if isinstance(table_rows, list):
                            table_rows_count += len(table_rows)
                fields_keys = len(root_fields) if isinstance(root_fields, dict) else 0
                fields_extracted = table_rows_count + fields_keys
                print(f"[INFO] Structured data: {table_rows_count} table rows, {fields_keys} extracted fields")
            else:
                fields_extracted = len([k for k in fields.keys() if k not in ["full_text", "pages", "Fields"]])
                if isinstance(root_fields, dict):
                    fields_extracted += len(root_fields)
        else:
            fields_extracted = 0

        print(f"[INFO] Final stats - confidence: {confidence}, fields_count: {fields_extracted}")
        status = "completed"
        error_message = None
    except Exception as e:
        # Broad catch is intentional here: the endpoint records the failure
        # instead of returning a 500, so the UI can show a failed history row.
        import traceback
        total_ms = int((time.time() - start) * 1000)
        confidence = 0.0
        fields = {}
        fields_extracted = 0
        status = "failed"
        error_message = str(e)
        print(f"[ERROR] Extraction failed: {error_message}")
        print(f"[ERROR] Traceback: {traceback.format_exc()}")

    # Persist the record (success or failure) for the History page.
    rec = ExtractionRecord(
        user_id=current_user.id,
        file_name=file.filename,
        file_type=content_type,
        file_size=size_str,
        status=status,
        confidence=confidence,
        fields_extracted=fields_extracted,
        total_time_ms=total_ms,
        raw_output=json.dumps(fields),  # JSON (not str()) preserves structure
        file_base64=file_base64,  # base64-encoded file for preview
        error_message=error_message,
    )
    db.add(rec)
    db.commit()
    db.refresh(rec)

    stages = make_stages(total_ms, status)
    # Response shape consumed by the frontend
    return {
        "id": rec.id,
        "fileName": rec.file_name,
        "fileType": rec.file_type,
        "fileSize": rec.file_size,
        "status": status,
        "confidence": confidence,
        "fieldsExtracted": fields_extracted,
        "totalTime": total_ms,
        "fields": fields,
        "stages": {k: v.dict() for k, v in stages.items()},
        "errorMessage": error_message,
    }
@app.get("/api/history", response_model=List[ExtractionRecordBase])
def get_history(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Backing endpoint for the History page.

    Returns up to the last 100 records for the current user, newest first,
    with synthetic per-stage timing data. Multiple copies created from the
    same shared extraction are collapsed to the most recent one.
    """
    records = (
        db.query(ExtractionRecord)
        .filter(ExtractionRecord.user_id == current_user.id)
        .order_by(ExtractionRecord.created_at.desc())
        .limit(100)
        .all()
    )

    # Collapse duplicate shared copies: records are newest-first, so the
    # first record seen for a given source extraction is the one we keep.
    seen_sources = set()
    visible = []
    for record in records:
        source_id = record.shared_from_extraction_id
        if source_id:
            if source_id in seen_sources:
                continue  # duplicate copy of an already-listed share
            seen_sources.add(source_id)
        visible.append(record)

    results: List[ExtractionRecordBase] = []
    for record in visible:
        results.append(
            ExtractionRecordBase(
                id=record.id,
                fileName=record.file_name,
                fileType=record.file_type or "",
                fileSize=record.file_size or "",
                extractedAt=record.created_at,
                status=record.status or "completed",
                confidence=record.confidence or 0.0,
                fieldsExtracted=record.fields_extracted or 0,
                totalTime=record.total_time_ms or 0,
                stages=make_stages(record.total_time_ms or 1000, record.status or "completed"),
                errorMessage=record.error_message,
            )
        )
    return results
@app.get("/api/extraction/{extraction_id}")
def get_extraction(
    extraction_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Get a specific extraction by ID with full fields data.

    Used when viewing output from the History page. Only records owned by
    the current user are visible; anything else yields a 404.
    """
    import ast
    import json

    rec = (
        db.query(ExtractionRecord)
        .filter(
            ExtractionRecord.id == extraction_id,
            ExtractionRecord.user_id == current_user.id
        )
        .first()
    )
    if not rec:
        raise HTTPException(status_code=404, detail="Extraction not found")

    # raw_output is stored as JSON (current format) or str(dict) (legacy).
    fields = {}
    if rec.raw_output:
        try:
            fields = json.loads(rec.raw_output)
        except (json.JSONDecodeError, TypeError):
            # Legacy records were saved with str(dict); literal_eval is safe
            # (no code execution) but only accepts Python-literal syntax,
            # so guard on a leading "{" and catch its specific failures.
            try:
                if rec.raw_output.strip().startswith('{'):
                    fields = ast.literal_eval(rec.raw_output)
                else:
                    fields = {}
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                fields = {}

    stages = make_stages(rec.total_time_ms or 1000, rec.status or "completed")
    return {
        "id": rec.id,
        "fileName": rec.file_name,
        "fileType": rec.file_type or "",
        "fileSize": rec.file_size or "",
        "status": rec.status or "completed",
        "confidence": rec.confidence or 0.0,
        "fieldsExtracted": rec.fields_extracted or 0,
        "totalTime": rec.total_time_ms or 0,
        "fields": fields,
        "fileBase64": rec.file_base64,  # base64-encoded file for preview
        "stages": {k: v.dict() for k, v in stages.items()},
        "errorMessage": rec.error_message,
    }
@app.post("/api/share")
async def share_extraction(
    extraction_id: int = Body(...),
    recipient_emails: List[str] = Body(...),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Share an extraction with one or more users via email.

    Creates one share token per recipient (valid 30 days) and emails each
    recipient their personal share link.

    Raises:
        400 if no recipients are given or an address fails business-email
            validation (raised by validate_business_email).
        404 if the extraction does not belong to the current user.
        500 if token creation fails or every email send fails.
    """
    import secrets
    from datetime import datetime, timedelta
    from .brevo_service import send_share_email
    from .email_validator import validate_business_email

    # Validate recipient emails list
    if not recipient_emails:
        raise HTTPException(status_code=400, detail="At least one recipient email is required")

    # validate_business_email raises HTTPException itself on invalid input,
    # which propagates naturally — no wrapper needed.
    for email in recipient_emails:
        validate_business_email(email)

    # The extraction must belong to the requesting user.
    extraction = (
        db.query(ExtractionRecord)
        .filter(
            ExtractionRecord.id == extraction_id,
            ExtractionRecord.user_id == current_user.id
        )
        .first()
    )
    if not extraction:
        raise HTTPException(status_code=404, detail="Extraction not found")

    # Base URL used to build share links.
    base_url = os.environ.get("VITE_API_BASE_URL", "https://seth0330-ezofisocr.hf.space")

    # Create one token per recipient, committing them as a batch.
    successful_shares = []
    failed_shares = []
    share_records = []
    for recipient_email in recipient_emails:
        recipient_email = recipient_email.strip().lower()
        share_token = secrets.token_urlsafe(32)
        expires_at = datetime.utcnow() + timedelta(days=30)
        share_record = ShareToken(
            token=share_token,
            extraction_id=extraction_id,
            sender_user_id=current_user.id,
            recipient_email=recipient_email,
            expires_at=expires_at,
        )
        db.add(share_record)
        share_records.append((share_record, share_token, recipient_email))

    try:
        db.commit()
        for share_record, share_token, recipient_email in share_records:
            db.refresh(share_record)
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=500, detail=f"Failed to create share tokens: {str(e)}")

    # Send one email per recipient; a failure for one recipient must not
    # block the others, so failures are collected rather than raised.
    for share_record, share_token, recipient_email in share_records:
        share_link = f"{base_url}/share/{share_token}"
        try:
            sender_name = current_user.name if current_user.name else None
            await send_share_email(recipient_email, current_user.email, share_link, sender_name)
            successful_shares.append(recipient_email)
        except Exception as e:
            print(f"[ERROR] Failed to send share email to {recipient_email}: {str(e)}")
            failed_shares.append(recipient_email)

    # Build response message
    if len(failed_shares) == 0:
        message = f"Extraction shared successfully with {len(successful_shares)} recipient(s)"
    elif len(successful_shares) == 0:
        raise HTTPException(status_code=500, detail="Failed to send share emails to all recipients")
    else:
        message = f"Extraction shared with {len(successful_shares)} recipient(s). Failed to send to: {', '.join(failed_shares)}"

    return {
        "success": True,
        "message": message,
        "successful_count": len(successful_shares),
        "failed_count": len(failed_shares),
        "successful_emails": successful_shares,
        "failed_emails": failed_shares if failed_shares else None
    }
class ShareLinkRequest(BaseModel):
    """Request body for POST /api/share/link."""
    extraction_id: int  # ID of an ExtractionRecord owned by the caller
@app.post("/api/share/link")
async def create_share_link(
    request: ShareLinkRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Create a shareable link for an extraction without recipient emails.

    The returned link can be copied and distributed manually; its token has
    no bound recipient and expires after 30 days.
    """
    import secrets
    from datetime import datetime, timedelta

    # The extraction must be owned by the requesting user.
    owned = (
        db.query(ExtractionRecord)
        .filter(
            ExtractionRecord.id == request.extraction_id,
            ExtractionRecord.user_id == current_user.id
        )
        .first()
    )
    if not owned:
        raise HTTPException(status_code=404, detail="Extraction not found")

    # Mint a secure token and persist it with a 30-day expiry.
    token_value = secrets.token_urlsafe(32)
    valid_until = datetime.utcnow() + timedelta(days=30)
    record = ShareToken(
        token=token_value,
        extraction_id=request.extraction_id,
        sender_user_id=current_user.id,
        recipient_email=None,  # public, copyable link — no fixed recipient
        expires_at=valid_until,
    )
    db.add(record)
    db.commit()
    db.refresh(record)

    base_url = os.environ.get("VITE_API_BASE_URL", "https://seth0330-ezofisocr.hf.space")
    return {
        "success": True,
        "share_link": f"{base_url}/share/{token_value}",
        "share_token": token_value,
        "expires_at": valid_until.isoformat() if valid_until else None
    }
@app.get("/api/share/{token}")
async def access_shared_extraction(
    token: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Access a shared extraction and copy it into the current user's account.

    Called after the recipient logs in via a share link. Idempotent: repeat
    visits with the same token, or any token pointing at an original the
    user already holds a copy of, return the existing copy instead of
    creating duplicates.

    Raises:
        404 if the token or the original extraction no longer exists.
        410 if the share link has expired.
    """
    from datetime import datetime

    # Find the share token
    share = (
        db.query(ShareToken)
        .filter(ShareToken.token == token)
        .first()
    )
    if not share:
        raise HTTPException(status_code=404, detail="Share link not found or expired")

    # Check if token is expired
    if share.expires_at and share.expires_at < datetime.utcnow():
        raise HTTPException(status_code=410, detail="Share link has expired")

    # Get the original extraction
    original_extraction = (
        db.query(ExtractionRecord)
        .filter(ExtractionRecord.id == share.extraction_id)
        .first()
    )
    if not original_extraction:
        raise HTTPException(status_code=404, detail="Original extraction not found")

    # Fast path 1: this exact share token was already redeemed by this user —
    # return the extraction that redemption produced.
    if share.accessed and share.accessed_by_user_id == current_user.id:
        existing_copy = (
            db.query(ExtractionRecord)
            .filter(
                ExtractionRecord.user_id == current_user.id,
                ExtractionRecord.shared_from_extraction_id == original_extraction.id
            )
            .order_by(ExtractionRecord.created_at.desc())
            .first()
        )
        if existing_copy:
            return {
                "success": True,
                "extraction_id": existing_copy.id,
                "message": "Extraction already shared with you"
            }

    # Fast path 2: the user already holds a copy of this original (possibly
    # via a different share token) — mark this token used and reuse the copy.
    existing_copy = (
        db.query(ExtractionRecord)
        .filter(
            ExtractionRecord.user_id == current_user.id,
            ExtractionRecord.shared_from_extraction_id == original_extraction.id
        )
        .first()
    )
    if existing_copy:
        share.accessed = True
        share.accessed_at = datetime.utcnow()
        share.accessed_by_user_id = current_user.id
        db.commit()
        return {
            "success": True,
            "extraction_id": existing_copy.id,
            "message": "Extraction already shared with you"
        }

    # Copy the extraction into the recipient's account. raw_output is copied
    # verbatim as the stored string — the previous version also parsed it
    # into a dict that was never used, so that dead work has been removed.
    new_extraction = ExtractionRecord(
        user_id=current_user.id,
        file_name=original_extraction.file_name,
        file_type=original_extraction.file_type,
        file_size=original_extraction.file_size,
        status=original_extraction.status or "completed",
        confidence=original_extraction.confidence or 0.0,
        fields_extracted=original_extraction.fields_extracted or 0,
        total_time_ms=original_extraction.total_time_ms or 0,
        raw_output=original_extraction.raw_output,  # copy the JSON string as-is
        file_base64=original_extraction.file_base64,  # copy the base64 file
        shared_from_extraction_id=original_extraction.id,
        shared_by_user_id=share.sender_user_id,
    )
    db.add(new_extraction)

    # Mark share as accessed
    share.accessed = True
    share.accessed_at = datetime.utcnow()
    share.accessed_by_user_id = current_user.id
    db.commit()
    db.refresh(new_extraction)
    return {
        "success": True,
        "extraction_id": new_extraction.id,
        "message": "Extraction shared successfully"
    }
# Static frontend mounting (used after we build React)
# Dockerfile copies the Vite build into backend/frontend_dist
# IMPORTANT: API routes must be defined BEFORE this so they take precedence
frontend_dir = os.path.join(
    os.path.dirname(os.path.dirname(__file__)), "frontend_dist"
)
if os.path.isdir(frontend_dir):
    # Serve hashed build artifacts (JS, CSS, images, etc.) from assets/
    assets_dir = os.path.join(frontend_dir, "assets")
    if os.path.isdir(assets_dir):
        app.mount(
            "/assets",
            StaticFiles(directory=assets_dir),
            name="assets",
        )

    # Serve static files from root (logo.png, favicon.ico, etc.)
    # Files in public/ directory are copied to dist/ root during Vite build
    # These routes must be defined BEFORE the catch-all route
    @app.get("/logo.png")
    async def serve_logo():
        """Serve logo.png from frontend_dist root."""
        from fastapi.responses import FileResponse
        logo_path = os.path.join(frontend_dir, "logo.png")
        if os.path.exists(logo_path):
            return FileResponse(logo_path, media_type="image/png")
        from fastapi import HTTPException
        raise HTTPException(status_code=404)

    @app.get("/favicon.ico")
    async def serve_favicon():
        """Serve favicon.ico from frontend_dist root."""
        from fastapi.responses import FileResponse
        favicon_path = os.path.join(frontend_dir, "favicon.ico")
        if os.path.exists(favicon_path):
            return FileResponse(favicon_path, media_type="image/x-icon")
        from fastapi import HTTPException
        raise HTTPException(status_code=404)

    # Catch-all route to serve index.html for React Router
    # This must be last so API routes and static files are matched first
    @app.get("/{full_path:path}")
    async def serve_frontend(full_path: str):
        """
        Serve React app for all non-API routes.
        React Router will handle client-side routing.
        """
        # Skip API routes, docs, static assets, and known static files —
        # reaching here for one of those means the real route didn't match,
        # so a 404 is the correct answer rather than index.html.
        if (full_path.startswith("api/") or
                full_path.startswith("docs") or
                full_path.startswith("openapi.json") or
                full_path.startswith("assets/") or
                full_path in ["logo.png", "favicon.ico"]):
            from fastapi import HTTPException
            raise HTTPException(status_code=404)
        # Serve index.html for all other routes (React Router handles routing)
        from fastapi.responses import FileResponse
        index_path = os.path.join(frontend_dir, "index.html")
        if os.path.exists(index_path):
            return FileResponse(index_path)
        from fastapi import HTTPException
        raise HTTPException(status_code=404)