import io
import numpy as np
from PIL import Image
import gradio as gr
import torch
import cv2
from ultralytics import YOLO
import attacks  # local attacks.py — keep next to app.py or on an importable package path
import os, glob
from pathlib import Path
import base64
MODEL_PATH = "weights/fed_model2.pt"
MODEL_PATH_C = "weights/yolov8s_3cls.pt"
names = ['car', 'van', 'truck']
imgsz = 640
SAMPLE_DIR = "./images"
SAMPLE_IMAGES = sorted([
p for p in glob.glob(os.path.join(SAMPLE_DIR, "*"))
if os.path.splitext(p)[1].lower() in [".jpg", ".jpeg", ".png", ".bmp", ".webp"]
])[:9]  # keep at most the first nine sample images
# Load ultralytics model (wrapper)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
yolom = YOLO(MODEL_PATH) # wrapper
def iou(a, b):
    """Intersection-over-Union of two axis-aligned boxes given as (x1, y1, x2, y2)."""
    ax1, ay1, ax2, ay2 = a
    bx1, by1, bx2, by2 = b
    # Overlap extents are clamped at zero so disjoint boxes contribute nothing.
    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    intersection = overlap_w * overlap_h
    if intersection <= 0:
        return 0.0
    area_a = max(0, ax2 - ax1) * max(0, ay2 - ay1)
    area_b = max(0, bx2 - bx1) * max(0, by2 - by1)
    union = area_a + area_b - intersection + 1e-9  # epsilon guards against division by zero
    return intersection / union
# def center_and_diag(b):  # unused — IoU matching alone proved sufficient
# x1, y1, x2, y2 = b
# cx = 0.5 * (x1 + x2); cy = 0.5 * (y1 + y2)
# diag = max(1e-9, ((x2 - x1)**2 + (y2 - y1)**2)**0.5)
# area = max(0, (x2 - x1)) * max(0, (y2 - y1))
# return cx, cy, diag, area
def run_detection_on_pil(img_pil: Image.Image, eval_model_state, conf: float = 0.45, GT_boxes=None):
    """
    Run detection on a PIL image and draw the predictions.

    GT_boxes and the returned preds share the same structure:
    [{'xyxy': (x1, y1, x2, y2), 'cls': int, 'conf': float (optional)}]

    Returns (annotated PIL image, preds). When GT_boxes is given, each
    prediction is greedily matched to the unused GT box with the highest IoU;
    matched same-class boxes are drawn green, matched different-class red.
    """
    # ---- 1) inference ----
    img = np.array(img_pil)
    # "yolom" selects the preloaded client model; any other tag loads the global
    # (FedAvg) model from disk. NOTE(review): this reloads the global model on
    # every call — consider caching it at module level if it becomes a bottleneck.
    eva_model = yolom if eval_model_state == "yolom" else YOLO(MODEL_PATH_C)
    res = eva_model.predict(source=img, conf=conf, imgsz=imgsz, save=False, verbose=False)
    r = res[0]
    im_out = img.copy()
    # Class-name table: prefer the result's own mapping, fall back to the model's.
    # Renamed from `names` to avoid shadowing the module-level class list.
    class_names = getattr(r, "names", None)
    if class_names is None and hasattr(eva_model, "model"):
        class_names = getattr(eva_model.model, "names", None)

    def _name_of(cls_id):
        # class_names may be a dict {id: name} or a list/tuple; fall back to the raw id.
        if isinstance(class_names, dict) and cls_id in class_names:
            return class_names[cls_id]
        if isinstance(class_names, (list, tuple)) and 0 <= cls_id < len(class_names):
            return class_names[cls_id]
        return str(cls_id)

    # ---- 2) normalize predictions into a simple structure ----
    preds = []
    try:
        bxs = r.boxes
        if bxs is not None and len(bxs) > 0:
            for b in bxs:
                xyxy = b.xyxy[0].detach().cpu().numpy().tolist()
                x1, y1, x2, y2 = [int(v) for v in xyxy]
                cls_id = int(b.cls[0].detach().cpu().numpy())
                conf_score = float(b.conf[0].detach().cpu().numpy())
                preds.append({'xyxy': (x1, y1, x2, y2), 'cls': cls_id, 'conf': conf_score})
    except Exception as e:
        # Best-effort: a malformed result yields an un-annotated image instead of a crash.
        print("collect preds error:", e)
    # ---- 3) IoU matching + drawing ----
    IOU_THR = 0.3
    gt_used = set()
    for p in preds:
        color = (0, 255, 0)  # default / matched same class: green
        px1, py1, px2, py2 = p['xyxy']
        label = f"{_name_of(p['cls'])}:{p.get('conf', 0.0):.2f}"
        if GT_boxes is not None:  # fixed: was `!= None`; identity check is the correct idiom
            # Greedy match: the highest-IoU ground-truth box not used yet.
            best_j, best_iou = -1, 0.0
            for j, g in enumerate(GT_boxes):
                if j in gt_used:
                    continue
                i = iou(p['xyxy'], g['xyxy'])
                if i > best_iou:
                    best_iou, best_j = i, j
            # Color rule: matched + same class = green; matched + different class = red.
            if best_iou >= IOU_THR:
                gt_used.add(best_j)
                if p['cls'] != int(GT_boxes[best_j]['cls']):
                    color = (255, 0, 0)  # class mismatch: red
        cv2.rectangle(im_out, (px1, py1), (px2, py2), color, 2)
        cv2.putText(im_out, label, (px1, max(10, py1 - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
    return Image.fromarray(im_out), preds
def detect_and_attack(image, eval_model_state, attack_mode, eps, alpha, iters, conf, target_cls):
    """Run clean detection, optionally craft an adversarial example, and re-detect on it.

    Returns a 3-tuple (clean visualization, adversarial visualization, path to
    the saved adversarial PNG). The last two are None when attack_mode is
    "none", and all three are None when no image was provided.
    """
    if image is None:
        return None, None, None
    pil = Image.fromarray(image.astype('uint8')).convert('RGB')
    # Clean pass — its predictions double as pseudo ground truth for the adversarial pass.
    original_vis, GT_boxes = run_detection_on_pil(pil, eval_model_state, conf=conf, GT_boxes=None)
    if attack_mode == "none":
        return original_vis, None, None
    try:
        if attack_mode == "fgsm":
            adv_pil = attacks.fgsm_attack_on_detector(
                yolom, pil, eps=eps, device=device, imgsz=imgsz,
                gt_xywh=GT_boxes, target_cls=target_cls)
        elif attack_mode == "pgd":
            adv_pil = attacks.pgd_attack_on_detector(
                yolom, pil, eps=eps, alpha=alpha, iters=iters, device=device,
                imgsz=imgsz, gt_xywh=GT_boxes, target_cls=target_cls)
        else:
            adv_pil = attacks.demo_random_perturbation(pil, eps=eps)
    except Exception as ex:
        # Best-effort: a failed white-box attack degrades to random noise instead of crashing the UI.
        print("Whitebox attack failed:", ex)
        adv_pil = attacks.demo_random_perturbation(pil, eps=eps)
    out_dir = Path("tmp")  # relative to the current working directory
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / "adv.png"
    adv_pil.save(out_file, format="PNG")
    # Adversarial pass, color-coded against the clean predictions.
    adv_vis, _ = run_detection_on_pil(adv_pil, eval_model_state, conf=conf, GT_boxes=GT_boxes)
    return original_vis, adv_vis, str(out_file)
def handle_select(key):
    """Return a human-readable "key=…, label=…" string for a class index."""
    label = names[key]
    return f"key={key}, label={label}"
def img_to_data_url(path: str) -> str:
    """Read the file at *path* and return it inlined as a base64 data URL.

    JPEG and PNG extensions get their image MIME types; everything else falls
    back to "image/octet-stream".
    """
    suffix = os.path.splitext(path)[1].lower()
    if suffix == ".png":
        mime = "image/png"
    elif suffix in (".jpg", ".jpeg"):
        mime = "image/jpeg"
    else:
        # NOTE(review): "image/octet-stream" is nonstandard (the usual fallback is
        # application/octet-stream) — preserved as-is to keep behavior identical.
        mime = "image/octet-stream"
    with open(path, "rb") as fh:
        payload = base64.b64encode(fh.read()).decode("utf-8")
    return f"data:{mime};base64,{payload}"
# Gradio UI
# Gradio UI: built and launched only when run as a script.
if __name__ == "__main__":
    title = "Federated Adversarial Attack — FGSM/PGD Demo"
    # Logos are inlined as base64 data URLs so the HTML header needs no static file route.
    logo_left_path = os.path.join("./logos", "logo_NCL.png")
    logo_right_path = os.path.join("./logos", "logo_EdgeAI.png")
    logo_left = img_to_data_url(logo_left_path)
    logo_right = img_to_data_url(logo_right_path)
    desc_html = f"""
    <div style="display:flex;align-items:center;justify-content:space-between;gap:20px;">
      <img src="{logo_left}" style="height:60px;">
      <div style="flex:1;text-align:center;">
        <h1 style="margin:0;">{title}</h1>
        <p style="font-size:14px;">
          Adversarial examples are generated locally using a
          <strong>client-side</strong> model’s gradients (white-box), then evaluated against the
          <strong>server-side aggregated (FedAvg) global model</strong>.
          If the perturbation transfers, it can degrade or alter the FedAvg model’s predictions on the same input image.
          Object detection in this demo is limited to <strong>'car'</strong>, <strong>'van'</strong>, and <strong>'truck'</strong> classes only.
        </p>
      </div>
      <img src="{logo_right}" style="height:60px;">
    </div>
    """
    with gr.Blocks(title=title) as demo:
        # Centered page header
        gr.Markdown(desc_html)
        with gr.Row():
            # ===== Left column: two input sections =====
            with gr.Column(scale=5):
                # Input section 1: upload widget & sample picker, side by side
                with gr.Row():
                    with gr.Column(scale=7):
                        in_img = gr.Image(type="numpy", label="Input image", format="png", image_mode="RGB")
                    with gr.Column(scale=2):
                        if SAMPLE_IMAGES:
                            gr.Examples(
                                examples=SAMPLE_IMAGES,
                                inputs=[in_img],
                                label=f"Select from sample images",
                                examples_per_page=9,
                                # run_on_click defaults to False (clicking only fills the input, it does not run)
                            )
                # Input section 2: attack mode and parameters
                with gr.Accordion("Attack settings", open=True):
                    attack_mode = gr.Radio(
                        choices=["none", "fgsm", "pgd", "random noise"],
                        value="fgsm",
                        label="",
                        show_label=False
                    )
                    eps = gr.Slider(0.0, 1, step=0.01, value=0.0314, label="eps")
                    alpha = gr.Slider(0.001, 0.05, step=0.001, value=0.0078, label="alpha (PGD step)")
                    iters = gr.Slider(1, 100, step=1, value=10, label="PGD iterations")
                    conf = gr.Slider(0.0, 1.0, step=0.01, value=0.45, label="Confidence threshold (live)")
                    with gr.Row():
                        target_cls = gr.Dropdown(
                            choices=[(name, i) for i, name in enumerate(names)],
                            value=2,  # default key=2
                            label="Class-targeted attack"
                        )
                with gr.Row():
                    btn_reset = gr.ClearButton(components=[], value="Reset to defaults")
                    btn_submit = gr.Button("Submit", variant="primary")
            # ===== Right column: two output sections =====
            with gr.Column(scale=5):
                # Evaluation-model selector
                with gr.Row():
                    eval_choice = gr.Dropdown(
                        choices=[(f"Client model {MODEL_PATH}", "client"),
                                 (f"Global model {MODEL_PATH_C}", "global")],
                        value="global",  # initial value must be one of the legal choice values
                        label="Evaluation model"
                    )
                out_orig = gr.Image(label="Original detection", format="png")
                out_adv = gr.Image(label="After attack detection", format="png")
                out_adv_file = gr.File(label="Download Adversarial example (PNG)")
        eval_model_state = gr.State(value="yolom")
        # Submit: run the full pipeline manually
        btn_submit.click(
            fn=detect_and_attack,
            inputs=[in_img, eval_model_state, attack_mode, eps, alpha, iters, conf, target_cls],
            outputs=[out_orig, out_adv, out_adv_file]
        )
        def to_defaults():
            """Restore the four attack sliders to their startup defaults."""
            return 0.0314, 0.0078, 10, 0.45
        btn_reset.click(
            fn=to_defaults,
            outputs = [eps, alpha, iters, conf],
        )
        # Only the conf slider is "live": releasing it re-runs the pipeline
        conf.release(
            fn=detect_and_attack,
            inputs=[in_img, eval_model_state, attack_mode, eps, alpha, iters, conf, target_cls],
            outputs=[out_orig, out_adv, out_adv_file]
        )
        # Single merged callback: normalize the dropdown value and return
        # (updated dropdown value, model tag for the State).
        def on_eval_change(val: str):
            """Normalize *val* to "client"/"global" and map it to a model tag."""
            # NOTE(review): run_detection_on_pil only compares the tag against
            # "yolom"; "yolom_c" (or any other value) selects the global model —
            # confirm this tag naming is intentional.
            if isinstance(val, (list, tuple)):
                val = val[0] if len(val) else "client"
            if val not in ("client", "global"):
                val = "client"
            model = "yolom" if val == "client" else "yolom_c"
            return gr.update(value=val), model
        # The single change binding (a State-only change binding was removed to avoid concurrent overwrites)
        eval_choice.change(
            fn=on_eval_change,
            inputs=eval_choice,
            outputs=[eval_choice, eval_model_state]
        )
        # Sync once on page load so the State is never empty/inconsistent at startup
        demo.load(
            fn=on_eval_change,
            inputs=eval_choice,
            outputs=[eval_choice, eval_model_state]
        )
    demo.queue(default_concurrency_limit=2, max_size=20)
    # On Hugging Face Spaces (SPACE_ID is set), bind to all interfaces on the assigned port.
    if os.getenv("SPACE_ID"):
        demo.launch(
            server_name="0.0.0.0",
            server_port=int(os.getenv("PORT", 7860)),
            show_error=True,
        )
    else:
        demo.launch()