Spaces:
Runtime error
Runtime error
# Entry script: configures a SubstraRunner for NUM_CLIENTS clients, attaches a
# torch-based algorithm and the FedAvg strategy, then launches the compute plan.
import os

import torch
from dotenv import load_dotenv
from substrafl.strategies import FedAvg

from substra_helpers.dataset import TorchDataset
from substra_helpers.model import CNN
from substra_helpers.substra_runner import SubstraRunner, algo_generator

# Populate the process environment from a .env file (if one is found) before
# NUM_CLIENTS is read below.
load_dotenv()

# Required setting: raises KeyError when NUM_CLIENTS is unset and ValueError
# when its value is not an integer literal.
NUM_CLIENTS = int(os.environ["NUM_CLIENTS"])

# Seed torch's global RNG before the model is built; the same seed is also
# handed to algo_generator further down.
# NOTE(review): presumably this makes CNN's initial weights reproducible --
# confirm that CNN initialises its parameters through torch's RNG.
seed = 42
torch.manual_seed(seed)

model = CNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

# The runner methods are invoked in a fixed sequence; keep this order.
runner = SubstraRunner(num_clients=NUM_CLIENTS)
runner.set_up_clients()
runner.prepare_data()
runner.register_data()
runner.register_metric()

# algo_generator(...) returns a callable, which is invoked immediately with no
# arguments; the resulting object is what the runner uses as its algorithm.
runner.algorithm = algo_generator(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    index_generator=runner.index_generator,
    dataset=TorchDataset,
    seed=seed,
)()

runner.strategy = FedAvg()

runner.set_aggregation()
runner.set_testing()

runner.run_compute_plan()