# PrivaMed / client.py
# Uploaded by NMMAITA — commit 906d507 ("Upload 17 files").
import flwr as fl
import tensorflow as tf
import numpy as np
import pandas as pd
from model import create_model
def load_data(file_path, client_id, num_clients=2, train_fraction=0.8):
    """Load a CSV dataset and return one client's training shard plus the shared test split.

    The last column is the target and every other column is a feature. The
    first ``train_fraction`` of rows (file order, no shuffling) form the
    training set. The training set is cut into ``num_clients`` equal,
    contiguous shards. The remaining rows form the test set, which all clients
    share.

    Features are standardized with the mean/std of the training rows only, so
    no test-set statistics leak into training. Both splits use those same
    statistics.

    Args:
        file_path: Path or file-like object accepted by ``pandas.read_csv``.
        client_id: Zero-based index of the shard to return; must be in
            ``[0, num_clients)``.
        num_clients: Number of equal training shards (default 2, matching the
            previous hard-coded split). Leftover rows from the integer
            division are not assigned to any client.
        train_fraction: Fraction of rows used for training, in ``(0, 1]``
            (default 0.8).

    Returns:
        ``(X_train_shard, y_train_shard, X_test, y_test)`` as NumPy arrays.

    Raises:
        ValueError: If ``num_clients < 1``, ``client_id`` is out of range,
            ``train_fraction`` is outside ``(0, 1]``, or the training split is
            too small to give every client at least one row.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")
    if not 0 <= client_id < num_clients:
        raise ValueError(
            f"client_id must be in [0, {num_clients}), got {client_id}"
        )
    if not 0 < train_fraction <= 1:
        raise ValueError(f"train_fraction must be in (0, 1], got {train_fraction}")

    data = pd.read_csv(file_path)
    # Last column is the target; everything before it is a feature.
    X = data.iloc[:, :-1].values
    y = data.iloc[:, -1].values

    split = int(len(X) * train_fraction)
    size = split // num_clients
    if size == 0:
        raise ValueError(
            f"training split has {split} rows; too few for {num_clients} clients"
        )
    X_train, X_test = X[:split], X[split:]
    y_train, y_test = y[:split], y[split:]

    # Fit the scaler on training rows only to avoid leaking test statistics;
    # the epsilon guards against division by zero for constant columns.
    mean = X_train.mean(axis=0)
    std = X_train.std(axis=0) + 1e-8
    X_train = (X_train - mean) / std
    X_test = (X_test - mean) / std

    start = client_id * size
    stop = start + size
    return X_train[start:stop], y_train[start:stop], X_test, y_test
class HospitalClient(fl.client.NumPyClient):
    """Flower NumPy client that trains the shared model on one hospital's data shard.

    Each instance loads its own training shard (selected by ``client_id``)
    plus the common test split via ``load_data``. It builds its local model
    with ``create_model``.
    """

    # Number of local epochs used when the server's fit config does not
    # provide a "local_epochs" entry (the previously hard-coded value).
    DEFAULT_LOCAL_EPOCHS = 5

    def __init__(self, client_id, file_path):
        """Load this client's data and build its local model.

        Args:
            client_id: Index of the training shard this client owns.
            file_path: CSV dataset path passed to ``load_data``.
        """
        self.x_train, self.y_train, self.x_test, self.y_test = load_data(
            file_path, client_id
        )
        # The model's input width is the number of feature columns.
        self.model = create_model(input_shape=self.x_train.shape[1])

    def get_parameters(self, config):
        """Return the current local model weights (``model.get_weights()``)."""
        return self.model.get_weights()

    def fit(self, parameters, config):
        """Train locally starting from the server's global weights.

        Args:
            parameters: Global model weights sent by the server.
            config: Server-supplied fit config. An optional ``"local_epochs"``
                entry overrides ``DEFAULT_LOCAL_EPOCHS``.

        Returns:
            ``(updated weights, number of local training rows, {})``.
        """
        self.model.set_weights(parameters)
        epochs = int((config or {}).get("local_epochs", self.DEFAULT_LOCAL_EPOCHS))
        self.model.fit(self.x_train, self.y_train, epochs=epochs, verbose=0)
        return self.model.get_weights(), len(self.x_train), {}

    def evaluate(self, parameters, config):
        """Evaluate the server's global weights on the shared test split.

        Returns:
            ``(loss, number of test rows, {"accuracy": accuracy})``.
        """
        self.model.set_weights(parameters)
        # NOTE(review): assumes create_model compiles with exactly one metric
        # (accuracy), so evaluate() yields [loss, accuracy] — confirm in model.py.
        loss, accuracy = self.model.evaluate(self.x_test, self.y_test, verbose=0)
        # Flower expects plain Python scalars; cast in case NumPy floats come back.
        return float(loss), len(self.x_test), {"accuracy": float(accuracy)}
if __name__ == "__main__":
    import sys

    # Usage: python client.py [client_id] [csv_path] [server_address]
    client_id = int(sys.argv[1]) if len(sys.argv) > 1 else 0
    file_path = sys.argv[2] if len(sys.argv) > 2 else "diabetes.csv"
    # Optional third argument lets a client reach a non-local server; the
    # previous hard-coded address stays the default.
    server_address = sys.argv[3] if len(sys.argv) > 3 else "localhost:8080"

    # NOTE(review): newer Flower releases deprecate start_numpy_client in favour
    # of fl.client.start_client(..., client=HospitalClient(...).to_client());
    # confirm the installed Flower version before migrating.
    fl.client.start_numpy_client(
        server_address=server_address,
        client=HospitalClient(client_id, file_path),
    )