# FedAdv / app.py
# Author: MarshallCN
# Last change: add logo (commit e3e8c47)
import io
import numpy as np
from PIL import Image
import gradio as gr
import torch
import cv2
from ultralytics import YOLO
import attacks  # local attacks.py; keep it next to app.py or on an importable package path
import os, glob
from pathlib import Path
import base64
# --- Configuration ---
MODEL_PATH = "weights/fed_model2.pt"      # client-side (local) model weights
MODEL_PATH_C = "weights/yolov8s_3cls.pt"  # server-side aggregated (FedAvg) global model weights
names = ['car', 'van', 'truck']  # the only classes this demo detects
imgsz = 640  # inference image size passed to the YOLO predictor
SAMPLE_DIR = "./images"
# Sample gallery: any image file in SAMPLE_DIR with a recognized extension,
# capped at the first 9 after sorting.
SAMPLE_IMAGES = sorted([
    p for p in glob.glob(os.path.join(SAMPLE_DIR, "*"))
    if os.path.splitext(p)[1].lower() in [".jpg", ".jpeg", ".png", ".bmp", ".webp"]
])[:9]
# Load ultralytics model (wrapper); the client model is loaded once at startup.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
yolom = YOLO(MODEL_PATH)  # wrapper
def iou(a, b):
    """Intersection-over-union of two axis-aligned boxes given as (x1, y1, x2, y2)."""
    ax1, ay1, ax2, ay2 = a
    bx1, by1, bx2, by2 = b
    # Overlap extents, clamped to zero when the boxes are disjoint.
    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    intersection = overlap_w * overlap_h
    if intersection <= 0:
        return 0.0
    union = (max(0, ax2 - ax1) * max(0, ay2 - ay1)
             + max(0, bx2 - bx1) * max(0, by2 - by1)
             - intersection)
    # Small epsilon guards against division by zero for degenerate boxes.
    return intersection / (union + 1e-9)
# def center_and_diag(b):  # unused: IoU matching alone proved good enough
#     x1, y1, x2, y2 = b
#     cx = 0.5 * (x1 + x2); cy = 0.5 * (y1 + y2)
#     diag = max(1e-9, ((x2 - x1)**2 + (y2 - y1)**2)**0.5)
#     area = max(0, (x2 - x1)) * max(0, (y2 - y1))
#     return cx, cy, diag, area
def run_detection_on_pil(img_pil: Image.Image, eval_model_state, conf: float = 0.45, GT_boxes=None):
    """
    Run detection on a PIL image and draw the predicted boxes.

    GT_boxes and the returned preds share the same structure:
    [{'xyxy': (x1,y1,x2,y2), 'cls': int, 'conf': float(optional)}]

    When GT_boxes is given, each prediction is greedily IoU-matched against
    the unused GT boxes and drawn green when the matched class agrees, red
    when it disagrees (unmatched predictions stay green).

    Returns (annotated PIL image, preds list).
    """
    # ---- 1) Inference ----
    img = np.array(img_pil)
    # "yolom" selects the preloaded client model; any other value reloads the
    # global model from disk. NOTE(review): the global model is re-created on
    # every call — consider loading it once like `yolom`.
    eva_model = yolom if eval_model_state == "yolom" else YOLO(MODEL_PATH_C)
    res = eva_model.predict(source=img, conf=conf, imgsz=imgsz, save=False, verbose=False)
    r = res[0]
    im_out = img.copy()
    # Class-name table (best-effort lookup; this local `names` intentionally
    # shadows the module-level class list).
    names = getattr(r, "names", None)
    if names is None and hasattr(eva_model, "model") and hasattr(eva_model.model, "names"):
        names = eva_model.model.names
    # ---- 2) Normalize predictions into a simple structure ----
    preds = []
    try:
        bxs = r.boxes
        if bxs is not None and len(bxs) > 0:
            for b in bxs:
                xyxy = b.xyxy[0].detach().cpu().numpy().tolist()
                x1, y1, x2, y2 = [int(v) for v in xyxy]
                cls_id = int(b.cls[0].detach().cpu().numpy())
                conf_score = float(b.conf[0].detach().cpu().numpy())
                preds.append({'xyxy': (x1, y1, x2, y2), 'cls': cls_id, 'conf': conf_score})
    except Exception as e:
        print("collect preds error:", e)
    # ---- 3) IoU matching + drawing ----
    IOU_THR = 0.3
    # CENTER_DIST_RATIO = 0.30  # center distance / pred-box diagonal <= 0.30 would count as the same object
    # AREA_RATIO_THR = 0.25     # lower bound on area ratio: min(area_p, area_g) / max(...) >= 0.25
    gt_used = set()
    for p in preds:
        color = (0, 255, 0)  # same class: green
        px1, py1, px2, py2 = p['xyxy']
        # `names` may be a dict (id -> label) or a list/tuple; fall back to the raw id string.
        pname = names[p['cls']] if (names is not None and p['cls'] in getattr(names, 'keys', lambda: [])()) else (
            names[p['cls']] if (isinstance(names, (list, tuple)) and 0 <= p['cls'] < len(names)) else str(p['cls'])
        )
        label = f"{pname}:{p.get('conf', 0.0):.2f}"
        if GT_boxes != None:
            # Find the unused GT box with the highest IoU against this prediction.
            best_j, best_iou = -1, 0.0
            for j, g in enumerate(GT_boxes):
                if j in gt_used:
                    continue
                i = iou(p['xyxy'], g['xyxy'])
                if i > best_iou:
                    best_iou, best_j = i, j
            # Color rule: matched & same class = green; matched & different class = red.
            if best_iou >= IOU_THR:
                gt_used.add(best_j)
                if p['cls'] != int(GT_boxes[best_j]['cls']):
                    color = (255, 0, 0)  # class mismatch: red
        cv2.rectangle(im_out, (px1, py1), (px2, py2), color, 2)
        cv2.putText(im_out, label, (px1, max(10, py1 - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
    return Image.fromarray(im_out), preds
def detect_and_attack(image, eval_model_state, attack_mode, eps, alpha, iters, conf, target_cls):
    """Detect on the clean image, craft an adversarial example, and re-detect.

    Returns (clean visualization, adversarial visualization, path to the saved
    adversarial PNG); the last two are None when attack_mode is "none" or the
    input image is missing.
    """
    if image is None:
        return None, None, None

    clean_img = Image.fromarray(image.astype('uint8')).convert('RGB')
    clean_vis, clean_preds = run_detection_on_pil(
        clean_img, eval_model_state, conf=conf, GT_boxes=None)

    if attack_mode == "none":
        return clean_vis, None, None

    # White-box attacks use the client model's gradients; if they fail we fall
    # back to plain random noise so the demo still produces an output.
    try:
        if attack_mode == "fgsm":
            adv_img = attacks.fgsm_attack_on_detector(
                yolom, clean_img, eps=eps, device=device, imgsz=imgsz,
                gt_xywh=clean_preds, target_cls=target_cls)
        elif attack_mode == "pgd":
            adv_img = attacks.pgd_attack_on_detector(
                yolom, clean_img, eps=eps, alpha=alpha, iters=iters,
                device=device, imgsz=imgsz, gt_xywh=clean_preds,
                target_cls=target_cls)
        else:
            adv_img = attacks.demo_random_perturbation(clean_img, eps=eps)
    except Exception as ex:
        print("Whitebox attack failed:", ex)
        adv_img = attacks.demo_random_perturbation(clean_img, eps=eps)

    # Persist the adversarial example so the UI can offer it for download.
    out_dir = Path("tmp")  # relative to the current working directory
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / "adv.png"
    adv_img.save(out_file, format="PNG")

    # Re-run detection on the adversarial image, matching against the clean preds.
    adv_vis, _ = run_detection_on_pil(
        adv_img, eval_model_state, conf=conf, GT_boxes=clean_preds)
    return clean_vis, adv_vis, str(out_file)
def handle_select(key):
    """Format the selected class index together with its label for display."""
    label = names[key]
    return f"key={key}, label={label}"
def img_to_data_url(path: str) -> str:
    """Encode an image file as a base64 data URL for inline HTML embedding.

    Parameters
    ----------
    path : str
        Path to the image file on disk.

    Returns
    -------
    str
        A ``data:<mime>;base64,<payload>`` string.
    """
    # Table-driven MIME lookup; now also covers .bmp/.webp/.gif (the sample
    # gallery explicitly accepts .bmp and .webp). Fallback fixed from the
    # nonstandard "image/octet-stream" to "application/octet-stream".
    mime_by_ext = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".bmp": "image/bmp",
        ".webp": "image/webp",
        ".gif": "image/gif",
    }
    ext = os.path.splitext(path)[1].lower()
    mime = mime_by_ext.get(ext, "application/octet-stream")
    with open(path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")
    return f"data:{mime};base64,{b64}"
# Gradio UI
if __name__ == "__main__":
    title = "Federated Adversarial Attack — FGSM/PGD Demo"
    logo_left_path = os.path.join("./logos", "logo_NCL.png")
    logo_right_path = os.path.join("./logos", "logo_EdgeAI.png")
    logo_left = img_to_data_url(logo_left_path)
    logo_right = img_to_data_url(logo_right_path)
    # Header: logos on both sides, centered title/description in the middle.
    desc_html = f"""
    <div style="display:flex;align-items:center;justify-content:space-between;gap:20px;">
      <img src="{logo_left}" style="height:60px;">
      <div style="flex:1;text-align:center;">
        <h1 style="margin:0;">{title}</h1>
        <p style="font-size:14px;">
          Adversarial examples are generated locally using a
          <strong>client-side</strong> model’s gradients (white-box), then evaluated against the
          <strong>server-side aggregated (FedAvg) global model</strong>.
          If the perturbation transfers, it can degrade or alter the FedAvg model’s predictions on the same input image.
          Object detection in this demo is limited to <strong>'car'</strong>, <strong>'van'</strong>, and <strong>'truck'</strong> classes only.
        </p>
      </div>
      <img src="{logo_right}" style="height:60px;">
    </div>
    """
    with gr.Blocks(title=title) as demo:
        # Centered header
        gr.Markdown(desc_html)
        with gr.Row():
            # ===== Left column: the two input sections =====
            with gr.Column(scale=5):
                # Input section 1: upload widget & sample picker, side by side
                with gr.Row():
                    with gr.Column(scale=7):
                        in_img = gr.Image(type="numpy", label="Input image", format="png", image_mode="RGB")
                    with gr.Column(scale=2):
                        if SAMPLE_IMAGES:
                            gr.Examples(
                                examples=SAMPLE_IMAGES,
                                inputs=[in_img],
                                label=f"Select from sample images",
                                examples_per_page=9,
                                # run_on_click defaults to False (fills the input, does not run)
                            )
                # Input section 2: attack type and parameters
                with gr.Accordion("Attack settings", open=True):
                    attack_mode = gr.Radio(
                        choices=["none", "fgsm", "pgd", "random noise"],
                        value="fgsm",
                        label="",
                        show_label=False
                    )
                    eps = gr.Slider(0.0, 1, step=0.01, value=0.0314, label="eps")
                    alpha = gr.Slider(0.001, 0.05, step=0.001, value=0.0078, label="alpha (PGD step)")
                    iters = gr.Slider(1, 100, step=1, value=10, label="PGD iterations")
                    conf = gr.Slider(0.0, 1.0, step=0.01, value=0.45, label="Confidence threshold (live)")
                    with gr.Row():
                        target_cls = gr.Dropdown(
                            choices=[(name, i) for i, name in enumerate(names)],
                            value=2,  # default key=2
                            label="Class-targeted attack"
                        )
                with gr.Row():
                    btn_reset = gr.ClearButton(components=[], value="Reset to defaults")
                    btn_submit = gr.Button("Submit", variant="primary")
            # ===== Right column: the two output sections =====
            with gr.Column(scale=5):
                # Evaluation-model selector
                with gr.Row():
                    eval_choice = gr.Dropdown(
                        choices=[(f"Client model {MODEL_PATH}", "client"),
                                 (f"Global model {MODEL_PATH_C}", "global")],
                        value="global",  # initial value must be a legal choice value
                        label="Evaluation model"
                    )
                out_orig = gr.Image(label="Original detection", format="png")
                out_adv = gr.Image(label="After attack detection", format="png")
                out_adv_file = gr.File(label="Download Adversarial example (PNG)")
        eval_model_state = gr.State(value="yolom")
        # Submit: run manually
        btn_submit.click(
            fn=detect_and_attack,
            inputs=[in_img, eval_model_state, attack_mode, eps, alpha, iters, conf, target_cls],
            outputs=[out_orig, out_adv, out_adv_file]
        )
        def to_defaults():
            # Restore the four attack sliders to their initial values.
            return 0.0314, 0.0078, 10, 0.45
        btn_reset.click(
            fn=to_defaults,
            outputs = [eps, alpha, iters, conf],
        )
        # Only the conf slider re-runs detection "live" (on slider release)
        conf.release(
            fn=detect_and_attack,
            inputs=[in_img, eval_model_state, attack_mode, eps, alpha, iters, conf, target_cls],
            outputs=[out_orig, out_adv, out_adv_file]
        )
        # Single merged callback: normalize the dropdown value and return
        # (updated dropdown value, model key stored in the State).
        def on_eval_change(val: str):
            if isinstance(val, (list, tuple)):
                val = val[0] if len(val) else "client"
            if val not in ("client", "global"):
                val = "client"
            model = "yolom" if val == "client" else "yolom_c"
            return gr.update(value=val), model
        # Keep this as the only `change` binding (a second binding that writes
        # only the State would race with this one and overwrite it).
        eval_choice.change(
            fn=on_eval_change,
            inputs=eval_choice,
            outputs=[eval_choice, eval_model_state]
        )
        # Sync once on page load to avoid an empty/inconsistent initial state.
        demo.load(
            fn=on_eval_change,
            inputs=eval_choice,
            outputs=[eval_choice, eval_model_state]
        )
    demo.queue(default_concurrency_limit=2, max_size=20)
    if os.getenv("SPACE_ID"):
        demo.launch(
            server_name="0.0.0.0",
            server_port=int(os.getenv("PORT", 7860)),
            show_error=True,
        )
    else:
        demo.launch()