|
|
import argparse
|
|
|
import flwr as fl
|
|
|
import torch
|
|
|
import pandas as pd
|
|
|
import numpy as np
|
|
|
import matplotlib.pyplot as plt
|
|
|
import seaborn as sns
|
|
|
from typing import List, Tuple, Dict, Optional
|
|
|
from flwr.common import Metrics, ndarrays_to_parameters
|
|
|
from pathlib import Path
|
|
|
import json
|
|
|
from datetime import datetime
|
|
|
|
|
|
from utils import Net
|
|
|
|
|
|
|
|
|
# Global plotting defaults applied to every figure generated by this module.
plt.style.use('seaborn-v0_8-darkgrid')

sns.set_palette("husl")
|
|
|
|
|
|
|
|
|
# Maps the CLI --strategy choice to the corresponding Flower aggregation
# strategy class. Keys double as the argparse `choices` list in main().
STRATEGIES = {
    "fedavg": fl.server.strategy.FedAvg,
    "fedadam": fl.server.strategy.FedAdam,
    "fedyogi": fl.server.strategy.FedYogi,
    "fedadagrad": fl.server.strategy.FedAdagrad,
}
|
|
|
|
|
|
class MetricsCollector:
    """Collects and organizes per-round training and evaluation metrics.

    Holds three parallel dictionaries of lists (one entry per round):
    training losses, evaluation losses, and cross-client convergence
    statistics. Clients are numbered 1..3.
    """

    def __init__(self, strategy_name):
        """Initialize empty metric series.

        Args:
            strategy_name: Name of the FL aggregation strategy; used only to
                label output files and plot titles.
        """
        self.strategy_name = strategy_name
        # Per-round training losses: global weighted mean plus one series
        # per client.
        self.train_metrics = {
            "rounds": [],
            "global_train_loss": [],
            "client_1_train_loss": [],
            "client_2_train_loss": [],
            "client_3_train_loss": [],
        }
        # Per-round evaluation (validation) losses, same layout as above.
        self.eval_metrics = {
            "rounds": [],
            "global_eval_loss": [],
            "client_1_eval_loss": [],
            "client_2_eval_loss": [],
            "client_3_eval_loss": [],
        }
        # Cross-client dispersion statistics per round (filled by callers
        # using calculate_convergence_metrics).
        self.convergence_metrics = {
            "rounds": [],
            "loss_variance": [],
            "loss_std": [],
            "max_min_diff": [],
        }

    def add_train_round(self, round_num, metrics):
        """Record one round of training metrics.

        Only keys that already exist in ``train_metrics`` are stored; any
        other entries in *metrics* are silently ignored. Note that "rounds"
        is always appended, so a round missing a client key leaves that
        client's series shorter than "rounds".
        """
        self.train_metrics["rounds"].append(round_num)
        for key, value in metrics.items():
            if key in self.train_metrics:
                self.train_metrics[key].append(value)

    def add_eval_round(self, round_num, metrics):
        """Record one round of evaluation metrics (same filtering as training)."""
        self.eval_metrics["rounds"].append(round_num)
        for key, value in metrics.items():
            if key in self.eval_metrics:
                self.eval_metrics[key].append(value)

    def calculate_convergence_metrics(self, client_losses):
        """Return (variance, std deviation, max-min spread) of client losses.

        Results are cast to plain floats so the caller can store them in
        structures that are later JSON-serialized. Returns (0.0, 0.0, 0.0)
        for an empty input.
        """
        if not client_losses:
            return 0.0, 0.0, 0.0
        variance = float(np.var(client_losses))
        std_dev = float(np.std(client_losses))
        spread = float(max(client_losses) - min(client_losses))
        return variance, std_dev, spread
|
|
|
|
|
|
def weighted_average(metrics: List[Tuple[int, "Metrics"]]) -> Dict[str, float]:
    """Aggregate evaluation metrics reported by the clients.

    Computes the example-weighted mean of each client's "loss" and also
    passes through each client's individual loss under a per-client key.

    Args:
        metrics: List of (num_examples, metrics_dict) pairs, one per client.
            Each dict must contain "loss" and may contain "client_id"
            (clients with id 0 / no id are excluded from the per-client keys
            but still counted in the weighted mean).

    Returns:
        Dict with "global_eval_loss" plus one "client_<id>_eval_loss" entry
        per identified client. Returns {"global_eval_loss": 0} when there
        are no examples.
    """
    individual_losses = {}
    for num_examples, m in metrics:
        client_id = m.get("client_id", 0)
        if client_id > 0:
            individual_losses[f"client_{client_id}_eval_loss"] = m["loss"]

    # Example-weighted mean; compute the total once instead of twice.
    total_examples = sum(num_examples for num_examples, _ in metrics)
    weighted_sum = sum(num_examples * m["loss"] for num_examples, m in metrics)
    avg_loss = weighted_sum / total_examples if total_examples > 0 else 0

    aggregated_metrics = {"global_eval_loss": avg_loss}
    aggregated_metrics.update(individual_losses)

    return aggregated_metrics
|
|
|
|
|
|
def fit_metrics_aggregator(metrics: List[Tuple[int, "Metrics"]]) -> Dict[str, float]:
    """Aggregate training (fit) metrics reported by the clients.

    Mirror of :func:`weighted_average` for the training phase: computes the
    example-weighted mean of each client's "train_loss" and passes through
    per-client losses.

    Args:
        metrics: List of (num_examples, metrics_dict) pairs, one per client.
            Each dict must contain "train_loss" and may contain "client_id".

    Returns:
        Dict with "global_train_loss" plus one "client_<id>_train_loss"
        entry per identified client; {"global_train_loss": 0} when there
        are no examples.
    """
    individual_losses = {}
    for num_examples, m in metrics:
        client_id = m.get("client_id", 0)
        if client_id > 0:
            individual_losses[f"client_{client_id}_train_loss"] = m["train_loss"]

    # Example-weighted mean; compute the total once instead of twice.
    total_examples = sum(num_examples for num_examples, _ in metrics)
    weighted_sum = sum(num_examples * m["train_loss"] for num_examples, m in metrics)
    avg_loss = weighted_sum / total_examples if total_examples > 0 else 0

    aggregated_metrics = {"global_train_loss": avg_loss}
    aggregated_metrics.update(individual_losses)

    return aggregated_metrics
|
|
|
|
|
|
def create_visualizations(collector: MetricsCollector, output_dir: Path):
    """Create all performance visualizations as PDF files in *output_dir*.

    Produces three figures:
      1. performance_analysis_<strategy>.pdf — 2x2 grid: global train/eval
         loss, per-client train loss, per-client eval loss, and the
         round-to-round improvement rate.
      2. convergence_analysis_<strategy>.pdf — 1x3 grid: variance, standard
         deviation, and max-min spread of per-client validation losses.
      3. heatmap_performance_<strategy>.pdf — per-client validation loss
         heat map over rounds.

    Side effect: appends the computed dispersion statistics to
    ``collector.convergence_metrics`` while preparing figure 2.
    """

    # ---- Figure 1: 2x2 performance overview -------------------------------
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # Panel (0,0): global training vs. validation loss per round.
    ax1 = axes[0, 0]
    rounds = collector.train_metrics["rounds"]
    ax1.plot(rounds, collector.train_metrics["global_train_loss"],
             'b-', marker='s', label='Treino Global', linewidth=2)
    ax1.plot(collector.eval_metrics["rounds"], collector.eval_metrics["global_eval_loss"],
             'r-', marker='o', label='Validação Global', linewidth=2)
    ax1.set_title('Desempenho do Modelo Global', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Rodada')
    ax1.set_ylabel('Perda (MSE)')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Panel (0,1): per-client training loss plus the global mean.
    ax2 = axes[0, 1]
    colors = ['#2E7D32', '#1565C0', '#E65100']  # one fixed color per client
    for i in range(1, 4):
        key = f"client_{i}_train_loss"
        # Skip clients that never reported this metric.
        if key in collector.train_metrics and collector.train_metrics[key]:
            ax2.plot(rounds, collector.train_metrics[key],
                     marker='o', label=f'Cliente {i}', color=colors[i-1], linewidth=1.5)
    ax2.plot(rounds, collector.train_metrics["global_train_loss"],
             'k--', label='Média Global', linewidth=2, alpha=0.7)
    ax2.set_title('Perda de Treinamento por Cliente', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Rodada')
    ax2.set_ylabel('Perda de Treino (MSE)')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # Panel (1,0): per-client validation loss plus the global mean.
    ax3 = axes[1, 0]
    for i in range(1, 4):
        key = f"client_{i}_eval_loss"
        if key in collector.eval_metrics and collector.eval_metrics[key]:
            ax3.plot(collector.eval_metrics["rounds"], collector.eval_metrics[key],
                     marker='s', label=f'Cliente {i}', color=colors[i-1], linewidth=1.5)
    ax3.plot(collector.eval_metrics["rounds"], collector.eval_metrics["global_eval_loss"],
             'k--', label='Média Global', linewidth=2, alpha=0.7)
    ax3.set_title('Perda de Validação por Cliente', fontsize=14, fontweight='bold')
    ax3.set_xlabel('Rodada')
    ax3.set_ylabel('Perda de Validação (MSE)')
    ax3.legend()
    ax3.grid(True, alpha=0.3)

    # Panel (1,1): round-to-round change (delta) of the global losses.
    # Negative values mean the loss improved in that round.
    ax4 = axes[1, 1]
    if len(rounds) > 1:
        train_improvement = np.diff(collector.train_metrics["global_train_loss"])
        eval_improvement = np.diff(collector.eval_metrics["global_eval_loss"])
        ax4.plot(rounds[1:], train_improvement, 'g-', marker='v', label='Δ Treino', linewidth=1.5)
        ax4.plot(collector.eval_metrics["rounds"][1:], eval_improvement,
                 'm-', marker='^', label='Δ Validação', linewidth=1.5)
        ax4.axhline(y=0, color='k', linestyle='--', alpha=0.5)  # zero = no change
    ax4.set_title('Taxa de Melhoria (Δ Perda)', fontsize=14, fontweight='bold')
    ax4.set_xlabel('Rodada')
    ax4.set_ylabel('Mudança na Perda')
    ax4.legend()
    ax4.grid(True, alpha=0.3)

    plt.suptitle(f'Análise de Desempenho - Estratégia: {collector.strategy_name.upper()}',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig(output_dir / f'performance_analysis_{collector.strategy_name}.pdf', dpi=300, bbox_inches='tight')
    plt.close()

    # ---- Figure 2: convergence / heterogeneity across clients -------------
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Compute per-round dispersion of the clients' eval losses and store it
    # on the collector (side effect used later by save_detailed_metrics).
    # NOTE(review): iterates *train* round indices while reading the *eval*
    # series; assumes the two phases have matching rounds — verify.
    for round_idx, round_num in enumerate(rounds):
        client_losses = []
        for i in range(1, 4):
            key = f"client_{i}_eval_loss"
            if key in collector.eval_metrics and round_idx < len(collector.eval_metrics[key]):
                client_losses.append(collector.eval_metrics[key][round_idx])

        if client_losses:
            var, std, diff = collector.calculate_convergence_metrics(client_losses)
            collector.convergence_metrics["rounds"].append(round_num)
            collector.convergence_metrics["loss_variance"].append(var)
            collector.convergence_metrics["loss_std"].append(std)
            collector.convergence_metrics["max_min_diff"].append(diff)

    # Panel 0: variance of client losses per round.
    ax1 = axes[0]
    ax1.plot(collector.convergence_metrics["rounds"],
             collector.convergence_metrics["loss_variance"],
             'b-', marker='o', linewidth=2)
    ax1.fill_between(collector.convergence_metrics["rounds"],
                     collector.convergence_metrics["loss_variance"],
                     alpha=0.3)
    ax1.set_title('Variância da Perda entre Clientes', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Rodada')
    ax1.set_ylabel('Variância')
    ax1.grid(True, alpha=0.3)

    # Panel 1: standard deviation of client losses per round.
    ax2 = axes[1]
    ax2.plot(collector.convergence_metrics["rounds"],
             collector.convergence_metrics["loss_std"],
             'g-', marker='s', linewidth=2)
    ax2.fill_between(collector.convergence_metrics["rounds"],
                     collector.convergence_metrics["loss_std"],
                     alpha=0.3, color='green')
    ax2.set_title('Desvio Padrão da Perda entre Clientes', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Rodada')
    ax2.set_ylabel('Desvio Padrão')
    ax2.grid(True, alpha=0.3)

    # Panel 2: spread (max - min) of client losses per round.
    ax3 = axes[2]
    ax3.plot(collector.convergence_metrics["rounds"],
             collector.convergence_metrics["max_min_diff"],
             'r-', marker='^', linewidth=2)
    ax3.fill_between(collector.convergence_metrics["rounds"],
                     collector.convergence_metrics["max_min_diff"],
                     alpha=0.3, color='red')
    ax3.set_title('Diferença Máx-Mín entre Clientes', fontsize=14, fontweight='bold')
    ax3.set_xlabel('Rodada')
    ax3.set_ylabel('Diferença')
    ax3.grid(True, alpha=0.3)

    plt.suptitle(f'Análise de Convergência e Heterogeneidade - {collector.strategy_name.upper()}',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig(output_dir / f'convergence_analysis_{collector.strategy_name}.pdf', dpi=300, bbox_inches='tight')
    plt.close()

    # ---- Figure 3: per-client validation-loss heat map --------------------
    fig, ax = plt.subplots(figsize=(12, 6))

    # Rows = clients 1..3, columns = rounds.
    heatmap_data = []
    for i in range(1, 4):
        key = f"client_{i}_eval_loss"
        if key in collector.eval_metrics:
            heatmap_data.append(collector.eval_metrics[key])

    if heatmap_data:
        # NOTE(review): assumes all client series have equal length; ragged
        # lists would make np.array produce an object array — confirm.
        heatmap_array = np.array(heatmap_data)
        # Reversed red-yellow-green: high loss shows red, low loss green.
        im = ax.imshow(heatmap_array, aspect='auto', cmap='RdYlGn_r')

        ax.set_xticks(range(len(collector.eval_metrics["rounds"])))
        ax.set_xticklabels(collector.eval_metrics["rounds"])
        ax.set_yticks(range(3))
        ax.set_yticklabels([f'Cliente {i}' for i in range(1, 4)])

        ax.set_xlabel('Rodada', fontsize=12)
        ax.set_title(f'Mapa de Calor - Perda de Validação por Cliente - {collector.strategy_name.upper()}',
                     fontsize=14, fontweight='bold')

        # Annotate each cell with its numeric loss value.
        for i in range(3):
            for j in range(len(collector.eval_metrics["rounds"])):
                if j < len(heatmap_data[i]):
                    text = ax.text(j, i, f'{heatmap_data[i][j]:.4f}',
                                   ha="center", va="center", color="black", fontsize=8)

        plt.colorbar(im, ax=ax, label='Perda (MSE)')

    plt.tight_layout()
    plt.savefig(output_dir / f'heatmap_performance_{collector.strategy_name}.pdf', dpi=300, bbox_inches='tight')
    plt.close()

    print(f"Visualizações salvas em {output_dir}")
|
|
|
|
|
|
def save_detailed_metrics(collector: "MetricsCollector", output_dir: Path):
    """Persist collected metrics as a CSV file plus a JSON summary.

    Writes two files into *output_dir*:
      * detailed_metrics_<strategy>.csv — train and eval series, one row per
        round, tagged with a 'phase' column ('train' or 'eval').
      * analysis_<strategy>.json — final losses, absolute improvements,
        convergence series, and a timestamp.

    Args:
        collector: Object holding ``strategy_name``, ``train_metrics``,
            ``eval_metrics`` and ``convergence_metrics``.
        output_dir: Existing directory to write the files into.
    """
    all_metrics = pd.DataFrame()

    # Metric lists can end up with unequal lengths ("rounds" is appended
    # unconditionally while client keys are appended only when reported).
    # Wrapping each list in a Series makes pandas pad shorter columns with
    # NaN instead of raising "arrays must all be same length".
    if collector.train_metrics["rounds"]:
        train_df = pd.DataFrame({k: pd.Series(v) for k, v in collector.train_metrics.items()})
        train_df['phase'] = 'train'
        all_metrics = pd.concat([all_metrics, train_df], ignore_index=True)

    if collector.eval_metrics["rounds"]:
        eval_df = pd.DataFrame({k: pd.Series(v) for k, v in collector.eval_metrics.items()})
        eval_df['phase'] = 'eval'
        all_metrics = pd.concat([all_metrics, eval_df], ignore_index=True)

    csv_file = output_dir / f'detailed_metrics_{collector.strategy_name}.csv'
    all_metrics.to_csv(csv_file, index=False)
    print(f"Métricas detalhadas salvas em {csv_file}")

    # Summary statistics; hoist the repeated deep lookups once.
    train_losses = collector.train_metrics["global_train_loss"]
    eval_losses = collector.eval_metrics["global_eval_loss"]
    stats = {
        "strategy": collector.strategy_name,
        "total_rounds": len(collector.train_metrics["rounds"]),
        "final_global_train_loss": train_losses[-1] if train_losses else None,
        "final_global_eval_loss": eval_losses[-1] if eval_losses else None,
        # Absolute (not percentage) drop from first to last round.
        "train_improvement": (train_losses[0] - train_losses[-1]) if len(train_losses) > 1 else 0,
        "eval_improvement": (eval_losses[0] - eval_losses[-1]) if len(eval_losses) > 1 else 0,
        "convergence_metrics": collector.convergence_metrics,
        "timestamp": datetime.now().isoformat(),
    }

    json_file = output_dir / f'analysis_{collector.strategy_name}.json'
    with open(json_file, 'w') as f:
        json.dump(stats, f, indent=2)
    print(f"Análise estatística salva em {json_file}")
|
|
|
|
|
|
def main():
    """Entry point: parse CLI options, run the Flower FL server, then
    generate visualizations, detailed metric files and a console summary.
    """
    parser = argparse.ArgumentParser(description="Flower Server Enhanced")
    parser.add_argument(
        "--strategy", type=str, choices=list(STRATEGIES.keys()), default="fedavg",
        help="Estratégia de agregação a ser usada."
    )
    parser.add_argument(
        "--rounds", type=int, default=10, help="Número de rodadas de FL."
    )
    parser.add_argument(
        "--min-clients", type=int, default=3, help="Número mínimo de clientes."
    )
    parser.add_argument(
        "--prediction-length", type=int, default=10, help="Tamanho da previsão."
    )
    args = parser.parse_args()

    print(f"Tamanho da Previsão: {args.prediction_length}")

    # Instantiate the model once to extract the initial global weights as
    # NumPy arrays (Flower's expected format).
    # NOTE(review): input_size=6 and hidden_size=50 are hard-coded — they
    # must match the clients' model; confirm against the client code.
    net = Net(input_size=6, hidden_size=50, output_size=args.prediction_length)
    initial_parameters = [val.cpu().numpy() for _, val in net.state_dict().items()]

    # Console banner with the run configuration.
    print(f"\n{'='*60}")
    print(f"SERVIDOR DE APRENDIZADO FEDERADO")
    print(f"{'='*60}")
    print(f"Estratégia: {args.strategy.upper()}")
    print(f"Rodadas: {args.rounds}")
    print(f"Clientes mínimos: {args.min_clients}")
    print(f"{'='*60}\n")

    # All artifacts (plots, CSV, JSON) are written under ./results.
    output_dir = Path("results")
    output_dir.mkdir(exist_ok=True)

    collector = MetricsCollector(args.strategy)

    # Closures over `collector`: wrap the module-level aggregators so that
    # each round's aggregated metrics are also recorded for later analysis.
    def enhanced_weighted_average(metrics: List[Tuple[int, Metrics]]) -> Dict[str, float]:
        """Aggregate eval metrics and store them under the reported round."""
        result = weighted_average(metrics)
        # Clients are expected to report their current round as "round";
        # a missing key yields 0 and the round is not recorded.
        round_num = max([m.get("round", 0) for _, m in metrics])
        if round_num > 0:
            collector.add_eval_round(round_num, result)
        return result

    def enhanced_fit_aggregator(metrics: List[Tuple[int, Metrics]]) -> Dict[str, float]:
        """Aggregate fit metrics and store them under the reported round."""
        result = fit_metrics_aggregator(metrics)
        round_num = max([m.get("round", 0) for _, m in metrics])
        if round_num > 0:
            collector.add_train_round(round_num, result)
        return result

    strategy_params = {
        "min_fit_clients": args.min_clients,
        "min_available_clients": args.min_clients,
        "evaluate_metrics_aggregation_fn": enhanced_weighted_average,
        "fit_metrics_aggregation_fn": enhanced_fit_aggregator,
    }

    # NOTE(review): only the non-FedAvg (adaptive) strategies are given the
    # initial parameters here — presumably because Flower's FedAdam family
    # requires them while FedAvg does not; confirm against flwr docs.
    if args.strategy != "fedavg":
        strategy_params["initial_parameters"] = ndarrays_to_parameters(initial_parameters)

    strategy = STRATEGIES[args.strategy](**strategy_params)

    print("Iniciando servidor FL...")
    # Blocks until all configured rounds have completed.
    history = fl.server.start_server(
        server_address="0.0.0.0:8080",
        config=fl.server.ServerConfig(num_rounds=args.rounds),
        strategy=strategy,
    )

    print("\n" + "="*60)
    print("TREINAMENTO CONCLUÍDO - GERANDO ANÁLISES")
    print("="*60)

    # Plots first (this also fills collector.convergence_metrics), then the
    # CSV/JSON export which includes those convergence series.
    create_visualizations(collector, output_dir)

    save_detailed_metrics(collector, output_dir)

    # Console summary of the run.
    print("\n" + "="*60)
    print("RESUMO DO TREINAMENTO")
    print("="*60)

    if collector.train_metrics["global_train_loss"]:
        initial_loss = collector.train_metrics["global_train_loss"][0]
        final_loss = collector.train_metrics["global_train_loss"][-1]
        # Percentage improvement relative to the first round.
        # NOTE(review): raises ZeroDivisionError if the initial loss is 0.
        improvement = ((initial_loss - final_loss) / initial_loss) * 100

        print(f"Perda inicial de treino: {initial_loss:.6f}")
        print(f"Perda final de treino: {final_loss:.6f}")
        print(f"Melhoria no treino: {improvement:.2f}%")

    if collector.eval_metrics["global_eval_loss"]:
        initial_eval = collector.eval_metrics["global_eval_loss"][0]
        final_eval = collector.eval_metrics["global_eval_loss"][-1]
        eval_improvement = ((initial_eval - final_eval) / initial_eval) * 100

        print(f"\nPerda inicial de validação: {initial_eval:.6f}")
        print(f"Perda final de validação: {final_eval:.6f}")
        print(f"Melhoria na validação: {eval_improvement:.2f}%")

    # Qualitative convergence verdict from the final cross-client std dev.
    if collector.convergence_metrics["loss_std"]:
        final_std = collector.convergence_metrics["loss_std"][-1]
        print(f"\nDesvio padrão final entre clientes: {final_std:.6f}")
        print(f"Convergência: {'Boa' if final_std < 0.01 else 'Moderada' if final_std < 0.05 else 'Baixa'}")

    print("\n" + "="*60)
    print(f"Resultados salvos em: {output_dir}")
    print("="*60)
|
|
# Script entry point: run the server only when executed directly, not on import.
if __name__ == "__main__":
    main()