import io import numpy as np from PIL import Image import gradio as gr import torch import cv2 from ultralytics import YOLO import attacks # 上面那个 attacks.py,确保和 app.py 在同一目录或可 import 的包路径 import os, glob from pathlib import Path import base64 MODEL_PATH = "weights/fed_model2.pt" MODEL_PATH_C = "weights/yolov8s_3cls.pt" names = ['car', 'van', 'truck'] imgsz = 640 SAMPLE_DIR = "./images" SAMPLE_IMAGES = sorted([ p for p in glob.glob(os.path.join(SAMPLE_DIR, "*")) if os.path.splitext(p)[1].lower() in [".jpg", ".jpeg", ".png", ".bmp", ".webp"] ])[:9] # 只取前4张 # Load ultralytics model (wrapper) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") yolom = YOLO(MODEL_PATH) # wrapper def iou(a, b): ax1, ay1, ax2, ay2 = a bx1, by1, bx2, by2 = b iw = max(0, min(ax2, bx2) - max(ax1, bx1)) ih = max(0, min(ay2, by2) - max(ay1, by1)) inter = iw * ih if inter <= 0: return 0.0 area_a = max(0, ax2 - ax1) * max(0, ay2 - ay1) area_b = max(0, bx2 - bx1) * max(0, by2 - by1) return inter / (area_a + area_b - inter + 1e-9) # def center_and_diag(b): #IOU足够好 未启用 # x1, y1, x2, y2 = b # cx = 0.5 * (x1 + x2); cy = 0.5 * (y1 + y2) # diag = max(1e-9, ((x2 - x1)**2 + (y2 - y1)**2)**0.5) # area = max(0, (x2 - x1)) * max(0, (y2 - y1)) # return cx, cy, diag, area def run_detection_on_pil(img_pil: Image.Image, eval_model_state, conf: float = 0.45, GT_boxes=None): """ 推理+可视化。GT_boxes 和返回的 preds 都是: [{'xyxy': (x1,y1,x2,y2), 'cls': int, 'conf': float(optional)}] """ # ---- 1) 推理 ---- img = np.array(img_pil) eva_model = yolom if eval_model_state == "yolom" else YOLO(MODEL_PATH_C) res = eva_model.predict(source=img, conf=conf, imgsz=imgsz, save=False, verbose=False) r = res[0] im_out = img.copy() # 名称表(尽量稳) names = getattr(r, "names", None) if names is None and hasattr(eva_model, "model") and hasattr(eva_model.model, "names"): names = eva_model.model.names # ---- 2) 规整预测框到简单结构 ---- preds = [] try: bxs = r.boxes if bxs is not None and len(bxs) > 0: for b in bxs: xyxy = b.xyxy[0].detach().cpu().numpy().tolist() x1, y1, x2, y2 = [int(v) for v in xyxy] cls_id = int(b.cls[0].detach().cpu().numpy()) conf_score = float(b.conf[0].detach().cpu().numpy()) preds.append({'xyxy': (x1, y1, x2, y2), 'cls': cls_id, 'conf': conf_score}) except Exception as e: print("collect preds error:", e) # ---- 3) IoU 匹配 + 画框 ---- IOU_THR = 0.3 # CENTER_DIST_RATIO = 0.30 # 中心点距离 / 预测框对角线 <= 0.30 即视为同一目标 # AREA_RATIO_THR = 0.25 # 面积比例下限:min(area_p, area_g) / max(...) >= 0.25 gt_used = set() for p in preds: color = (0, 255, 0) # 同类:绿 px1, py1, px2, py2 = p['xyxy'] pname = names[p['cls']] if (names is not None and p['cls'] in getattr(names, 'keys', lambda: [])()) else ( names[p['cls']] if (isinstance(names, (list, tuple)) and 0 <= p['cls'] < len(names)) else str(p['cls']) ) label = f"{pname}:{p.get('conf', 0.0):.2f}" if GT_boxes != None: # 找 IoU 最高的未用 GT best_j, best_iou = -1, 0.0 for j, g in enumerate(GT_boxes): if j in gt_used: continue i = iou(p['xyxy'], g['xyxy']) if i > best_iou: best_iou, best_j = i, j # 颜色规则:匹配且同类=绿;匹配但异类=红; if best_iou >= IOU_THR: gt_used.add(best_j) if p['cls'] != int(GT_boxes[best_j]['cls']): color = (255, 0, 0) # 异类:红 cv2.rectangle(im_out, (px1, py1), (px2, py2), color, 2) cv2.putText(im_out, label, (px1, max(10, py1 - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1) return Image.fromarray(im_out), preds def detect_and_attack(image, eval_model_state, attack_mode, eps, alpha, iters, conf, target_cls): if image is None: return None, None, None pil = Image.fromarray(image.astype('uint8')).convert('RGB') original_vis, GT_boxes = run_detection_on_pil(pil, eval_model_state, conf=conf, GT_boxes=None) if attack_mode == "none": return original_vis, None, None try: if attack_mode == "fgsm": adv_pil = attacks.fgsm_attack_on_detector(yolom, pil, eps=eps, device=device, imgsz=imgsz, gt_xywh=GT_boxes, target_cls=target_cls) elif attack_mode == "pgd": adv_pil = attacks.pgd_attack_on_detector(yolom, pil, eps=eps, alpha=alpha, iters=iters, device=device, imgsz=imgsz, gt_xywh=GT_boxes, target_cls=target_cls) else: adv_pil = attacks.demo_random_perturbation(pil, eps=eps) except Exception as ex: print("Whitebox attack failed:", ex) adv_pil = attacks.demo_random_perturbation(pil, eps=eps) adv_path = Path("tmp") # 相对当前工作目录 adv_path.mkdir(parents=True, exist_ok=True) adv_file = adv_path / "adv.png" adv_pil.save(adv_file, format="PNG") adv_vis, _ = run_detection_on_pil(adv_pil, eval_model_state, conf=conf, GT_boxes=GT_boxes) return original_vis, adv_vis, str(adv_file) def handle_select(key): return f"key={key}, label={names[key]}" def img_to_data_url(path: str) -> str: ext = os.path.splitext(path)[1].lower() if ext in [".jpg", ".jpeg"]: mime = "image/jpeg" elif ext == ".png": mime = "image/png" else: mime = "image/octet-stream" with open(path, "rb") as f: b64 = base64.b64encode(f.read()).decode("utf-8") return f"data:{mime};base64,{b64}" # Gradio UI if __name__ == "__main__": title = "Federated Adversarial Attack — FGSM/PGD Demo" logo_left_path = os.path.join("./logos", "logo_NCL.png") logo_right_path = os.path.join("./logos", "logo_EdgeAI.png") logo_left = img_to_data_url(logo_left_path) logo_right = img_to_data_url(logo_right_path) desc_html = f"""
Adversarial examples are generated locally using a client-side model’s gradients (white-box), then evaluated against the server-side aggregated (FedAvg) global model. If the perturbation transfers, it can degrade or alter the FedAvg model’s predictions on the same input image. Object detection in this demo is limited to 'car', 'van', and 'truck' classes only.