Str0keOOOO committed on
Commit
bb6b32a
·
1 Parent(s): 5d42100

fix:修改艺术字体美化和距离度量的BUG

Browse files
BFDS_train.py CHANGED
@@ -67,14 +67,13 @@ class Argument:
67
  self.middle_epoch = 0 # 引入目标域数据的起始轮次
68
 
69
  # 基于映射
70
- # FIXME 基于映射有bug
71
- self.distance_option = False # 是否采用基于映射的损失
72
- self.distance_loss = "JMMD" # 损失模型 MK-MMD/JMMD/CORAL
73
  self.distance_tradeoff = "Step" # 损失的trade_off参数 Cons/Step
74
  self.distance_lambda = 1 # 若调整模式为Cons,指定其具体值
75
 
76
  # 基于领域对抗
77
- self.adversarial_option = True # 是否采用领域对抗
78
  self.adversarial_loss = "CDA" # 领域对抗损失
79
  self.hidden_size = 1024 # 对抗网络的隐藏层维数
80
  self.grl_option = "Step" # 梯度反转层权重选择静态or动态更新
@@ -96,7 +95,6 @@ class Argument:
96
  print(f"警告: Parameter '{param_name}' does not exist.")
97
 
98
  def set_recommended_params(self):
99
- # TODO cyq来写一个
100
  # 给用户设定的推荐参数
101
  recommended_params = {
102
  "data_set": "BFDS-Project/Bearing-Fault-Diagnosis-System",
 
67
  self.middle_epoch = 0 # 引入目标域数据的起始轮次
68
 
69
  # 基于映射
70
+ self.distance_option = True # 是否采用基于映射的损失
71
+ self.distance_loss = "MK-MMD" # 损失模型 MK-MMD/JMMD/CORAL
 
72
  self.distance_tradeoff = "Step" # 损失的trade_off参数 Cons/Step
73
  self.distance_lambda = 1 # 若调整模式为Cons,指定其具体值
74
 
75
  # 基于领域对抗
76
+ self.adversarial_option = False # 是否采用领域对抗
77
  self.adversarial_loss = "CDA" # 领域对抗损失
78
  self.hidden_size = 1024 # 对抗网络的隐藏层维数
79
  self.grl_option = "Step" # 梯度反转层权重选择静态or动态更新
 
95
  print(f"警告: Parameter '{param_name}' does not exist.")
96
 
97
  def set_recommended_params(self):
 
98
  # 给用户设定的推荐参数
99
  recommended_params = {
100
  "data_set": "BFDS-Project/Bearing-Fault-Diagnosis-System",
BFDS_web.py CHANGED
@@ -239,22 +239,16 @@ def change_adversarial_tradeoff(adversarial_option, adversarial_tradeoff):
239
  return (gr.update(visible=(adversarial_option and adversarial_tradeoff == "Cons")),)
240
 
241
 
 
 
 
242
  # gradio BFDS_web.py --demo-name app
243
- # FIXME 做一个网页logo
244
  with gr.Blocks(title="BFDS WebUI") as app:
245
- # FIXME 怎么显示有问题呢????
246
- gr.HTML(r"""
247
- <pre style="text-align: center;">
248
- ________ ________ ________ ________ ________ ________ ________ ___ _______ ________ _________
249
- |\ __ \|\ _____\\ ___ \|\ ____\ |\ __ \|\ __ \|\ __ \ |\ \|\ ___ \ |\ ____\\___ ___\
250
- \ \ \|\ /\ \ \__/\ \ \_|\ \ \ \___|_ ____________\ \ \|\ \ \ \|\ \ \ \|\ \ \ \ \ \ __/|\ \ \___\|___ \ \_|
251
- \ \ __ \ \ __\\ \ \ \\ \ \_____ \|\____________\ \ ____\ \ _ _\ \ \\\ \ __ \ \ \ \ \_|/_\ \ \ \ \ \
252
- \ \ \|\ \ \ \_| \ \ \_\\ \|____|\ \|____________|\ \ \___|\ \ \\ \\ \ \\\ \|\ \\_\ \ \ \_|\ \ \ \____ \ \ \
253
- \ \_______\ \__\ \ \_______\____\_\ \ \ \__\ \ \__\\ _\\ \_______\ \________\ \_______\ \_______\ \ \__\
254
- \|_______|\|__| \|_______|\_________\ \|__| \|__|\|__|\|_______|\|________|\|_______|\|_______| \|__|
255
- \|_________|
256
- </pre>
257
- """)
258
  with gr.Tab("模型训练"):
259
  gr.Markdown("在此模块中,您可以选择不同的迁移学习方法进行模型训练。")
260
  with gr.Row():
 
239
  return (gr.update(visible=(adversarial_option and adversarial_tradeoff == "Cons")),)
240
 
241
 
242
+ with open("docs/BFDS_font.html", "r", encoding="utf-8") as f:
243
+ BFDS_font_html = f.read()
244
+
245
  # gradio BFDS_web.py --demo-name app
 
246
  with gr.Blocks(title="BFDS WebUI") as app:
247
+ gr.HTML(BFDS_font_html)
248
+ gr.Markdown("""
249
+ # 轴承故障诊断系统
250
+ 基于深度迁移学习的智能轴承故障诊断系统。支持多种迁移学习算法、信号处理方法和故障诊断模型。
251
+ """)
 
 
 
 
 
 
 
 
252
  with gr.Tab("模型训练"):
253
  gr.Markdown("在此模块中,您可以选择不同的迁移学习方法进行模型训练。")
254
  with gr.Row():
dataset/dataset.py CHANGED
@@ -56,36 +56,35 @@ class SignalDatasetCreator:
56
  def data_split(self, batch_size, num_workers, device):
57
  # 这里源域和目标域都是我们提供用来验证迁移学习的正确性
58
  # get source train and val
59
- data_frame = get_dataset(self.data_set, self.source[0], self.source[1])
60
- data_set = SignalDataset(data_frame)
61
- lengths = [round(0.8 * len(data_set)), len(data_set) - round(0.8 * len(data_set))]
62
- train_data, eval_data = random_split(data_set, lengths)
63
- source_train = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=(device == "cuda"))
64
- source_val = DataLoader(dataset=eval_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=(device == "cuda"))
65
  # get target train and val
66
- data_frame = get_dataset(self.data_set, self.target[0], self.target[1])
67
- data_set = SignalDataset(data_frame)
68
- lengths = [round(0.8 * len(data_set)), len(data_set) - round(0.8 * len(data_set))]
69
- train_data, eval_data = random_split(data_set, lengths)
70
- target_train = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=(device == "cuda"))
71
- target_val = DataLoader(dataset=eval_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=(device == "cuda"))
72
  return source_train, source_val, target_train, target_val
73
 
74
  def owned_data_split(self, data_path, batch_size, num_workers, device):
75
- # FIXME 这里不给标签影响acc吗?
76
  # 这里目标域是用户自己提供的数据集
77
  # get source train and val
78
- data_frame = get_dataset(self.data_set, self.source[0], self.source[1])
79
- data_set = SignalDataset(data_frame)
80
- lengths = [round(0.8 * len(data_set)), len(data_set) - round(0.8 * len(data_set))]
81
- train_data, eval_data = random_split(data_set, lengths)
82
- source_train = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=(device == "cuda"))
83
- source_val = DataLoader(dataset=eval_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=(device == "cuda"))
84
  # get target train and val
85
- data_frame = get_owned_dataset(data_path)
86
- data_set = SignalDataset(data_frame)
87
- lengths = [round(0.8 * len(data_set)), len(data_set) - round(0.8 * len(data_set))]
88
- train_data, eval_data = random_split(data_set, lengths)
89
- target_train = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=(device == "cuda"))
90
- target_val = DataLoader(dataset=eval_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=(device == "cuda"))
91
  return source_train, source_val, target_train, target_val
 
56
  def data_split(self, batch_size, num_workers, device):
57
  # 这里源域和目标域都是我们提供用来验证迁移学习的正确性
58
  # get source train and val
59
+ data_frame_source = get_dataset(self.data_set, self.source[0], self.source[1])
60
+ data_set_source = SignalDataset(data_frame_source)
61
+ lengths_source = [round(0.8 * len(data_set_source)), len(data_set_source) - round(0.8 * len(data_set_source))]
62
+ train_data_source, eval_data_source = random_split(data_set_source, lengths_source)
63
+ source_train = DataLoader(dataset=train_data_source, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=(device == "cuda"), drop_last=True)
64
+ source_val = DataLoader(dataset=eval_data_source, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=(device == "cuda"), drop_last=True)
65
  # get target train and val
66
+ data_frame_target = get_dataset(self.data_set, self.target[0], self.target[1])
67
+ data_set_target = SignalDataset(data_frame_target)
68
+ lengths_target = [round(0.8 * len(data_set_target)), len(data_set_target) - round(0.8 * len(data_set_target))]
69
+ train_data_target, eval_data_target = random_split(data_set_target, lengths_target)
70
+ target_train = DataLoader(dataset=train_data_target, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=(device == "cuda"), drop_last=True)
71
+ target_val = DataLoader(dataset=eval_data_target, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=(device == "cuda"), drop_last=True)
72
  return source_train, source_val, target_train, target_val
73
 
74
  def owned_data_split(self, data_path, batch_size, num_workers, device):
 
75
  # 这里目标域是用户自己提供的数据集
76
  # get source train and val
77
+ data_frame_source = get_dataset(self.data_set, self.source[0], self.source[1])
78
+ data_set_source = SignalDataset(data_frame_source)
79
+ lengths_source = [round(0.8 * len(data_set_source)), len(data_set_source) - round(0.8 * len(data_set_source))]
80
+ train_data_source, eval_data_source = random_split(data_set_source, lengths_source)
81
+ source_train = DataLoader(dataset=train_data_source, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=(device == "cuda"), drop_last=True)
82
+ source_val = DataLoader(dataset=eval_data_source, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=(device == "cuda"), drop_last=True)
83
  # get target train and val
84
+ data_frame_target = get_owned_dataset(data_path)
85
+ data_set_target = SignalDataset(data_frame_target)
86
+ lengths_target = [round(0.8 * len(data_set_target)), len(data_set_target) - round(0.8 * len(data_set_target))]
87
+ train_data_target, eval_data_target = random_split(data_set_target, lengths_target)
88
+ target_train = DataLoader(dataset=train_data_target, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=(device == "cuda"), drop_last=True)
89
+ target_val = DataLoader(dataset=eval_data_target, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=(device == "cuda"), drop_last=True)
90
  return source_train, source_val, target_train, target_val
docs/BFDS_font.html ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Bearing Fault Diagnosis ASCII Art</title>
7
+ <style>
8
+ #taag_output_text {
9
+ font-family: "Courier New", ui-monospace, monospace;
10
+ font-size: 10pt;
11
+ white-space: pre;
12
+ overflow-wrap: break-word;
13
+ margin-top: 15px;
14
+ margin-bottom: 15px;
15
+ float: left;
16
+ }
17
+ </style>
18
+ </head>
19
+ <body>
20
+ <pre id="taag_output_text">
21
+ ________ ________ ________ ________ ________ ________ ________ ___ _______ ________ _________
22
+ |\ __ \|\ _____\\ ___ \|\ ____\ |\ __ \|\ __ \|\ __ \ |\ \|\ ___ \ |\ ____\\___ ___\
23
+ \ \ \|\ /\ \ \__/\ \ \_|\ \ \ \___|_ ____________\ \ \|\ \ \ \|\ \ \ \|\ \ \ \ \ \ __/|\ \ \___\|___ \ \_|
24
+ \ \ __ \ \ __\\ \ \ \\ \ \_____ \|\____________\ \ ____\ \ _ _\ \ \\\ \ __ \ \ \ \ \_|/_\ \ \ \ \ \
25
+ \ \ \|\ \ \ \_| \ \ \_\\ \|____|\ \|____________|\ \ \___|\ \ \\ \\ \ \\\ \|\ \\_\ \ \ \_|\ \ \ \____ \ \ \
26
+ \ \_______\ \__\ \ \_______\____\_\ \ \ \__\ \ \__\\ _\\ \_______\ \________\ \_______\ \_______\ \ \__\
27
+ \|_______|\|__| \|_______|\_________\ \|__| \|__|\|__|\|_______|\|________|\|_______|\|_______| \|__|
28
+ \|_________|
29
+ </pre>
30
+ </body>
31
+ </html>
docs/demo.png DELETED

Git LFS Details

  • SHA256: 17b2460182578ad9cd6db2e80543dd40f917da399cd2382aa07d4dd1b94512f9
  • Pointer size: 131 Bytes
  • Size of remote file: 139 kB
utils/predict.py CHANGED
@@ -6,7 +6,6 @@ from pathlib import Path
6
  import models
7
 
8
 
9
- # TODO 处理更多文件
10
  def audio_to_signal(audio_file, sr=None):
11
  signal, _ = librosa.load(audio_file, sr=sr)
12
  return signal
@@ -36,7 +35,6 @@ def predict(model_state_dict, signal_file, args):
36
  file_extension = Path(signal_file).suffix
37
  if file_extension == ".csv":
38
  signal = csv_to_signal(signal_file).reshape(-1, 1, 224)
39
- # FIXME 这里能搞多少后缀就多少后缀
40
  elif file_extension in [".wav", ".mp3"]:
41
  signal = audio_to_signal(signal_file).reshape(-1, 1, 224)
42
  else:
@@ -44,5 +42,4 @@ def predict(model_state_dict, signal_file, args):
44
  signal = torch.tensor(signal, dtype=torch.float32).to(device)
45
  output = model_all(signal)
46
  predictions = output.mean(dim=0)
47
- # FIXME 这里没有过softmax
48
  return predictions
 
6
  import models
7
 
8
 
 
9
  def audio_to_signal(audio_file, sr=None):
10
  signal, _ = librosa.load(audio_file, sr=sr)
11
  return signal
 
35
  file_extension = Path(signal_file).suffix
36
  if file_extension == ".csv":
37
  signal = csv_to_signal(signal_file).reshape(-1, 1, 224)
 
38
  elif file_extension in [".wav", ".mp3"]:
39
  signal = audio_to_signal(signal_file).reshape(-1, 1, 224)
40
  else:
 
42
  signal = torch.tensor(signal, dtype=torch.float32).to(device)
43
  output = model_all(signal)
44
  predictions = output.mean(dim=0)
 
45
  return predictions
utils/train.py CHANGED
@@ -14,7 +14,7 @@ import matplotlib.pyplot as plt
14
  import models
15
  from models.AdversarialNet import AdversarialNet, calc_coeff, grl_hook, Entropy
16
  from dataset.dataset import SignalDatasetCreator
17
- from .loss import DAN, JAN, CORAL
18
 
19
 
20
  class train_utils:
@@ -26,7 +26,7 @@ class train_utils:
26
  def setup(self):
27
  args = self.args
28
  self.save_dir = os.path.join(args.checkpoint_dir, args.model_name + "_" + datetime.strftime(datetime.now(), "%m%d-%H%M%S"))
29
-
30
  # 判断训练设备
31
  if torch.cuda.is_available():
32
  self.device = torch.device("cuda")
 
14
  import models
15
  from models.AdversarialNet import AdversarialNet, calc_coeff, grl_hook, Entropy
16
  from dataset.dataset import SignalDatasetCreator
17
+ from utils.loss import DAN, JAN, CORAL
18
 
19
 
20
  class train_utils:
 
26
  def setup(self):
27
  args = self.args
28
  self.save_dir = os.path.join(args.checkpoint_dir, args.model_name + "_" + datetime.strftime(datetime.now(), "%m%d-%H%M%S"))
29
+
30
  # 判断训练设备
31
  if torch.cuda.is_available():
32
  self.device = torch.device("cuda")