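# PrivaMed / server.py
# Uploaded by NMMAITA (commit 906d507, "Upload 17 files").
# Flower federated-learning server: aggregates client updates with FedAvg
# and evaluates the global model centrally on a held-out CSV split.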
import flwr as fl
import numpy as np
import json
import pandas as pd
from model import create_model
import sys
# Per-round evaluation metrics for the current run. main() resets it and the
# evaluate callback built by get_evaluate_fn() appends one dict per round.
results_log = list()
def _model_name_for(file_path):
    """Return the filename under which the final global model is saved.

    The choice is a substring match on the whole *file_path* string (not only
    its base name); ``"diabetes"`` is checked before ``"heart"``.
    """
    if "diabetes" in file_path:
        return "diabetes_model.keras"
    if "heart" in file_path:
        return "heart_model.keras"
    return "global_model.keras"


def _load_test_split(file_path, train_fraction=0.8):
    """Load *file_path* and return ``(X_test, y_test, n_features)``.

    The last CSV column is the label; every other column is a feature.
    Features are z-scored column-wise, then the rows after the first
    ``train_fraction`` share are kept as the server-side test set.
    """
    data = pd.read_csv(file_path)
    X = data.iloc[:, :-1].values
    y = data.iloc[:, -1].values
    # NOTE(review): mean/std are computed over ALL rows, test rows included,
    # which leaks test statistics into preprocessing. Kept unchanged because
    # the clients presumably normalise the same way -- confirm before fixing.
    # The 1e-8 term avoids division by zero for constant columns.
    X = (X - X.mean(axis=0)) / (X.std(axis=0) + 1e-8)
    split = int(len(X) * train_fraction)
    return X[split:], y[split:], X.shape[1]


def get_evaluate_fn(file_path, num_rounds=3):
    """Build the server-side evaluation callback for Flower's FedAvg strategy.

    Args:
        file_path: CSV dataset whose last column is the label. The final 20 %
            of rows (after normalisation) form the server's test set.
        num_rounds: round number after which the global model is saved to
            disk. It should equal the ``ServerConfig(num_rounds=...)`` value;
            the default of 3 matches what ``main`` uses.

    Returns:
        ``evaluate(server_round, parameters, config)``. It does the following:

        * loads *parameters* into a fresh ``create_model`` instance;
        * evaluates that model on the test split;
        * appends ``{"round", "loss", "accuracy"}`` to the module-level
          ``results_log``;
        * rewrites ``training_results.json`` with the full log;
        * returns ``(loss, {"accuracy": accuracy})``.
    """
    X_test, y_test, input_shape = _load_test_split(file_path)
    model_name = _model_name_for(file_path)

    def evaluate(server_round, parameters, config):
        model = create_model(input_shape=input_shape)
        model.set_weights(parameters)
        loss, accuracy = model.evaluate(X_test, y_test, verbose=0)
        # Keras may hand back NumPy scalars; round() would keep that type and
        # json.dump cannot serialise it, so normalise to plain Python floats.
        loss, accuracy = float(loss), float(accuracy)
        results_log.append({
            "round": server_round,
            "loss": round(loss, 4),
            "accuracy": round(accuracy, 4),
        })
        # Rewrite the whole file each round so it always holds the full log.
        with open("training_results.json", "w") as f:
            json.dump(results_log, f)
        if server_round == num_rounds:
            model.save(model_name)
            print(f"Model saved as {model_name}")
        return loss, {"accuracy": accuracy}

    return evaluate
def main(file_path="diabetes.csv"):
    """Run the federated-learning server for the dataset at *file_path*.

    Clears the in-memory ``results_log`` and truncates
    ``training_results.json`` to an empty JSON list. It then starts a Flower
    server on ``localhost:8080`` with these settings:

    * FedAvg strategy (``min_fit_clients``, ``min_evaluate_clients`` and
      ``min_available_clients`` all 2);
    * server-side evaluation from ``get_evaluate_fn(file_path)``;
    * 3 rounds.
    """
    global results_log
    results_log = []

    # Start every run from an empty results file.
    with open("training_results.json", "w") as out_file:
        json.dump([], out_file)

    evaluate_fn = get_evaluate_fn(file_path)
    strategy = fl.server.strategy.FedAvg(
        min_fit_clients=2,
        min_evaluate_clients=2,
        min_available_clients=2,
        evaluate_fn=evaluate_fn,
    )
    server_config = fl.server.ServerConfig(num_rounds=3)
    fl.server.start_server(
        server_address="localhost:8080",
        config=server_config,
        strategy=strategy,
    )
if __name__ == "__main__":
    # The dataset CSV may be passed as the first command-line argument.
    cli_args = sys.argv[1:]
    file_path = cli_args[0] if cli_args else "diabetes.csv"
    main(file_path)