File size: 12,047 Bytes
d501783
 
 
 
 
 
685f135
d501783
e51df14
685f135
e3e8c47
e51df14
d501783
e51df14
d501783
 
 
e51df14
685f135
e51df14
 
 
685f135
e51df14
d501783
 
 
685f135
f60b72c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d501783
f60b72c
 
d501783
685f135
f60b72c
 
d501783
f60b72c
e51df14
d501783
 
f60b72c
 
 
 
 
 
 
 
d501783
f60b72c
 
 
 
 
 
 
 
d501783
f60b72c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d501783
685f135
d501783
685f135
e51df14
685f135
e51df14
f60b72c
e51df14
d501783
685f135
d501783
 
685f135
d501783
685f135
d501783
 
 
 
 
 
685f135
 
 
 
 
f60b72c
685f135
 
 
 
d501783
e3e8c47
 
 
 
 
 
 
 
 
 
 
e51df14
d501783
 
 
e3e8c47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
685f135
e51df14
 
e3e8c47
e51df14
 
 
 
 
 
 
685f135
e51df14
 
 
 
 
 
685f135
e51df14
 
 
 
685f135
e51df14
 
 
 
 
 
685f135
e51df14
 
 
 
685f135
 
 
 
e51df14
685f135
 
e51df14
 
 
 
 
 
 
 
685f135
 
e51df14
 
 
685f135
 
 
e51df14
685f135
 
e51df14
 
 
685f135
 
 
 
 
 
 
 
 
e51df14
 
 
 
 
685f135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e51df14
 
685f135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e51df14
d501783
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
import io
import numpy as np
from PIL import Image
import gradio as gr
import torch
import cv2
from ultralytics import YOLO
import attacks  # 上面那个 attacks.py,确保和 app.py 在同一目录或可 import 的包路径
import os, glob
from pathlib import Path
import base64

# Weight files. MODEL_PATH is loaded eagerly below as `yolom`; MODEL_PATH_C is
# only referenced by path here.
MODEL_PATH = "weights/fed_model2.pt"
MODEL_PATH_C = "weights/yolov8s_3cls.pt"

# Class labels (the list index is used as the class key in the UI dropdown)
# and the image size passed to the detector.
names = ['car', 'van', 'truck']
imgsz = 640

# Sample gallery: image files found directly inside SAMPLE_DIR, sorted by
# path and capped at the first nine entries.
SAMPLE_DIR = "./images"
SAMPLE_IMAGES = sorted(
    candidate
    for candidate in glob.glob(os.path.join(SAMPLE_DIR, "*"))
    if os.path.splitext(candidate)[1].lower() in (".jpg", ".jpeg", ".png", ".bmp", ".webp")
)[:9]

# Prefer CUDA when it is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Ultralytics wrapper around the MODEL_PATH weights.
yolom = YOLO(MODEL_PATH)

def iou(a, b):
    """Return the intersection-over-union of two axis-aligned boxes.

    Both ``a`` and ``b`` are 4-element ``(x1, y1, x2, y2)`` sequences.
    Boxes that do not overlap (or only touch) yield ``0.0``. A small epsilon
    in the denominator guards against division by zero.
    """
    a_left, a_top, a_right, a_bottom = a
    b_left, b_top, b_right, b_bottom = b

    # Width/height of the overlap rectangle, clamped at zero when disjoint.
    overlap_w = max(0, min(a_right, b_right) - max(a_left, b_left))
    overlap_h = max(0, min(a_bottom, b_bottom) - max(a_top, b_top))
    overlap = overlap_w * overlap_h
    if overlap <= 0:
        return 0.0

    size_a = max(0, a_right - a_left) * max(0, a_bottom - a_top)
    size_b = max(0, b_right - b_left) * max(0, b_bottom - b_top)
    union = size_a + size_b - overlap + 1e-9
    return overlap / union
    
# def center_and_diag(b):  # IoU matching is good enough, so this helper is not enabled
#     x1, y1, x2, y2 = b
#     cx = 0.5 * (x1 + x2); cy = 0.5 * (y1 + y2)
#     diag = max(1e-9, ((x2 - x1)**2 + (y2 - y1)**2)**0.5)
#     area = max(0, (x2 - x1)) * max(0, (y2 - y1))
#     return cx, cy, diag, area
  
def _class_display_name(class_names, cls_id):
    """Return the label for ``cls_id`` from a dict- or list-style name table.

    Falls back to ``str(cls_id)`` when the table is missing or has no entry
    for the id.
    """
    if hasattr(class_names, "keys"):
        if cls_id in class_names.keys():
            return class_names[cls_id]
    elif isinstance(class_names, (list, tuple)) and 0 <= cls_id < len(class_names):
        return class_names[cls_id]
    return str(cls_id)


def _collect_predictions(result):
    """Flatten the boxes of one detection result into plain dicts.

    Best effort: any error while reading the boxes is printed and whatever
    was collected up to that point is returned.
    """
    preds = []
    try:
        boxes = result.boxes
        if boxes is not None and len(boxes) > 0:
            for box in boxes:
                coords = box.xyxy[0].detach().cpu().numpy().tolist()
                x1, y1, x2, y2 = (int(v) for v in coords)
                preds.append({
                    'xyxy': (x1, y1, x2, y2),
                    'cls': int(box.cls[0].detach().cpu().numpy()),
                    'conf': float(box.conf[0].detach().cpu().numpy()),
                })
    except Exception as e:
        print("collect preds error:", e)
    return preds


def run_detection_on_pil(img_pil: Image.Image, eval_model_state, conf: float = 0.45, GT_boxes=None):
    """Run detection on a PIL image and draw the predicted boxes on a copy.

    Args:
        img_pil: Input image. Non-RGB images are converted to RGB first.
        eval_model_state: ``"yolom"`` selects the preloaded ``yolom`` model;
            any other value loads the weights at ``MODEL_PATH_C``.
        conf: Confidence threshold forwarded to ``predict``.
        GT_boxes: Optional reference boxes, a list of
            ``{'xyxy': (x1, y1, x2, y2), 'cls': int, 'conf': float (optional)}``.
            When given, each prediction is greedily matched to the unused
            reference box with the highest IoU.

    Returns:
        ``(annotated_image, preds)`` where ``annotated_image`` is a PIL image
        and ``preds`` has the same structure as ``GT_boxes``. Boxes are green
        by default and red when a prediction matches a reference box
        (IoU >= 0.3) of a different class.
    """
    # ---- 1) Inference ----
    rgb_pil = img_pil if img_pil.mode == "RGB" else img_pil.convert("RGB")
    # NOTE(review): the non-"yolom" model is re-loaded from disk on every call.
    # Caching it would be faster but would share one instance across
    # concurrent requests - confirm that is acceptable before changing.
    eva_model = yolom if eval_model_state == "yolom" else YOLO(MODEL_PATH_C)
    # Pass the PIL image itself: Ultralytics treats numpy inputs as BGR, so
    # handing it an RGB array would feed the model swapped channels.
    res = eva_model.predict(source=rgb_pil, conf=conf, imgsz=imgsz, save=False, verbose=False)
    r = res[0]
    im_out = np.array(rgb_pil)  # fresh, writable RGB canvas for drawing

    # Class-name table: prefer the one on the result, else the model's.
    class_names = getattr(r, "names", None)
    if class_names is None and hasattr(eva_model, "model") and hasattr(eva_model.model, "names"):
        class_names = eva_model.model.names

    # ---- 2) Normalise predicted boxes into a simple structure ----
    preds = _collect_predictions(r)

    # ---- 3) IoU matching + drawing ----
    IOU_THR = 0.3
    same_class_color = (0, 255, 0)   # green (RGB)
    diff_class_color = (255, 0, 0)   # red (RGB)
    gt_used = set()

    for p in preds:
        color = same_class_color
        px1, py1, px2, py2 = p['xyxy']
        pname = _class_display_name(class_names, p['cls'])
        label = f"{pname}:{p.get('conf', 0.0):.2f}"

        if GT_boxes is not None:
            # Find the not-yet-used reference box with the highest IoU.
            best_j, best_iou = -1, 0.0
            for j, g in enumerate(GT_boxes):
                if j in gt_used:
                    continue
                overlap = iou(p['xyxy'], g['xyxy'])
                if overlap > best_iou:
                    best_iou, best_j = overlap, j

            # Matched and same class stays green; matched but different class turns red.
            if best_iou >= IOU_THR:
                gt_used.add(best_j)
                if p['cls'] != int(GT_boxes[best_j]['cls']):
                    color = diff_class_color

        cv2.rectangle(im_out, (px1, py1), (px2, py2), color, 2)
        cv2.putText(im_out, label, (px1, max(10, py1 - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

    return Image.fromarray(im_out), preds


def detect_and_attack(image, eval_model_state, attack_mode, eps, alpha, iters, conf, target_cls):
    """Detect on the input image, optionally attack it, and detect again.

    Args:
        image: Input image as a numpy array, or ``None`` when nothing is loaded.
        eval_model_state: Model key forwarded to ``run_detection_on_pil``.
        attack_mode: ``"none"``, ``"fgsm"``, ``"pgd"``; any other value uses
            the random-perturbation demo attack.
        eps, alpha, iters: Attack parameters (``alpha``/``iters`` only for PGD).
        conf: Confidence threshold used for both detection passes.
        target_cls: Target class forwarded to the FGSM/PGD attacks.

    Returns:
        ``(original_vis, adv_vis, adv_path)``. All three are ``None`` when
        ``image`` is ``None``; the last two are ``None`` when ``attack_mode``
        is ``"none"``. ``adv_path`` is the path of the saved adversarial PNG.
    """
    import tempfile  # local import: only needed for the per-request output file

    if image is None:
        return None, None, None

    pil = Image.fromarray(image.astype('uint8')).convert('RGB')

    # Detections on the clean image double as the reference boxes for the
    # attack and for colouring the post-attack detections.
    original_vis, GT_boxes = run_detection_on_pil(pil, eval_model_state, conf=conf, GT_boxes=None)

    if attack_mode == "none":
        return original_vis, None, None

    # The gradient-based attacks always run against `yolom`, independent of
    # which model is selected for evaluation.
    try:
        if attack_mode == "fgsm":
            adv_pil = attacks.fgsm_attack_on_detector(
                yolom, pil, eps=eps, device=device, imgsz=imgsz,
                gt_xywh=GT_boxes, target_cls=target_cls,
            )
        elif attack_mode == "pgd":
            # Coerce to int so a float coming from the slider cannot trip the
            # attack and silently trigger the random-noise fallback below.
            adv_pil = attacks.pgd_attack_on_detector(
                yolom, pil, eps=eps, alpha=alpha, iters=int(iters), device=device,
                imgsz=imgsz, gt_xywh=GT_boxes, target_cls=target_cls,
            )
        else:
            adv_pil = attacks.demo_random_perturbation(pil, eps=eps)
    except Exception as ex:
        # Best effort: fall back to random noise so the demo still shows output.
        print("Whitebox attack failed:", ex)
        adv_pil = attacks.demo_random_perturbation(pil, eps=eps)

    # Write each result to its own file. A single shared "tmp/adv.png" can be
    # overwritten by another request running concurrently before it is served.
    # NOTE(review): files accumulate under tmp/; add periodic cleanup if needed.
    adv_dir = Path("tmp")  # relative to the current working directory
    adv_dir.mkdir(parents=True, exist_ok=True)
    with tempfile.NamedTemporaryFile(prefix="adv_", suffix=".png", dir=adv_dir, delete=False) as fh:
        adv_pil.save(fh, format="PNG")
        adv_file = fh.name

    adv_vis, _ = run_detection_on_pil(adv_pil, eval_model_state, conf=conf, GT_boxes=GT_boxes)
    return original_vis, adv_vis, str(adv_file)

def handle_select(key):
    """Describe a selected class key together with its label from ``names``."""
    label = names[key]
    return f"key={key}, label={label}"

def img_to_data_url(path: str) -> str:
    """Return the file at ``path`` encoded as a base64 ``data:`` URL.

    The MIME type is chosen from the (case-insensitive) file extension.
    Unknown extensions use the generic ``application/octet-stream`` type.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    mime_by_ext = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".gif": "image/gif",
        ".bmp": "image/bmp",
        ".webp": "image/webp",
        ".svg": "image/svg+xml",
    }
    ext = os.path.splitext(path)[1].lower()
    mime = mime_by_ext.get(ext, "application/octet-stream")
    with open(path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode("ascii")
    return f"data:{mime};base64,{b64}"

# Gradio UI
if __name__ == "__main__":
    # Script entry point: build the Gradio Blocks UI, wire the callbacks,
    # then queue and launch the app.
    title = "Federated Adversarial Attack — FGSM/PGD Demo"

    logo_left_path = os.path.join("./logos", "logo_NCL.png")
    logo_right_path = os.path.join("./logos", "logo_EdgeAI.png")

    # Logos are inlined as base64 data URLs so the header HTML is self-contained.
    logo_left = img_to_data_url(logo_left_path)
    logo_right = img_to_data_url(logo_right_path)

    desc_html = f"""
<div style="display:flex;align-items:center;justify-content:space-between;gap:20px;">
  <img src="{logo_left}" style="height:60px;">
  <div style="flex:1;text-align:center;">
    <h1 style="margin:0;">{title}</h1>
    <p style="font-size:14px;">
      Adversarial examples are generated locally using a
      <strong>client-side</strong> model’s gradients (white-box), then evaluated against the
      <strong>server-side aggregated (FedAvg) global model</strong>.
      If the perturbation transfers, it can degrade or alter the FedAvg model’s predictions on the same input image.
      Object detection in this demo is limited to <strong>'car'</strong>, <strong>'van'</strong>, and <strong>'truck'</strong> classes only.
    </p>
  </div>
  <img src="{logo_right}" style="height:60px;">
</div>
"""

    with gr.Blocks(title=title) as demo:
        # Header: centred title and description flanked by the two logos
        gr.Markdown(desc_html)

        with gr.Row():
            # ===== Left column: the input sections =====
            with gr.Column(scale=5):
                # Input section 1: upload widget and sample picker, side by side
                with gr.Row():
                    with gr.Column(scale=7):
                        in_img = gr.Image(type="numpy", label="Input image", format="png", image_mode="RGB")
                    with gr.Column(scale=2):
                        # The sample gallery is only shown when sample images were found.
                        if SAMPLE_IMAGES:
                            gr.Examples(
                                examples=SAMPLE_IMAGES,
                                inputs=[in_img],
                                label=f"Select from sample images",
                                examples_per_page=9,
                                # run_on_click defaults to False (fills the input only, does not run)
                            )

                # Input section 2: attack mode and parameters
                with gr.Accordion("Attack settings", open=True):
                    attack_mode = gr.Radio(
                        choices=["none", "fgsm", "pgd", "random noise"],
                        value="fgsm",
                        label="",          
                        show_label=False
                    )
                    eps   = gr.Slider(0.0, 1, step=0.01, value=0.0314, label="eps")
                    alpha = gr.Slider(0.001, 0.05, step=0.001, value=0.0078, label="alpha (PGD step)")
                    iters = gr.Slider(1, 100, step=1, value=10, label="PGD iterations")
                    conf  = gr.Slider(0.0, 1.0, step=0.01, value=0.45, label="Confidence threshold (live)")
                with gr.Row():
                    # Dropdown entries are (label, value) pairs; the value is the class index.
                    target_cls = gr.Dropdown(
                        choices=[(name, i) for i, name in enumerate(names)],
                        value=2,      # default key=2
                        label="Class-targeted attack"
                    )
                with gr.Row():
                    btn_reset  = gr.ClearButton(components=[], value="Reset to defaults")
                    btn_submit = gr.Button("Submit", variant="primary")

            # ===== Right column: the output sections =====
            with gr.Column(scale=5):
                # Evaluation-model selector
                with gr.Row():
                    eval_choice = gr.Dropdown(
                        choices=[(f"Client model {MODEL_PATH}", "client"),
                                 (f"Global model {MODEL_PATH_C}", "global")],
                        value="global",                 # initial value must be one of the valid choice values
                        label="Evaluation model"
                    )
                    
                out_orig = gr.Image(label="Original detection", format="png")
                out_adv  = gr.Image(label="After attack detection", format="png")
                out_adv_file = gr.File(label="Download Adversarial example (PNG)")


        # Holds the model key passed to detect_and_attack ("yolom" or "yolom_c").
        # NOTE: starts as "yolom" while the dropdown above starts at "global";
        # the demo.load hook below re-syncs the two on page load.
        eval_model_state = gr.State(value="yolom")  
        # Submit: run the pipeline manually
        btn_submit.click(
            fn=detect_and_attack,
            inputs=[in_img, eval_model_state, attack_mode, eps, alpha, iters, conf, target_cls],
            outputs=[out_orig, out_adv, out_adv_file]
        )

        # Reset restores only eps, alpha, iters and conf (see the outputs list);
        # attack mode, target class and evaluation model are left untouched.
        def to_defaults():
            return 0.0314, 0.0078, 10, 0.45
        btn_reset.click(
            fn=to_defaults,
            outputs = [eps, alpha, iters, conf],
        )

        # Only the conf slider is "live": releasing it re-runs the whole pipeline
        conf.release(
            fn=detect_and_attack,
            inputs=[in_img, eval_model_state, attack_mode, eps, alpha, iters, conf, target_cls],
            outputs=[out_orig, out_adv, out_adv_file]
        )

        # Single merged callback: normalise the dropdown value and return
        # (updated dropdown value, model key). Unknown or empty values fall
        # back to "client".
        def on_eval_change(val: str):
            if isinstance(val, (list, tuple)):
                val = val[0] if len(val) else "client"
            if val not in ("client", "global"):
                val = "client"
            model = "yolom" if val == "client" else "yolom_c"
            return gr.update(value=val), model
        # Keep this as the only change binding (a separate handler that only
        # wrote the State was removed to avoid concurrent overwrites)
        eval_choice.change(
            fn=on_eval_change,
            inputs=eval_choice,
            outputs=[eval_choice, eval_model_state]
        )

        # Sync once on page load so the initial state is not empty/inconsistent
        demo.load(
            fn=on_eval_change,
            inputs=eval_choice,
            outputs=[eval_choice, eval_model_state]
        ) 

 
    # Up to 2 events run concurrently by default; at most 20 may wait in the queue.
    demo.queue(default_concurrency_limit=2, max_size=20)
    # When SPACE_ID is set (presumably by the hosting platform - confirm),
    # bind to all interfaces on $PORT (default 7860); otherwise launch with
    # Gradio's defaults.
    if os.getenv("SPACE_ID"):
        demo.launch(
            server_name="0.0.0.0",
            server_port=int(os.getenv("PORT", 7860)),
            show_error=True,
        )
    else:
        demo.launch()