|
|
import io |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import gradio as gr |
|
|
import torch |
|
|
import cv2 |
|
|
from ultralytics import YOLO |
|
|
import attacks |
|
|
import os, glob |
|
|
from pathlib import Path |
|
|
import base64 |
|
|
|
|
|
# Client-side checkpoint: the locally trained YOLO model whose gradients drive
# the white-box attacks (shown as "Client model" in the evaluation dropdown).
MODEL_PATH = "weights/fed_model2.pt"
# Server-side FedAvg-aggregated checkpoint (shown as "Global model" in the UI).
MODEL_PATH_C = "weights/yolov8s_3cls.pt"

# Detection in this demo is limited to these three classes.
names = ['car', 'van', 'truck']
# Inference image size forwarded to YOLO.predict.
imgsz = 640

# Up to nine sample images from ./images, offered as gr.Examples in the UI.
SAMPLE_DIR = "./images"
SAMPLE_IMAGES = sorted([
    p for p in glob.glob(os.path.join(SAMPLE_DIR, "*"))
    if os.path.splitext(p)[1].lower() in [".jpg", ".jpeg", ".png", ".bmp", ".webp"]
])[:9]

# Prefer GPU when available.  The client model is loaded once at import time;
# the global model is loaded lazily inside run_detection_on_pil.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
yolom = YOLO(MODEL_PATH)
|
|
|
|
|
def iou(a, b):
    """Return the intersection-over-union of two boxes given as (x1, y1, x2, y2)."""
    ax1, ay1, ax2, ay2 = a
    bx1, by1, bx2, by2 = b

    # Overlap extents; negative means the boxes are disjoint on that axis.
    overlap_w = min(ax2, bx2) - max(ax1, bx1)
    overlap_h = min(ay2, by2) - max(ay1, by1)
    if overlap_w <= 0 or overlap_h <= 0:
        return 0.0
    inter = overlap_w * overlap_h

    # Clamp degenerate (inverted) boxes to zero area, as the original does.
    area_a = max(0, ax2 - ax1) * max(0, ay2 - ay1)
    area_b = max(0, bx2 - bx1) * max(0, by2 - by1)
    union = area_a + area_b - inter
    # Small epsilon guards against division by zero for zero-area unions.
    return inter / (union + 1e-9)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_detection_on_pil(img_pil: Image.Image, eval_model_state, conf: float = 0.45, GT_boxes=None):
    """Run inference on a PIL image and draw the detections on a copy of it.

    ``GT_boxes`` and the returned ``preds`` share the same schema:
    ``[{'xyxy': (x1, y1, x2, y2), 'cls': int, 'conf': float (optional)}]``

    Args:
        img_pil: input image (RGB).
        eval_model_state: ``"yolom"`` selects the client model loaded at
            import time; any other value selects the global model at
            ``MODEL_PATH_C``.
        conf: confidence threshold forwarded to ``YOLO.predict``.
        GT_boxes: optional reference boxes.  A prediction that matches a
            reference box (IoU >= 0.3) but disagrees on class is drawn in
            red; everything else is drawn in green.

    Returns:
        ``(annotated PIL image, list of prediction dicts)``
    """
    img = np.array(img_pil)

    if eval_model_state == "yolom":
        eva_model = yolom
    else:
        # Loading a YOLO checkpoint from disk is expensive; cache the global
        # model on the function instead of re-reading it on every request.
        eva_model = getattr(run_detection_on_pil, "_global_model", None)
        if eva_model is None:
            eva_model = YOLO(MODEL_PATH_C)
            run_detection_on_pil._global_model = eva_model

    res = eva_model.predict(source=img, conf=conf, imgsz=imgsz, save=False, verbose=False)
    r = res[0]
    im_out = img.copy()

    # Class-id -> name table; ultralytics exposes it either on the result or
    # on the underlying model, as a dict or a sequence.  Renamed from `names`
    # so it no longer shadows the module-level class list.
    cls_names = getattr(r, "names", None)
    if cls_names is None and hasattr(eva_model, "model") and hasattr(eva_model.model, "names"):
        cls_names = eva_model.model.names

    def _class_name(cls_id):
        # Resolve a class id against dict- or sequence-style name tables,
        # falling back to the raw id when no mapping applies.
        if isinstance(cls_names, dict) and cls_id in cls_names:
            return cls_names[cls_id]
        if isinstance(cls_names, (list, tuple)) and 0 <= cls_id < len(cls_names):
            return cls_names[cls_id]
        return str(cls_id)

    # Collect predictions as plain Python dicts; a malformed result object
    # should degrade to "no boxes", not crash the demo.
    preds = []
    try:
        bxs = r.boxes
        if bxs is not None and len(bxs) > 0:
            for b in bxs:
                xyxy = b.xyxy[0].detach().cpu().numpy().tolist()
                x1, y1, x2, y2 = [int(v) for v in xyxy]
                cls_id = int(b.cls[0].detach().cpu().numpy())
                conf_score = float(b.conf[0].detach().cpu().numpy())
                preds.append({'xyxy': (x1, y1, x2, y2), 'cls': cls_id, 'conf': conf_score})
    except Exception as e:
        print("collect preds error:", e)

    IOU_THR = 0.3  # minimum overlap for a prediction to claim a reference box

    gt_used = set()  # indices of reference boxes already matched (greedy 1:1)

    for p in preds:
        color = (0, 255, 0)  # green: default / class agrees with reference
        px1, py1, px2, py2 = p['xyxy']
        label = f"{_class_name(p['cls'])}:{p.get('conf', 0.0):.2f}"

        if GT_boxes is not None:
            # Greedily match this prediction to the best unused reference box.
            best_j, best_iou = -1, 0.0
            for j, g in enumerate(GT_boxes):
                if j in gt_used:
                    continue
                i = iou(p['xyxy'], g['xyxy'])
                if i > best_iou:
                    best_iou, best_j = i, j

            if best_iou >= IOU_THR:
                gt_used.add(best_j)
                if p['cls'] != int(GT_boxes[best_j]['cls']):
                    color = (255, 0, 0)  # red: matched but the class flipped

        cv2.rectangle(im_out, (px1, py1), (px2, py2), color, 2)
        cv2.putText(im_out, label, (px1, max(10, py1 - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

    return Image.fromarray(im_out), preds
|
|
|
|
|
|
|
|
def detect_and_attack(image, eval_model_state, attack_mode, eps, alpha, iters, conf, target_cls):
    """Detect on the input image, optionally craft an adversarial example
    with the client model, and re-detect on the perturbed image.

    Returns a triple of (original detection visualization, post-attack
    visualization or None, path to the saved adversarial PNG or None).
    """
    if image is None:
        return None, None, None

    src_img = Image.fromarray(image.astype('uint8')).convert('RGB')

    # Baseline pass; its predictions serve as the reference boxes for the
    # post-attack comparison below.
    original_vis, ref_boxes = run_detection_on_pil(src_img, eval_model_state, conf=conf, GT_boxes=None)

    if attack_mode == "none":
        return original_vis, None, None

    # White-box attacks use the client model's gradients; if an attack
    # raises, fall back to random noise so the demo still produces output.
    try:
        if attack_mode == "fgsm":
            adv_img = attacks.fgsm_attack_on_detector(yolom, src_img, eps=eps, device=device, imgsz=imgsz, gt_xywh=ref_boxes, target_cls=target_cls)
        elif attack_mode == "pgd":
            adv_img = attacks.pgd_attack_on_detector(yolom, src_img, eps=eps, alpha=alpha, iters=iters, device=device, imgsz=imgsz, gt_xywh=ref_boxes, target_cls=target_cls)
        else:
            adv_img = attacks.demo_random_perturbation(src_img, eps=eps)
    except Exception as ex:
        print("Whitebox attack failed:", ex)
        adv_img = attacks.demo_random_perturbation(src_img, eps=eps)

    # Persist the adversarial example so the UI can offer it for download.
    out_dir = Path("tmp")
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / "adv.png"
    adv_img.save(out_file, format="PNG")

    adv_vis, _ = run_detection_on_pil(adv_img, eval_model_state, conf=conf, GT_boxes=ref_boxes)
    return original_vis, adv_vis, str(out_file)
|
|
|
|
|
def handle_select(key):
    """Return a short "key=<k>, label=<name>" summary for a class index."""
    selected_label = names[key]
    return f"key={key}, label={selected_label}"
|
|
|
|
|
def img_to_data_url(path: str) -> str:
    """Read an image file and return it as a base64 ``data:`` URL.

    The MIME type is derived from the file extension.  All extensions the
    sample-image filter accepts (jpg/jpeg/png/bmp/webp) are covered; unknown
    extensions fall back to the generic binary type.

    Args:
        path: filesystem path of the image file.

    Returns:
        A ``data:<mime>;base64,<payload>`` string suitable for an
        ``<img src="...">`` attribute.

    Raises:
        OSError: if the file cannot be opened or read.
    """
    mime_by_ext = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".bmp": "image/bmp",
        ".webp": "image/webp",
        ".gif": "image/gif",
    }
    ext = os.path.splitext(path)[1].lower()
    # "application/octet-stream" is the registered generic MIME type; the
    # previous fallback "image/octet-stream" is not a valid type.
    mime = mime_by_ext.get(ext, "application/octet-stream")
    with open(path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")
    return f"data:{mime};base64,{b64}"
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    title = "Federated Adversarial Attack — FGSM/PGD Demo"

    # Logos are inlined as data URLs so the page has no external file deps.
    logo_left_path = os.path.join("./logos", "logo_NCL.png")
    logo_right_path = os.path.join("./logos", "logo_EdgeAI.png")

    logo_left = img_to_data_url(logo_left_path)
    logo_right = img_to_data_url(logo_right_path)

    # Header banner: left logo, centered title + description, right logo.
    desc_html = f"""
<div style="display:flex;align-items:center;justify-content:space-between;gap:20px;">
<img src="{logo_left}" style="height:60px;">
<div style="flex:1;text-align:center;">
<h1 style="margin:0;">{title}</h1>
<p style="font-size:14px;">
Adversarial examples are generated locally using a
<strong>client-side</strong> model’s gradients (white-box), then evaluated against the
<strong>server-side aggregated (FedAvg) global model</strong>.
If the perturbation transfers, it can degrade or alter the FedAvg model’s predictions on the same input image.
Object detection in this demo is limited to <strong>'car'</strong>, <strong>'van'</strong>, and <strong>'truck'</strong> classes only.
</p>
</div>
<img src="{logo_right}" style="height:60px;">
</div>
"""

    with gr.Blocks(title=title) as demo:

        gr.Markdown(desc_html)

        with gr.Row():

            # Left column: input image, sample picker, and attack settings.
            with gr.Column(scale=5):

                with gr.Row():
                    with gr.Column(scale=7):
                        in_img = gr.Image(type="numpy", label="Input image", format="png", image_mode="RGB")
                    with gr.Column(scale=2):
                        if SAMPLE_IMAGES:
                            gr.Examples(
                                examples=SAMPLE_IMAGES,
                                inputs=[in_img],
                                label=f"Select from sample images",
                                examples_per_page=9,
                            )

                with gr.Accordion("Attack settings", open=True):
                    attack_mode = gr.Radio(
                        choices=["none", "fgsm", "pgd", "random noise"],
                        value="fgsm",
                        label="",
                        show_label=False
                    )
                    # Defaults: eps 0.0314 ≈ 8/255, alpha 0.0078 ≈ 2/255.
                    eps = gr.Slider(0.0, 1, step=0.01, value=0.0314, label="eps")
                    alpha = gr.Slider(0.001, 0.05, step=0.001, value=0.0078, label="alpha (PGD step)")
                    iters = gr.Slider(1, 100, step=1, value=10, label="PGD iterations")
                    conf = gr.Slider(0.0, 1.0, step=0.01, value=0.45, label="Confidence threshold (live)")
                    with gr.Row():
                        # Dropdown maps display names to integer class ids.
                        target_cls = gr.Dropdown(
                            choices=[(name, i) for i, name in enumerate(names)],
                            value=2,
                            label="Class-targeted attack"
                        )
                with gr.Row():
                    btn_reset = gr.ClearButton(components=[], value="Reset to defaults")
                    btn_submit = gr.Button("Submit", variant="primary")

            # Right column: evaluation-model picker and result outputs.
            with gr.Column(scale=5):

                with gr.Row():
                    eval_choice = gr.Dropdown(
                        choices=[(f"Client model {MODEL_PATH}", "client"),
                                 (f"Global model {MODEL_PATH_C}", "global")],
                        value="global",
                        label="Evaluation model"
                    )

                out_orig = gr.Image(label="Original detection", format="png")
                out_adv = gr.Image(label="After attack detection", format="png")
                out_adv_file = gr.File(label="Download Adversarial example (PNG)")

        # Per-session model selector consumed by detect_and_attack; "yolom"
        # means the client model (demo.load below syncs it with eval_choice).
        eval_model_state = gr.State(value="yolom")

        btn_submit.click(
            fn=detect_and_attack,
            inputs=[in_img, eval_model_state, attack_mode, eps, alpha, iters, conf, target_cls],
            outputs=[out_orig, out_adv, out_adv_file]
        )

        def to_defaults():
            # Restore the attack sliders to their initial values.
            return 0.0314, 0.0078, 10, 0.45

        btn_reset.click(
            fn=to_defaults,
            outputs = [eps, alpha, iters, conf],
        )

        # Re-run the full pipeline when the confidence slider is released
        # ("live" threshold behavior).
        conf.release(
            fn=detect_and_attack,
            inputs=[in_img, eval_model_state, attack_mode, eps, alpha, iters, conf, target_cls],
            outputs=[out_orig, out_adv, out_adv_file]
        )

        def on_eval_change(val: str):
            """Normalize the dropdown value and map it to the state token
            that run_detection_on_pil understands ("yolom" = client model)."""
            if isinstance(val, (list, tuple)):
                val = val[0] if len(val) else "client"
            if val not in ("client", "global"):
                val = "client"
            model = "yolom" if val == "client" else "yolom_c"
            return gr.update(value=val), model

        eval_choice.change(
            fn=on_eval_change,
            inputs=eval_choice,
            outputs=[eval_choice, eval_model_state]
        )

        # Run once on page load so eval_model_state matches the dropdown's
        # initial value ("global") instead of the State default ("yolom").
        demo.load(
            fn=on_eval_change,
            inputs=eval_choice,
            outputs=[eval_choice, eval_model_state]
        )

    demo.queue(default_concurrency_limit=2, max_size=20)
    # On Hugging Face Spaces (SPACE_ID set), bind to all interfaces and the
    # platform-provided port; locally, use Gradio's defaults.
    if os.getenv("SPACE_ID"):
        demo.launch(
            server_name="0.0.0.0",
            server_port=int(os.getenv("PORT", 7860)),
            show_error=True,
        )
    else:
        demo.launch()
|
|
|
|
|
|
|
|
|