MarshallCN commited on
Commit
e51df14
·
1 Parent(s): 3cce1d2

Regular changes (checkpoints ignored)

Browse files
.gitignore CHANGED
@@ -6,4 +6,5 @@ AdvTest.ipynb
6
  *.pyc
7
  *checkpoint*
8
  *checkpoint*/
9
- .ipynb_checkpoints/*
 
 
6
  *.pyc
7
  *checkpoint*
8
  *checkpoint*/
9
+ **/.ipynb_checkpoints/
10
+ *-checkpoint.*
app.py CHANGED
@@ -6,18 +6,29 @@ import torch
6
  from ultralytics import YOLO
7
  import cv2
8
  import attacks # 上面那个 attacks.py,确保和 app.py 在同一目录或可 import 的包路径
 
 
9
 
10
  # MODEL_PATH = "weights/yolov8s_3cls.pt"
11
  MODEL_PATH = "weights/fed_model2.pt"
 
12
 
13
  names = ['car', 'van', 'truck']
14
  imgsz = 640
 
 
 
 
 
 
 
15
  # Load ultralytics model (wrapper)
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
  yolom = YOLO(MODEL_PATH) # wrapper
 
18
  # put underlying module to eval on correct device might be needed in attacks functions
19
-
20
- def run_detection_on_pil(img_pil: Image.Image, conf: float = 0.4):
21
  """
22
  Use ultralytics wrapper predict to get a visualization image with boxes.
23
  This is inference-only and does not require gradient.
@@ -25,7 +36,8 @@ def run_detection_on_pil(img_pil: Image.Image, conf: float = 0.4):
25
  # ultralytics accepts numpy array (H,W,3) in RGB, we pass it directly
26
  img = np.array(img_pil)
27
  # use model.predict with verbose=False to avoid prints
28
- res = yolom.predict(source=img, conf=conf, imgsz=imgsz, save=False, verbose=False)
 
29
  r = res[0]
30
  im_out = img.copy()
31
  # Boxes object may be empty
@@ -45,15 +57,17 @@ def run_detection_on_pil(img_pil: Image.Image, conf: float = 0.4):
45
  pass
46
  return Image.fromarray(im_out)
47
 
48
- def detect_and_attack(image, attack_mode, eps, alpha, iters):
49
  if image is None:
50
  return None, None
 
51
  pil = Image.fromarray(image.astype('uint8'), 'RGB')
52
- original_vis = run_detection_on_pil(pil)
 
 
53
  if attack_mode == "none":
54
  return original_vis, None
55
 
56
- # Try the whitebox attacks; if they fail, fallback to demo perturbation
57
  try:
58
  if attack_mode == "fgsm":
59
  adv_pil = attacks.fgsm_attack_on_detector(yolom, pil, eps=eps, device=device, imgsz=imgsz)
@@ -62,34 +76,122 @@ def detect_and_attack(image, attack_mode, eps, alpha, iters):
62
  else:
63
  adv_pil = attacks.demo_random_perturbation(pil, eps=eps)
64
  except Exception as ex:
65
- # fallback with informative message (will show in server logs)
66
  print("Whitebox attack failed:", ex)
67
  adv_pil = attacks.demo_random_perturbation(pil, eps=eps)
68
 
69
- adv_vis = run_detection_on_pil(adv_pil)
70
  return original_vis, adv_vis
71
 
 
72
  # Gradio UI
73
  if __name__ == "__main__":
74
  title = "Federated Adversarial Attack — FGSM/PGD Demo"
75
- desc = "If the underlying model exposes a torch.nn.Module (whitebox), FGSM/PGD will compute gradients and produce true adversarial examples. "\
76
- "If not possible, the app falls back to a demo random perturbation."
77
-
78
- iface = gr.Interface(
79
- fn=detect_and_attack,
80
- inputs=[
81
- gr.Image(type="numpy", label="Input image"),
82
- gr.Radio(choices=["none", "fgsm", "pgd", "random noise"], value="none", label="Attack mode"),
83
- gr.Slider(minimum=0.0, maximum=0.3, step=0.01, value=0.0314, label="eps (L_inf radius)"),
84
- gr.Slider(minimum=0.001, maximum=0.05, step=0.001, value=0.0078, label="alpha (PGD step)"),
85
- gr.Slider(minimum=1, maximum=100, step=1, value=10, label="PGD iterations"),
86
- ],
87
- outputs=[gr.Image(label="Original detection"), gr.Image(label="After attack detection")],
88
- title=title,
89
- description=desc,
90
- allow_flagging="never",
91
  )
92
- iface.launch()
93
- # iface.launch(server_name="0.0.0.0", server_port=7860)
94
- # iface.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
 
6
  from ultralytics import YOLO
7
  import cv2
8
  import attacks # 上面那个 attacks.py,确保和 app.py 在同一目录或可 import 的包路径
9
+ import os, glob
10
+
11
 
12
  # MODEL_PATH = "weights/yolov8s_3cls.pt"
13
  MODEL_PATH = "weights/fed_model2.pt"
14
+ MODEL_PATH_C = "weights/yolov8s_3cls.pt"
15
 
16
  names = ['car', 'van', 'truck']
17
  imgsz = 640
18
+
19
+ SAMPLE_DIR = "./images/train"
20
+ SAMPLE_IMAGES = sorted([
21
+ p for p in glob.glob(os.path.join(SAMPLE_DIR, "*"))
22
+ if os.path.splitext(p)[1].lower() in [".jpg", ".jpeg", ".png", ".bmp", ".webp"]
23
+ ])[:4] # 只取前4张
24
+
25
  # Load ultralytics model (wrapper)
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
  yolom = YOLO(MODEL_PATH) # wrapper
28
+ # yolom_c = YOLO(MODEL_PATH_C) # wrapper
29
  # put underlying module to eval on correct device might be needed in attacks functions
30
+
31
+ def run_detection_on_pil(img_pil: Image.Image, eval_model_state, conf: float = 0.45):
32
  """
33
  Use ultralytics wrapper predict to get a visualization image with boxes.
34
  This is inference-only and does not require gradient.
 
36
  # ultralytics accepts numpy array (H,W,3) in RGB, we pass it directly
37
  img = np.array(img_pil)
38
  # use model.predict with verbose=False to avoid prints
39
+ eva_model = yolom if eval_model_state == "yolom" else YOLO(MODEL_PATH_C)
40
+ res = eva_model.predict(source=img, conf=conf, imgsz=imgsz, save=False, verbose=False)
41
  r = res[0]
42
  im_out = img.copy()
43
  # Boxes object may be empty
 
57
  pass
58
  return Image.fromarray(im_out)
59
 
60
+ def detect_and_attack(image, eval_model_state, attack_mode, eps, alpha, iters, conf):
61
  if image is None:
62
  return None, None
63
+
64
  pil = Image.fromarray(image.astype('uint8'), 'RGB')
65
+
66
+ original_vis = run_detection_on_pil(pil, eval_model_state, conf=conf)
67
+
68
  if attack_mode == "none":
69
  return original_vis, None
70
 
 
71
  try:
72
  if attack_mode == "fgsm":
73
  adv_pil = attacks.fgsm_attack_on_detector(yolom, pil, eps=eps, device=device, imgsz=imgsz)
 
76
  else:
77
  adv_pil = attacks.demo_random_perturbation(pil, eps=eps)
78
  except Exception as ex:
 
79
  print("Whitebox attack failed:", ex)
80
  adv_pil = attacks.demo_random_perturbation(pil, eps=eps)
81
 
82
+ adv_vis = run_detection_on_pil(adv_pil, eval_model_state, conf=conf)
83
  return original_vis, adv_vis
84
 
85
+
86
  # Gradio UI
87
  if __name__ == "__main__":
88
  title = "Federated Adversarial Attack — FGSM/PGD Demo"
89
+ desc_html = (
90
+ "Adversarial examples are generated locally using a "
91
+ "<strong>client-side</strong> model’s gradients (white-box), then evaluated against the "
92
+ "<strong>server-side aggregated (FedAvg) central model</strong>. "
93
+ "If the perturbation transfers, it can "
94
+ "degrade or alter the FedAvg model’s predictions on the same input image."
 
 
 
 
 
 
 
 
 
 
95
  )
96
+ with gr.Blocks(title=title) as demo:
97
+ # 标题居中
98
+ gr.Markdown(f"""
99
+ <div>
100
+ <h1 style='text-align:center;margin-bottom:0.2rem'>{title}</h1>
101
+ <p style='opacity:0.85'>{desc_html}</p>
102
+ </div>""")
103
+
104
+ with gr.Row():
105
+ # ===== 左列:两个输入区块 =====
106
+ with gr.Column(scale=5):
107
+ # 输入区块 1:上传窗口 & 样例选择 —— 左右并列
108
+ with gr.Row():
109
+ with gr.Column(scale=7):
110
+ in_img = gr.Image(type="numpy", label="Input image")
111
+ with gr.Column(scale=2):
112
+ if SAMPLE_IMAGES:
113
+ gr.Examples(
114
+ examples=SAMPLE_IMAGES,
115
+ inputs=[in_img],
116
+ label=f"Select from sample images",
117
+ examples_per_page=4,
118
+ # run_on_click 默认为 False(只填充,不执行)
119
+ )
120
+
121
+ # 输入 2:攻击与参数
122
+ with gr.Accordion("Attack mode", open=True):
123
+ attack_mode = gr.Radio(
124
+ choices=["none", "fgsm", "pgd", "random noise"],
125
+ value="fgsm",
126
+ label="",
127
+ show_label=False
128
+ )
129
+ eps = gr.Slider(0.0, 0.3, step=0.01, value=0.0314, label="eps")
130
+ alpha = gr.Slider(0.001, 0.05, step=0.001, value=0.0078, label="alpha (PGD step)")
131
+ iters = gr.Slider(1, 100, step=1, value=10, label="PGD iterations")
132
+ conf = gr.Slider(0.0, 1.0, step=0.01, value=0.45, label="Confidence threshold (live)")
133
+
134
+ with gr.Row():
135
+ btn_clear = gr.ClearButton(
136
+ components=[in_img, eps, alpha, iters, conf], # 不清空 attack_mode
137
+ value="Clear"
138
+ )
139
+ btn_submit = gr.Button("Submit", variant="primary")
140
+
141
+ # ===== 右列:两个输出区块 =====
142
+ with gr.Column(scale=5):
143
+ # 新增:评测模型选择
144
+ with gr.Row():
145
+ eval_choice = gr.Dropdown(
146
+ choices=[(f"Client model {MODEL_PATH}", "client"),
147
+ (f"Central model {MODEL_PATH_C}", "central")],
148
+ value="client", # ★ 初始值为合法 value
149
+ label="Evaluation model"
150
+ )
151
+
152
+ eval_model_state = gr.State(value="yolom")
153
+
154
+ # ★ 合并后的单一回调:规范化下拉值 + 返回(更新后的下拉值, 模型对象)
155
+ def on_eval_change(val: str):
156
+ if isinstance(val, (list, tuple)):
157
+ val = val[0] if len(val) else "client"
158
+ if val not in ("client", "central"):
159
+ val = "client"
160
+ model = "yolom" if val == "client" else "yolom_c"
161
+ return gr.update(value=val), model
162
+
163
+ # 页面加载时同步一次,避免初次为空/不一致
164
+ demo.load(
165
+ fn=on_eval_change,
166
+ inputs=eval_choice,
167
+ outputs=[eval_choice, eval_model_state]
168
+ )
169
+
170
+ # 仅这一条 change 绑定(删掉你原来那个只写 State 的 change,避免并发覆盖)
171
+ eval_choice.change(
172
+ fn=on_eval_change,
173
+ inputs=eval_choice,
174
+ outputs=[eval_choice, eval_model_state]
175
+ )
176
+ out_orig = gr.Image(label="Original detection")
177
+ out_adv = gr.Image(label="After attack detection")
178
+
179
+ # Submit:手动运行
180
+ btn_submit.click(
181
+ fn=detect_and_attack,
182
+ inputs=[in_img, eval_model_state, attack_mode, eps, alpha, iters, conf],
183
+ outputs=[out_orig, out_adv]
184
+ )
185
+
186
+ # 仅 conf 滑块“实时”
187
+ conf.release(
188
+ fn=detect_and_attack,
189
+ inputs=[in_img, eval_model_state, attack_mode, eps, alpha, iters, conf],
190
+ outputs=[out_orig, out_adv]
191
+ )
192
+
193
+ # demo.queue(concurrency_count=2, max_size=20)
194
+ demo.launch()
195
+ # demo.launch(server_name="0.0.0.0", server_port=7860)
196
+
197
 
images/train/c1.png ADDED

Git LFS Details

  • SHA256: 6681aac50ca717f3048089301482a01b0d6f007ef126231ecbead3ad02d1984e
  • Pointer size: 131 Bytes
  • Size of remote file: 859 kB
images/train/c2.png ADDED

Git LFS Details

  • SHA256: 913c94e8eb2dcab02b0c48a36cfdcdc2811368eeaa59baea3c2180c592433da4
  • Pointer size: 131 Bytes
  • Size of remote file: 263 kB
images/train/c3.png ADDED

Git LFS Details

  • SHA256: ffca42a28caa50805c61217e4a08884b8c038c0edd4786879dc74a1703a2bf82
  • Pointer size: 131 Bytes
  • Size of remote file: 498 kB
images/train/c4.png ADDED

Git LFS Details

  • SHA256: 9793d2afd966a929ff6147c22937b81c4da1fbebc98eddd869b31b389702f2bb
  • Pointer size: 131 Bytes
  • Size of remote file: 275 kB