Spaces:
Build error
Build error
Commit ·
bb6b32a
1
Parent(s): 5d42100
fix:修改艺术字体美化和距离度量的BUG
Browse files- BFDS_train.py +3 -5
- BFDS_web.py +8 -14
- dataset/dataset.py +24 -25
- docs/BFDS_font.html +31 -0
- docs/demo.png +0 -3
- utils/predict.py +0 -3
- utils/train.py +2 -2
BFDS_train.py
CHANGED
|
@@ -67,14 +67,13 @@ class Argument:
|
|
| 67 |
self.middle_epoch = 0 # 引入目标域数据的起始轮次
|
| 68 |
|
| 69 |
# 基于映射
|
| 70 |
-
|
| 71 |
-
self.
|
| 72 |
-
self.distance_loss = "JMMD" # 损失模型 MK-MMD/JMMD/CORAL
|
| 73 |
self.distance_tradeoff = "Step" # 损失的trade_off参数 Cons/Step
|
| 74 |
self.distance_lambda = 1 # 若调整模式为Cons,指定其具体值
|
| 75 |
|
| 76 |
# 基于领域对抗
|
| 77 |
-
self.adversarial_option =
|
| 78 |
self.adversarial_loss = "CDA" # 领域对抗损失
|
| 79 |
self.hidden_size = 1024 # 对抗网络的隐藏层维数
|
| 80 |
self.grl_option = "Step" # 梯度反转层权重选择静态or动态更新
|
|
@@ -96,7 +95,6 @@ class Argument:
|
|
| 96 |
print(f"警告: Parameter '{param_name}' does not exist.")
|
| 97 |
|
| 98 |
def set_recommended_params(self):
|
| 99 |
-
# TODO cyq来写一个
|
| 100 |
# 给用户设定的推荐参数
|
| 101 |
recommended_params = {
|
| 102 |
"data_set": "BFDS-Project/Bearing-Fault-Diagnosis-System",
|
|
|
|
| 67 |
self.middle_epoch = 0 # 引入目标域数据的起始轮次
|
| 68 |
|
| 69 |
# 基于映射
|
| 70 |
+
self.distance_option = True # 是否采用基于映射的损失
|
| 71 |
+
self.distance_loss = "MK-MMD" # 损失模型 MK-MMD/JMMD/CORAL
|
|
|
|
| 72 |
self.distance_tradeoff = "Step" # 损失的trade_off参数 Cons/Step
|
| 73 |
self.distance_lambda = 1 # 若调整模式为Cons,指定其具体值
|
| 74 |
|
| 75 |
# 基于领域对抗
|
| 76 |
+
self.adversarial_option = False # 是否采用领域对抗
|
| 77 |
self.adversarial_loss = "CDA" # 领域对抗损失
|
| 78 |
self.hidden_size = 1024 # 对抗网络的隐藏层维数
|
| 79 |
self.grl_option = "Step" # 梯度反转层权重选择静态or动态更新
|
|
|
|
| 95 |
print(f"警告: Parameter '{param_name}' does not exist.")
|
| 96 |
|
| 97 |
def set_recommended_params(self):
|
|
|
|
| 98 |
# 给用户设定的推荐参数
|
| 99 |
recommended_params = {
|
| 100 |
"data_set": "BFDS-Project/Bearing-Fault-Diagnosis-System",
|
BFDS_web.py
CHANGED
|
@@ -239,22 +239,16 @@ def change_adversarial_tradeoff(adversarial_option, adversarial_tradeoff):
|
|
| 239 |
return (gr.update(visible=(adversarial_option and adversarial_tradeoff == "Cons")),)
|
| 240 |
|
| 241 |
|
|
|
|
|
|
|
|
|
|
| 242 |
# gradio BFDS_web.py --demo-name app
|
| 243 |
-
# FIXME 做一个网页logo
|
| 244 |
with gr.Blocks(title="BFDS WebUI") as app:
|
| 245 |
-
|
| 246 |
-
gr.
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
\ \ \|\ /\ \ \__/\ \ \_|\ \ \ \___|_ ____________\ \ \|\ \ \ \|\ \ \ \|\ \ \ \ \ \ __/|\ \ \___\|___ \ \_|
|
| 251 |
-
\ \ __ \ \ __\\ \ \ \\ \ \_____ \|\____________\ \ ____\ \ _ _\ \ \\\ \ __ \ \ \ \ \_|/_\ \ \ \ \ \
|
| 252 |
-
\ \ \|\ \ \ \_| \ \ \_\\ \|____|\ \|____________|\ \ \___|\ \ \\ \\ \ \\\ \|\ \\_\ \ \ \_|\ \ \ \____ \ \ \
|
| 253 |
-
\ \_______\ \__\ \ \_______\____\_\ \ \ \__\ \ \__\\ _\\ \_______\ \________\ \_______\ \_______\ \ \__\
|
| 254 |
-
\|_______|\|__| \|_______|\_________\ \|__| \|__|\|__|\|_______|\|________|\|_______|\|_______| \|__|
|
| 255 |
-
\|_________|
|
| 256 |
-
</pre>
|
| 257 |
-
""")
|
| 258 |
with gr.Tab("模型训练"):
|
| 259 |
gr.Markdown("在此模块中,您可以选择不同的迁移学习方法进行模型训练。")
|
| 260 |
with gr.Row():
|
|
|
|
| 239 |
return (gr.update(visible=(adversarial_option and adversarial_tradeoff == "Cons")),)
|
| 240 |
|
| 241 |
|
| 242 |
+
with open("docs/BFDS_font.html", "r", encoding="utf-8") as f:
|
| 243 |
+
BFDS_font_html = f.read()
|
| 244 |
+
|
| 245 |
# gradio BFDS_web.py --demo-name app
|
|
|
|
| 246 |
with gr.Blocks(title="BFDS WebUI") as app:
|
| 247 |
+
gr.HTML(BFDS_font_html)
|
| 248 |
+
gr.Markdown("""
|
| 249 |
+
# 轴承故障诊断系统
|
| 250 |
+
基于深度迁移学习的智能轴承故障诊断系统。支持多种迁移学习算法、信号处理方法和故障诊断模型。
|
| 251 |
+
""")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
with gr.Tab("模型训练"):
|
| 253 |
gr.Markdown("在此模块中,您可以选择不同的迁移学习方法进行模型训练。")
|
| 254 |
with gr.Row():
|
dataset/dataset.py
CHANGED
|
@@ -56,36 +56,35 @@ class SignalDatasetCreator:
|
|
| 56 |
def data_split(self, batch_size, num_workers, device):
|
| 57 |
# 这里源域和目标域都是我们提供用来验证迁移学习的正确性
|
| 58 |
# get source train and val
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
source_train = DataLoader(dataset=
|
| 64 |
-
source_val = DataLoader(dataset=
|
| 65 |
# get target train and val
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
target_train = DataLoader(dataset=
|
| 71 |
-
target_val = DataLoader(dataset=
|
| 72 |
return source_train, source_val, target_train, target_val
|
| 73 |
|
| 74 |
def owned_data_split(self, data_path, batch_size, num_workers, device):
|
| 75 |
-
# FIXME 这里不给标签影响acc吗?
|
| 76 |
# 这里目标域是用户自己提供的数据集
|
| 77 |
# get source train and val
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
source_train = DataLoader(dataset=
|
| 83 |
-
source_val = DataLoader(dataset=
|
| 84 |
# get target train and val
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
target_train = DataLoader(dataset=
|
| 90 |
-
target_val = DataLoader(dataset=
|
| 91 |
return source_train, source_val, target_train, target_val
|
|
|
|
| 56 |
def data_split(self, batch_size, num_workers, device):
|
| 57 |
# 这里源域和目标域都是我们提供用来验证迁移学习的正确性
|
| 58 |
# get source train and val
|
| 59 |
+
data_frame_source = get_dataset(self.data_set, self.source[0], self.source[1])
|
| 60 |
+
data_set_source = SignalDataset(data_frame_source)
|
| 61 |
+
lengths_source = [round(0.8 * len(data_set_source)), len(data_set_source) - round(0.8 * len(data_set_source))]
|
| 62 |
+
train_data_source, eval_data_source = random_split(data_set_source, lengths_source)
|
| 63 |
+
source_train = DataLoader(dataset=train_data_source, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=(device == "cuda"), drop_last=True)
|
| 64 |
+
source_val = DataLoader(dataset=eval_data_source, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=(device == "cuda"), drop_last=True)
|
| 65 |
# get target train and val
|
| 66 |
+
data_frame_target = get_dataset(self.data_set, self.target[0], self.target[1])
|
| 67 |
+
data_set_target = SignalDataset(data_frame_target)
|
| 68 |
+
lengths_target = [round(0.8 * len(data_set_target)), len(data_set_target) - round(0.8 * len(data_set_target))]
|
| 69 |
+
train_data_target, eval_data_target = random_split(data_set_target, lengths_target)
|
| 70 |
+
target_train = DataLoader(dataset=train_data_target, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=(device == "cuda"), drop_last=True)
|
| 71 |
+
target_val = DataLoader(dataset=eval_data_target, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=(device == "cuda"), drop_last=True)
|
| 72 |
return source_train, source_val, target_train, target_val
|
| 73 |
|
| 74 |
def owned_data_split(self, data_path, batch_size, num_workers, device):
|
|
|
|
| 75 |
# 这里目标域是用户自己提供的数据集
|
| 76 |
# get source train and val
|
| 77 |
+
data_frame_source = get_dataset(self.data_set, self.source[0], self.source[1])
|
| 78 |
+
data_set_source = SignalDataset(data_frame_source)
|
| 79 |
+
lengths_source = [round(0.8 * len(data_set_source)), len(data_set_source) - round(0.8 * len(data_set_source))]
|
| 80 |
+
train_data_source, eval_data_source = random_split(data_set_source, lengths_source)
|
| 81 |
+
source_train = DataLoader(dataset=train_data_source, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=(device == "cuda"), drop_last=True)
|
| 82 |
+
source_val = DataLoader(dataset=eval_data_source, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=(device == "cuda"), drop_last=True)
|
| 83 |
# get target train and val
|
| 84 |
+
data_frame_target = get_owned_dataset(data_path)
|
| 85 |
+
data_set_target = SignalDataset(data_frame_target)
|
| 86 |
+
lengths_target = [round(0.8 * len(data_set_target)), len(data_set_target) - round(0.8 * len(data_set_target))]
|
| 87 |
+
train_data_target, eval_data_target = random_split(data_set_target, lengths_target)
|
| 88 |
+
target_train = DataLoader(dataset=train_data_target, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=(device == "cuda"), drop_last=True)
|
| 89 |
+
target_val = DataLoader(dataset=eval_data_target, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=(device == "cuda"), drop_last=True)
|
| 90 |
return source_train, source_val, target_train, target_val
|
docs/BFDS_font.html
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<title>Bearing Fault Diagnosis ASCII Art</title>
|
| 7 |
+
<style>
|
| 8 |
+
#taag_output_text {
|
| 9 |
+
font-family: "Courier New", ui-monospace, monospace;
|
| 10 |
+
font-size: 10pt;
|
| 11 |
+
white-space: pre;
|
| 12 |
+
overflow-wrap: break-word;
|
| 13 |
+
margin-top: 15px;
|
| 14 |
+
margin-bottom: 15px;
|
| 15 |
+
float: left;
|
| 16 |
+
}
|
| 17 |
+
</style>
|
| 18 |
+
</head>
|
| 19 |
+
<body>
|
| 20 |
+
<pre id="taag_output_text">
|
| 21 |
+
________ ________ ________ ________ ________ ________ ________ ___ _______ ________ _________
|
| 22 |
+
|\ __ \|\ _____\\ ___ \|\ ____\ |\ __ \|\ __ \|\ __ \ |\ \|\ ___ \ |\ ____\\___ ___\
|
| 23 |
+
\ \ \|\ /\ \ \__/\ \ \_|\ \ \ \___|_ ____________\ \ \|\ \ \ \|\ \ \ \|\ \ \ \ \ \ __/|\ \ \___\|___ \ \_|
|
| 24 |
+
\ \ __ \ \ __\\ \ \ \\ \ \_____ \|\____________\ \ ____\ \ _ _\ \ \\\ \ __ \ \ \ \ \_|/_\ \ \ \ \ \
|
| 25 |
+
\ \ \|\ \ \ \_| \ \ \_\\ \|____|\ \|____________|\ \ \___|\ \ \\ \\ \ \\\ \|\ \\_\ \ \ \_|\ \ \ \____ \ \ \
|
| 26 |
+
\ \_______\ \__\ \ \_______\____\_\ \ \ \__\ \ \__\\ _\\ \_______\ \________\ \_______\ \_______\ \ \__\
|
| 27 |
+
\|_______|\|__| \|_______|\_________\ \|__| \|__|\|__|\|_______|\|________|\|_______|\|_______| \|__|
|
| 28 |
+
\|_________|
|
| 29 |
+
</pre>
|
| 30 |
+
</body>
|
| 31 |
+
</html>
|
docs/demo.png
DELETED
Git LFS Details
|
utils/predict.py
CHANGED
|
@@ -6,7 +6,6 @@ from pathlib import Path
|
|
| 6 |
import models
|
| 7 |
|
| 8 |
|
| 9 |
-
# TODO 处理更多文件
|
| 10 |
def audio_to_signal(audio_file, sr=None):
|
| 11 |
signal, _ = librosa.load(audio_file, sr=sr)
|
| 12 |
return signal
|
|
@@ -36,7 +35,6 @@ def predict(model_state_dict, signal_file, args):
|
|
| 36 |
file_extension = Path(signal_file).suffix
|
| 37 |
if file_extension == ".csv":
|
| 38 |
signal = csv_to_signal(signal_file).reshape(-1, 1, 224)
|
| 39 |
-
# FIXME 这里能搞多少后缀就多少后缀
|
| 40 |
elif file_extension in [".wav", ".mp3"]:
|
| 41 |
signal = audio_to_signal(signal_file).reshape(-1, 1, 224)
|
| 42 |
else:
|
|
@@ -44,5 +42,4 @@ def predict(model_state_dict, signal_file, args):
|
|
| 44 |
signal = torch.tensor(signal, dtype=torch.float32).to(device)
|
| 45 |
output = model_all(signal)
|
| 46 |
predictions = output.mean(dim=0)
|
| 47 |
-
# FIXME 这里没有过softmax
|
| 48 |
return predictions
|
|
|
|
| 6 |
import models
|
| 7 |
|
| 8 |
|
|
|
|
| 9 |
def audio_to_signal(audio_file, sr=None):
|
| 10 |
signal, _ = librosa.load(audio_file, sr=sr)
|
| 11 |
return signal
|
|
|
|
| 35 |
file_extension = Path(signal_file).suffix
|
| 36 |
if file_extension == ".csv":
|
| 37 |
signal = csv_to_signal(signal_file).reshape(-1, 1, 224)
|
|
|
|
| 38 |
elif file_extension in [".wav", ".mp3"]:
|
| 39 |
signal = audio_to_signal(signal_file).reshape(-1, 1, 224)
|
| 40 |
else:
|
|
|
|
| 42 |
signal = torch.tensor(signal, dtype=torch.float32).to(device)
|
| 43 |
output = model_all(signal)
|
| 44 |
predictions = output.mean(dim=0)
|
|
|
|
| 45 |
return predictions
|
utils/train.py
CHANGED
|
@@ -14,7 +14,7 @@ import matplotlib.pyplot as plt
|
|
| 14 |
import models
|
| 15 |
from models.AdversarialNet import AdversarialNet, calc_coeff, grl_hook, Entropy
|
| 16 |
from dataset.dataset import SignalDatasetCreator
|
| 17 |
-
from .loss import DAN, JAN, CORAL
|
| 18 |
|
| 19 |
|
| 20 |
class train_utils:
|
|
@@ -26,7 +26,7 @@ class train_utils:
|
|
| 26 |
def setup(self):
|
| 27 |
args = self.args
|
| 28 |
self.save_dir = os.path.join(args.checkpoint_dir, args.model_name + "_" + datetime.strftime(datetime.now(), "%m%d-%H%M%S"))
|
| 29 |
-
|
| 30 |
# 判断训练设备
|
| 31 |
if torch.cuda.is_available():
|
| 32 |
self.device = torch.device("cuda")
|
|
|
|
| 14 |
import models
|
| 15 |
from models.AdversarialNet import AdversarialNet, calc_coeff, grl_hook, Entropy
|
| 16 |
from dataset.dataset import SignalDatasetCreator
|
| 17 |
+
from utils.loss import DAN, JAN, CORAL
|
| 18 |
|
| 19 |
|
| 20 |
class train_utils:
|
|
|
|
| 26 |
def setup(self):
|
| 27 |
args = self.args
|
| 28 |
self.save_dir = os.path.join(args.checkpoint_dir, args.model_name + "_" + datetime.strftime(datetime.now(), "%m%d-%H%M%S"))
|
| 29 |
+
|
| 30 |
# 判断训练设备
|
| 31 |
if torch.cuda.is_available():
|
| 32 |
self.device = torch.device("cuda")
|