FedAdv / federated_utils.py
MarshallCN
add targeted attack
685f135
# federated_utils.py
import copy
import torch
import tempfile
from typing import List, Optional, Tuple
from ultralytics import YOLO
import os
import numpy as np
def average_state_dicts(state_dicts: List[dict]) -> dict:
    """Return the unweighted element-wise mean of several state_dicts (FedAvg step).

    All inputs must have exactly the same keys in the same order. For every key
    the tensors are summed in float32 on the CPU and divided by the number of
    inputs. The result is then cast back to the dtype and device of the
    corresponding tensor in the first state_dict, so fp16 weights stay fp16 and
    integer buffers stay integer (rounded to the nearest value).

    Args:
        state_dicts: non-empty list of mappings ``name -> tensor``.

    Returns:
        A new dict with the averaged tensors. The inputs are not modified.

    Raises:
        ValueError: if ``state_dicts`` is empty or the key lists differ.
    """
    if not state_dicts:
        raise ValueError("average_state_dicts() requires at least one state_dict")
    keys = list(state_dicts[0].keys())
    for sd in state_dicts:
        # Explicit raise instead of assert: asserts vanish under `python -O`.
        if list(sd.keys()) != keys:
            raise ValueError("state_dict keys mismatch")
    n = len(state_dicts)
    avg = {}
    for k in keys:
        ref = state_dicts[0][k]
        # sum() starts from int 0, so this builds a new tensor and never
        # mutates the inputs.
        mean = sum(sd[k].cpu().float() for sd in state_dicts) / n
        if not (ref.is_floating_point() or ref.is_complex()):
            # Integer/bool buffers: round instead of truncating on the cast back.
            mean = mean.round()
        # Restore the reference dtype as well as the device. Otherwise every
        # entry (including integer buffers) would come back as float32.
        avg[k] = mean.to(device=ref.device, dtype=ref.dtype)
    return avg
def load_model_state(path: str) -> dict:
    """Open ``path`` with ``YOLO`` and return its network's state_dict, every tensor moved to CPU."""
    network = YOLO(path).model
    cpu_state = {}
    for name, tensor in network.state_dict().items():
        cpu_state[name] = tensor.cpu()
    return cpu_state
def save_state_dict(sd: dict, out_path: str):
    """Serialize the mapping ``sd`` to ``out_path`` via ``torch.save`` (returns None)."""
    torch.save(obj=sd, f=out_path)
def train_client_local(client_data_yaml: str,
                       init_weights_path: str,
                       epochs: int = 1,
                       batch: int = 8,
                       imgsz: int = 640,
                       device: Optional[str] = None) -> dict:
    """Run a very short local training with ultralytics and return the state_dict on CPU.

    Args:
        client_data_yaml: dataset YAML passed to ``train(data=...)``.
        init_weights_path: weights file opened with ``YOLO(...)``.
        epochs, batch, imgsz: forwarded unchanged to ``train()``.
        device: training device. Any falsy value means "cuda" if available,
            else "cpu".

    Returns:
        ``name -> tensor`` mapping of the model after training, moved to CPU.
    """
    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    yolo = YOLO(init_weights_path)
    # Lightweight training settings; adjust here for more specific hyperparameters.
    train_kwargs = dict(
        data=client_data_yaml,
        epochs=epochs,
        batch=batch,
        imgsz=imgsz,
        device=device,
        save=False,
        verbose=False,
    )
    yolo.train(**train_kwargs)
    trained = yolo.model.state_dict()
    return {name: tensor.cpu() for name, tensor in trained.items()}
def run_fedavg_from_checkpoints(client_ckpt_paths: List[str], out_global_path: str) -> str:
    """Average client weight files (FedAvg) and save the result as a raw state_dict.

    Each file is first read with ``torch.load``. If the result is a dict with a
    "model" key, that entry is used; otherwise the loaded object itself is
    used. Either way it must expose ``.items()`` of tensors. If any step of
    that fails, the file is loaded through ``YOLO(path)`` instead.

    Args:
        client_ckpt_paths: non-empty list of weight-file paths.
        out_global_path: where the averaged state_dict is written with
            ``torch.save``.

    Returns:
        ``out_global_path``.

    Raises:
        ValueError: if ``client_ckpt_paths`` is empty.
    """
    if not client_ckpt_paths:
        raise ValueError("run_fedavg_from_checkpoints() requires at least one checkpoint path")
    sds = []
    for p in client_ckpt_paths:
        try:
            # NOTE(security): torch.load may unpickle arbitrary objects; only
            # pass trusted checkpoint files.
            tmp = torch.load(p, map_location='cpu')
            sd = tmp['model'] if isinstance(tmp, dict) and 'model' in tmp else tmp
            sd_cpu = {k: v.cpu() for k, v in sd.items()}
        except Exception:
            # Not readable as a plain state_dict. For example, the "model" entry
            # may be a module object rather than a tensor mapping (presumably
            # the case for ultralytics checkpoints), or torch.load itself
            # failed. Let the YOLO wrapper load it.
            model = YOLO(p)
            sd_cpu = {k: v.cpu() for k, v in model.model.state_dict().items()}
        sds.append(sd_cpu)
    avg = average_state_dicts(sds)
    torch.save(avg, out_global_path)
    return out_global_path
def _write_yolo_checkpoint(template, state: dict, path: str) -> str:
    """Copy ``state`` into ``template.model`` and save it as a checkpoint for ``YOLO(path)``.

    A bare ``torch.save(state_dict)`` file is not usable here: a plain
    name -> tensor mapping is presumably not loadable by ``YOLO(path)``,
    because ultralytics reads the model object from the checkpoint.
    """
    template.model.load_state_dict(state)
    # NOTE(review): minimal checkpoint layout, assumed to match what the
    # ultralytics weight loader reads (a "model" entry holding the nn.Module).
    # Confirm against the installed ultralytics version.
    torch.save({"model": template.model}, path)
    return path


def run_federated_simulation(client_data_yaml_list: List[str],
                             init_global_weights: str,
                             rounds: int = 1,
                             local_epochs: int = 1,
                             device: Optional[str] = None,
                             imgsz: int = 640,
                             batch: int = 8) -> Tuple[str, List[str]]:
    """Ultra-light synchronous federated simulation (FedAvg).

    For each round:
      every client runs a short local training (``train_client_local``)
      starting from the current global weights; the client state_dicts are
      averaged (``average_state_dicts``) into the new global weights, which are
      saved as ``global_round_{r}.pt`` (a raw state_dict, via
      ``save_state_dict``) in the current working directory.

    The per-round starting checkpoints given to the clients live in a
    temporary directory that is removed when the function returns. If a
    client's training raises, a warning is printed and the current global
    weights are used as that client's result.

    Args:
        client_data_yaml_list: one dataset YAML per client.
        init_global_weights: initial weights, opened with ``YOLO(...)``.
        rounds: number of federated rounds.
        local_epochs, imgsz, batch: forwarded to ``train_client_local``.
        device: training device. Any falsy value means "cuda" if available,
            else "cpu".

    Returns:
        ``(init_global_weights, [global_round_0.pt, global_round_1.pt, ...])``.

    Raises:
        ValueError: if ``rounds > 0`` and ``client_data_yaml_list`` is empty.
    """
    if rounds > 0 and not client_data_yaml_list:
        raise ValueError("run_federated_simulation() requires at least one client")
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    global_state = load_model_state(init_global_weights)
    # Separate model instance used only to turn the averaged state_dict back
    # into a file YOLO() can open.
    template = YOLO(init_global_weights)
    round_paths: List[str] = []
    with tempfile.TemporaryDirectory(prefix="fedavg_") as tmp_dir:
        for r in range(rounds):
            # Every client starts from the same global state, so write it once per round.
            init_ckpt = _write_yolo_checkpoint(
                template, global_state, os.path.join(tmp_dir, f"global_r{r}.pt")
            )
            client_states = []
            for i, cyaml in enumerate(client_data_yaml_list):
                try:
                    sd_client = train_client_local(
                        cyaml, init_ckpt, epochs=local_epochs, batch=batch, imgsz=imgsz, device=device
                    )
                except Exception as e:
                    print(f"[WARN] client {i} local train failed: {e}; fallback to global")
                    # Use the in-memory global weights rather than re-reading
                    # the file. average_state_dicts does not mutate its inputs,
                    # so sharing this dict is safe.
                    sd_client = global_state
                client_states.append(sd_client)
            # FedAvg
            new_global = average_state_dicts(client_states)
            out_path = f"global_round_{r}.pt"
            save_state_dict(new_global, out_path)
            round_paths.append(out_path)
            global_state = new_global
    return init_global_weights, round_paths