# federated_utils.py
"""Lightweight FedAvg utilities for Ultralytics YOLO models.

Global weights written by this module (``global_round_<r>.pt`` and the output
of :func:`run_fedavg_from_checkpoints`) are plain PyTorch ``state_dict`` files,
not full Ultralytics checkpoints: read them with ``torch.load`` and apply them
with ``YOLO(...).model.load_state_dict(...)`` rather than ``YOLO(path)``.
"""
import copy
import os
import tempfile
from typing import List, Optional, Sequence, Tuple

import numpy as np  # noqa: F401  (not used below)
import torch
from ultralytics import YOLO


def average_state_dicts(
    state_dicts: List[dict],
    client_weights: Optional[Sequence[float]] = None,
) -> dict:
    """Return the (optionally weighted) element-wise average of ``state_dicts``.

    Each entry is accumulated in float64 on CPU, then cast back to the dtype
    and device of the matching tensor in ``state_dicts[0]``.  Half-precision
    parameters therefore stay half precision, and integer/bool buffers
    (e.g. BatchNorm's ``num_batches_tracked``) stay integral (rounded).

    Args:
        state_dicts: Non-empty list of state dicts with identical key lists
            (same keys, same order) and matching tensor shapes.
        client_weights: Optional per-state-dict weights, e.g. client sample
            counts as in weighted FedAvg.  ``None`` means a uniform average.

    Returns:
        A new dict mapping every key to its averaged tensor.

    Raises:
        ValueError: if ``state_dicts`` is empty, the key lists differ, or
            ``client_weights`` has the wrong length or a non-positive sum.
    """
    if not state_dicts:
        raise ValueError("state_dicts must not be empty")
    keys = list(state_dicts[0].keys())
    for idx, sd in enumerate(state_dicts):
        if list(sd.keys()) != keys:
            raise ValueError(f"state_dict keys mismatch (state_dict #{idx})")

    if client_weights is None:
        client_weights = [1.0] * len(state_dicts)
    if len(client_weights) != len(state_dicts):
        raise ValueError("client_weights must have one entry per state_dict")
    total = float(sum(client_weights))
    if total <= 0:
        raise ValueError("client_weights must sum to a positive value")
    coeffs = [float(w) / total for w in client_weights]

    averaged = {}
    for key in keys:
        ref = state_dicts[0][key]
        acc = sum(
            coeff * sd[key].detach().cpu().double()
            for coeff, sd in zip(coeffs, state_dicts)
        )
        if not ref.is_floating_point():
            # Integer/bool buffers: round instead of truncating on the cast.
            acc = acc.round()
        averaged[key] = acc.to(dtype=ref.dtype, device=ref.device)
    return averaged


def load_model_state(path: str) -> dict:
    """Open ``path`` with Ultralytics and return its network state_dict on CPU."""
    model = YOLO(path)
    sd = model.model.state_dict()
    return {k: v.cpu() for k, v in sd.items()}


def save_state_dict(sd: dict, out_path: str):
    """Write ``sd`` to ``out_path`` with ``torch.save`` (a raw state dict, not a YOLO checkpoint)."""
    torch.save(sd, out_path)


def train_client_local(
    client_data_yaml: str,
    init_weights_path: str,
    epochs: int = 1,
    batch: int = 8,
    imgsz: int = 640,
    device: Optional[str] = None,
) -> dict:
    """Run a very short local Ultralytics training and return the weights.

    Args:
        client_data_yaml: Ultralytics dataset YAML for this client.
        init_weights_path: Starting weights; must be something ``YOLO()`` can
            open (e.g. a full Ultralytics ``.pt`` checkpoint) - a bare state
            dict written by :func:`save_state_dict` is not enough.
        epochs: Number of local epochs passed to ``YOLO.train``.
        batch: Batch size passed to ``YOLO.train``.
        imgsz: Training image size passed to ``YOLO.train``.
        device: Training device; defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.

    Returns:
        The model's state_dict after training, with every tensor on CPU.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    model = YOLO(init_weights_path)
    # Lightweight training; tune the hyper-parameters here if needed.
    model.train(
        data=client_data_yaml,
        epochs=epochs,
        batch=batch,
        imgsz=imgsz,
        device=device,
        save=False,
        verbose=False,
    )
    sd = model.model.state_dict()
    return {k: v.cpu() for k, v in sd.items()}


def _load_checkpoint_state(path: str) -> dict:
    """Return the CPU state_dict stored in ``path``.

    Accepts a raw state dict, a ``{"model": state_dict}`` dict, or any file
    Ultralytics itself can open (e.g. a checkpoint whose ``"model"`` entry is
    a pickled network).
    """
    # NOTE(security): torch.load / YOLO unpickle the file - only pass
    # checkpoints from trusted sources.
    try:
        obj = torch.load(path, map_location="cpu")
    except Exception:
        # e.g. newer torch refusing pickled modules with weights_only=True;
        # the Ultralytics loader below handles such checkpoints.
        obj = None
    if isinstance(obj, dict):
        sd = obj["model"] if "model" in obj else obj
        if isinstance(sd, dict) and all(isinstance(v, torch.Tensor) for v in sd.values()):
            return {k: v.cpu() for k, v in sd.items()}
    # Any other layout (notably a pickled nn.Module under "model"): let
    # Ultralytics build the network and read its weights.
    return load_model_state(path)


def run_fedavg_from_checkpoints(
    client_ckpt_paths: List[str],
    out_global_path: str,
    client_weights: Optional[Sequence[float]] = None,
) -> str:
    """Average the weights stored in ``client_ckpt_paths`` and save the result.

    Args:
        client_ckpt_paths: Client checkpoint files (see
            :func:`_load_checkpoint_state` for accepted layouts).
        out_global_path: Where to write the averaged weights (raw state dict).
        client_weights: Optional per-client weights for weighted FedAvg;
            ``None`` means a uniform average.

    Returns:
        ``out_global_path``.
    """
    sds = [_load_checkpoint_state(p) for p in client_ckpt_paths]
    avg = average_state_dicts(sds, client_weights)
    save_state_dict(avg, out_global_path)
    return out_global_path


def _write_yolo_checkpoint(template: YOLO, state: dict, path: str) -> None:
    """Save ``state`` as a minimal checkpoint that ``YOLO(path)`` can open.

    Ultralytics checkpoints keep the pickled network under ``"model"``, so a
    bare state dict cannot be passed to ``YOLO()``.  The network is deep-copied
    from ``template``; ``template`` itself is never modified.
    """
    net = copy.deepcopy(template.model)
    net.load_state_dict(state)
    ckpt_meta = getattr(template, "ckpt", None) or {}
    train_args = dict(ckpt_meta.get("train_args") or {})
    # Keep the task explicit so the reloaded model keeps the template's task
    # instead of whatever Ultralytics would assume by default.
    train_args.setdefault("task", template.task)
    torch.save({"model": net, "train_args": train_args}, path)


def run_federated_simulation(
    client_data_yaml_list: List[str],
    init_global_weights: str,
    rounds: int = 1,
    local_epochs: int = 1,
    device: Optional[str] = None,
    imgsz: int = 640,
    batch: int = 8,
    out_dir: Optional[str] = None,
) -> Tuple[str, List[str]]:
    """Ultra-light synchronous federated-learning simulation.

    For each round: every client trains briefly, starting from the current
    global weights, producing a client state_dict; the client state_dicts are
    averaged (FedAvg) into the new global weights, which are saved as
    ``global_round_<r>.pt`` (a raw state dict).

    If a client's training raises, a warning is printed and the current
    global weights are used as that client's contribution.  All clients are
    assumed to share the same architecture and class set, otherwise the
    averaging step fails on mismatched keys or shapes.

    Args:
        client_data_yaml_list: One Ultralytics dataset YAML per client.
        init_global_weights: Initial weights, anything ``YOLO()`` can open.
        rounds: Number of federated rounds.
        local_epochs: Local epochs per client per round.
        device: Training device; defaults to ``"cuda"`` when available.
        imgsz: Training image size.
        batch: Training batch size.
        out_dir: Directory for the per-round files; ``None`` writes to the
            current working directory (the original behavior).

    Returns:
        ``(init_global_weights, [path of the global weights after each round])``.

    Raises:
        ValueError: if ``rounds > 0`` and no clients are given.
    """
    if rounds > 0 and not client_data_yaml_list:
        raise ValueError("client_data_yaml_list must not be empty")
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    template = YOLO(init_global_weights)
    global_state = {
        k: v.detach().cpu().clone() for k, v in template.model.state_dict().items()
    }
    round_paths: List[str] = []

    # Temporary start-of-round checkpoints live here and are removed afterwards.
    with tempfile.TemporaryDirectory(prefix="fedavg_") as tmp_dir:
        for r in range(rounds):
            # Every client starts from the same global weights, so one
            # YOLO-loadable checkpoint per round suffices.
            init_tmp = os.path.join(tmp_dir, f"__tmp_global_r{r}.pt")
            _write_yolo_checkpoint(template, global_state, init_tmp)

            client_states = []
            for i, cyaml in enumerate(client_data_yaml_list):
                try:
                    sd_client = train_client_local(
                        cyaml,
                        init_tmp,
                        epochs=local_epochs,
                        batch=batch,
                        imgsz=imgsz,
                        device=device,
                    )
                except Exception as e:
                    print(f"[WARN] client {i} local train failed: {e}; fallback to global")
                    sd_client = {k: v.clone() for k, v in global_state.items()}
                client_states.append(sd_client)

            # FedAvg
            new_global = average_state_dicts(client_states)
            out_name = f"global_round_{r}.pt"
            out_path = os.path.join(out_dir, out_name) if out_dir else out_name
            save_state_dict(new_global, out_path)
            round_paths.append(out_path)
            global_state = new_global

    return init_global_weights, round_paths