| """
|
| Federated Learning API endpoints.
|
|
|
Provides endpoints for federated learning (FL) operations:

- Client registration

- Contribution submission

- Federation status

- Aggregation trigger
|
| """
|
|
|
| from fastapi import APIRouter
|
| from fastapi import HTTPException
|
| from fastapi import status
|
| from pydantic import BaseModel
|
| from pydantic import Field
|
|
|
| from app.api.dependencies import ValidatedApiKey
|
| from app.services.federated_learning import fl_manager
|
|
|
# Router that collects the federated learning endpoints declared in this module.
router = APIRouter()
|
|
|
|
|
class ClientRegistrationRequest(BaseModel):
    """Request body for registering a federated learning client.

    Attributes:
        client_id: Required client identifier, 3 to 64 characters long.
        organization: Optional organization name, at most 128 characters.
    """

    # Required: ``default=...`` marks the field as mandatory for pydantic.
    client_id: str = Field(default=..., min_length=3, max_length=64)
    # Optional: omitted or null values fall back to None.
    organization: str | None = Field(default=None, max_length=128)
|
|
|
|
|
class ContributionRequest(BaseModel):
    """Request body for submitting one local training contribution.

    Attributes:
        client_id: Identifier of the submitting client, 3 to 64 characters.
        gradient_hash: Hash string for the submitted gradients, 16 to 128
            characters long.
        samples_trained: Number of local samples trained on, 1 to 100000
            inclusive.
        local_accuracy: Locally measured accuracy, between 0.0 and 1.0
            inclusive.
    """

    # All four fields are required (``default=...``); bounds are enforced
    # by pydantic validation before the endpoint body runs.
    client_id: str = Field(default=..., min_length=3, max_length=64)
    gradient_hash: str = Field(default=..., min_length=16, max_length=128)
    samples_trained: int = Field(default=..., ge=1, le=100000)
    local_accuracy: float = Field(default=..., ge=0.0, le=1.0)
|
|
|
|
|
@router.post(
    "/federated/register",
    summary="Register as Federated Client",
    description="Register as a federated learning participant.",
)
async def register_client(
    request: ClientRegistrationRequest,
    api_key: ValidatedApiKey,
) -> dict:
    """Register the caller as a federated learning client.

    Args:
        request: Validated registration payload (client id and optional
            organization).
        api_key: Injected through the ``ValidatedApiKey`` dependency. Its
            value is not read here; it is presumably declared so the
            dependency rejects unauthenticated calls.

    Returns:
        A dict with ``status`` set to ``"registered"``, followed by the
        ``client_id``, ``organization`` and ``registered_at`` attributes of
        the client object returned by ``fl_manager.register_client``.
    """
    registered = fl_manager.register_client(
        client_id=request.client_id,
        organization=request.organization,
    )
    return dict(
        status="registered",
        client_id=registered.client_id,
        organization=registered.organization,
        registered_at=registered.registered_at,
    )
|
|
|
|
|
@router.post(
    "/federated/contribute",
    summary="Submit Training Contribution",
    description="Submit model gradients from local training.",
)
async def submit_contribution(
    request: ContributionRequest,
    api_key: ValidatedApiKey,
) -> dict:
    """Forward a client's training contribution to the FL manager.

    Args:
        request: Validated contribution payload.
        api_key: Injected through the ``ValidatedApiKey`` dependency. Its
            value is not read here.

    Returns:
        Whatever ``fl_manager.submit_contribution`` returns, passed through
        unchanged.
    """
    contribution = {
        "client_id": request.client_id,
        "gradient_hash": request.gradient_hash,
        "samples_trained": request.samples_trained,
        "local_accuracy": request.local_accuracy,
    }
    return fl_manager.submit_contribution(**contribution)
|
|
|
|
|
@router.get(
    "/federated/status",
    summary="Federation Status",
    description="Get current federated learning status.",
)
async def federation_status() -> dict:
    """Return the current federation status reported by the FL manager.

    Unlike the other routes in this module, this one declares no
    ``ValidatedApiKey`` parameter.
    NOTE(review): confirm that public access to this route is intended.
    """
    snapshot = fl_manager.get_federation_status()
    return snapshot
|
|
|
|
|
@router.post(
    "/federated/aggregate",
    summary="Trigger Aggregation",
    description="Trigger federated model aggregation (admin only).",
)
async def trigger_aggregation(
    api_key: ValidatedApiKey,
) -> dict:
    """Ask the FL manager to run a (simulated) aggregation round.

    Args:
        api_key: Injected through the ``ValidatedApiKey`` dependency. Its
            value is not read here.

    Returns:
        The result of ``fl_manager.simulate_aggregation`` unchanged.

    NOTE(review): the route is described as "admin only", but the only
    guard visible here is the same ``ValidatedApiKey`` dependency the
    non-admin routes use. Confirm whether an admin check happens elsewhere
    or still needs to be added.
    """
    aggregation = fl_manager.simulate_aggregation()
    return aggregation
|
|
|