Str0keOOOO committed
Commit f1799a0 · 2 Parent(s): e3f9290 bb6b32a

merge: merge
BFDS_train.py CHANGED
@@ -3,6 +3,19 @@ import logging
3
  import warnings
4
  import json
5
  from datetime import datetime
6
 
7
  from utils.logger import setlogger
8
  from utils.train import train_utils
@@ -28,15 +41,12 @@ class Argument:
28
  self.model_name = "ResNet_1d" # 模型名
29
  self.bottleneck = True # 是否使用bottleneck层
30
  self.bottleneck_num = 256 # bottleneck层的输出维数
31
- self.pretrained = False # 是否使用预训练模型
32
 
33
  # 训练
34
  self.batch_size = 64 # 批次大小
35
  self.cuda_device = "0" # 训练设备
36
- self.last_batch = False # 是否保留后的不完整批次
37
- self.max_epoch = 10 # 训练最大轮数
38
  self.num_workers = 0 # 训练设备数
39
- self.pretrained = False # 是否加载预训练模型
40
 
41
  # 数据记录
42
  self.checkpoint_dir = "./checkpoint" # 参数保存路径
@@ -58,7 +68,7 @@ class Argument:
58
 
59
  # 基于映射
60
  self.distance_option = True # 是否采用基于映射的损失
61
- self.distance_loss = "JMMD" # 损失模型 MK-MMD/JMMD/CORAL
62
  self.distance_tradeoff = "Step" # 损失的trade_off参数 Cons/Step
63
  self.distance_lambda = 1 # 若调整模式为Cons,指定其具体值
64
 
@@ -74,31 +84,55 @@ class Argument:
74
  # 输出可视化
75
  self.wavelet = "cmor1.5-1.0" # 小波类型
76
 
77
- def update_param(self, param_name, param_value):
78
- if hasattr(self, param_name):
79
- setattr(self, param_name, param_value)
80
- else:
81
- raise AttributeError(f"Parameter '{param_name}' does not exist.")
82
-
83
-
84
- def update_param(args, batch_size, optimizer, learning_rate, scheduler, transfer_method, distance_loss):
85
- if transfer_method not in ["基于映射", "基于领域对抗"]:
86
- return "错误: 迁移学习方式无效,请选择 '基于映射' 或 '基于领域对抗'。"
87
- args.update_param("batch_size", batch_size)
88
- args.update_param("opt", optimizer.lower())
89
- args.update_param("lr", learning_rate)
90
- args.update_param("lr_scheduler", scheduler)
91
- if transfer_method == "基于映射":
92
- args.update_param("adversarial_option", False)
93
- args.update_param("distance_option", True)
94
- elif transfer_method == "基于领域对抗":
95
- args.update_param("adversarial_option", True)
96
- args.update_param("distance_option", False)
97
- args.update_param("distance_loss", distance_loss)
98
- # 返回所有参数
99
- # FIXME __dict__
100
- all_params = {attr: getattr(args, attr) for attr in dir(args) if not attr.startswith("__") and not callable(getattr(args, attr))}
101
- return json.dumps(all_params, indent=2)
102
 
103
 
104
  if __name__ == "__main__":
 
3
  import warnings
4
  import json
5
  from datetime import datetime
6
+ import requests
7
+
8
+ if __name__ == "__main__":
9
+ try:
10
+ # 这里尝试连接hugging face连接不上就换国内镜像源
11
+ response = requests.get("https://huggingface.co", timeout=5)
12
+ if response.status_code == 200:
13
+ print("成功连接到 Hugging Face")
14
+ else:
15
+ print(f"连接失败,状态码: {response.status_code}")
16
+ except requests.exceptions.RequestException:
17
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
18
+ print(f"无法连接到 Hugging Face:换源到{os.environ['HF_ENDPOINT']}")
19
 
20
  from utils.logger import setlogger
21
  from utils.train import train_utils
 
41
  self.model_name = "ResNet_1d" # 模型名
42
  self.bottleneck = True # 是否使用bottleneck层
43
  self.bottleneck_num = 256 # bottleneck层的输出维数
 
44
 
45
  # 训练
46
  self.batch_size = 64 # 批次大小
47
  self.cuda_device = "0" # 训练设备
48
+ self.max_epoch = 2 # 训练最大轮数
 
49
  self.num_workers = 0 # 训练设备数
 
50
 
51
  # 数据记录
52
  self.checkpoint_dir = "./checkpoint" # 参数保存路径
 
68
 
69
  # 基于映射
70
  self.distance_option = True # 是否采用基于映射的损失
71
+ self.distance_loss = "MK-MMD" # 损失模型 MK-MMD/JMMD/CORAL
72
  self.distance_tradeoff = "Step" # 损失的trade_off参数 Cons/Step
73
  self.distance_lambda = 1 # 若调整模式为Cons,指定其具体值
74
 
 
84
  # 输出可视化
85
  self.wavelet = "cmor1.5-1.0" # 小波类型
86
 
87
+ def update_params(self, **kwargs):
88
+ """
89
+ 使用 **kwargs 动态更新 args 的参数。
90
+ """
91
+ for param_name, param_value in kwargs.items():
92
+ if hasattr(self, param_name):
93
+ setattr(self, param_name, param_value)
94
+ else:
95
+ print(f"警告: Parameter '{param_name}' does not exist.")
96
+
97
+ def set_recommended_params(self):
98
+ # 给用户设定的推荐参数
99
+ recommended_params = {
100
+ "data_set": "BFDS-Project/Bearing-Fault-Diagnosis-System",
101
+ "conditions": fetch_all_conditions_from_huggingface("BFDS-Project/Bearing-Fault-Diagnosis-System"),
102
+ "labels": {"Normal Baseline Data": 0, "Ball": 1, "Inner Race": 2, "Outer Race Centered": 3, "Outer Race Opposite": 4, "Outer Race Orthogonal": 5},
103
+ "transfer_task": [["CWRU", "CWRU_12k_Drive_End_Bearing_Fault_Data"], ["CWRU", "CWRU_12k_Fan_End_Bearing_Fault_Data"]],
104
+ "normalize_type": None,
105
+ "model_name": "CNN",
106
+ "bottleneck": True,
107
+ "bottleneck_num": 256,
108
+ "batch_size": 64,
109
+ "cuda_device": "0",
110
+ "max_epoch": 2,
111
+ "num_workers": 0,
112
+ "checkpoint_dir": "./checkpoint",
113
+ "print_step": 50,
114
+ "opt": "adam",
115
+ "momentum": 0.9,
116
+ "weight_decay": 1e-5,
117
+ "lr": 1e-3,
118
+ "lr_scheduler": "step",
119
+ "gamma": 0.1,
120
+ "steps": [150, 250],
121
+ "middle_epoch": 0,
122
+ "distance_option": True,
123
+ "distance_loss": "JMMD",
124
+ "distance_tradeoff": "Step",
125
+ "distance_lambda": 1,
126
+ "adversarial_option": False,
127
+ "adversarial_loss": "CDA",
128
+ "hidden_size": 1024,
129
+ "grl_option": "Step",
130
+ "grl_lambda": 1,
131
+ "adversarial_tradeoff": "Step",
132
+ "adversarial_lambda": 1,
133
+ "wavelet": "cmor1.5-1.0",
134
+ }
135
+ self.update_params(**recommended_params)
136
 
137
 
138
  if __name__ == "__main__":
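
A minimal usage sketch of the `Argument` API added above (illustrative only; the override values are arbitrary). Note that the mirror fallback sets `HF_ENDPOINT` before the `utils` imports, which matters because `huggingface_hub` typically reads that variable at import time, and that `set_recommended_params` needs network access to fetch the dataset conditions.

```python
# Illustrative sketch, assuming BFDS_train.Argument as defined in the diff above.
from BFDS_train import Argument

args = Argument()
# Override a few fields; unknown names only print a warning (see update_params).
args.update_params(batch_size=32, max_epoch=10, distance_loss="CORAL")
print(args.batch_size, args.max_epoch, args.distance_loss)

# set_recommended_params() additionally queries the Hub for the dataset
# conditions, so it requires a reachable (or mirrored) Hugging Face endpoint.
# args.set_recommended_params()
```
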
BFDS_web.py CHANGED
@@ -1,11 +1,39 @@
1
  import gradio as gr
2
  import matplotlib
3
  import matplotlib.pyplot as plt
4
- from BFDS_train import Argument, update_param
5
- import pandas as pd
6
  import torch
7
  from utils.predict import predict
8
 
9
  # 设置 Matplotlib 的后端为非交互式后端
10
  matplotlib.use("Agg")
11
  plt.rcParams.update(
@@ -20,106 +48,359 @@ plt.rcParams.update(
20
 
21
  # 初始化 Argument 实例
22
  args = Argument()
 
23
 
24
 
25
  # 更新参数的函数
26
- def transfer_learning(batch_size, optimizer, learning_rate, scheduler, transfer_method, distance_loss):
27
  # 这里更新参数
28
- all_params = update_param(args, batch_size, optimizer, learning_rate, scheduler, transfer_method, distance_loss)
29
  # 这里进行训练
30
 
31
- # 这里返回各种结果
32
- return all_params
33
 
34
 
35
  # 下面是信号推理的函数
36
  def signal_inference(model_file, signal_file):
 
37
  if model_file is None or signal_file is None:
38
- raise ValueError("请上传模型文件和信号数据")
39
  model_state_dict = torch.load(model_file)
40
- if isinstance(signal_file, list):
41
- for signal_file_single in signal_file:
42
- signal = pd.read_csv(signal_file_single)
43
- # FIXME 最后做成(n,1,128)的形式
44
- else:
45
- signal = pd.read_csv(signal_file)
46
- result = predict(model_state_dict, signal)
47
  return result
48
 
49
 
50
- # 创建一个绘图函数
51
- def create_plot():
52
- x = [1, 2, 3, 4, 5]
53
- y = [1, 4, 9, 16, 25]
54
- fig, ax = plt.subplots()
55
- ax.plot(x, y, label="y = x^2")
56
- ax.set_title("示例折线图")
57
- ax.set_xlabel("X 轴")
58
- ax.set_ylabel("Y 轴")
59
- ax.legend()
60
- return fig
61
 
62
 
63
  with gr.Blocks(title="BFDS WebUI") as app:
64
  with gr.Tab("模型训练"):
65
  gr.Markdown("在此模块中,您可以选择不同的迁移学习方法进行模型训练。")
66
- with gr.Tab("使用预训练模型"):
67
- gr.Markdown("使用预训练模型进行迁移学习,您可以选择以下参数进行配置:")
68
- with gr.Row():
69
- with gr.Column():
70
- batch_size_slider = gr.Slider(1, 258, label="batch_size", step=1, value=args.batch_size)
71
- optimizer_radio = gr.Radio(
72
- label="选择优化器",
73
- choices=["Adam", "SGD", "RMSprop"],
74
- value=args.opt.capitalize(),
75
- )
76
- learning_rate_slider = gr.Slider(1e-5, 1e-2, label="学习率", step=1e-5, value=args.lr)
77
- scheduler_radio = gr.Radio(
78
- label="学习率调度器",
79
- choices=["step", "exp", "stepLR", "fix"],
80
- value=args.lr_scheduler,
81
- )
82
- transfer_method_radio = gr.Radio(
83
- label="迁移学习方式",
84
- choices=["基于映射", "基于领域对抗"],
85
- value="基于领域对抗" if args.adversarial_option else "基于映射",
86
- )
87
- with gr.Column():
88
- distance_loss_radio = gr.Radio(
89
- label="距离损失函数",
90
- choices=["MK-MMD", "JMMD", "CORAL"],
91
- value="MK-MMD", # 修复默认值为有效选项
92
- )
93
- update_button = gr.Button("开始训练")
94
- with gr.Row():
95
- with gr.Column():
96
- # FIXME 需要弄好看一点
97
- args_all_params = gr.Textbox(label="更新结果", lines=8)
98
- with gr.Column():
99
- gr.Plot(create_plot)
100
- with gr.Tab("不使用预训练模型"):
101
- gr.Markdown("使用从零开始训练的方式,不依赖预训练模型。")
102
 
103
  with gr.Tab("信号推理"):
104
  model_file = gr.File(label="模型文件", file_count="single", file_types=[".bin", ".pth", ".pt"])
105
- with gr.Tab("单次推理"):
106
- gr.Markdown("在此模块中,您可以上传信号数据进行推理。")
107
- signal_file_single = gr.File(label="上传信号数据", file_count="single", file_types=[".csv"])
108
- signal_inference_single_button = gr.Button("开始推理")
109
- signal_inference_single_output = gr.Textbox(label="推理结果", lines=8)
110
- with gr.Tab("批量推理"):
111
- gr.Markdown("在此模块中,您可以上传信号数据进行批量推理。")
112
- signal_file_multiple = gr.File(label="上传信号数据", file_count="multiple", file_types=[".csv"])
113
- signal_inference_multiple_button = gr.Button("开始批量推理")
114
- signal_inference_multiple_output = gr.Textbox(label="批量推理结果", lines=8)
115
 
116
  # 下面是所有函数绑定
117
- update_button.click(
118
  transfer_learning,
119
- inputs=[batch_size_slider, optimizer_radio, learning_rate_slider, scheduler_radio, transfer_method_radio, distance_loss_radio],
120
- outputs=args_all_params,
121
  )
122
- signal_inference_single_button.click(signal_inference, inputs=[model_file, signal_file_single], outputs=signal_inference_single_output)
123
- signal_inference_multiple_button.click(signal_inference, inputs=[model_file, signal_file_multiple], outputs=signal_inference_multiple_output)
124
  app.queue()
125
  app.launch()
 
1
+ import os
2
+ import requests
3
+ import zipfile
4
+
5
+ if __name__ == "__main__":
6
+ try:
7
+ # 这里尝试连接hugging face连接不上就换国内镜像源
8
+ response = requests.get("https://huggingface.co", timeout=5)
9
+ if response.status_code == 200:
10
+ print("成功连接到 Hugging Face")
11
+ else:
12
+ print(f"连接失败,状态码: {response.status_code}")
13
+ except requests.exceptions.RequestException:
14
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
15
+ print(f"无法连接到 Hugging Face:换源到{os.environ['HF_ENDPOINT']}")
16
+
17
+
18
  import gradio as gr
19
  import matplotlib
20
  import matplotlib.pyplot as plt
21
+ from BFDS_train import Argument
 
22
  import torch
23
  from utils.predict import predict
24
 
25
+ import logging
26
+ import warnings
27
+ from datetime import datetime
28
+
29
+
30
+ from utils.logger import setlogger
31
+ from utils.train import train_utils
32
+ from utils.fetch_conditions import fetch_all_conditions_from_huggingface
33
+
34
+ dataset_name = "BFDS-Project/Bearing-Fault-Diagnosis-System"
35
+ conditions = fetch_all_conditions_from_huggingface(dataset_name)
36
+
37
  # 设置 Matplotlib 的后端为非交互式后端
38
  matplotlib.use("Agg")
39
  plt.rcParams.update(
 
48
 
49
  # 初始化 Argument 实例
50
  args = Argument()
51
+ args.set_recommended_params()
52
 
53
 
54
  # 更新参数的函数
55
+ def transfer_learning(
56
+ source_config,
57
+ source_split,
58
+ target_path,
59
+ normalize_type,
60
+ model_name,
61
+ bottleneck,
62
+ bottleneck_num,
63
+ batch_size,
64
+ cuda_device,
65
+ max_epoch,
66
+ num_workers,
67
+ opt,
68
+ momentum,
69
+ weight_decay,
70
+ lr,
71
+ lr_scheduler,
72
+ gamma,
73
+ steps_start,
74
+ steps_end,
75
+ middle_epoch,
76
+ distance_option,
77
+ distance_loss,
78
+ distance_tradeoff,
79
+ distance_lambda,
80
+ adversarial_option,
81
+ adversarial_loss,
82
+ hidden_size,
83
+ grl_option,
84
+ grl_lambda,
85
+ adversarial_tradeoff,
86
+ adversarial_lambda,
87
+ wavelet,
88
+ ):
89
+ args_params_dict = {
90
+ "transfer_task": [[source_config, source_split], []],
91
+ "normalize_type": normalize_type,
92
+ "model_name": model_name,
93
+ "bottleneck": bottleneck,
94
+ "bottleneck_num": bottleneck_num,
95
+ "batch_size": batch_size,
96
+ "cuda_device": cuda_device,
97
+ "max_epoch": max_epoch,
98
+ "num_workers": num_workers,
99
+ "opt": opt,
100
+ "momentum": momentum,
101
+ "weight_decay": weight_decay,
102
+ "lr": lr,
103
+ "lr_scheduler": lr_scheduler,
104
+ "gamma": gamma,
105
+ "steps": [steps_start, steps_end],
106
+ "middle_epoch": middle_epoch,
107
+ "distance_option": distance_option,
108
+ "distance_loss": distance_loss,
109
+ "distance_tradeoff": distance_tradeoff,
110
+ "distance_lambda": distance_lambda,
111
+ "adversarial_option": adversarial_option,
112
+ "adversarial_loss": adversarial_loss,
113
+ "hidden_size": hidden_size,
114
+ "grl_option": grl_option,
115
+ "grl_lambda": grl_lambda,
116
+ "adversarial_tradeoff": adversarial_tradeoff,
117
+ "adversarial_lambda": adversarial_lambda,
118
+ "wavelet": wavelet,
119
+ }
120
  # 这里更新参数
121
+ if target_path is None:
122
+ raise ValueError("请上传目标域数据!")
123
+ args.update_params(**args_params_dict)
124
  # 这里进行训练
125
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_device.strip()
126
+ warnings.filterwarnings("ignore")
127
+ save_dir = os.path.join(args.checkpoint_dir, args.model_name + "_" + datetime.strftime(datetime.now(), "%m%d-%H%M%S"))
128
+ setattr(args, "save_dir", save_dir)
129
+ if not os.path.exists(args.save_dir):
130
+ os.makedirs(args.save_dir)
131
+ # 设定日志
132
+ setlogger(os.path.join(args.save_dir, "train.log"))
133
+ # 保存超参数
134
+ for k, v in args.__dict__.items():
135
+ if k[-3:] != "dir":
136
+ logging.info(f"{k}: {v}")
137
+ # 训练
138
+ trainer = train_utils(args, owned=True, data_path=target_path)
139
+ trainer.setup()
140
+ trainer.train()
141
+ fig = trainer.generate_fig()
142
 
143
+ # 压缩 save_dir 文件夹
144
+ zip_filename = f"{trainer.save_dir}.zip"
145
+ with zipfile.ZipFile(zip_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
146
+ for root, dirs, files in os.walk(trainer.save_dir):
147
+ for file in files:
148
+ zipf.write(os.path.join(root, file), os.path.relpath(os.path.join(root, file), os.path.join(trainer.save_dir, "..")))
149
+
150
+ return fig, zip_filename
151
 
152
 
153
  # 下面是信号推理的函数
154
  def signal_inference(model_file, signal_file):
155
+ result = []
156
  if model_file is None or signal_file is None:
157
+ raise ValueError("请上传模型文件和信号数据!")
158
  model_state_dict = torch.load(model_file)
159
+ for signal_file_single in signal_file:
160
+ result.append(predict(model_state_dict, signal_file_single, args))
161
  return result
162
 
163
 
164
+ def change_source_split(source_config_radio):
165
+ source_splits = conditions[source_config_radio]
166
+ return gr.update(choices=source_splits, value=source_splits[0])
167
+
168
+
169
+ def change_bottleneck(bottleneck):
170
+ return gr.update(visible=bottleneck)
171
+
172
+
173
+ def change_opt(opt):
174
+ if opt == "sgd":
175
+ return gr.update(visible=True), gr.update(visible=True)
176
+ elif opt == "adam":
177
+ return gr.update(visible=False), gr.update(visible=False)
178
+
179
+
180
+ def change_lr_scheduler(lr_scheduler):
181
+ if lr_scheduler == "step":
182
+ return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
183
+ elif lr_scheduler == "exp":
184
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
185
+ elif lr_scheduler == "stepLR":
186
+ return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
187
+ elif lr_scheduler == "fix":
188
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
189
+
190
+
191
+ def change_steps_start(steps_start, steps_end):
192
+ if steps_start >= steps_end:
193
+ steps_start = steps_end - 1
194
+ return gr.update(value=steps_start, maximum=steps_end - 1)
195
+
196
+
197
+ def change_steps_end(steps_start, steps_end):
198
+ if steps_end <= steps_start:
199
+ steps_end = steps_start + 1
200
+ return gr.update(value=steps_end, minimum=steps_start + 1)
201
 
202
 
203
+ def change_max_epoch(max_epoch, middle_epoch):
204
+ if middle_epoch >= max_epoch:
205
+ middle_epoch = max_epoch - 1
206
+ return gr.update(value=max_epoch, maximum=max_epoch - 1)
207
+
208
+
209
+ def change_middle_epoch(max_epoch, middle_epoch):
210
+ if middle_epoch >= max_epoch:
211
+ middle_epoch = max_epoch - 1
212
+ return gr.update(value=middle_epoch, maximum=max_epoch - 1)
213
+
214
+
215
+ def change_distance_option(distance_option, distance_tradeoff):
216
+ if distance_option:
217
+ return gr.update(value=False), gr.update(visible=distance_option), gr.update(visible=distance_option), gr.update(visible=(distance_option and distance_tradeoff == "Cons"))
218
+ else:
219
+ return gr.update(value=False), gr.update(visible=distance_option), gr.update(visible=distance_option), gr.update(visible=(distance_option and distance_tradeoff == "Cons"))
220
+
221
+
222
+ def change_adversarial_option(adversarial_option, adversarial_tradeoff):
223
+ return (
224
+ gr.update(value=not adversarial_option),
225
+ gr.update(visible=adversarial_option),
226
+ gr.update(visible=adversarial_option),
227
+ gr.update(visible=adversarial_option),
228
+ gr.update(visible=adversarial_option),
229
+ gr.update(visible=adversarial_option),
230
+ gr.update(visible=(adversarial_option and adversarial_tradeoff == "Cons")),
231
+ )
232
+
233
+
234
+ def change_distance_tradeoff(distance_option, distance_tradeoff):
235
+ return gr.update(visible=(distance_option and distance_tradeoff == "Cons"))
236
+
237
+
238
+ def change_adversarial_tradeoff(adversarial_option, adversarial_tradeoff):
239
+ return (gr.update(visible=(adversarial_option and adversarial_tradeoff == "Cons")),)
240
+
241
+
242
+ with open("docs/BFDS_font.html", "r", encoding="utf-8") as f:
243
+ BFDS_font_html = f.read()
244
+
245
+ # gradio BFDS_web.py --demo-name app
246
  with gr.Blocks(title="BFDS WebUI") as app:
247
+ gr.HTML(BFDS_font_html)
248
+ gr.Markdown("""
249
+ # 轴承故障诊断系统
250
+ 基于深度迁移学习的智能轴承故障诊断系统。支持多种迁移学习算法、信号处理方法和故障诊断模型。
251
+ """)
252
  with gr.Tab("模型训练"):
253
  gr.Markdown("在此模块中,您可以选择不同的迁移学习方法进行模型训练。")
254
+ with gr.Row():
255
+ with gr.Column():
256
+ source_config_radio = gr.Radio(
257
+ label="选择源域数据集名称",
258
+ choices=list(conditions.keys()),
259
+ value=args.transfer_task[0][0],
260
+ )
261
+ source_split_radio = gr.Radio(
262
+ label="选择源域数据集工况",
263
+ choices=conditions[args.transfer_task[0][0]],
264
+ value=args.transfer_task[0][1],
265
+ )
266
+ target_file = gr.File(label="目标域数据集", file_count="single", file_types=[".csv"])
267
+ normalize_type_radio = gr.Radio(
268
+ label="选择归一化方式",
269
+ choices=["mean-std", "min-max", None],
270
+ value=args.normalize_type,
271
+ )
272
+ model_name_radio = gr.Radio(
273
+ label="选择模型名称",
274
+ choices=["CNN"],
275
+ value=args.model_name,
276
+ )
277
+ bottleneck_checkbox = gr.Checkbox(
278
+ label="是否使用瓶颈层",
279
+ value=args.bottleneck,
280
+ )
281
+ bottleneck_num_slider = gr.Slider(1, 1024, label="瓶颈层神经元个数", step=1, value=args.bottleneck_num, visible=args.bottleneck)
282
+ batch_size_slider = gr.Slider(1, 258, label="batch_size", step=1, value=args.batch_size)
283
+ cuda_device_radio = gr.Radio(
284
+ label="选择GPU设备",
285
+ choices=["0"],
286
+ value=args.cuda_device,
287
+ )
288
+ max_epoch_slider = gr.Slider(args.middle_epoch + 1, 100, label="max_epoch", step=1, value=args.max_epoch)
289
+ num_workers_slider = gr.Slider(1, 16, label="num_workers", step=1, value=args.num_workers)
290
+ opt_radio = gr.Radio(
291
+ label="选择优化器",
292
+ choices=["sgd", "adam"],
293
+ value=args.opt,
294
+ )
295
+ momentum_slider = gr.Slider(0, 1, label="momentum", step=0.01, value=args.momentum)
296
+ weight_decay_slider = gr.Slider(1e-5, 1e-1, label="weight_decay", step=1e-5, value=args.weight_decay)
297
+ lr_slider = gr.Slider(1e-5, 1e-2, label="学习率", step=1e-5, value=args.lr)
298
+ lr_scheduler_radio = gr.Radio(
299
+ label="学习率调度器",
300
+ choices=["step", "exp", "stepLR", "fix"],
301
+ value=args.lr_scheduler,
302
+ )
303
+ gamma_slider = gr.Slider(1e-5, 1e-2, label="gamma", step=1e-5, value=args.gamma, visible=args.lr_scheduler != "fix")
304
+ steps_start_slider = gr.Slider(1, args.steps[1] - 1, label="steps 第一个值", step=1, value=args.steps[0], visible=(args.lr_scheduler == "step" or args.lr_scheduler == "stepLR"))
305
+ steps_end_slider = gr.Slider(args.steps[0] + 1, 1000, label="steps 第二个值", step=1, value=args.steps[1], visible=(args.lr_scheduler == "step" or args.lr_scheduler == "stepLR"))
306
+ middle_epoch_slider = gr.Slider(0, args.max_epoch - 1, label="middle_epoch", step=1, value=args.middle_epoch)
307
+ wavelet_radio = gr.Radio(
308
+ label="选择波形变换",
309
+ choices=["cmor1.5-1.0"],
310
+ value=args.wavelet,
311
+ )
312
+ with gr.Column():
313
+ # 这两个true和false不能一起出现
314
+ distance_option_checkbox = gr.Checkbox(
315
+ label="是否使用距离损失",
316
+ value=args.distance_option,
317
+ )
318
+ distance_loss_radio = gr.Radio(label="距离损失函数", choices=["MK-MMD", "JMMD", "CORAL"], value=args.distance_loss, visible=args.distance_option)
319
+ distance_tradeoff_radio = gr.Radio(label="距离损失权重", choices=["Cons", "Step"], value=args.distance_tradeoff, visible=args.distance_option)
320
+ distance_lambda_slider = gr.Slider(1, 2, label="距离损失权重", step=1e-5, value=args.distance_lambda, visible=(args.distance_option and args.distance_tradeoff == "Cons"))
321
+ adversarial_option_checkbox = gr.Checkbox(
322
+ label="是否使用对抗损失",
323
+ value=args.adversarial_option,
324
+ )
325
+ adversarial_loss_radio = gr.Radio(label="对抗损失函数", choices=["DA", "CDA", "CDA+E"], value=args.adversarial_loss, visible=args.adversarial_option)
326
+ hidden_size_slider = gr.Slider(1, 1024, label="对抗层神经元个数", step=1, value=args.hidden_size, visible=args.adversarial_option)
327
+ grl_option_radio = gr.Radio(label="是否使用梯度反转层", choices=["Step"], value=args.grl_option, visible=args.adversarial_option)
328
+ grl_lambda_slider = gr.Slider(1, 2, label="梯度反转层系数", step=1e-5, value=args.grl_lambda, visible=args.adversarial_option)
329
+ adversarial_tradeoff_radio = gr.Radio(label="对抗损失权重", choices=["Cons", "Step"], value=args.adversarial_tradeoff, visible=args.adversarial_option)
330
+ adversarial_lambda_slider = gr.Slider(1, 2, label="对抗损失权重", step=1e-5, value=args.adversarial_lambda, visible=(args.adversarial_option and args.adversarial_tradeoff == "Cons"))
331
+
332
+ transfer_learning_button = gr.Button("开始训练")
333
+ with gr.Row():
334
+ with gr.Column():
335
+ download_output = gr.File(label="下载训练结果压缩包", interactive=False)
336
+ with gr.Column():
337
+ plot_component = gr.Plot(label="训练结果图表")
338
 
339
  with gr.Tab("信号推理"):
340
  model_file = gr.File(label="模型文件", file_count="single", file_types=[".bin", ".pth", ".pt"])
341
+ gr.Markdown("在此模块中,您可以上传信号数据进行批量推理")
342
+ signal_file_multiple = gr.File(label="上传信号数据", file_count="multiple", file_types=[".csv"])
343
+ signal_inference_button = gr.Button("开始批量推理")
344
+ signal_inference_output = gr.Textbox(label="批量推理结果", lines=8)
345
 
346
  # 下面是所有函数绑定
347
+ transfer_learning_button.click(
348
  transfer_learning,
349
+ inputs=[
350
+ source_config_radio,
351
+ source_split_radio,
352
+ target_file,
353
+ normalize_type_radio,
354
+ model_name_radio,
355
+ bottleneck_checkbox,
356
+ bottleneck_num_slider,
357
+ batch_size_slider,
358
+ cuda_device_radio,
359
+ max_epoch_slider,
360
+ num_workers_slider,
361
+ opt_radio,
362
+ momentum_slider,
363
+ weight_decay_slider,
364
+ lr_slider,
365
+ lr_scheduler_radio,
366
+ gamma_slider,
367
+ steps_start_slider,
368
+ steps_end_slider,
369
+ middle_epoch_slider,
370
+ distance_option_checkbox,
371
+ distance_loss_radio,
372
+ distance_tradeoff_radio,
373
+ distance_lambda_slider,
374
+ adversarial_option_checkbox,
375
+ adversarial_loss_radio,
376
+ hidden_size_slider,
377
+ grl_option_radio,
378
+ grl_lambda_slider,
379
+ adversarial_tradeoff_radio,
380
+ adversarial_lambda_slider,
381
+ wavelet_radio,
382
+ ],
383
+ outputs=[plot_component, download_output],
384
+ )
385
+ source_config_radio.change(change_source_split, inputs=[source_config_radio], outputs=[source_split_radio])
386
+ opt_radio.change(change_opt, inputs=[opt_radio], outputs=[momentum_slider, weight_decay_slider])
387
+ bottleneck_checkbox.change(change_bottleneck, inputs=[bottleneck_checkbox], outputs=[bottleneck_num_slider])
388
+ lr_scheduler_radio.change(change_lr_scheduler, inputs=[lr_scheduler_radio], outputs=[steps_start_slider, steps_end_slider, gamma_slider])
389
+ steps_start_slider.change(change_steps_start, inputs=[steps_start_slider, steps_end_slider], outputs=[steps_start_slider])
390
+ steps_end_slider.change(change_steps_end, inputs=[steps_start_slider, steps_end_slider], outputs=[steps_end_slider])
391
+ max_epoch_slider.change(change_middle_epoch, inputs=[max_epoch_slider, middle_epoch_slider], outputs=[middle_epoch_slider])
392
+ middle_epoch_slider.change(change_middle_epoch, inputs=[max_epoch_slider, middle_epoch_slider], outputs=[middle_epoch_slider])
393
+ distance_option_checkbox.change(
394
+ change_distance_option, inputs=[distance_option_checkbox, distance_tradeoff_radio], outputs=[adversarial_option_checkbox, distance_loss_radio, distance_tradeoff_radio, distance_lambda_slider]
395
  )
396
+ adversarial_option_checkbox.change(
397
+ change_adversarial_option,
398
+ inputs=[adversarial_option_checkbox, adversarial_tradeoff_radio],
399
+ outputs=[distance_option_checkbox, adversarial_loss_radio, hidden_size_slider, grl_option_radio, grl_lambda_slider, adversarial_tradeoff_radio, adversarial_lambda_slider],
400
+ )
401
+ distance_tradeoff_radio.change(change_distance_tradeoff, inputs=[distance_option_checkbox, distance_tradeoff_radio], outputs=[distance_lambda_slider])
402
+ adversarial_tradeoff_radio.change(change_adversarial_tradeoff, inputs=[adversarial_option_checkbox, adversarial_tradeoff_radio], outputs=[adversarial_lambda_slider])
403
+ signal_inference_button.click(signal_inference, inputs=[model_file, signal_file_multiple], outputs=signal_inference_output)
404
+
405
  app.queue()
406
  app.launch()
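
The WebUI wires nearly every control through the same pattern: a `change` callback that returns `gr.update(visible=...)` for its dependent widgets. A minimal self-contained sketch of that pattern (widget names below are illustrative, not the ones used above):

```python
# Minimal illustration of the show/hide pattern used by change_opt and friends.
import gradio as gr

def toggle_sgd_options(opt):
    show = (opt == "sgd")  # momentum / weight_decay only apply to SGD
    return gr.update(visible=show), gr.update(visible=show)

with gr.Blocks() as demo:
    opt_radio = gr.Radio(["sgd", "adam"], value="adam", label="optimizer")
    momentum = gr.Slider(0, 1, value=0.9, label="momentum", visible=False)
    weight_decay = gr.Slider(1e-5, 1e-1, value=1e-5, label="weight_decay", visible=False)
    opt_radio.change(toggle_sgd_options, inputs=[opt_radio], outputs=[momentum, weight_decay])

if __name__ == "__main__":
    demo.launch()
```
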
dataset/dataset.py CHANGED
@@ -1,4 +1,5 @@
1
  import pandas as pd
 
2
  from datasets import load_dataset
3
  import torch
4
  from torch.utils.data import Dataset, DataLoader, random_split
@@ -6,11 +7,27 @@ from typing import Optional, Literal
6
 
7
 
8
  def get_dataset(data_set, subset, split):
9
- # TODO 换源
10
  ds = load_dataset(data_set, subset)
11
  return ds[split].to_pandas()
12
 
13
 
14
  class SignalDataset(Dataset):
15
  def __init__(self, data_frame: pd.DataFrame, normalize_type: Optional[Literal["mean-std", "min-max"]] = None):
16
  if normalize_type == "mean-std":
@@ -36,34 +53,38 @@ class SignalDatasetCreator:
36
  self.source = transfer_task[0]
37
  self.target = transfer_task[1]
38
 
39
- def data_split(self, batch_size, num_workers, device, transfer_learning=True):
40
- if transfer_learning:
41
- # get source train and val
42
- data_frame = get_dataset(self.data_set, self.source[0], self.source[1])
43
- data_set = SignalDataset(data_frame)
44
- lengths = [round(0.8 * len(data_set)), len(data_set) - round(0.8 * len(data_set))]
45
- train_data, eval_data = random_split(data_set, lengths)
46
- source_train = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=(device == "cuda"),drop_last=True)
47
- source_val = DataLoader(dataset=eval_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=(device == "cuda"),drop_last=True)
48
- # get target train and val
49
- data_frame=get_dataset(self.data_set, self.target[0], self.target[1])
50
- data_set = SignalDataset(data_frame)
51
- lengths = [round(0.8 * len(data_set)), len(data_set) - round(0.8 * len(data_set))]
52
- train_data, eval_data = random_split(data_set, lengths)
53
- target_train = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=(device == "cuda"),drop_last=True)
54
- target_val = DataLoader(dataset=eval_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=(device == "cuda"),drop_last=True)
55
- return source_train, source_val, target_train, target_val
56
- else:
57
- # get source train and val
58
- data_frame = get_dataset(self.data_set, self.source[0], self.source[1])
59
- data_set = SignalDataset(data_frame)
60
- lengths = [round(0.8 * len(data_set)), len(data_set) - round(0.8 * len(data_set))]
61
- train_data, eval_data = random_split(data_set, lengths)
62
- source_train = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=(device == "cuda"))
63
- source_val = DataLoader(dataset=eval_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=(device == "cuda"))
64
- # get target val
65
- get_dataset(self.data_set, self.target[0], self.target[1])
66
- data_frame = get_dataset(self.data_set, self.target[0, 0], self.target[0, 1])
67
- data_set = SignalDataset(data_frame)
68
- target_val = DataLoader(dataset=data_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=(device == "cuda"))
69
- return source_train, source_val, target_val
1
  import pandas as pd
2
+ import numpy as np
3
  from datasets import load_dataset
4
  import torch
5
  from torch.utils.data import Dataset, DataLoader, random_split
 
7
 
8
 
9
  def get_dataset(data_set, subset, split):
 
10
  ds = load_dataset(data_set, subset)
11
  return ds[split].to_pandas()
12
 
13
 
14
+ def get_owned_dataset(data_path):
15
+ # 提供更多读取方式,和预测一起整理一下
16
+ df = pd.read_csv(data_path).dropna()
17
+ data = df.values
18
+ if data.size % 224 != 0:
19
+ raise ValueError(f"数据大小 {data.size} 不能被 224 整除,无法重塑为 (-1, 224)")
20
+ # 重塑数据为 (-1, 224)
21
+ reshaped_data = data.reshape(-1, 224)
22
+ # 创建一个全是 0 的列
23
+ zero_column = np.zeros((reshaped_data.shape[0], 1))
24
+ # 将 reshaped_data 和 zero_column 拼接成新的数组
25
+ new_data = np.hstack((reshaped_data, zero_column))
26
+ # 将新的数组转换为 DataFrame
27
+ owned_df = pd.DataFrame(new_data, columns=[f"col_{i}" for i in range(225)])
28
+ return owned_df
29
+
30
+
31
  class SignalDataset(Dataset):
32
  def __init__(self, data_frame: pd.DataFrame, normalize_type: Optional[Literal["mean-std", "min-max"]] = None):
33
  if normalize_type == "mean-std":
 
53
  self.source = transfer_task[0]
54
  self.target = transfer_task[1]
55
 
56
+ def data_split(self, batch_size, num_workers, device):
57
+ # 这里源域和目标域都是我们提供用来验证迁移学习的正确性
58
+ # get source train and val
59
+ data_frame_source = get_dataset(self.data_set, self.source[0], self.source[1])
60
+ data_set_source = SignalDataset(data_frame_source)
61
+ lengths_source = [round(0.8 * len(data_set_source)), len(data_set_source) - round(0.8 * len(data_set_source))]
62
+ train_data_source, eval_data_source = random_split(data_set_source, lengths_source)
63
+ source_train = DataLoader(dataset=train_data_source, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=(device == "cuda"), drop_last=True)
64
+ source_val = DataLoader(dataset=eval_data_source, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=(device == "cuda"), drop_last=True)
65
+ # get target train and val
66
+ data_frame_target = get_dataset(self.data_set, self.target[0], self.target[1])
67
+ data_set_target = SignalDataset(data_frame_target)
68
+ lengths_target = [round(0.8 * len(data_set_target)), len(data_set_target) - round(0.8 * len(data_set_target))]
69
+ train_data_target, eval_data_target = random_split(data_set_target, lengths_target)
70
+ target_train = DataLoader(dataset=train_data_target, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=(device == "cuda"), drop_last=True)
71
+ target_val = DataLoader(dataset=eval_data_target, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=(device == "cuda"), drop_last=True)
72
+ return source_train, source_val, target_train, target_val
73
+
74
+ def owned_data_split(self, data_path, batch_size, num_workers, device):
75
+ # 这里目标域是用户自己提供的数据集
76
+ # get source train and val
77
+ data_frame_source = get_dataset(self.data_set, self.source[0], self.source[1])
78
+ data_set_source = SignalDataset(data_frame_source)
79
+ lengths_source = [round(0.8 * len(data_set_source)), len(data_set_source) - round(0.8 * len(data_set_source))]
80
+ train_data_source, eval_data_source = random_split(data_set_source, lengths_source)
81
+ source_train = DataLoader(dataset=train_data_source, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=(device == "cuda"), drop_last=True)
82
+ source_val = DataLoader(dataset=eval_data_source, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=(device == "cuda"), drop_last=True)
83
+ # get target train and val
84
+ data_frame_target = get_owned_dataset(data_path)
85
+ data_set_target = SignalDataset(data_frame_target)
86
+ lengths_target = [round(0.8 * len(data_set_target)), len(data_set_target) - round(0.8 * len(data_set_target))]
87
+ train_data_target, eval_data_target = random_split(data_set_target, lengths_target)
88
+ target_train = DataLoader(dataset=train_data_target, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=(device == "cuda"), drop_last=True)
89
+ target_val = DataLoader(dataset=eval_data_target, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=(device == "cuda"), drop_last=True)
90
+ return source_train, source_val, target_train, target_val
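
A toy numpy/pandas sketch of what `get_owned_dataset` does to a user-supplied CSV: the flattened signal is cut into 224-sample rows and a dummy all-zero label column is appended, so the frame matches the 225-column layout `SignalDataset` expects (values below are made up).

```python
# Toy example of the reshaping in get_owned_dataset (values are made up).
import numpy as np
import pandas as pd

signal = np.arange(448, dtype=float)      # pretend 1-D recording, length divisible by 224
windows = signal.reshape(-1, 224)         # -> (2, 224) windows
labels = np.zeros((windows.shape[0], 1))  # dummy label column (all zeros)
frame = pd.DataFrame(np.hstack((windows, labels)),
                     columns=[f"col_{i}" for i in range(225)])
print(frame.shape)                        # (2, 225)
```
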
docs/BFDS_font.html ADDED
@@ -0,0 +1,31 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Bearing Fault Diagnosis ASCII Art</title>
7
+ <style>
8
+ #taag_output_text {
9
+ font-family: "Courier New", ui-monospace, monospace;
10
+ font-size: 10pt;
11
+ white-space: pre;
12
+ overflow-wrap: break-word;
13
+ margin-top: 15px;
14
+ margin-bottom: 15px;
15
+ float: left;
16
+ }
17
+ </style>
18
+ </head>
19
+ <body>
20
+ <pre id="taag_output_text">
21
+ ________ ________ ________ ________ ________ ________ ________ ___ _______ ________ _________
22
+ |\ __ \|\ _____\\ ___ \|\ ____\ |\ __ \|\ __ \|\ __ \ |\ \|\ ___ \ |\ ____\\___ ___\
23
+ \ \ \|\ /\ \ \__/\ \ \_|\ \ \ \___|_ ____________\ \ \|\ \ \ \|\ \ \ \|\ \ \ \ \ \ __/|\ \ \___\|___ \ \_|
24
+ \ \ __ \ \ __\\ \ \ \\ \ \_____ \|\____________\ \ ____\ \ _ _\ \ \\\ \ __ \ \ \ \ \_|/_\ \ \ \ \ \
25
+ \ \ \|\ \ \ \_| \ \ \_\\ \|____|\ \|____________|\ \ \___|\ \ \\ \\ \ \\\ \|\ \\_\ \ \ \_|\ \ \ \____ \ \ \
26
+ \ \_______\ \__\ \ \_______\____\_\ \ \ \__\ \ \__\\ _\\ \_______\ \________\ \_______\ \_______\ \ \__\
27
+ \|_______|\|__| \|_______|\_________\ \|__| \|__|\|__|\|_______|\|________|\|_______|\|_______| \|__|
28
+ \|_________|
29
+ </pre>
30
+ </body>
31
+ </html>
docs/demo.png DELETED

Git LFS Details

  • SHA256: 17b2460182578ad9cd6db2e80543dd40f917da399cd2382aa07d4dd1b94512f9
  • Pointer size: 131 Bytes
  • Size of remote file: 139 kB
utils/fetch_conditions.py CHANGED
@@ -1,9 +1,23 @@
1
  from datasets import get_dataset_config_names, get_dataset_split_names
2
  import json
3
 
4
 
5
  def fetch_all_conditions_from_huggingface(dataset_name):
6
- # TODO 换源
7
  """所有数据集的subset和split
8
  具体见网页https://huggingface.co/datasets/BFDS-Project/Bearing-Fault-Diagnosis-System
9
  Args:
@@ -26,4 +40,6 @@ if __name__ == "__main__":
26
  dataset_name = "BFDS-Project/Bearing-Fault-Diagnosis-System"
27
  conditions = fetch_all_conditions_from_huggingface(dataset_name)
28
  print("huggingface上的数据集配置和分割信息:")
29
- print(json.dumps(conditions, indent=2))
1
+ import os
2
+ import requests
3
+
4
+ if __name__ == "__main__":
5
+ try:
6
+ # 这里尝试连接hugging face连接不上就换国内镜像源
7
+ response = requests.get("https://huggingface.co", timeout=5)
8
+ if response.status_code == 200:
9
+ print("成功连接到 Hugging Face")
10
+ else:
11
+ print(f"连接失败,状态码: {response.status_code}")
12
+ except requests.exceptions.RequestException:
13
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
14
+ print(f"无法连接到 Hugging Face:换源到{os.environ['HF_ENDPOINT']}")
15
+
16
  from datasets import get_dataset_config_names, get_dataset_split_names
17
  import json
18
 
19
 
20
  def fetch_all_conditions_from_huggingface(dataset_name):
 
21
  """所有数据集的subset和split
22
  具体见网页https://huggingface.co/datasets/BFDS-Project/Bearing-Fault-Diagnosis-System
23
  Args:
 
40
  dataset_name = "BFDS-Project/Bearing-Fault-Diagnosis-System"
41
  conditions = fetch_all_conditions_from_huggingface(dataset_name)
42
  print("huggingface上的数据集配置和分割信息:")
43
+ # print(json.dumps(conditions, indent=2))
44
+ # 返回conditions的key用数组存储
45
+ print(list(conditions.keys()))
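
For reference, the rest of the code assumes `fetch_all_conditions_from_huggingface` returns a dict mapping each dataset config to its list of splits, roughly as below (entries abbreviated to the two CWRU splits that appear elsewhere in this commit):

```python
# Approximate return shape, inferred from how BFDS_web.py consumes it.
conditions = {
    "CWRU": [
        "CWRU_12k_Drive_End_Bearing_Fault_Data",
        "CWRU_12k_Fan_End_Bearing_Fault_Data",
    ],
}
print(list(conditions.keys()))   # config names for the source-domain radio
print(conditions["CWRU"][0])     # default split for the selected config
```
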
utils/logger.py CHANGED
@@ -8,6 +8,9 @@ def setlogger(path):
8
  path(_str_): log文件保存路径
9
  """
10
  logger = logging.getLogger()
11
  logger.setLevel(logging.INFO)
12
  logFormatter = logging.Formatter("%(asctime)s %(message)s", "%m-%d %H:%M:%S") # 格式为 月-日 时:分:秒
13
 
 
8
  path(_str_): log文件保存路径
9
  """
10
  logger = logging.getLogger()
11
+ if logger.hasHandlers(): # 检查是否已经有处理器
12
+ logger.handlers.clear() # 清除已有的处理器,避免重复输出
13
+
14
  logger.setLevel(logging.INFO)
15
  logFormatter = logging.Formatter("%(asctime)s %(message)s", "%m-%d %H:%M:%S") # 格式为 月-日 时:分:秒
16
 
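
The handler reset added above prevents duplicated log lines when `setlogger` is called more than once in the same process (for example, two training runs launched from the WebUI). A small standalone illustration:

```python
# Without the clear(), the second pass would attach a second handler and
# every message would be printed twice.
import logging

for run in range(2):
    logger = logging.getLogger()
    if logger.hasHandlers():
        logger.handlers.clear()
    logger.setLevel(logging.INFO)
    logger.addHandler(logging.StreamHandler())
    logging.info("run %d: logged exactly once", run)
```
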
utils/predict.py CHANGED
@@ -1,28 +1,45 @@
1
- from models.CNN import cnn_features
2
  import torch
3
  import torch.nn as nn
4
 
5
- from main import Argument
6
 
7
- # FIXME 这里的 Argument 类看是直接导入还是后期需要main传进来
8
- args = Argument()
 
9
 
10
 
11
- def predict(model_state_dict, signal):
12
- model = cnn_features()
13
  bottleneck_layer = nn.Sequential(
14
  nn.Linear(model.output_num(), args.bottleneck_num),
15
  nn.ReLU(inplace=True),
16
  nn.Dropout(),
17
- )
18
- classifier_layer = nn.Linear(args.bottleneck_num, 10)
19
- model_all = nn.Sequential(model, bottleneck_layer, classifier_layer)
20
  model_all.load_state_dict(model_state_dict)
21
-
22
- # 设置为评估模式
23
  model_all.eval()
24
- # 进行预测
25
  with torch.no_grad():
26
- # FIXME signal的处理还没有写
27
- output = model_all(torch.randn(10, 1, 1024))
28
- return torch.argmax(output, dim=1)
1
  import torch
2
  import torch.nn as nn
3
+ import pandas as pd
4
+ import librosa
5
+ from pathlib import Path
6
+ import models
7
 
 
8
 
9
+ def audio_to_signal(audio_file, sr=None):
10
+ signal, _ = librosa.load(audio_file, sr=sr)
11
+ return signal
12
 
13
 
14
+ def csv_to_signal(signal_file):
15
+ signal = pd.read_csv(signal_file).to_numpy().flatten()
16
+ return signal
17
+
18
+
19
+ # 修改backbone
20
+ def predict(model_state_dict, signal_file, args):
21
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
+ model = getattr(models, args.model_name)().to(device)
23
  bottleneck_layer = nn.Sequential(
24
  nn.Linear(model.output_num(), args.bottleneck_num),
25
  nn.ReLU(inplace=True),
26
  nn.Dropout(),
27
+ ).to(device)
28
+ classifier_layer = nn.Linear(args.bottleneck_num, len(args.labels)).to(device)
29
+ model_all = nn.Sequential(model, bottleneck_layer, classifier_layer).to(device)
30
  model_all.load_state_dict(model_state_dict)
31
+ # 模型预测
 
32
  model_all.eval()
 
33
  with torch.no_grad():
34
+ # 根据文件后缀选择处理方式
35
+ file_extension = Path(signal_file).suffix
36
+ if file_extension == ".csv":
37
+ signal = csv_to_signal(signal_file).reshape(-1, 1, 224)
38
+ elif file_extension in [".wav", ".mp3"]:
39
+ signal = audio_to_signal(signal_file).reshape(-1, 1, 224)
40
+ else:
41
+ raise ValueError(f"Unsupported file type: {file_extension}")
42
+ signal = torch.tensor(signal, dtype=torch.float32).to(device)
43
+ output = model_all(signal)
44
+ predictions = output.mean(dim=0)
45
+ return predictions
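
A hedged sketch of calling the new `predict` from user code; the checkpoint path and CSV name are placeholders, and `args` must carry the same `model_name`, `bottleneck_num` and `labels` the checkpoint was trained with. The signal is framed into `(-1, 1, 224)` windows exactly as in the function above.

```python
# Illustrative only; paths are placeholders.
import torch
from BFDS_train import Argument
from utils.predict import predict

args = Argument()
args.set_recommended_params()                        # model_name="CNN", bottleneck_num=256, labels=...
state_dict = torch.load("checkpoint/model.pth", map_location="cpu")
scores = predict(state_dict, "my_signal.csv", args)  # mean logits over all 224-sample windows
print(int(scores.argmax()))                          # index of the predicted fault class
```
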
utils/train.py CHANGED
@@ -14,12 +14,14 @@ import matplotlib.pyplot as plt
14
  import models
15
  from models.AdversarialNet import AdversarialNet, calc_coeff, grl_hook, Entropy
16
  from dataset.dataset import SignalDatasetCreator
17
- from .loss import DAN, JAN, CORAL
18
 
19
 
20
  class train_utils:
21
- def __init__(self, args):
22
  self.args = args
 
 
24
  def setup(self):
25
  args = self.args
@@ -38,17 +40,18 @@ class train_utils:
38
  logging.info(f"using {self.device_count} cpu")
39
 
40
  # 加载数据集
41
- signal_dataset_creator = SignalDatasetCreator(args.data_set, args.labels, args.transfer_task)
42
- self.datasets = {}
43
- self.datasets["source_train"], self.datasets["source_val"], self.datasets["target_train"], self.datasets["target_val"] = signal_dataset_creator.data_split(
44
- args.batch_size, args.num_workers, self.device, transfer_learning=True
45
- )
46
- self.dataloaders = {
47
- "source_train": self.datasets["source_train"],
48
- "source_val": self.datasets["source_val"],
49
- "target_train": self.datasets["target_train"],
50
- "target_val": self.datasets["target_val"],
51
- }
 
52
  # 定义模型
53
  self.model = getattr(models, args.model_name)()
54
  if args.bottleneck:
@@ -91,9 +94,7 @@ class train_utils:
91
  )
92
  else:
93
  if args.bottleneck_num:
94
- self.AdversarialNet = AdversarialNet(
95
- in_feature=args.bottleneck_num, hidden_size=args.hidden_size, max_iter=self.max_iter, grl_option=args.grl_option, grl_lambda=args.grl_lambda
96
- )
97
  else:
98
  self.AdversarialNet = AdversarialNet(
99
  in_feature=self.model.output_num(), hidden_size=args.hidden_size, max_iter=self.max_iter, grl_option=args.grl_option, grl_lambda=args.grl_lambda
@@ -336,9 +337,7 @@ class train_utils:
336
  domain_label_source = torch.zeros(labels.size(0)).float()
337
  domain_label_target = torch.ones(inputs.size(0) - labels.size(0)).float()
338
  adversarial_label = torch.cat((domain_label_source, domain_label_target), dim=0).to(self.device)
339
- weight = torch.cat(
340
- (entropy_source / torch.sum(entropy_source).detach().item(), entropy_target / torch.sum(entropy_target).detach().item()), dim=0
341
- )
342
 
343
  # 展开权重,对损失重新加权
344
  adversarial_loss = torch.sum(weight.view(-1, 1) * self.adversarial_loss(adversarial_out.squeeze(), adversarial_label))
@@ -451,3 +450,27 @@ class train_utils:
451
 
452
  plt.tight_layout()
453
  plt.show()
14
  import models
15
  from models.AdversarialNet import AdversarialNet, calc_coeff, grl_hook, Entropy
16
  from dataset.dataset import SignalDatasetCreator
17
+ from utils.loss import DAN, JAN, CORAL
18
 
19
 
20
  class train_utils:
21
+ def __init__(self, args, owned=False, data_path=None):
22
  self.args = args
23
+ self.owned = owned
24
+ self.data_path = data_path
25
 
26
  def setup(self):
27
  args = self.args
 
40
  logging.info(f"using {self.device_count} cpu")
41
 
42
  # 加载数据集
43
+ if self.owned:
44
+ signal_dataset_creator = SignalDatasetCreator(args.data_set, args.labels, args.transfer_task)
45
+ self.dataloaders = {}
46
+ self.dataloaders["source_train"], self.dataloaders["source_val"], self.dataloaders["target_train"], self.dataloaders["target_val"] = signal_dataset_creator.owned_data_split(
47
+ self.data_path, args.batch_size, args.num_workers, self.device
48
+ )
49
+ else:
50
+ signal_dataset_creator = SignalDatasetCreator(args.data_set, args.labels, args.transfer_task)
51
+ self.dataloaders = {}
52
+ self.dataloaders["source_train"], self.dataloaders["source_val"], self.dataloaders["target_train"], self.dataloaders["target_val"] = signal_dataset_creator.data_split(
53
+ args.batch_size, args.num_workers, self.device
54
+ )
55
  # 定义模型
56
  self.model = getattr(models, args.model_name)()
57
  if args.bottleneck:
 
94
  )
95
  else:
96
  if args.bottleneck_num:
97
+ self.AdversarialNet = AdversarialNet(in_feature=args.bottleneck_num, hidden_size=args.hidden_size, max_iter=self.max_iter, grl_option=args.grl_option, grl_lambda=args.grl_lambda)
98
  else:
99
  self.AdversarialNet = AdversarialNet(
100
  in_feature=self.model.output_num(), hidden_size=args.hidden_size, max_iter=self.max_iter, grl_option=args.grl_option, grl_lambda=args.grl_lambda
 
337
  domain_label_source = torch.zeros(labels.size(0)).float()
338
  domain_label_target = torch.ones(inputs.size(0) - labels.size(0)).float()
339
  adversarial_label = torch.cat((domain_label_source, domain_label_target), dim=0).to(self.device)
340
+ weight = torch.cat((entropy_source / torch.sum(entropy_source).detach().item(), entropy_target / torch.sum(entropy_target).detach().item()), dim=0)
341
 
342
  # 展开权重,对损失重新加权
343
  adversarial_loss = torch.sum(weight.view(-1, 1) * self.adversarial_loss(adversarial_out.squeeze(), adversarial_label))
 
450
 
451
  plt.tight_layout()
452
  plt.show()
453
+
454
+ def generate_fig(self):
455
+ args = self.args
456
+
457
+ fig, axs = plt.subplots(1, 2, figsize=(14, 6))
458
+
459
+ axs[0].set_title("Accuracy")
460
+ axs[0].set_xlabel("epoches")
461
+ axs[0].set_ylabel("accuracy")
462
+ axs[0].plot(range(args.max_epoch), self.acc["source_train"], label="source_train")
463
+ axs[0].plot(range(args.max_epoch), self.acc["source_val"], label="source_val")
464
+ axs[0].plot(range(args.max_epoch), self.acc["target_val"], label="target_val")
465
+ axs[0].legend()
466
+
467
+ axs[1].set_title(f"Loss Function: {args.distance_loss}")
468
+ axs[1].set_xlabel("epoches")
469
+ axs[1].set_ylabel("loss")
470
+ axs[1].plot(range(args.max_epoch), self.loss["source_train"], label="source_train")
471
+ axs[1].plot(range(args.max_epoch), self.loss["source_val"], label="source_val")
472
+ axs[1].plot(range(args.max_epoch), self.loss["target_val"], label="target_val")
473
+ axs[1].legend()
474
+
475
+ plt.tight_layout()
476
+ return fig
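
`generate_fig` assumes `train()` has recorded one accuracy and one loss value per epoch for each of the three plotted splits. A standalone sketch of the same two-panel figure with dummy histories (values made up):

```python
# Standalone sketch of the figure generate_fig builds (dummy data).
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

max_epoch = 2
acc = {k: [0.5, 0.8] for k in ("source_train", "source_val", "target_val")}
loss = {k: [1.2, 0.6] for k in ("source_train", "source_val", "target_val")}

fig, axs = plt.subplots(1, 2, figsize=(14, 6))
for title, history, ax in (("accuracy", acc, axs[0]), ("loss", loss, axs[1])):
    ax.set_title(title)
    ax.set_xlabel("epoch")
    ax.set_ylabel(title)
    for split, values in history.items():
        ax.plot(range(max_epoch), values, label=split)
    ax.legend()
plt.tight_layout()
fig.savefig("training_curves.png")
```
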