Spaces:
Sleeping
Sleeping
| """Starts flwr.server.start_server with our custom strategy.""" | |
| from __future__ import annotations | |
| import logging | |
| import flwr as fl | |
| from server.config import ( | |
| FLOWER_SERVER_ADDRESS, | |
| GRPC_MAX_MESSAGE_LENGTH, | |
| NUM_ROUNDS, | |
| ) | |
| from server.hf_persistence import load_checkpoint, save_checkpoint | |
| from server.model import init_weights, weights_to_parameters | |
| from server.state import StrategyState | |
| from server.strategy import TabularFedAvg | |
# Module-level logger; handlers/level are set up by run() via logging.basicConfig.
log = logging.getLogger(__name__)
class PersistingStrategy(TabularFedAvg):
    """TabularFedAvg that saves a checkpoint after each successful aggregation.

    Defined at module level rather than inside ``build_strategy``: it uses
    nothing from that function's scope, and a single class object keeps
    ``isinstance`` checks consistent across calls and makes the class importable.
    """

    def aggregate_fit(self, server_round, results, failures):
        """Aggregate client fit results, then persist the latest weights.

        Args:
            server_round: Current server round, forwarded to the parent strategy.
            results: Successful client fit results, forwarded unchanged.
            failures: Failed client fit attempts, forwarded unchanged.

        Returns:
            The ``(parameters, metrics)`` pair returned by
            ``TabularFedAvg.aggregate_fit``, unchanged.
        """
        agg_params, metrics = super().aggregate_fit(server_round, results, failures)
        # Only checkpoint when this round produced an aggregate and weights are
        # available on the shared state.
        # NOTE(review): assumes TabularFedAvg.aggregate_fit refreshes
        # self.state.latest_weights for the current round — confirm.
        if agg_params is not None and self.state.latest_weights is not None:
            save_checkpoint(self.state.latest_weights)
            log.info("Saved checkpoint after round %s", server_round)
        return agg_params, metrics


def build_strategy(state: StrategyState) -> TabularFedAvg:
    """Build the server strategy, warm-starting from a saved checkpoint if any.

    Args:
        state: Shared strategy state passed to the strategy constructor.

    Returns:
        A ``PersistingStrategy`` whose initial parameters come from the loaded
        checkpoint, or from ``init_weights()`` when none is loaded.
    """
    # A falsy result from load_checkpoint() (presumably None when no checkpoint
    # exists — confirm in server.hf_persistence) falls back to fresh weights.
    initial_weights = load_checkpoint() or init_weights()
    return PersistingStrategy(
        state=state,
        initial_parameters=weights_to_parameters(initial_weights),
    )
def run(state: StrategyState) -> None:
    """Configure INFO-level logging and launch the Flower server.

    The strategy is produced by ``build_strategy(state)``. The server is
    configured with ``NUM_ROUNDS`` rounds, binds to ``FLOWER_SERVER_ADDRESS``,
    and receives ``GRPC_MAX_MESSAGE_LENGTH`` as ``grpc_max_message_length``.

    Args:
        state: Shared strategy state forwarded to ``build_strategy``.
    """
    logging.basicConfig(level=logging.INFO)
    fl_strategy = build_strategy(state)

    address = FLOWER_SERVER_ADDRESS
    log.info("Starting Flower server on %s", address)

    server_config = fl.server.ServerConfig(num_rounds=NUM_ROUNDS)
    fl.server.start_server(
        server_address=address,
        config=server_config,
        strategy=fl_strategy,
        grpc_max_message_length=GRPC_MAX_MESSAGE_LENGTH,
    )
if __name__ == "__main__":
    # Script entry point: serve using a newly constructed StrategyState.
    run(state=StrategyState())