| from fastapi import FastAPI, Request |
| from fastapi.staticfiles import StaticFiles |
| from fastapi.responses import HTMLResponse, JSONResponse |
| from fastapi.middleware.cors import CORSMiddleware |
| import os |
| import logging |
|
|
| |
# Process-wide logging configuration: INFO and above through the root handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application object; title/description/version are the app metadata
# exposed by FastAPI.
app = FastAPI(
    version="1.0.0",
    title="Universal Model Trainer",
    description="Train any model for any purpose on HuggingFace",
)

# NOTE(review): a wildcard origin combined with allow_credentials=True makes
# Starlette echo back the caller's Origin for cookie-bearing requests, so any
# website could issue credentialed requests and read the responses. If the
# dashboard is only ever served same-origin, consider an explicit origin list
# or allow_credentials=False — confirm no cross-origin client relies on this.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
|
|
| |
from app.auth import setup_auth
from app.config import settings

# Authentication is wired onto the app before any router module is imported,
# preserving the original ordering of these side effects.
setup_auth(app)

from app.routers import models, datasets, training, system, jobs
from app.routers import auth as auth_router

# Registration order: the auth router first, then the feature routers.
for _router_module in (auth_router, models, datasets, training, system, jobs):
    app.include_router(_router_module.router)
del _router_module
|
|
| |
# Serve bundled assets from the "static" directory next to this module, if any.
# isdir (not exists): StaticFiles checks that its directory really is a
# directory, so a plain file named "static" would otherwise make the mount
# raise while this module is being imported.
static_dir = os.path.join(os.path.dirname(__file__), "static")
if os.path.isdir(static_dir):
    app.mount("/static", StaticFiles(directory=static_dir), name="static")
|
|
| |
@app.get("/login", response_class=HTMLResponse)
async def login_page():
    """Return the login form markup produced by ``app.auth.get_login_html``."""
    from app.auth import get_login_html

    page_markup = get_login_html()
    return HTMLResponse(content=page_markup)
|
|
| |
@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
    """Serve the main dashboard page.

    When ``settings.APP_PASSWORD`` is set and the request is not authenticated
    (per ``app.auth.is_authenticated``), respond with a 302 redirect to
    ``/login``. Otherwise return the contents of ``templates/dashboard.html``
    (located next to this module); if that file is missing, return a 404 HTML
    page instead.
    """
    if settings.APP_PASSWORD:
        from app.auth import is_authenticated

        if not is_authenticated(request):
            # A proper redirect response: clients follow the Location header,
            # so the former inline <script> fallback body is unnecessary.
            from fastapi.responses import RedirectResponse

            return RedirectResponse(url="/login", status_code=302)

    template_path = os.path.join(os.path.dirname(__file__), "templates", "dashboard.html")
    try:
        # Explicit encoding: the platform default (e.g. cp1252 on Windows)
        # would mis-decode any non-ASCII characters in the template.
        with open(template_path, encoding="utf-8") as f:
            dashboard_html = f.read()
    except FileNotFoundError:
        logger.error("Dashboard template not found at %s", template_path)
        return HTMLResponse(
            content="<h1>Dashboard not found</h1><p>Please check that dashboard.html exists.</p>",
            status_code=404,
        )
    return HTMLResponse(content=dashboard_html)
|
|
@app.get("/health")
async def health_check():
    """Health check endpoint, intended to stay public for monitoring.

    Returns:
        dict: ``status`` (always ``"healthy"``), ``version`` (read from the
        FastAPI app object so it cannot drift from the version declared when
        the app was created) and ``auth_enabled`` (whether
        ``settings.APP_PASSWORD`` is set).
    """
    return {
        "status": "healthy",
        "version": app.version,
        "auth_enabled": bool(settings.APP_PASSWORD),
    }
|
|
@app.on_event("startup")
async def startup_event():
    """On server start: log the access mode, create DB tables, start the queue."""
    logger.info("Universal Model Trainer starting up...")

    access_mode_message = (
        "Authentication enabled - password protected mode"
        if settings.APP_PASSWORD
        else "No password configured - public access mode"
    )
    logger.info(access_mode_message)

    # Imported here rather than at module top, after the mode has been logged.
    from app.database import init_db

    await init_db()
    logger.info("Database tables initialized")

    from app.routers.training import get_queue

    training_queue = get_queue()
    await training_queue.start()
    logger.info("Training queue started - workers are now processing jobs")

    logger.info("Routers registered: auth, models, datasets, training, system, jobs")
|
|
@app.on_event("shutdown")
async def shutdown_event():
    """On server shutdown: stop the training queue obtained from ``get_queue``."""
    logger.info("Universal Model Trainer shutting down...")

    from app.routers.training import get_queue

    await get_queue().stop()
    logger.info("Training queue stopped")
|
|
| |
@app.exception_handler(404)
async def not_found_handler(request: Request, exc):
    """Handle every 404 raised as an HTTPException.

    Starlette consults status-code handlers before class-based ones, so this
    handler also receives 404s raised deliberately by endpoints, not just
    requests that matched no route.

    - Paths under ``/api/`` get a JSON body. A 404 carrying only the default
      "Not Found" phrase (what an unmatched route produces) or no detail is
      reported as "Endpoint not found"; a specific detail raised by an
      endpoint is passed through instead of being masked.
    - Any other path is redirected (302) to the dashboard at ``/``.
    """
    from http import HTTPStatus

    if request.url.path.startswith("/api/"):
        detail = getattr(exc, "detail", None)
        if not detail or detail == HTTPStatus.NOT_FOUND.phrase:
            detail = "Endpoint not found"
        return JSONResponse(
            status_code=404,
            content={"detail": detail},
            headers=getattr(exc, "headers", None),
        )

    return HTMLResponse(
        content="<script>window.location.href='/';</script>",
        status_code=302,
        headers={"location": "/"},
    )
|
|
| |
import html

from fastapi.exceptions import HTTPException


@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render ``HTTPException``s as JSON or HTML depending on the request.

    - 401: if the ``Accept`` header contains ``text/html`` the login page is
      returned with status 401; otherwise JSON with the detail and a
      ``login_url`` pointing at ``/login``.
    - Paths under ``/api/``: JSON ``{"detail": ...}`` with the exception's
      status code.
    - Anything else: a minimal HTML error page with the detail HTML-escaped.

    Headers attached to the exception (e.g. ``Allow``, ``Retry-After``) are
    forwarded on the non-401 responses, as FastAPI's default handler does.

    NOTE(review): this is registered for FastAPI's ``HTTPException`` subclass;
    errors raised by Starlette itself with the base class (e.g. a 405 from the
    router) bypass it. FastAPI's docs suggest registering for
    ``starlette.exceptions.HTTPException`` — consider switching.
    """
    if exc.status_code == 401:
        # Exception headers are intentionally not forwarded on 401s: a Basic
        # WWW-Authenticate challenge, if present, would make browsers show a
        # native credential prompt instead of the app's login flow.
        accept = request.headers.get("accept", "")
        if "text/html" in accept:
            from app.auth import get_login_html

            return HTMLResponse(
                content=get_login_html(error="Please log in to continue."),
                status_code=401,
            )
        return JSONResponse(
            status_code=401,
            content={"detail": str(exc.detail), "login_url": "/login"},
        )

    headers = getattr(exc, "headers", None)

    if request.url.path.startswith("/api/"):
        return JSONResponse(
            status_code=exc.status_code,
            content={"detail": str(exc.detail)},
            headers=headers,
        )

    # The detail may echo request-controlled text (e.g. an identifier taken
    # from the URL), so escape it before embedding it in HTML to prevent
    # reflected XSS.
    safe_detail = html.escape(str(exc.detail))
    return HTMLResponse(
        content=f"<h1>Error {exc.status_code}</h1><p>{safe_detail}</p>",
        status_code=exc.status_code,
        headers=headers,
    )
|
|