# FastAPI backend for the Document Capture demo (deployed as a Hugging Face Space).
| import os | |
| import time | |
| from typing import List, Dict, Optional | |
| from fastapi import FastAPI, UploadFile, File, Depends, Form, HTTPException, Body | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer | |
| from sqlalchemy.orm import Session | |
| from pydantic import BaseModel | |
| from .db import Base, engine, SessionLocal | |
| from .models import ExtractionRecord, User, ShareToken | |
| from .schemas import ExtractionRecordBase, ExtractionStage | |
| from .openrouter_client import extract_fields_from_document | |
| from .auth import get_current_user, get_db, verify_token | |
| from .auth_routes import router as auth_router | |
| from .api_key_auth import get_user_from_api_key | |
# Allowed MIME types for uploaded documents. "image/jpg" is not a standard
# MIME type but some clients send it, so it is accepted alongside "image/jpeg".
ALLOWED_CONTENT_TYPES = [
    "application/pdf",
    "image/png",
    "image/jpeg",
    "image/jpg",
    "image/tiff",
    "image/tif"
]
# Allowed file extensions (fallback validation when the content type is
# missing or generic, e.g. "application/octet-stream").
ALLOWED_EXTENSIONS = [".pdf", ".png", ".jpg", ".jpeg", ".tiff", ".tif"]
# Maximum file size: 4 MB
MAX_FILE_SIZE = 4 * 1024 * 1024  # 4 MB in bytes
# Ensure data dir exists for SQLite (the engine expects data/ to be present).
os.makedirs("data", exist_ok=True)
# Create tables for all models registered on Base (no-op if they exist).
Base.metadata.create_all(bind=engine)
app = FastAPI(title="Document Capture Demo – Backend")
# Include auth routes (login/register endpoints defined in auth_routes).
app.include_router(auth_router)
# CORS (for safety we allow all; you can tighten later)
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# permissive; browsers reject credentialed wildcard responses — confirm the
# intended origin policy before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def get_db():
    """Yield a SQLAlchemy session and guarantee it is closed afterwards.

    NOTE(review): this definition shadows the ``get_db`` imported from
    ``.auth`` above — confirm the two implementations are equivalent.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
async def get_current_user_or_api_key_user(
    api_key_user: Optional[User] = Depends(get_user_from_api_key),
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(HTTPBearer(auto_error=False)),
    db: Session = Depends(get_db),
) -> User:
    """
    Flexible authentication dependency supporting API keys and JWT bearer tokens.

    An authenticated API-key user wins outright; otherwise the bearer token
    (if any) is verified and its ``sub`` claim resolved to a User row.
    Raises HTTP 401 when neither mechanism yields a user.
    """
    # API-key authentication takes precedence when it succeeded.
    if api_key_user is not None:
        return api_key_user

    # Fall back to JWT bearer authentication when a token was supplied.
    if credentials is not None:
        jwt_user = None
        try:
            claims = verify_token(credentials.credentials)
            jwt_user = (
                db.query(User)
                .filter(User.id == int(claims.get("sub")))
                .first()
            )
        except Exception:
            # Invalid/expired token or malformed claim: fall through to 401.
            jwt_user = None
        if jwt_user is not None:
            return jwt_user

    # Neither mechanism authenticated the caller.
    raise HTTPException(
        status_code=401,
        detail="Authentication required. Provide either a Bearer token or X-API-Key header.",
        headers={"WWW-Authenticate": "Bearer"},
    )
def ping():
    """Healthcheck: report that the backend process is alive."""
    return dict(status="ok", message="backend alive")
def make_stages(total_ms: int, status: str) -> Dict[str, ExtractionStage]:
    """
    Fabricate per-stage timing data for the History UI.

    The total duration is apportioned across four fixed stages
    (15% / 55% / 20% / 10%). On failure the aiAnalysis stage is marked
    "failed" and the downstream stages "skipped"; uploading always completes.
    Non-positive durations are normalized to one second.
    """
    if total_ms <= 0:
        total_ms = 1000
    ok = status == "completed"
    stage_spec = [
        ("uploading", 0.15, "completed", "normal"),
        ("aiAnalysis", 0.55, "completed" if ok else "failed", "normal"),
        ("dataExtraction", 0.2, "completed" if ok else "skipped", "fast"),
        ("outputRendering", 0.1, "completed" if ok else "skipped", "normal"),
    ]
    return {
        name: ExtractionStage(
            time=int(total_ms * share),
            status=stage_status,
            variation=variation,
        )
        for name, share, stage_status, variation in stage_spec
    }
async def extract_document(
    file: UploadFile = File(...),
    key_fields: Optional[str] = Form(None),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user_or_api_key_user),
):
    """
    Main extraction endpoint for document parsing.
    Supports both JWT Bearer token and API key authentication.

    Authentication methods:
    1. JWT Bearer token: Header "Authorization: Bearer <token>"
    2. API Key: Header "X-API-Key: <api_key>"

    Parameters:
    - file: Document file (PDF, PNG, JPEG, TIFF) - max 4MB
    - key_fields: Optional comma-separated list of specific fields to extract (e.g., "Invoice Number,Invoice Date")

    Returns JSON with extracted fields, text, confidence, and metadata.
    A failed extraction is still persisted (status "failed") and returned
    with confidence 0 and an errorMessage rather than raising to the caller.
    """
    start = time.time()
    # The whole upload is read into memory before the size check; the 4 MB
    # cap keeps this bounded.
    content = await file.read()
    content_type = file.content_type or "application/octet-stream"
    file_size = len(content)
    size_mb = file_size / 1024 / 1024
    size_str = f"{size_mb:.2f} MB"
    # Convert file content to base64 for storage (used later for previews).
    import base64
    file_base64 = base64.b64encode(content).decode("utf-8")
    # Validate file size
    if file_size > MAX_FILE_SIZE:
        raise HTTPException(
            status_code=400,
            detail=f"File size exceeds 4 MB limit. Your file is {size_mb:.2f} MB."
        )
    # Validate file type: accept if EITHER the MIME type or the filename
    # extension matches the allow-lists (extension is the fallback for
    # clients that send a generic content type).
    file_extension = ""
    if file.filename:
        file_extension = "." + file.filename.split(".")[-1].lower()
    is_valid_type = (
        content_type in ALLOWED_CONTENT_TYPES or
        file_extension in ALLOWED_EXTENSIONS
    )
    if not is_valid_type:
        raise HTTPException(
            status_code=400,
            detail="Only PDF, PNG, JPG, and TIFF files are allowed."
        )
    try:
        print(f"[INFO] Starting extraction for file: {file.filename}, type: {content_type}, size: {size_str}")
        if key_fields:
            print(f"[INFO] Key fields requested: {key_fields}")
        # Delegate the actual OCR/LLM extraction to the OpenRouter client.
        extracted = await extract_fields_from_document(content, content_type, file.filename, key_fields)
        total_ms = int((time.time() - start) * 1000)
        print(f"[INFO] Extraction completed. Response keys: {list(extracted.keys())}")
        print(f"[INFO] Fields extracted: {extracted.get('fields', {})}")
        # Defaults to 90 when the extractor did not report a confidence.
        confidence = float(extracted.get("confidence", 90))
        fields = extracted.get("fields", {})
        # Get Fields from root level (populated when the user provided key_fields)
        root_fields = extracted.get("Fields", {})
        # Get full_text for text output
        full_text = extracted.get("full_text", "")
        if full_text:
            full_text_words = len(str(full_text).split())
            print(f"[INFO] Full text extracted: {full_text_words} words")
        # Check if fields contain structured data (from table parsing):
        # a dict with page_X keys is already structured. If fields is empty
        # or has no page_X keys, add full_text and pages for text display.
        if not fields or (isinstance(fields, dict) and not any(k.startswith("page_") for k in fields.keys())):
            if full_text:
                fields["full_text"] = full_text
            # Also check for pages array
            pages_data = extracted.get("pages", [])
            if pages_data and isinstance(pages_data, list):
                print(f"[INFO] Extracted text from {len(pages_data)} page(s)")
                fields["pages"] = pages_data
        # Add Fields at root level if it exists
        if root_fields:
            fields["Fields"] = root_fields
        # Compute fields_extracted: for structured page data count table rows
        # plus root Fields; otherwise count plain keys (excluding the
        # bookkeeping keys injected above) plus root Fields.
        if isinstance(fields, dict):
            # Check if it's structured page data
            if any(k.startswith("page_") for k in fields.keys()):
                # Count table rows from all pages
                table_rows_count = 0
                for page_key, page_data in fields.items():
                    if page_key.startswith("page_") and isinstance(page_data, dict):
                        table_rows = page_data.get("table", [])
                        if isinstance(table_rows, list):
                            table_rows_count += len(table_rows)
                # Count Fields from root level
                fields_keys = 0
                if isinstance(root_fields, dict):
                    fields_keys = len(root_fields)
                fields_extracted = table_rows_count + fields_keys
                print(f"[INFO] Structured data: {table_rows_count} table rows, {fields_keys} extracted fields")
            else:
                # Regular fields count (excluding full_text, pages, and Fields)
                fields_extracted = len([k for k in fields.keys() if k not in ["full_text", "pages", "Fields"]])
                # Add Fields count if it exists
                if isinstance(root_fields, dict):
                    fields_extracted += len(root_fields)
        else:
            fields_extracted = 0
        print(f"[INFO] Final stats - confidence: {confidence}, fields_count: {fields_extracted}")
        status = "completed"
        error_message = None
    except Exception as e:
        # Extraction failure is recorded, not propagated: the record is saved
        # with status "failed" and the response carries errorMessage.
        import traceback
        total_ms = int((time.time() - start) * 1000)
        confidence = 0.0
        fields = {}
        fields_extracted = 0
        status = "failed"
        error_message = str(e)
        print(f"[ERROR] Extraction failed: {error_message}")
        print(f"[ERROR] Traceback: {traceback.format_exc()}")
    # Save record to DB
    import json
    import base64
    rec = ExtractionRecord(
        user_id=current_user.id,
        file_name=file.filename,
        file_type=content_type,
        file_size=size_str,
        status=status,
        confidence=confidence,
        fields_extracted=fields_extracted,
        total_time_ms=total_ms,
        raw_output=json.dumps(fields),  # Use JSON instead of str() to preserve structure
        file_base64=file_base64,  # Store base64 encoded file for preview
        error_message=error_message,
    )
    db.add(rec)
    db.commit()
    db.refresh(rec)
    stages = make_stages(total_ms, status)
    # Response shape that frontend will consume
    return {
        "id": rec.id,
        "fileName": rec.file_name,
        "fileType": rec.file_type,
        "fileSize": rec.file_size,
        "status": status,
        "confidence": confidence,
        "fieldsExtracted": fields_extracted,
        "totalTime": total_ms,
        "fields": fields,
        "stages": {k: v.dict() for k, v in stages.items()},
        "errorMessage": error_message,
    }
def get_history(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Used by the History page.

    Fetches the caller's 100 most recent extraction records, collapses
    multiple copies of the same shared extraction down to the newest one,
    and attaches synthetic stage timing data for the UI.
    """
    recent = (
        db.query(ExtractionRecord)
        .filter(ExtractionRecord.user_id == current_user.id)
        .order_by(ExtractionRecord.created_at.desc())
        .limit(100)
        .all()
    )

    # Records arrive newest-first, so keeping only the first occurrence of
    # each shared_from_extraction_id keeps the most recent copy. Original
    # (non-shared) extractions are always kept.
    seen_sources = set()
    visible = []
    for record in recent:
        source_id = record.shared_from_extraction_id
        if source_id:
            if source_id in seen_sources:
                continue  # a newer copy of this shared extraction was already kept
            seen_sources.add(source_id)
        visible.append(record)

    history: List[ExtractionRecordBase] = []
    for record in visible:
        history.append(
            ExtractionRecordBase(
                id=record.id,
                fileName=record.file_name,
                fileType=record.file_type or "",
                fileSize=record.file_size or "",
                extractedAt=record.created_at,
                status=record.status or "completed",
                confidence=record.confidence or 0.0,
                fieldsExtracted=record.fields_extracted or 0,
                totalTime=record.total_time_ms or 0,
                stages=make_stages(record.total_time_ms or 1000, record.status or "completed"),
                errorMessage=record.error_message,
            )
        )
    return history
def get_extraction(
    extraction_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Get a specific extraction by ID with full fields data.
    Used when viewing output from History page.

    Raises HTTP 404 when the record does not exist or belongs to another user.
    Unparseable raw_output degrades to an empty fields dict rather than erroring.
    """
    import json

    rec = (
        db.query(ExtractionRecord)
        .filter(
            ExtractionRecord.id == extraction_id,
            ExtractionRecord.user_id == current_user.id
        )
        .first()
    )
    if not rec:
        raise HTTPException(status_code=404, detail="Extraction not found")

    # Parse the stored raw_output back into a dict. New records store JSON;
    # old records stored Python's str(dict) format, handled via ast.literal_eval
    # for backward compatibility.
    fields = {}
    if rec.raw_output:
        try:
            fields = json.loads(rec.raw_output)
        except (json.JSONDecodeError, TypeError):
            import ast
            try:
                # Only attempt literal_eval when the payload looks like a dict.
                if rec.raw_output.strip().startswith('{'):
                    fields = ast.literal_eval(rec.raw_output)
            except (ValueError, SyntaxError, MemoryError, TypeError):
                # Narrowed from a bare except: these are the failure modes
                # ast.literal_eval documents for malformed/oversized input.
                fields = {}

    stages = make_stages(rec.total_time_ms or 1000, rec.status or "completed")
    return {
        "id": rec.id,
        "fileName": rec.file_name,
        "fileType": rec.file_type or "",
        "fileSize": rec.file_size or "",
        "status": rec.status or "completed",
        "confidence": rec.confidence or 0.0,
        "fieldsExtracted": rec.fields_extracted or 0,
        "totalTime": rec.total_time_ms or 0,
        "fields": fields,
        "fileBase64": rec.file_base64,  # Include base64 encoded file for preview
        "stages": {k: v.dict() for k, v in stages.items()},
        "errorMessage": rec.error_message,
    }
async def share_extraction(
    extraction_id: int = Body(...),
    recipient_emails: List[str] = Body(...),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Share an extraction with one or more users via email.
    Creates share tokens and sends emails to recipients.

    Raises:
        400: no recipients supplied, or a recipient fails business-email validation.
        404: extraction does not exist or is not owned by the caller.
        500: the tokens cannot be persisted, or every email send fails.
    """
    import secrets
    from datetime import datetime, timedelta
    from .brevo_service import send_share_email
    from .email_validator import validate_business_email
    # Validate recipient emails list
    if not recipient_emails or len(recipient_emails) == 0:
        raise HTTPException(status_code=400, detail="At least one recipient email is required")
    # Validate each recipient email is a business email (validator raises on failure)
    for email in recipient_emails:
        try:
            validate_business_email(email)
        except HTTPException:
            raise  # Re-raise HTTPException from validate_business_email
    # Get the extraction record (must belong to the current user)
    extraction = (
        db.query(ExtractionRecord)
        .filter(
            ExtractionRecord.id == extraction_id,
            ExtractionRecord.user_id == current_user.id
        )
        .first()
    )
    if not extraction:
        raise HTTPException(status_code=404, detail="Extraction not found")
    # Generate share link base URL
    base_url = os.environ.get("VITE_API_BASE_URL", "https://seth0330-ezofisocr.hf.space")
    # Phase 1: build one share token per recipient (all added before a single commit)
    successful_shares = []
    failed_shares = []
    share_records = []
    for recipient_email in recipient_emails:
        # Normalize for storage/comparison.
        recipient_email = recipient_email.strip().lower()
        # Generate secure share token for this recipient
        share_token = secrets.token_urlsafe(32)
        # Create share token record (expires in 30 days)
        expires_at = datetime.utcnow() + timedelta(days=30)
        share_record = ShareToken(
            token=share_token,
            extraction_id=extraction_id,
            sender_user_id=current_user.id,
            recipient_email=recipient_email,
            expires_at=expires_at,
        )
        db.add(share_record)
        share_records.append((share_record, share_token, recipient_email))
    # Commit all share tokens in one transaction; roll back all on failure.
    try:
        db.commit()
        for share_record, share_token, recipient_email in share_records:
            db.refresh(share_record)
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=500, detail=f"Failed to create share tokens: {str(e)}")
    # Phase 2: send emails. A failure for one recipient does not abort the rest;
    # failed recipients are collected and reported in the response message.
    for share_record, share_token, recipient_email in share_records:
        share_link = f"{base_url}/share/{share_token}"
        try:
            # Get sender's name from current_user, fallback to None if not available
            sender_name = current_user.name if current_user.name else None
            await send_share_email(recipient_email, current_user.email, share_link, sender_name)
            successful_shares.append(recipient_email)
        except Exception as e:
            # Log error but continue with other emails
            print(f"[ERROR] Failed to send share email to {recipient_email}: {str(e)}")
            failed_shares.append(recipient_email)
            # Optionally, you could delete the share token if email fails
            # db.delete(share_record)
    # Build response message (all sent / none sent / partial success)
    if len(failed_shares) == 0:
        message = f"Extraction shared successfully with {len(successful_shares)} recipient(s)"
    elif len(successful_shares) == 0:
        raise HTTPException(status_code=500, detail=f"Failed to send share emails to all recipients")
    else:
        message = f"Extraction shared with {len(successful_shares)} recipient(s). Failed to send to: {', '.join(failed_shares)}"
    return {
        "success": True,
        "message": message,
        "successful_count": len(successful_shares),
        "failed_count": len(failed_shares),
        "successful_emails": successful_shares,
        "failed_emails": failed_shares if failed_shares else None
    }
class ShareLinkRequest(BaseModel):
    """Request body for create_share_link: the extraction to create a link for."""
    extraction_id: int
async def create_share_link(
    request: ShareLinkRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Create a shareable link for an extraction without requiring recipient emails.
    Returns a share link that can be copied and shared manually.

    The link is backed by a ShareToken with no recipient_email (a public,
    copyable link) that expires 30 days after creation.
    """
    import secrets
    from datetime import datetime, timedelta

    # The extraction must exist and be owned by the caller.
    owned = (
        db.query(ExtractionRecord)
        .filter(
            ExtractionRecord.id == request.extraction_id,
            ExtractionRecord.user_id == current_user.id
        )
        .first()
    )
    if not owned:
        raise HTTPException(status_code=404, detail="Extraction not found")

    # Mint the token and its 30-day expiry, then persist the share record.
    token_value = secrets.token_urlsafe(32)
    valid_until = datetime.utcnow() + timedelta(days=30)
    record = ShareToken(
        token=token_value,
        extraction_id=request.extraction_id,
        sender_user_id=current_user.id,
        recipient_email=None,  # None marks a public (copyable) share link
        expires_at=valid_until,
    )
    db.add(record)
    db.commit()
    db.refresh(record)

    # Assemble the user-facing link from the configured frontend base URL.
    base_url = os.environ.get("VITE_API_BASE_URL", "https://seth0330-ezofisocr.hf.space")
    return {
        "success": True,
        "share_link": f"{base_url}/share/{token_value}",
        "share_token": token_value,
        "expires_at": valid_until.isoformat() if valid_until else None
    }
async def access_shared_extraction(
    token: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Access a shared extraction and copy it to the current user's account.
    This endpoint is called after the user logs in via the share link.

    Idempotent per (user, original extraction): if a copy already exists for
    this user, its ID is returned instead of creating a duplicate.

    Raises:
        404: the token or the original extraction no longer exists.
        410: the share link has expired.
    """
    from datetime import datetime
    # Find the share token
    share = (
        db.query(ShareToken)
        .filter(ShareToken.token == token)
        .first()
    )
    if not share:
        raise HTTPException(status_code=404, detail="Share link not found or expired")
    # Check if token is expired
    if share.expires_at and share.expires_at < datetime.utcnow():
        raise HTTPException(status_code=410, detail="Share link has expired")
    # Get the original extraction
    original_extraction = (
        db.query(ExtractionRecord)
        .filter(ExtractionRecord.id == share.extraction_id)
        .first()
    )
    if not original_extraction:
        raise HTTPException(status_code=404, detail="Original extraction not found")
    # If this specific share token was already redeemed by this user, return
    # the most recent copy instead of creating another one.
    if share.accessed and share.accessed_by_user_id == current_user.id:
        existing_copy = (
            db.query(ExtractionRecord)
            .filter(
                ExtractionRecord.user_id == current_user.id,
                ExtractionRecord.shared_from_extraction_id == original_extraction.id
            )
            .order_by(ExtractionRecord.created_at.desc())
            .first()
        )
        if existing_copy:
            return {
                "success": True,
                "extraction_id": existing_copy.id,
                "message": "Extraction already shared with you"
            }
    # Also check if any copy exists for this user from this original extraction
    # (e.g. redeemed earlier via a different share token).
    existing_copy = (
        db.query(ExtractionRecord)
        .filter(
            ExtractionRecord.user_id == current_user.id,
            ExtractionRecord.shared_from_extraction_id == original_extraction.id
        )
        .first()
    )
    if existing_copy:
        # Already copied: mark this share as accessed and return the existing ID.
        share.accessed = True
        share.accessed_at = datetime.utcnow()
        share.accessed_by_user_id = current_user.id
        db.commit()
        return {
            "success": True,
            "extraction_id": existing_copy.id,
            "message": "Extraction already shared with you"
        }
    # Copy the extraction to the current user's account. raw_output and the
    # base64 file payload are copied verbatim from the original record.
    # (A previous version parsed raw_output into a dict here, but the parsed
    # value was never used — the dead parse and its bare except were removed.)
    new_extraction = ExtractionRecord(
        user_id=current_user.id,
        file_name=original_extraction.file_name,
        file_type=original_extraction.file_type,
        file_size=original_extraction.file_size,
        status=original_extraction.status or "completed",
        confidence=original_extraction.confidence or 0.0,
        fields_extracted=original_extraction.fields_extracted or 0,
        total_time_ms=original_extraction.total_time_ms or 0,
        raw_output=original_extraction.raw_output,  # Copy the JSON string
        file_base64=original_extraction.file_base64,  # Copy the base64 file
        shared_from_extraction_id=original_extraction.id,
        shared_by_user_id=share.sender_user_id,
    )
    db.add(new_extraction)
    # Mark share as accessed
    share.accessed = True
    share.accessed_at = datetime.utcnow()
    share.accessed_by_user_id = current_user.id
    db.commit()
    db.refresh(new_extraction)
    return {
        "success": True,
        "extraction_id": new_extraction.id,
        "message": "Extraction shared successfully"
    }
# Static frontend mounting (used after we build React)
# Dockerfile copies the Vite build into backend/frontend_dist
# IMPORTANT: API routes must be defined BEFORE this so they take precedence
frontend_dir = os.path.join(
    os.path.dirname(os.path.dirname(__file__)), "frontend_dist"
)
if os.path.isdir(frontend_dir):
    # Serve static files (JS, CSS, images, etc.) from assets directory
    assets_dir = os.path.join(frontend_dir, "assets")
    if os.path.isdir(assets_dir):
        app.mount(
            "/assets",
            StaticFiles(directory=assets_dir),
            name="assets",
        )

    # Serve static files from root (logo.png, favicon.ico, etc.)
    # Files in public/ directory are copied to dist/ root during Vite build
    # These routes must be defined BEFORE the catch-all route
    # NOTE(review): route decorators (e.g. @app.get("/logo.png")) appear to
    # have been lost in this copy of the file — without them these handlers
    # are never registered. Confirm against the original source.
    async def serve_logo():
        """Serve logo.png from frontend_dist root."""
        from fastapi.responses import FileResponse
        logo_path = os.path.join(frontend_dir, "logo.png")
        if os.path.exists(logo_path):
            return FileResponse(logo_path, media_type="image/png")
        from fastapi import HTTPException
        raise HTTPException(status_code=404)

    async def serve_favicon():
        """Serve favicon.ico from frontend_dist root."""
        from fastapi.responses import FileResponse
        favicon_path = os.path.join(frontend_dir, "favicon.ico")
        if os.path.exists(favicon_path):
            return FileResponse(favicon_path, media_type="image/x-icon")
        from fastapi import HTTPException
        raise HTTPException(status_code=404)

    # Catch-all route to serve index.html for React Router
    # This must be last so API routes and static files are matched first
    async def serve_frontend(full_path: str):
        """
        Serve React app for all non-API routes.
        React Router will handle client-side routing.
        """
        # Skip API routes, docs, static assets, and known static files
        if (full_path.startswith("api/") or
            full_path.startswith("docs") or
            full_path.startswith("openapi.json") or
            full_path.startswith("assets/") or
            full_path in ["logo.png", "favicon.ico"]):
            from fastapi import HTTPException
            raise HTTPException(status_code=404)
        # Serve index.html for all other routes (React Router will handle routing)
        from fastapi.responses import FileResponse
        index_path = os.path.join(frontend_dir, "index.html")
        if os.path.exists(index_path):
            return FileResponse(index_path)
        from fastapi import HTTPException
        raise HTTPException(status_code=404)