File size: 1,916 Bytes
906d507
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import flwr as fl
import numpy as np
import json
import pandas as pd
from model import create_model
import sys

# Per-round evaluation records ({"round", "loss", "accuracy"}) appended by the
# evaluate callback and dumped to training_results.json after every round.
# main() rebinds this name to a fresh list at the start of each run.
results_log = []


def get_evaluate_fn(file_path, num_rounds=3):
    """Build the server-side evaluation callback for a CSV dataset.

    The CSV at ``file_path`` is read once, up front. Every column except the
    last is treated as a feature and the last column as the label. Features
    are standardised with the mean/std of the whole file, and the final 20%
    of rows are kept as the evaluation set.

    Args:
        file_path: Path to the CSV file. If the path contains "diabetes" or
            "heart" the saved model is named after that dataset; otherwise
            it is saved as "global_model.keras".
        num_rounds: Round number at which the evaluated model is written to
            disk. Defaults to 3, matching the previously hard-coded value.

    Returns:
        A callable ``evaluate(server_round, parameters, config)`` returning
        ``(loss, {"accuracy": accuracy})``. Each call also appends a record
        to the module-level ``results_log`` and rewrites
        ``training_results.json`` with the full log.
    """
    data = pd.read_csv(file_path)
    X = data.iloc[:, :-1].values
    y = data.iloc[:, -1].values
    # The small epsilon avoids dividing by zero for constant columns.
    X = (X - X.mean(axis=0)) / (X.std(axis=0) + 1e-8)
    split = int(len(X) * 0.8)
    X_test, y_test = X[split:], y[split:]
    input_shape = X.shape[1]

    # Determine the output model file name from the dataset path.
    if "diabetes" in file_path:
        model_name = "diabetes_model.keras"
    elif "heart" in file_path:
        model_name = "heart_model.keras"
    else:
        model_name = "global_model.keras"

    def evaluate(server_round, parameters, config):
        """Evaluate ``parameters`` on the held-out rows and log the result."""
        model = create_model(input_shape=input_shape)
        model.set_weights(parameters)
        loss, accuracy = model.evaluate(X_test, y_test, verbose=0)
        # Coerce to built-in floats: NumPy scalars such as float32 survive
        # round() unchanged and are rejected by json.dump with a TypeError.
        loss = float(loss)
        accuracy = float(accuracy)

        results_log.append({
            "round": server_round,
            "loss": round(loss, 4),
            "accuracy": round(accuracy, 4)
        })
        # Rewrite the whole log every round so the file is always complete.
        with open("training_results.json", "w") as f:
            json.dump(results_log, f)

        # Persist the model once, at the configured final round.
        if server_round == num_rounds:
            model.save(model_name)
            print(f"Model saved as {model_name}")

        return loss, {"accuracy": accuracy}

    return evaluate


def main(file_path="diabetes.csv"):
    """Reset the results log and start the Flower server with a FedAvg strategy.

    Args:
        file_path: CSV dataset path forwarded to ``get_evaluate_fn`` to build
            the server-side evaluation callback.

    Side effects:
        Rebinds the module-level ``results_log`` to an empty list, overwrites
        ``training_results.json`` with an empty JSON list, and starts the
        server on ``localhost:8080`` configured for 3 rounds.
    """
    global results_log

    # Begin every run from a clean slate, in memory and on disk.
    results_log = list()
    with open("training_results.json", "w") as results_file:
        json.dump(results_log, results_file)

    # The same minimum client count applies to fit, evaluate and availability.
    required_clients = 2
    fedavg = fl.server.strategy.FedAvg(
        min_fit_clients=required_clients,
        min_evaluate_clients=required_clients,
        min_available_clients=required_clients,
        evaluate_fn=get_evaluate_fn(file_path),
    )

    server_config = fl.server.ServerConfig(num_rounds=3)
    fl.server.start_server(
        server_address="localhost:8080",
        config=server_config,
        strategy=fedavg,
    )


if __name__ == "__main__":
    # The first CLI argument, when given, selects the dataset CSV;
    # otherwise fall back to the bundled default.
    cli_args = sys.argv[1:]
    main(cli_args[0] if cli_args else "diabetes.csv")