# Author: mrs83 — commit 3e25ded ("Initial Import"), 1.2 kB.
# (Header from the Hugging Face file viewer; "raw" / "history" / "blame" were page links.)
"""huggingface_example: A Flower / Hugging Face app."""
from flwr.common import Context, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg
from .task import get_params, get_model
def server_fn(context: Context) -> ServerAppComponents:
    """Construct the components for the ServerApp.

    Reads these keys from ``context.run_config``:

    * ``"num-server-rounds"`` -- number of federated rounds to run.
    * ``"model-name"`` -- passed to ``get_model``; the resulting model's
      parameters become the initial global model.
    * ``"fraction-fit"`` / ``"fraction-evaluate"`` -- fraction of available
      clients sampled for the fit / evaluate phase of each round.

    Args:
        context: Flower run context carrying the run configuration.

    Returns:
        A ``ServerAppComponents`` bundling a ``ServerConfig`` with the
        configured number of rounds and a ``FedAvg`` strategy seeded with the
        model's initial parameters.

    Raises:
        KeyError: if one of the keys above is missing from the run config
            (the config is subscripted directly, with no defaults).
    """
    # Server-side settings: number of federated rounds. Cast because run_config
    # values are loosely typed and ServerConfig expects an int round count.
    num_rounds = int(context.run_config["num-server-rounds"])
    config = ServerConfig(num_rounds=num_rounds)

    # Initial global model: the parameters of the model returned by
    # get_model(model_name), converted to Flower's Parameters container.
    model_name = context.run_config["model-name"]
    ndarrays = get_params(get_model(model_name))
    global_model_init = ndarrays_to_parameters(ndarrays)

    # Per-round client sampling fractions for the fit and evaluate phases.
    fraction_fit = float(context.run_config["fraction-fit"])
    fraction_evaluate = float(context.run_config["fraction-evaluate"])

    # Let rounds proceed as soon as a single client is connected. FedAvg
    # defaults min_fit_clients and min_evaluate_clients to 2. All three
    # minimums are set together so evaluation is not skipped when only one
    # client is available, and so FedAvg does not warn that
    # min_available_clients is lower than the other minimums.
    min_clients = 1
    strategy = FedAvg(
        fraction_fit=fraction_fit,
        fraction_evaluate=fraction_evaluate,
        min_fit_clients=min_clients,
        min_evaluate_clients=min_clients,
        min_available_clients=min_clients,
        initial_parameters=global_model_init,
    )
    return ServerAppComponents(config=config, strategy=strategy)
# Flower entry point: the ServerApp builds its config and strategy by calling
# server_fn with the run Context.
app: ServerApp = ServerApp(server_fn=server_fn)