Spaces:
Running
Running
| from fastapi import APIRouter, HTTPException, BackgroundTasks | |
| from app.schemas.request import TrainingRequest | |
| from app.schemas.response import TrainingResponse | |
| from app.utils.model_loader import model_loader | |
| from mlpipeline.pipeline.training_pipeline import TrainingPipeline | |
# Training endpoints: routes added to this router are served under the
# "/train" prefix and grouped under the "training" tag in the OpenAPI docs.
router = APIRouter(tags=["training"], prefix="/train")
def run_training_task():
    """Run the training pipeline, then reload the served model.

    This is a plain blocking function. It is scheduled as a background task
    and also invoked directly by ``trigger_training``.

    Returns:
        The artifact returned by ``TrainingPipeline().run_pipeline()``;
        ``trigger_training`` reads its ``pushed_model_path`` attribute.

    Raises:
        Any exception from the pipeline or the model reload propagates
        unchanged to the caller.
    """
    # The former ``except Exception as e: raise e`` was a no-op that merely
    # added an extra frame to the traceback, so it has been removed.
    artifact = TrainingPipeline().run_pipeline()
    # Reached only if the pipeline did not raise, so a failed run never
    # triggers a reload.
    model_loader.reload_model()
    return artifact
async def trigger_training(request: TrainingRequest, background_tasks: BackgroundTasks):
    """Start the training pipeline, either queued in the background or inline.

    Args:
        request: The training request; only ``force_retrain`` is read here.
        background_tasks: FastAPI background-task queue, used when
            ``request.force_retrain`` is truthy.

    Returns:
        ``TrainingResponse`` with status ``"training_started"`` when the run
        was queued, or status ``"completed"`` plus ``model_path`` when it
        finished before responding.

    Raises:
        HTTPException: Status 500 wrapping any error raised while scheduling
            or running training; the original exception is chained.
    """
    # NOTE(review): no @router decorator is visible on this handler in the
    # reviewed source — confirm it is actually registered as a route.
    # NOTE(review): force_retrain=True queues the run in the background, while
    # a falsy force_retrain trains inline before replying — this mapping
    # looks inverted; confirm it is intended.
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    try:
        if request.force_retrain:
            background_tasks.add_task(run_training_task)
            return TrainingResponse(
                status="training_started",
                message="Training pipeline started in background",
            )
        # Training is long-running blocking work. Calling it directly in this
        # coroutine would freeze the event loop (and all other requests)
        # until it finished, so run it on the worker thread pool instead.
        artifact = await run_in_threadpool(run_training_task)
        return TrainingResponse(
            status="completed",
            message="Training completed successfully",
            model_path=artifact.pushed_model_path,
        )
    except Exception as e:
        # Chain the cause so the server-side traceback keeps the real error.
        raise HTTPException(status_code=500, detail=str(e)) from e
async def reload_model():
    """Reload the served model through ``model_loader``.

    Returns:
        ``TrainingResponse`` with status ``"success"`` once the reload
        completes.

    Raises:
        HTTPException: Status 500 wrapping any error from the reload; the
            original exception is chained.
    """
    # NOTE(review): no @router decorator is visible on this handler in the
    # reviewed source — confirm it is actually registered as a route.
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    try:
        # reload_model() is a synchronous call, presumably loading model files
        # from disk (TODO confirm). Run it on the thread pool so this
        # coroutine does not block the event loop while it works.
        await run_in_threadpool(model_loader.reload_model)
        return TrainingResponse(
            status="success",
            message="Model reloaded successfully",
        )
    except Exception as e:
        # Chain the cause so the server-side traceback keeps the real error.
        raise HTTPException(status_code=500, detail=str(e)) from e