"""Training routes: run the training pipeline and reload the served model."""

import logging

from fastapi import APIRouter, BackgroundTasks, HTTPException
from fastapi.concurrency import run_in_threadpool

from app.schemas.request import TrainingRequest
from app.schemas.response import TrainingResponse
from app.utils.model_loader import model_loader
from mlpipeline.pipeline.training_pipeline import TrainingPipeline

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/train", tags=["training"])


def run_training_task():
    """Run the training pipeline, then reload the model through ``model_loader``.

    Returns:
        The artifact object returned by ``TrainingPipeline.run_pipeline()``.

    Raises:
        Exception: Any error raised by the pipeline or by the reload is
            logged and re-raised unchanged.
    """
    try:
        artifact = TrainingPipeline().run_pipeline()
        model_loader.reload_model()
    except Exception:
        # When this runs as a background task there is no caller left to
        # report to, so record the traceback here before propagating.
        logger.exception("Training task failed")
        raise
    return artifact


@router.post("/", response_model=TrainingResponse)
async def trigger_training(
    request: TrainingRequest,
    background_tasks: BackgroundTasks,
) -> TrainingResponse:
    """Trigger the training pipeline.

    When ``request.force_retrain`` is truthy the pipeline is scheduled as a
    background task and a ``training_started`` response is returned right
    away. Otherwise the pipeline is awaited and the response carries the
    artifact's ``pushed_model_path``.

    Raises:
        HTTPException: Status 500, with the error text as ``detail``, when
            scheduling or running the pipeline raises.
    """
    try:
        if request.force_retrain:
            background_tasks.add_task(run_training_task)
            return TrainingResponse(
                status="training_started",
                message="Training pipeline started in background",
            )

        # run_training_task is synchronous; calling it directly inside this
        # coroutine would block the event loop for the whole training run.
        artifact = await run_in_threadpool(run_training_task)
        return TrainingResponse(
            status="completed",
            message="Training completed successfully",
            model_path=artifact.pushed_model_path,
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


@router.post("/reload", response_model=TrainingResponse)
async def reload_model() -> TrainingResponse:
    """Reload the model through ``model_loader`` and report success.

    Raises:
        HTTPException: Status 500, with the error text as ``detail``, when
            the reload raises.
    """
    try:
        # NOTE(review): reload_model() is a synchronous call and presumably
        # does blocking I/O, so it is kept off the event loop - confirm.
        await run_in_threadpool(model_loader.reload_model)
    except Exception as exc:
        logger.exception("Model reload failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return TrainingResponse(
        status="success",
        message="Model reloaded successfully",
    )