"""Starts flwr.server.start_server with our custom strategy.""" from __future__ import annotations import logging import flwr as fl from server.config import ( FLOWER_SERVER_ADDRESS, GRPC_MAX_MESSAGE_LENGTH, NUM_ROUNDS, ) from server.hf_persistence import load_checkpoint, save_checkpoint from server.model import init_weights, weights_to_parameters from server.state import StrategyState from server.strategy import TabularFedAvg log = logging.getLogger(__name__) def build_strategy(state: StrategyState) -> TabularFedAvg: initial_weights = load_checkpoint() or init_weights() initial_parameters = weights_to_parameters(initial_weights) class PersistingStrategy(TabularFedAvg): def aggregate_fit(self, server_round, results, failures): agg_params, metrics = super().aggregate_fit(server_round, results, failures) if agg_params is not None and self.state.latest_weights is not None: save_checkpoint(self.state.latest_weights) return agg_params, metrics return PersistingStrategy( state=state, initial_parameters=initial_parameters, ) def run(state: StrategyState) -> None: logging.basicConfig(level=logging.INFO) strategy = build_strategy(state) log.info("Starting Flower server on %s", FLOWER_SERVER_ADDRESS) fl.server.start_server( server_address=FLOWER_SERVER_ADDRESS, config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS), strategy=strategy, grpc_max_message_length=GRPC_MAX_MESSAGE_LENGTH, ) if __name__ == "__main__": run(StrategyState())