|
|
|
|
|
import copy |
|
|
import torch |
|
|
import tempfile |
|
|
from typing import List, Optional, Tuple |
|
|
from ultralytics import YOLO |
|
|
import os |
|
|
import numpy as np |
|
|
|
|
|
def average_state_dicts(state_dicts: List[dict]) -> dict:
    """Return the element-wise, equally weighted mean of several state dicts.

    Every tensor is accumulated on the CPU in float32. The mean is then moved
    back to the device and dtype of the matching tensor in the first state
    dict, so half-precision weights and integer buffers keep their dtype
    (casting the mean back to an integer dtype truncates it).

    Args:
        state_dicts: Non-empty list of mappings from name to tensor. All
            mappings must list the same keys in the same order.

    Returns:
        A new dict with the same keys, holding the averaged tensors.

    Raises:
        ValueError: If ``state_dicts`` is empty.
        AssertionError: If the key lists of the state dicts differ.
    """
    if not state_dicts:
        raise ValueError("state_dicts must contain at least one state dict")

    reference = state_dicts[0]
    keys = list(reference.keys())
    for sd in state_dicts:
        if list(sd.keys()) != keys:
            # Raised explicitly instead of via `assert` so the check is not
            # stripped under `python -O`; the exception type is unchanged.
            raise AssertionError("state_dict keys mismatch")

    count = len(state_dicts)
    avg = {}
    for k in keys:
        # Accumulate in float32 on CPU for a stable sum regardless of the
        # storage dtype/device of the individual tensors.
        total = sum(sd[k].cpu().float() for sd in state_dicts)
        target = reference[k]
        avg[k] = (total / count).to(device=target.device, dtype=target.dtype)
    return avg
|
|
|
|
|
def load_model_state(path: str) -> dict:
    """Open ``path`` with ``YOLO`` and return the model's state dict on the CPU.

    Args:
        path: Weights file handed to ``YOLO(...)``.

    Returns:
        A plain dict mapping each state-dict key to its tensor after ``.cpu()``.
    """
    yolo = YOLO(path)
    cpu_state = {}
    for name, tensor in yolo.model.state_dict().items():
        cpu_state[name] = tensor.cpu()
    return cpu_state
|
|
|
|
|
def save_state_dict(sd: dict, out_path: str):
    """Serialize the state dict ``sd`` to ``out_path`` with ``torch.save``."""
    torch.save(obj=sd, f=out_path)
|
|
|
|
|
def train_client_local(client_data_yaml: str,
                       init_weights_path: str,
                       epochs: int = 1,
                       batch: int = 8,
                       imgsz: int = 640,
                       device: Optional[str] = None) -> dict:
    """Run a very short local training pass with ultralytics and return the state dict on CPU.

    Args:
        client_data_yaml: Dataset YAML passed to ``model.train(data=...)``.
        init_weights_path: Weights file opened with ``YOLO(...)`` as the starting point.
        epochs: Number of local training epochs.
        batch: Batch size forwarded to ``model.train``.
        imgsz: Image size forwarded to ``model.train``.
        device: Training device; when not given, ``"cuda"`` is used if
            available, otherwise ``"cpu"``.

    Returns:
        A dict mapping each state-dict key of the trained model to its CPU tensor.
    """
    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    yolo = YOLO(init_weights_path)

    train_kwargs = dict(
        data=client_data_yaml,
        epochs=epochs,
        batch=batch,
        imgsz=imgsz,
        device=device,
        save=False,
        verbose=False,
    )
    yolo.train(**train_kwargs)

    cpu_state = {}
    for name, tensor in yolo.model.state_dict().items():
        cpu_state[name] = tensor.cpu()
    return cpu_state
|
|
|
|
|
def run_fedavg_from_checkpoints(client_ckpt_paths: List[str], out_global_path: str) -> str:
    """Average the state dicts stored in several checkpoint files and save the result.

    Each file is first read with ``torch.load``. If the loaded object is a dict
    with a ``'model'`` entry, that entry is used as the state dict; otherwise
    the loaded object itself is used. If any step of that fails, the file is
    reopened through ``YOLO`` and the model's state dict is taken instead.

    Args:
        client_ckpt_paths: Paths of the client checkpoint files.
        out_global_path: Destination file for the averaged state dict.

    Returns:
        ``out_global_path``.
    """
    collected = []
    for ckpt_path in client_ckpt_paths:
        try:
            loaded = torch.load(ckpt_path, map_location='cpu')
            has_model_entry = isinstance(loaded, dict) and 'model' in loaded
            raw_state = loaded['model'] if has_model_entry else loaded
            cpu_state = {name: tensor.cpu() for name, tensor in raw_state.items()}
        except Exception:
            # Best-effort fallback: whatever could not be read as a plain
            # mapping of tensors is loaded through YOLO instead.
            yolo = YOLO(ckpt_path)
            cpu_state = {
                name: tensor.cpu()
                for name, tensor in yolo.model.state_dict().items()
            }
        collected.append(cpu_state)

    averaged = average_state_dicts(collected)
    torch.save(averaged, out_global_path)
    return out_global_path
|
|
|
|
|
def run_federated_simulation(client_data_yaml_list: List[str],
                             init_global_weights: str,
                             rounds: int = 1,
                             local_epochs: int = 1,
                             device: Optional[str] = None,
                             imgsz: int = 640,
                             batch: int = 8) -> Tuple[str, List[str]]:
    """Ultra-light, synchronous federated-learning simulation.

    For every round:
      * each client trains briefly from the current global weights and
        yields a client state dict;
      * the client state dicts are averaged into the new global state,
        which is saved as ``global_round_{r}.pt`` (a bare state dict written
        with ``torch.save``) in the current working directory.

    Args:
        client_data_yaml_list: One dataset YAML per simulated client.
        init_global_weights: Weights file the global model starts from.
        rounds: Number of federated rounds.
        local_epochs: Local training epochs per client and round.
        device: Training device; when not given, ``"cuda"`` is used if
            available, otherwise ``"cpu"``.
        imgsz: Image size forwarded to local training.
        batch: Batch size forwarded to local training.

    Returns:
        ``(init_global_weights, round_paths)`` where ``round_paths`` lists the
        file written for each round, in order.

    Raises:
        ValueError: If ``rounds`` is positive and no clients are given.
    """
    if rounds > 0 and not client_data_yaml_list:
        raise ValueError("client_data_yaml_list must not be empty")

    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    # The template model carries the architecture and checkpoint metadata.
    # Per-client start files are written from it because the local trainer
    # opens its weights file with YOLO(...), which loads full ultralytics
    # checkpoints rather than bare state dicts.
    template = YOLO(init_global_weights)
    global_state = {k: v.cpu().clone() for k, v in template.model.state_dict().items()}

    round_paths: List[str] = []
    for r in range(rounds):
        # Push the current global weights into the template once per round.
        template.model.load_state_dict(global_state)

        client_states = []
        for i, cyaml in enumerate(client_data_yaml_list):
            init_tmp = f"__tmp_global_r{r}_c{i}.pt"
            try:
                template.save(init_tmp)
                sd_client = train_client_local(
                    cyaml, init_tmp, epochs=local_epochs, batch=batch, imgsz=imgsz, device=device
                )
            except Exception as e:
                print(f"[WARN] client {i} local train failed: {e}; fallback to global")
                # Use the in-memory global state: re-reading the temp file
                # could fail for the same reason the training attempt did.
                sd_client = {k: v.clone() for k, v in global_state.items()}
            finally:
                # The start file is only needed while this client trains.
                if os.path.exists(init_tmp):
                    os.remove(init_tmp)
            client_states.append(sd_client)

        new_global = average_state_dicts(client_states)
        out_path = f"global_round_{r}.pt"
        save_state_dict(new_global, out_path)
        round_paths.append(out_path)
        global_state = new_global

    return init_global_weights, round_paths
|
|
|