# vectorplasticity's picture
# FIX: Start training queue workers on startup - jobs will now actually process!
# cac616c verified
# Raw
# History Blame Contribute Delete
# 5.73 kB
from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import os
import logging
# Configure logging: root logger at INFO, plus a module-level logger used by
# the startup/shutdown hooks below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Create FastAPI app. The version given here is the one exposed as app.version.
app = FastAPI(
    title="Universal Model Trainer",
    description="Train any model for any purpose on HuggingFace",
    version="1.0.0"
)
# Enable CORS
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is a
# risky combination. Browsers refuse "*" for credentialed requests per the
# CORS spec, and depending on the Starlette version the middleware may echo the
# caller's Origin instead, effectively trusting every site. Since auth below is
# session based (presumably cookie-backed - verify in app.auth), confirm this is
# intended or restrict allow_origins to an explicit list.
# NOTE: the order of add_middleware / setup_auth calls determines middleware
# nesting; keep it stable.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Setup authentication - session middleware only (no custom middleware).
# `settings` is also read by the route handlers below (APP_PASSWORD).
from app.auth import setup_auth
from app.config import settings
setup_auth(app) # Adds session middleware
# Import routers AFTER app is created and middleware is configured
from app.routers import models, datasets, training, system, jobs
from app.routers import auth as auth_router
# Register auth router first (for login/logout endpoints)
app.include_router(auth_router.router)
# Register the remaining routers. No prefix is passed here: each router is
# expected to declare its own prefix in its APIRouter(...) constructor.
app.include_router(models.router)
app.include_router(datasets.router)
app.include_router(training.router)
app.include_router(system.router)
app.include_router(jobs.router)
# Mount static files, but only when a "static" directory exists next to this
# module (the app still starts without it).
static_dir = os.path.join(os.path.dirname(__file__), "static")
if os.path.exists(static_dir):
    app.mount("/static", StaticFiles(directory=static_dir), name="static")
# Login page endpoint
@app.get("/login", response_class=HTMLResponse)
async def login_page():
    """Return the login form as an HTML response.

    The markup is produced by ``app.auth.get_login_html``. This endpoint does
    not inspect the session, so it is reachable whether or not the visitor is
    already logged in.
    """
    from app.auth import get_login_html as render_login_form

    markup = render_login_form()
    return HTMLResponse(content=markup)
# Root endpoint - serves dashboard (with auth check)
@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
    """Serve the main dashboard, enforcing login when a password is set.

    If ``settings.APP_PASSWORD`` is configured and ``is_authenticated`` rejects
    the request, a 302 redirect to ``/login`` is returned. Its body carries a
    JavaScript fallback for clients that ignore the Location header.
    Otherwise the contents of ``templates/dashboard.html`` (next to this
    module) are returned, or a 404 HTML page if that file is missing.
    """
    # Only enforce login when a password has been configured.
    if settings.APP_PASSWORD:
        from app.auth import is_authenticated
        if not is_authenticated(request):
            # Redirect to login page
            return HTMLResponse(
                content='<script>window.location.href="/login";</script>',
                status_code=302,
                headers={"location": "/login"},
            )

    template_path = os.path.join(os.path.dirname(__file__), "templates", "dashboard.html")
    try:
        # Decode explicitly as UTF-8 (assumes the template is saved as UTF-8).
        # The platform default encoding depends on the locale and can raise
        # UnicodeDecodeError on non-ASCII characters in the page.
        with open(template_path, "r", encoding="utf-8") as f:
            page = f.read()
    except FileNotFoundError:
        return HTMLResponse(
            content="<h1>Dashboard not found</h1><p>Please check that dashboard.html exists.</p>",
            status_code=404,
        )
    return HTMLResponse(content=page)
@app.get("/health")
async def health_check():
    """Health check endpoint - always public for monitoring.

    Returns the service status, the application version and whether password
    authentication is enabled. The version is read from the FastAPI app itself
    (``app.version``), so it cannot drift from the value passed to
    ``FastAPI(version=...)``.
    """
    return {
        "status": "healthy",
        "version": app.version,
        "auth_enabled": bool(settings.APP_PASSWORD),
    }
@app.on_event("startup")
async def startup_event():
    """Prepare the application before it starts serving requests.

    Steps, in order:
      1. log whether password protection is active (``settings.APP_PASSWORD``);
      2. create database tables via ``app.database.init_db``;
      3. start the training queue from ``app.routers.training.get_queue``,
         without which queued jobs would never be processed.
    """
    logger.info("Universal Model Trainer starting up...")

    auth_message = (
        "Authentication enabled - password protected mode"
        if settings.APP_PASSWORD
        else "No password configured - public access mode"
    )
    logger.info(auth_message)

    # The database is initialised first; the queue is started only after
    # init_db() has completed.
    from app.database import init_db
    await init_db()
    logger.info("Database tables initialized")

    from app.routers.training import get_queue
    await get_queue().start()
    logger.info("Training queue started - workers are now processing jobs")

    logger.info("Routers registered: auth, models, datasets, training, system, jobs")
@app.on_event("shutdown")
async def shutdown_event():
    """Stop background training work when the server shuts down.

    Logs the shutdown, awaits ``stop()`` on the queue returned by
    ``app.routers.training.get_queue``, then logs that it has stopped.
    """
    logger.info("Universal Model Trainer shutting down...")

    from app.routers.training import get_queue as _training_queue

    await _training_queue().stop()
    logger.info("Training queue stopped")
# Catch-all for 404 errors
@app.exception_handler(404)
async def not_found_handler(request: Request, exc):
    """Handle every HTTP 404 produced by the application.

    A handler registered for a status code also receives 404s that endpoints
    raise themselves (e.g. ``HTTPException(404, "Job not found")``), not just
    unmatched URLs. Starlette's exception middleware checks status-code
    handlers before class-based ones.

    * ``/api/...`` paths get a JSON body. An explicit detail supplied by the
      raising code is passed through. Otherwise (unmatched route, or no custom
      detail) the generic "Endpoint not found" message is used.
    * Every other path is redirected to the dashboard at ``/``.
    """
    # Check if it's an API request
    if request.url.path.startswith("/api/"):
        detail = getattr(exc, "detail", None)
        # Unmatched routes carry Starlette's stock "Not Found" detail; treat
        # that (or a missing detail) as "no such endpoint".
        if not detail or detail == "Not Found":
            detail = "Endpoint not found"
        return JSONResponse(
            status_code=404,
            content={"detail": detail},
        )
    # Return to dashboard for HTML requests
    return HTMLResponse(
        content="<script>window.location.href='/';</script>",
        status_code=302,
        headers={"location": "/"},
    )
# Exception handler for HTTPException: 401 auth errors plus a generic fallback
from fastapi.exceptions import HTTPException
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render an HTTPException as HTML or JSON depending on the request.

    * 401: browsers that send ``Accept: text/html`` get the login page (status
      401); other clients get JSON with ``detail`` and ``login_url``.
    * Any other status on ``/api/...`` paths: JSON ``{"detail": ...}``.
    * Any other status elsewhere: a minimal HTML error page. The detail text
      is HTML-escaped before being embedded.
    """
    from html import escape

    if exc.status_code == 401:
        # Unauthorized - check if it's an API or HTML request
        accept = request.headers.get("accept", "")
        if "text/html" in accept:
            from app.auth import get_login_html
            return HTMLResponse(
                content=get_login_html(error="Please log in to continue."),
                status_code=401,
            )
        return JSONResponse(
            status_code=401,
            content={"detail": str(exc.detail), "login_url": "/login"},
        )

    # Default error response
    if request.url.path.startswith("/api/"):
        return JSONResponse(
            status_code=exc.status_code,
            content={"detail": str(exc.detail)},
        )
    # The detail can contain request-derived text (names, ids, paths); escape
    # it so it cannot inject markup or script into the error page.
    return HTMLResponse(
        content=f"<h1>Error {exc.status_code}</h1><p>{escape(str(exc.detail))}</p>",
        status_code=exc.status_code,
    )