File size: 1,916 Bytes
906d507 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import flwr as fl
import numpy as np
import json
import pandas as pd
from model import create_model
import sys
# Module-level log of per-round evaluation results; main() rebinds it to a
# fresh list at the start of each run, and evaluate() appends to it.
results_log = list()
def get_evaluate_fn(file_path, num_rounds=3):
    """Build a Flower server-side evaluation callback for a CSV dataset.

    The CSV is read once, when this factory is called. All columns except the
    last are treated as features and the last column as the label. Features
    are z-score normalised column-wise, and the final 20% of rows (in file
    order) are held out as the evaluation split.

    Args:
        file_path: Path to the CSV file. Its name also selects the file the
            final model is saved to: a name containing "diabetes" gives
            ``diabetes_model.keras``, "heart" gives ``heart_model.keras``,
            anything else gives ``global_model.keras``.
        num_rounds: Round number after whose evaluation the model is saved.
            Defaults to 3, the value used by ``ServerConfig`` in ``main``.

    Returns:
        A callable ``evaluate(server_round, parameters, config)`` that returns
        ``(loss, {"accuracy": accuracy})`` as built-in floats.
    """
    data = pd.read_csv(file_path)
    X = data.iloc[:, :-1].values
    y = data.iloc[:, -1].values
    # NOTE(review): mean/std are computed over all rows, including the held-out
    # split; presumably kept to mirror client-side preprocessing — confirm.
    X = (X - X.mean(axis=0)) / (X.std(axis=0) + 1e-8)
    split = int(len(X) * 0.8)
    X_test, y_test = X[split:], y[split:]
    input_shape = X.shape[1]

    # Pick the output filename from the dataset name; str() allows Path inputs.
    name = str(file_path)
    if "diabetes" in name:
        model_name = "diabetes_model.keras"
    elif "heart" in name:
        model_name = "heart_model.keras"
    else:
        model_name = "global_model.keras"

    def evaluate(server_round, parameters, config):
        """Load ``parameters`` into a fresh model and score it on the test split."""
        model = create_model(input_shape=input_shape)
        model.set_weights(parameters)
        loss, accuracy = model.evaluate(X_test, y_test, verbose=0)
        # Cast to built-in float: numpy scalars such as float32 survive round()
        # unchanged and cannot be serialised by json.dump.
        loss, accuracy = float(loss), float(accuracy)
        results_log.append({
            "round": server_round,
            "loss": round(loss, 4),
            "accuracy": round(accuracy, 4),
        })
        # Rewrite the whole log every round so the file is always complete.
        with open("training_results.json", "w") as f:
            json.dump(results_log, f)
        if server_round == num_rounds:
            model.save(model_name)
            print(f"Model saved as {model_name}")
        return loss, {"accuracy": accuracy}

    return evaluate
def main(file_path="diabetes.csv"):
    """Reset the results log and run a 3-round FedAvg Flower server.

    Args:
        file_path: CSV file used for server-side evaluation; passed to
            ``get_evaluate_fn``.
    """
    global results_log
    results_log = []
    # Truncate the on-disk log so a new run does not inherit stale rounds.
    with open("training_results.json", "w") as log_file:
        json.dump([], log_file)

    evaluate_fn = get_evaluate_fn(file_path)
    # Require at least two connected clients for both fit and evaluate phases.
    fedavg = fl.server.strategy.FedAvg(
        min_fit_clients=2,
        min_evaluate_clients=2,
        min_available_clients=2,
        evaluate_fn=evaluate_fn,
    )
    server_config = fl.server.ServerConfig(num_rounds=3)
    fl.server.start_server(
        server_address="localhost:8080",
        config=server_config,
        strategy=fedavg,
    )
if __name__ == "__main__":
    # Optional first CLI argument selects the dataset CSV.
    file_path = sys.argv[1] if len(sys.argv) > 1 else "diabetes.csv"
    main(file_path)