Spaces:
Sleeping
Sleeping
| """huggingface_example: A Flower / Hugging Face app.""" | |
| import warnings | |
| import torch | |
| from flwr.client import Client, ClientApp, NumPyClient | |
| from flwr.common import Context | |
| from transformers import logging | |
| from .task import ( | |
| train, | |
| test, | |
| load_data, | |
| set_params, | |
| get_params, | |
| get_model, | |
| ) | |
# FutureWarning messages are suppressed for the whole process.
warnings.simplefilter("ignore", category=FutureWarning)
# Transformers would otherwise warn that the model still has to be trained on a
# downstream task; that training is exactly what this example performs, so only
# error-level Transformers log messages are kept.
logging.set_verbosity_error()
| # Flower client | |
class IMDBClient(NumPyClient):
    """Flower ``NumPyClient`` that trains and evaluates one model locally.

    The model is created with ``get_model(model_name)`` and moved to the first
    CUDA device when one is available, otherwise it stays on the CPU.
    """

    def __init__(self, model_name, trainloader, testloader) -> None:
        """Store the data loaders and build the model on the chosen device.

        Args:
            model_name: Identifier forwarded to ``get_model``.
            trainloader: Iterable of training batches passed to ``train``.
            testloader: Iterable of evaluation batches passed to ``test``.
        """
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.trainloader = trainloader
        self.testloader = testloader
        self.net = get_model(model_name)
        self.net.to(self.device)

    @staticmethod
    def _num_examples(loader) -> int:
        """Return how many examples ``loader`` serves.

        Flower weights client results by the number of *examples*. For a
        batched loader, ``len(loader)`` is the number of batches, so the
        length of the underlying ``dataset`` is preferred.
        NOTE(review): assumes the loaders are PyTorch ``DataLoader``-like
        objects exposing ``.dataset`` -- confirm against ``load_data``.
        When there is no sized ``dataset`` attribute, this falls back to
        ``len(loader)``, which is what was reported before.
        """
        try:
            return len(loader.dataset)
        except (AttributeError, TypeError):
            return len(loader)

    def fit(self, parameters, config) -> tuple[list, int, dict]:
        """Load ``parameters`` into the model and train it for one epoch.

        Args:
            parameters: Model parameters to apply through ``set_params``.
            config: Configuration sent with the request; not used here.

        Returns:
            The updated parameters from ``get_params``, the number of
            training examples, and an empty metrics dict.
        """
        set_params(self.net, parameters)
        train(self.net, self.trainloader, epochs=1, device=self.device)
        return get_params(self.net), self._num_examples(self.trainloader), {}

    def evaluate(self, parameters, config) -> tuple[float, int, dict[str, float]]:
        """Load ``parameters`` into the model and evaluate it on the test loader.

        Args:
            parameters: Model parameters to apply through ``set_params``.
            config: Configuration sent with the request; not used here.

        Returns:
            The loss as a float, the number of evaluation examples, and a
            metrics dict holding ``"accuracy"`` as a float.
        """
        set_params(self.net, parameters)
        loss, accuracy = test(self.net, self.testloader, device=self.device)
        return (
            float(loss),
            self._num_examples(self.testloader),
            {"accuracy": float(accuracy)},
        )
def client_fn(context: Context) -> Client:
    """Build the client that a ClientApp runs for the node behind ``context``."""
    # The node config identifies which data partition belongs to this node.
    node_cfg = context.node_config
    part_id = node_cfg["partition-id"]
    part_count = node_cfg["num-partitions"]

    # The run config carries the settings shared by the whole run.
    model_name = context.run_config["model-name"]

    train_dl, test_dl = load_data(part_id, part_count, model_name)
    numpy_client = IMDBClient(model_name, train_dl, test_dl)
    return numpy_client.to_client()
# Module-level ClientApp; it obtains its clients by calling client_fn.
app = ClientApp(
    client_fn=client_fn,
)