"""Gradio demo: white-box FGSM/PGD attacks on a YOLO detector.

Adversarial examples are crafted against the client model (``yolom``) and
then evaluated with either the client model or the global model, selected in
the UI.
"""

import base64
import functools
import glob
import io
import os
import threading
import uuid
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from ultralytics import YOLO

import attacks  # local attacks.py; must sit next to app.py or be importable

MODEL_PATH = "weights/fed_model2.pt"
MODEL_PATH_C = "weights/yolov8s_3cls.pt"
names = ['car', 'van', 'truck']
imgsz = 640

SAMPLE_DIR = "./images"
_SAMPLE_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".webp")
# Keep at most 9 samples, which matches `examples_per_page` in the UI.
SAMPLE_IMAGES = sorted(
    p for p in glob.glob(os.path.join(SAMPLE_DIR, "*"))
    if os.path.splitext(p)[1].lower() in _SAMPLE_EXTS
)[:9]

# Load the ultralytics wrapper for the client model (also the attack target).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
yolom = YOLO(MODEL_PATH)

# The queue runs up to two requests at once and the model objects are shared,
# so every use of a model (inference or attack) is serialized by this lock.
_MODEL_LOCK = threading.Lock()


@functools.lru_cache(maxsize=1)
def _load_global_model():
    """Load the global model once and reuse it for every later request."""
    return YOLO(MODEL_PATH_C)


def _get_eval_model(eval_model_state):
    """Return the client model for ``"yolom"``, otherwise the global model."""
    if eval_model_state == "yolom":
        return yolom
    return _load_global_model()


def iou(a, b):
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes.

    Returns 0.0 when the boxes do not overlap.
    """
    ax1, ay1, ax2, ay2 = a
    bx1, by1, bx2, by2 = b
    iw = max(0, min(ax2, bx2) - max(ax1, bx1))
    ih = max(0, min(ay2, by2) - max(ay1, by1))
    inter = iw * ih
    if inter <= 0:
        return 0.0
    area_a = max(0, ax2 - ax1) * max(0, ay2 - ay1)
    area_b = max(0, bx2 - bx1) * max(0, by2 - by1)
    # The epsilon guards against a zero denominator.
    return inter / (area_a + area_b - inter + 1e-9)


def _class_name(class_names, cls_id):
    """Resolve a class id to a label, falling back to the id as a string.

    ``class_names`` may be a mapping, a list/tuple, or ``None``.
    """
    if class_names is not None and hasattr(class_names, "keys"):
        if cls_id in class_names.keys():
            return class_names[cls_id]
    if isinstance(class_names, (list, tuple)) and 0 <= cls_id < len(class_names):
        return class_names[cls_id]
    return str(cls_id)


def run_detection_on_pil(img_pil: Image.Image, eval_model_state, conf: float = 0.45, GT_boxes=None):
    """Run detection on a PIL image and draw the predicted boxes.

    ``GT_boxes`` and the returned ``preds`` share one structure:
    ``[{'xyxy': (x1, y1, x2, y2), 'cls': int, 'conf': float (optional)}]``.

    Args:
        img_pil: Input image.
        eval_model_state: ``"yolom"`` selects the client model; any other
            value selects the global model.
        conf: Confidence threshold passed to the detector.
        GT_boxes: Optional reference boxes. A prediction that overlaps an
            unused reference box with IoU >= 0.3 but has a different class
            is drawn in red; every other prediction is drawn in green.

    Returns:
        ``(annotated PIL image, preds)``.
    """
    # ---- 1) Inference ----
    # Pass the PIL image itself: ultralytics treats numpy input as BGR, so an
    # RGB array would be evaluated with its colour channels swapped.
    rgb_pil = img_pil.convert("RGB")
    eva_model = _get_eval_model(eval_model_state)
    with _MODEL_LOCK:
        res = eva_model.predict(source=rgb_pil, conf=conf, imgsz=imgsz, save=False, verbose=False)
    r = res[0]
    im_out = np.array(rgb_pil)  # writable RGB copy used as the drawing canvas

    # Class-name table: prefer the result's, fall back to the model's.
    class_names = getattr(r, "names", None)
    if class_names is None and hasattr(eva_model, "model") and hasattr(eva_model.model, "names"):
        class_names = eva_model.model.names

    # ---- 2) Normalize predictions into plain dicts ----
    preds = []
    try:
        bxs = r.boxes
        if bxs is not None and len(bxs) > 0:
            for b in bxs:
                xyxy = b.xyxy[0].detach().cpu().numpy().tolist()
                x1, y1, x2, y2 = [int(v) for v in xyxy]
                cls_id = int(b.cls[0].detach().cpu().numpy())
                conf_score = float(b.conf[0].detach().cpu().numpy())
                preds.append({'xyxy': (x1, y1, x2, y2), 'cls': cls_id, 'conf': conf_score})
    except Exception as e:
        # Best effort: keep whatever was collected and still render the image.
        print("collect preds error:", e)

    # ---- 3) IoU matching + drawing ----
    IOU_THR = 0.3
    gt_used = set()
    for p in preds:
        color = (0, 255, 0)  # default / same class: green
        px1, py1, px2, py2 = p['xyxy']
        pname = _class_name(class_names, p['cls'])
        label = f"{pname}:{p.get('conf', 0.0):.2f}"

        if GT_boxes is not None:
            # Find the not-yet-used reference box with the highest IoU.
            best_j, best_iou = -1, 0.0
            for j, g in enumerate(GT_boxes):
                if j in gt_used:
                    continue
                overlap = iou(p['xyxy'], g['xyxy'])
                if overlap > best_iou:
                    best_iou, best_j = overlap, j
            # Matched and same class stays green; matched but different is red.
            if best_iou >= IOU_THR:
                gt_used.add(best_j)
                if p['cls'] != int(GT_boxes[best_j]['cls']):
                    color = (255, 0, 0)

        cv2.rectangle(im_out, (px1, py1), (px2, py2), color, 2)
        cv2.putText(im_out, label, (px1, max(10, py1 - 5)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

    return Image.fromarray(im_out), preds


def detect_and_attack(image, eval_model_state, attack_mode, eps, alpha, iters, conf, target_cls):
    """Detect on the clean image, attack it, and detect again.

    Args:
        image: Input image as a numpy array, or ``None``.
        eval_model_state: Model selector forwarded to ``run_detection_on_pil``.
        attack_mode: ``"none"``, ``"fgsm"``, ``"pgd"``; any other value
            applies the random perturbation.
        eps, alpha, iters: Attack parameters forwarded to ``attacks``.
        conf: Detection confidence threshold.
        target_cls: Target class forwarded to the FGSM/PGD attacks.

    Returns:
        ``(original visualization, adversarial visualization, path of the
        saved adversarial PNG)``. The last two are ``None`` when
        ``attack_mode`` is ``"none"``; all three are ``None`` without an image.
    """
    if image is None:
        return None, None, None

    pil = Image.fromarray(image.astype('uint8')).convert('RGB')
    # Detections on the clean image serve as the reference boxes for the
    # attack and for colouring the post-attack detections.
    original_vis, GT_boxes = run_detection_on_pil(pil, eval_model_state, conf=conf, GT_boxes=None)

    if attack_mode == "none":
        return original_vis, None, None

    try:
        with _MODEL_LOCK:
            if attack_mode == "fgsm":
                adv_pil = attacks.fgsm_attack_on_detector(
                    yolom, pil, eps=eps, device=device, imgsz=imgsz,
                    gt_xywh=GT_boxes, target_cls=target_cls)
            elif attack_mode == "pgd":
                # A slider may deliver a float; the iteration count must be int.
                adv_pil = attacks.pgd_attack_on_detector(
                    yolom, pil, eps=eps, alpha=alpha, iters=int(iters),
                    device=device, imgsz=imgsz,
                    gt_xywh=GT_boxes, target_cls=target_cls)
            else:
                adv_pil = attacks.demo_random_perturbation(pil, eps=eps)
    except Exception as ex:
        # Fall back to random noise so the demo still produces an output.
        print("Whitebox attack failed:", ex)
        adv_pil = attacks.demo_random_perturbation(pil, eps=eps)

    adv_dir = Path("tmp")  # relative to the current working directory
    adv_dir.mkdir(parents=True, exist_ok=True)
    # One file per request, so concurrent requests cannot overwrite each
    # other's download.
    adv_file = adv_dir / f"adv_{uuid.uuid4().hex}.png"
    adv_pil.save(adv_file, format="PNG")

    adv_vis, _ = run_detection_on_pil(adv_pil, eval_model_state, conf=conf, GT_boxes=GT_boxes)
    return original_vis, adv_vis, str(adv_file)


def handle_select(key):
    """Return a short description of class index ``key``."""
    return f"key={key}, label={names[key]}"


def img_to_data_url(path: str) -> str:
    """Read the file at ``path`` and return it as a base64 ``data:`` URL.

    The MIME type is chosen from the file extension; unknown extensions use
    ``application/octet-stream``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    mime_by_ext = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".gif": "image/gif",
        ".webp": "image/webp",
        ".bmp": "image/bmp",
        ".svg": "image/svg+xml",
    }
    ext = os.path.splitext(path)[1].lower()
    mime = mime_by_ext.get(ext, "application/octet-stream")
    with open(path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")
    return f"data:{mime};base64,{b64}"


def _logo_img_tag(path: str, alt: str) -> str:
    """Return an ``<img>`` tag embedding the logo, or ``""`` if it is missing."""
    if not os.path.isfile(path):
        return ""
    return f'<img src="{img_to_data_url(path)}" alt="{alt}" style="height:64px;">'


# Gradio UI
if __name__ == "__main__":
    title = "Federated Adversarial Attack — FGSM/PGD Demo"
    logo_left_path = os.path.join("./logos", "logo_NCL.png")
    logo_right_path = os.path.join("./logos", "logo_EdgeAI.png")
    logo_left = _logo_img_tag(logo_left_path, "NCL")
    logo_right = _logo_img_tag(logo_right_path, "EdgeAI")
    desc_html = f"""
<div style="display:flex; align-items:center; justify-content:space-between; gap:16px;">
  {logo_left}
  <h1 style="flex:1; text-align:center; margin:0;">{title}</h1>
  {logo_right}
</div>
<p>
Adversarial examples are generated locally using a client-side model’s gradients (white-box), then evaluated against the server-side aggregated (FedAvg) global model. If the perturbation transfers, it can degrade or alter the FedAvg model’s predictions on the same input image. Object detection in this demo is limited to 'car', 'van', and 'truck' classes only.
</p>
"""

    with gr.Blocks(title=title) as demo:
        # Centered title with the description underneath.
        gr.Markdown(desc_html)

        with gr.Row():
            # ===== Left column: the two input sections =====
            with gr.Column(scale=5):
                # Input 1: upload widget and sample picker, side by side.
                with gr.Row():
                    with gr.Column(scale=7):
                        in_img = gr.Image(type="numpy", label="Input image",
                                          format="png", image_mode="RGB")
                    with gr.Column(scale=2):
                        if SAMPLE_IMAGES:
                            # run_on_click defaults to False: clicking a
                            # sample only fills the input, it does not run.
                            gr.Examples(
                                examples=SAMPLE_IMAGES,
                                inputs=[in_img],
                                label="Select from sample images",
                                examples_per_page=9,
                            )

                # Input 2: attack mode and parameters.
                with gr.Accordion("Attack settings", open=True):
                    attack_mode = gr.Radio(
                        choices=["none", "fgsm", "pgd", "random noise"],
                        value="fgsm",
                        label="",
                        show_label=False
                    )
                    eps = gr.Slider(0.0, 1, step=0.01, value=0.0314, label="eps")
                    alpha = gr.Slider(0.001, 0.05, step=0.001, value=0.0078, label="alpha (PGD step)")
                    iters = gr.Slider(1, 100, step=1, value=10, label="PGD iterations")
                    conf = gr.Slider(0.0, 1.0, step=0.01, value=0.45, label="Confidence threshold (live)")
                    with gr.Row():
                        target_cls = gr.Dropdown(
                            choices=[(name, i) for i, name in enumerate(names)],
                            value=2,  # default key=2
                            label="Class-targeted attack"
                        )

                with gr.Row():
                    btn_reset = gr.ClearButton(components=[], value="Reset to defaults")
                    btn_submit = gr.Button("Submit", variant="primary")

            # ===== Right column: the two output sections =====
            with gr.Column(scale=5):
                # Evaluation-model selector.
                with gr.Row():
                    eval_choice = gr.Dropdown(
                        choices=[(f"Client model {MODEL_PATH}", "client"),
                                 (f"Global model {MODEL_PATH_C}", "global")],
                        value="global",  # must be one of the choice values
                        label="Evaluation model"
                    )
                out_orig = gr.Image(label="Original detection", format="png")
                out_adv = gr.Image(label="After attack detection", format="png")
                out_adv_file = gr.File(label="Download Adversarial example (PNG)")

        # Start in agreement with the dropdown's initial "global" value.
        eval_model_state = gr.State(value="yolom_c")

        # Submit: run on demand.
        btn_submit.click(
            fn=detect_and_attack,
            inputs=[in_img, eval_model_state, attack_mode, eps, alpha, iters, conf, target_cls],
            outputs=[out_orig, out_adv, out_adv_file]
        )

        def to_defaults():
            """Return the default values for eps, alpha, iters and conf."""
            return 0.0314, 0.0078, 10, 0.45

        btn_reset.click(
            fn=to_defaults,
            outputs=[eps, alpha, iters, conf],
        )

        # Only the confidence slider re-runs the pipeline "live" (on release).
        conf.release(
            fn=detect_and_attack,
            inputs=[in_img, eval_model_state, attack_mode, eps, alpha, iters, conf, target_cls],
            outputs=[out_orig, out_adv, out_adv_file]
        )

        def on_eval_change(val: str):
            """Normalize the dropdown value and map it to the model key.

            Returns ``(dropdown update, model key)``; the key is ``"yolom"``
            for the client model and ``"yolom_c"`` for the global model.
            """
            if isinstance(val, (list, tuple)):
                val = val[0] if val else "client"
            if val not in ("client", "global"):
                val = "client"
            model = "yolom" if val == "client" else "yolom_c"
            return gr.update(value=val), model

        # Single change binding that writes both the dropdown and the state,
        # so two handlers never race to overwrite the state.
        eval_choice.change(
            fn=on_eval_change,
            inputs=eval_choice,
            outputs=[eval_choice, eval_model_state]
        )

        # Sync once on page load so the state is never empty or stale.
        demo.load(
            fn=on_eval_change,
            inputs=eval_choice,
            outputs=[eval_choice, eval_model_state]
        )

    demo.queue(default_concurrency_limit=2, max_size=20)

    if os.getenv("SPACE_ID"):
        demo.launch(
            server_name="0.0.0.0",
            server_port=int(os.getenv("PORT", 7860)),
            show_error=True,
        )
    else:
        demo.launch()