# CreditScoreServer/server/fl_server.py
# Sammi1211 - Initial push (d7f62d0)
"""Starts flwr.server.start_server with our custom strategy."""
from __future__ import annotations
import logging
import flwr as fl
from server.config import (
FLOWER_SERVER_ADDRESS,
GRPC_MAX_MESSAGE_LENGTH,
NUM_ROUNDS,
)
from server.hf_persistence import load_checkpoint, save_checkpoint
from server.model import init_weights, weights_to_parameters
from server.state import StrategyState
from server.strategy import TabularFedAvg
# Module-level logger named after this module; used by run() below.
log = logging.getLogger(__name__)
def build_strategy(state: StrategyState) -> TabularFedAvg:
    """Create the strategy instance that checkpoints after each aggregation.

    Starting weights are taken from ``load_checkpoint()``; when that returns
    a falsy value, ``init_weights()`` supplies them instead. The chosen
    weights are converted with ``weights_to_parameters`` and handed to the
    strategy as ``initial_parameters``.

    Args:
        state: Strategy state object forwarded to the strategy constructor.

    Returns:
        An instance of a ``TabularFedAvg`` subclass whose ``aggregate_fit``
        additionally calls ``save_checkpoint``.
    """
    starting_weights = load_checkpoint()
    if not starting_weights:
        # Nothing usable was restored, so fall back to fresh weights.
        starting_weights = init_weights()
    starting_parameters = weights_to_parameters(starting_weights)

    class PersistingStrategy(TabularFedAvg):
        """``TabularFedAvg`` variant that saves weights after aggregation."""

        def aggregate_fit(self, server_round, results, failures):
            """Delegate to the parent aggregation, then persist the weights.

            The checkpoint is written only when the parent returned
            parameters and ``self.state.latest_weights`` is set.
            NOTE(review): presumably the parent ``aggregate_fit`` updates
            ``self.state.latest_weights`` - confirm in ``server.strategy``.
            """
            aggregated, fit_metrics = super().aggregate_fit(
                server_round, results, failures
            )
            # Guard clause: skip persistence when there is nothing to save.
            if aggregated is None or self.state.latest_weights is None:
                return aggregated, fit_metrics
            save_checkpoint(self.state.latest_weights)
            return aggregated, fit_metrics

    return PersistingStrategy(
        state=state,
        initial_parameters=starting_parameters,
    )
def run(state: StrategyState) -> None:
    """Configure logging and start the Flower server with our strategy.

    Sets the root logging level to INFO, builds the strategy through
    ``build_strategy(state)``, logs the bind address, and then calls
    ``fl.server.start_server`` for ``NUM_ROUNDS`` rounds.

    Args:
        state: Strategy state object passed on to ``build_strategy``.
    """
    logging.basicConfig(level=logging.INFO)
    persisting_strategy = build_strategy(state)

    log.info("Starting Flower server on %s", FLOWER_SERVER_ADDRESS)

    server_config = fl.server.ServerConfig(num_rounds=NUM_ROUNDS)
    fl.server.start_server(
        server_address=FLOWER_SERVER_ADDRESS,
        config=server_config,
        strategy=persisting_strategy,
        grpc_max_message_length=GRPC_MAX_MESSAGE_LENGTH,
    )
# Script entry point: start the server with a freshly constructed state.
if __name__ == "__main__":
    initial_state = StrategyState()
    run(initial_state)