Spaces:
Sleeping
Sleeping
| # federated_utils.py | |
| import copy | |
| import torch | |
| import tempfile | |
| from typing import List, Optional, Tuple | |
| from ultralytics import YOLO | |
| import os | |
| import numpy as np | |
def average_state_dicts(state_dicts: List[dict]) -> dict:
    """Return the element-wise mean of several state_dicts (equal-weight FedAvg).

    Each entry is accumulated on the CPU in at least float32 precision (float64
    for float64 and non-floating tensors), divided by the number of dicts, and
    then cast back to the dtype and device of the matching tensor in
    ``state_dicts[0]``. Non-floating entries (integer/bool buffers such as
    counters) are rounded before the cast, so they keep their original integer
    dtype instead of silently turning into float32.

    Args:
        state_dicts: Non-empty list of state_dicts sharing the same set of keys.

    Returns:
        A new dict with the keys in the order of ``state_dicts[0]``.

    Raises:
        ValueError: If ``state_dicts`` is empty or the key sets differ.
    """
    if not state_dicts:
        raise ValueError("average_state_dicts() needs at least one state_dict")

    reference = state_dicts[0]
    keys = list(reference.keys())
    key_set = set(keys)
    for idx, sd in enumerate(state_dicts[1:], start=1):
        # Key order is irrelevant for averaging; only the key sets must agree.
        if set(sd.keys()) != key_set:
            raise ValueError(f"state_dict keys mismatch (dict #{idx} differs from dict #0)")

    n = len(state_dicts)
    avg = {}
    for k in keys:
        ref = reference[k]
        if ref.is_floating_point():
            # fp16/bf16 -> fp32 accumulation; fp64 stays fp64.
            acc_dtype = torch.promote_types(ref.dtype, torch.float32)
        else:
            # float64 represents integer counters exactly up to 2**53.
            acc_dtype = torch.float64
        mean = sum(sd[k].cpu().to(acc_dtype) for sd in state_dicts) / n
        if not ref.is_floating_point():
            mean = mean.round()
        avg[k] = mean.to(device=ref.device, dtype=ref.dtype)
    return avg
def load_model_state(path: str) -> dict:
    """Load weights through ``YOLO(path)`` and return a CPU copy of its state_dict.

    Args:
        path: Anything ``YOLO(...)`` accepts (presumably an ultralytics ``.pt``
            checkpoint in this module's usage).

    Returns:
        Mapping of parameter/buffer name -> tensor, each moved to the CPU.
    """
    net = YOLO(path).model
    cpu_state = {}
    for name, tensor in net.state_dict().items():
        cpu_state[name] = tensor.cpu()
    return cpu_state
def save_state_dict(sd: dict, out_path: str):
    """Serialize the state_dict ``sd`` to ``out_path`` using ``torch.save``."""
    torch.save(obj=sd, f=out_path)
def train_client_local(client_data_yaml: str,
                       init_weights_path: str,
                       epochs: int = 1,
                       batch: int = 8,
                       imgsz: int = 640,
                       device: Optional[str] = None) -> dict:
    """Run a very short local ultralytics training and return the state_dict on CPU.

    Args:
        client_data_yaml: Dataset YAML for this client, passed as ``data=``.
        init_weights_path: Starting weights; must be loadable by ``YOLO(...)``.
        epochs: Forwarded to ``YOLO.train``.
        batch: Forwarded to ``YOLO.train``.
        imgsz: Forwarded to ``YOLO.train``.
        device: Training device; when falsy, ``"cuda"`` is used if available,
            otherwise ``"cpu"``.

    Returns:
        Mapping of parameter/buffer name -> CPU tensor, read from ``model.model``
        after training returns.
    """
    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    yolo = YOLO(init_weights_path)
    # Deliberately lightweight training; adjust these hyper-parameters if needed.
    train_kwargs = dict(
        data=client_data_yaml,
        epochs=epochs,
        batch=batch,
        imgsz=imgsz,
        device=device,
        save=False,
        verbose=False,
    )
    yolo.train(**train_kwargs)
    return {name: tensor.cpu() for name, tensor in yolo.model.state_dict().items()}
def run_fedavg_from_checkpoints(client_ckpt_paths: List[str], out_global_path: str) -> str:
    """Average the weights of several client checkpoints and save the result.

    Each path is first read with ``torch.load`` on the CPU. If the payload is a
    dict containing ``'model'``, that entry is used; otherwise the payload itself
    is treated as the state_dict. If any part of that direct path raises (for
    example the file cannot be read by ``torch.load``, or ``'model'`` holds
    something without ``.items()``), the path is loaded through ``YOLO(...)``
    instead.

    Args:
        client_ckpt_paths: Checkpoint files, one per client.
        out_global_path: Where the averaged plain state_dict is written with
            ``torch.save``.

    Returns:
        ``out_global_path``.
    """
    def _read_state(ckpt_path: str) -> dict:
        try:
            payload = torch.load(ckpt_path, map_location='cpu')
            if isinstance(payload, dict) and 'model' in payload:
                payload = payload['model']
            return {name: t.cpu() for name, t in payload.items()}
        except Exception:
            # Direct loading failed: fall back to the YOLO wrapper.
            wrapped = YOLO(ckpt_path)
            return {name: t.cpu() for name, t in wrapped.model.state_dict().items()}

    averaged = average_state_dicts([_read_state(p) for p in client_ckpt_paths])
    torch.save(averaged, out_global_path)
    return out_global_path
def _write_yolo_init_checkpoint(template: "YOLO", state: dict, out_path: str) -> None:
    """Save ``state`` to ``out_path`` as a checkpoint ``YOLO(out_path)`` can load.

    A bare state_dict is not loadable by ``YOLO(...)``; the checkpoint must carry
    the model object itself. ``state`` is therefore copied into
    ``template.model`` (same architecture, since the template was built from the
    initial global weights), and a deep copy of that module is stored under
    ``'model'`` alongside the template's own checkpoint metadata.
    """
    template.model.load_state_dict(state)
    ckpt = dict(getattr(template, "ckpt", None) or {})
    ckpt["model"] = copy.deepcopy(template.model)
    # NOTE(review): ultralytics' loader prefers a non-empty 'ema' entry over
    # 'model' (8.x attempt_load_one_weight); clear it so the averaged weights are
    # the ones picked up — confirm against the pinned ultralytics version.
    ckpt["ema"] = None
    torch.save(ckpt, out_path)


def run_federated_simulation(client_data_yaml_list: List[str],
                             init_global_weights: str,
                             rounds: int = 1,
                             local_epochs: int = 1,
                             device: Optional[str] = None,
                             imgsz: int = 640,
                             batch: int = 8) -> Tuple[str, List[str]]:
    """Ultra-light synchronous federated-learning simulation (FedAvg).

    For each of ``rounds`` rounds:
      * every client starts from the current global weights and trains locally
        for ``local_epochs`` epochs, yielding a client state_dict (a client whose
        training raises contributes the unchanged global weights instead);
      * the client state_dicts are averaged into the new global weights, which
        are saved as a plain state_dict to ``global_round_{r}.pt`` in the
        current working directory.

    Args:
        client_data_yaml_list: One dataset YAML per client.
        init_global_weights: Initial global checkpoint, loadable by ``YOLO(...)``.
        rounds: Number of federated rounds.
        local_epochs: Local training epochs per client per round.
        device: Training device; defaults to ``"cuda"`` if available, else ``"cpu"``.
        imgsz: Training image size forwarded to the client training.
        batch: Training batch size forwarded to the client training.

    Returns:
        ``(init_global_weights, [path of the global state_dict saved after each round])``.

    Raises:
        ValueError: If ``client_data_yaml_list`` is empty.
    """
    if not client_data_yaml_list:
        raise ValueError("run_federated_simulation() needs at least one client data YAML")
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    # The initial model doubles as an architecture/metadata template used to turn
    # each round's averaged state_dict into a checkpoint that YOLO can train from.
    template = YOLO(init_global_weights)
    # Clone so the global state never aliases the template's parameters, which
    # are overwritten in place by _write_yolo_init_checkpoint.
    global_state = {k: v.detach().cpu().clone() for k, v in template.model.state_dict().items()}

    round_paths: List[str] = []
    # Temporary starting checkpoints live in a private directory that is removed
    # automatically instead of accumulating in the working directory.
    with tempfile.TemporaryDirectory(prefix="fedavg_") as tmp_dir:
        for r in range(rounds):
            # All clients of a round start from the same global weights, so the
            # starting checkpoint is written once per round.
            init_tmp = os.path.join(tmp_dir, f"global_init_r{r}.pt")
            _write_yolo_init_checkpoint(template, global_state, init_tmp)

            client_states = []
            for i, cyaml in enumerate(client_data_yaml_list):
                try:
                    sd_client = train_client_local(
                        cyaml, init_tmp, epochs=local_epochs, batch=batch, imgsz=imgsz, device=device
                    )
                except Exception as e:
                    # Best effort: a failing client contributes the unchanged global weights.
                    print(f"[WARN] client {i} local train failed: {e}; fallback to global")
                    sd_client = global_state
                client_states.append(sd_client)

            # FedAvg
            new_global = average_state_dicts(client_states)
            out_path = f"global_round_{r}.pt"
            save_state_dict(new_global, out_path)
            round_paths.append(out_path)
            global_state = new_global
    return init_global_weights, round_paths