Spaces:
Build error
Build error
Commit ·
ec09c5e
1
Parent(s): 9e14f38
v1:第一版修整代码
Browse files- BFDS_train.py +10 -9
- BFDS_web.py +7 -9
- Dockerfile.cpu +31 -0
- README.md +1 -7
- dataset/get_data.py +3 -3
- dataset/get_dataset.py +4 -2
- models/CNN.py +1 -1
- models/ResNet18_1d.py +7 -9
- requirements-cpu.txt +96 -0
- utils/fetch_conditions.py +0 -1
- utils/future_use.py +0 -208
- utils/predict.py +2 -1
BFDS_train.py
CHANGED
|
@@ -2,8 +2,8 @@ import os
|
|
| 2 |
import logging
|
| 3 |
import warnings
|
| 4 |
import json
|
| 5 |
-
from datetime import datetime
|
| 6 |
import requests
|
|
|
|
| 7 |
|
| 8 |
if __name__ == "__main__":
|
| 9 |
try:
|
|
@@ -17,7 +17,7 @@ if __name__ == "__main__":
|
|
| 17 |
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
| 18 |
print(f"无法连接到 Hugging Face:换源到{os.environ['HF_ENDPOINT']}")
|
| 19 |
if not os.path.exists("./cache"):
|
| 20 |
-
os.makedirs("./cache")
|
| 21 |
os.environ["HF_DATASETS_CACHE"] = "./cache"
|
| 22 |
|
| 23 |
from utils.logger import setlogger
|
|
@@ -35,14 +35,14 @@ class Argument:
|
|
| 35 |
self.data_set = "BFDS-Project/Bearing-Fault-Diagnosis-System" # 数据集huggingface地址
|
| 36 |
self.conditions = fetch_all_conditions_from_huggingface(self.data_set) # 数据集的配置和分割信息如果想要知道明确的信息来确定迁移方向请自行运行fetch_conditions.py
|
| 37 |
self.labels = {"Normal Baseline Data": 0, "Ball": 1, "Inner Race": 2, "Outer Race Centered": 3, "Outer Race Opposite": 4, "Outer Race Orthogonal": 5} # 标签
|
| 38 |
-
self.transfer_task = [["
|
| 39 |
-
self.target_domain_labeled =
|
| 40 |
|
| 41 |
# 预处理
|
| 42 |
self.normalize_type = None # 归一化方式, mean-std/min-max/None
|
| 43 |
self.stratified_sampling = True # 是否分层采样, True/False
|
| 44 |
# 模型
|
| 45 |
-
self.model_name = "
|
| 46 |
self.bottleneck = True # 是否使用bottleneck层
|
| 47 |
self.bottleneck_num = 256 # bottleneck层的输出维数
|
| 48 |
|
|
@@ -118,7 +118,8 @@ class Argument:
|
|
| 118 |
"data_set": "BFDS-Project/Bearing-Fault-Diagnosis-System",
|
| 119 |
"conditions": fetch_all_conditions_from_huggingface("BFDS-Project/Bearing-Fault-Diagnosis-System"),
|
| 120 |
"labels": {"Normal Baseline Data": 0, "Ball": 1, "Inner Race": 2, "Outer Race Centered": 3, "Outer Race Opposite": 4, "Outer Race Orthogonal": 5},
|
| 121 |
-
"transfer_task": [["
|
|
|
|
| 122 |
"normalize_type": None,
|
| 123 |
"stratified_sampling": True,
|
| 124 |
"model_name": "CNN",
|
|
@@ -126,7 +127,7 @@ class Argument:
|
|
| 126 |
"bottleneck_num": 256,
|
| 127 |
"batch_size": 64,
|
| 128 |
"cuda_device": "0",
|
| 129 |
-
"max_epoch":
|
| 130 |
"num_workers": 0,
|
| 131 |
"checkpoint_dir": "./checkpoint",
|
| 132 |
"print_step": 50,
|
|
@@ -136,8 +137,8 @@ class Argument:
|
|
| 136 |
"lr": 1e-3,
|
| 137 |
"lr_scheduler": "step",
|
| 138 |
"gamma": 0.1,
|
| 139 |
-
"steps": [
|
| 140 |
-
"middle_epoch":
|
| 141 |
"distance_option": True,
|
| 142 |
"distance_loss": "JMMD",
|
| 143 |
"distance_tradeoff": "Step",
|
|
|
|
| 2 |
import logging
|
| 3 |
import warnings
|
| 4 |
import json
|
|
|
|
| 5 |
import requests
|
| 6 |
+
from datetime import datetime
|
| 7 |
|
| 8 |
if __name__ == "__main__":
|
| 9 |
try:
|
|
|
|
| 17 |
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
| 18 |
print(f"无法连接到 Hugging Face:换源到{os.environ['HF_ENDPOINT']}")
|
| 19 |
if not os.path.exists("./cache"):
|
| 20 |
+
os.makedirs("./cache")
|
| 21 |
os.environ["HF_DATASETS_CACHE"] = "./cache"
|
| 22 |
|
| 23 |
from utils.logger import setlogger
|
|
|
|
| 35 |
self.data_set = "BFDS-Project/Bearing-Fault-Diagnosis-System" # 数据集huggingface地址
|
| 36 |
self.conditions = fetch_all_conditions_from_huggingface(self.data_set) # 数据集的配置和分割信息如果想要知道明确的信息来确定迁移方向请自行运行fetch_conditions.py
|
| 37 |
self.labels = {"Normal Baseline Data": 0, "Ball": 1, "Inner Race": 2, "Outer Race Centered": 3, "Outer Race Opposite": 4, "Outer Race Orthogonal": 5} # 标签
|
| 38 |
+
self.transfer_task = [["CWRURPM", "12kDriveEndrpm1730"], ["CWRURPM", "12kDriveEndrpm1750"]] # 迁移方向
|
| 39 |
+
self.target_domain_labeled = True # 表示目标域在训练中是否带有标签
|
| 40 |
|
| 41 |
# 预处理
|
| 42 |
self.normalize_type = None # 归一化方式, mean-std/min-max/None
|
| 43 |
self.stratified_sampling = True # 是否分层采样, True/False
|
| 44 |
# 模型
|
| 45 |
+
self.model_name = "CNN" # 模型名
|
| 46 |
self.bottleneck = True # 是否使用bottleneck层
|
| 47 |
self.bottleneck_num = 256 # bottleneck层的输出维数
|
| 48 |
|
|
|
|
| 118 |
"data_set": "BFDS-Project/Bearing-Fault-Diagnosis-System",
|
| 119 |
"conditions": fetch_all_conditions_from_huggingface("BFDS-Project/Bearing-Fault-Diagnosis-System"),
|
| 120 |
"labels": {"Normal Baseline Data": 0, "Ball": 1, "Inner Race": 2, "Outer Race Centered": 3, "Outer Race Opposite": 4, "Outer Race Orthogonal": 5},
|
| 121 |
+
"transfer_task": [["CWRURPM", "12kDriveEndrpm1730"], ["CWRURPM", "12kDriveEndrpm1750"]],
|
| 122 |
+
"target_domain_labeled": False,
|
| 123 |
"normalize_type": None,
|
| 124 |
"stratified_sampling": True,
|
| 125 |
"model_name": "CNN",
|
|
|
|
| 127 |
"bottleneck_num": 256,
|
| 128 |
"batch_size": 64,
|
| 129 |
"cuda_device": "0",
|
| 130 |
+
"max_epoch": 100,
|
| 131 |
"num_workers": 0,
|
| 132 |
"checkpoint_dir": "./checkpoint",
|
| 133 |
"print_step": 50,
|
|
|
|
| 137 |
"lr": 1e-3,
|
| 138 |
"lr_scheduler": "step",
|
| 139 |
"gamma": 0.1,
|
| 140 |
+
"steps": [25, 75],
|
| 141 |
+
"middle_epoch": 50,
|
| 142 |
"distance_option": True,
|
| 143 |
"distance_loss": "JMMD",
|
| 144 |
"distance_tradeoff": "Step",
|
BFDS_web.py
CHANGED
|
@@ -1,6 +1,11 @@
|
|
|
|
|
| 1 |
import os
|
|
|
|
| 2 |
import requests
|
|
|
|
|
|
|
| 3 |
import zipfile
|
|
|
|
| 4 |
|
| 5 |
if __name__ == "__main__":
|
| 6 |
try:
|
|
@@ -20,18 +25,11 @@ if __name__ == "__main__":
|
|
| 20 |
|
| 21 |
import gradio as gr
|
| 22 |
from BFDS_train import Argument
|
| 23 |
-
import torch
|
| 24 |
-
from utils.predict import predict
|
| 25 |
-
import pandas as pd
|
| 26 |
-
|
| 27 |
-
import logging
|
| 28 |
-
import warnings
|
| 29 |
-
from datetime import datetime
|
| 30 |
-
|
| 31 |
-
|
| 32 |
from utils.logger import setlogger
|
|
|
|
| 33 |
from utils.train import train_utils
|
| 34 |
|
|
|
|
| 35 |
# 初始化 Argument 实例
|
| 36 |
args = Argument()
|
| 37 |
args.set_recommended_params()
|
|
|
|
| 1 |
+
import logging
|
| 2 |
import os
|
| 3 |
+
import pandas as pd
|
| 4 |
import requests
|
| 5 |
+
import torch
|
| 6 |
+
import warnings
|
| 7 |
import zipfile
|
| 8 |
+
from datetime import datetime
|
| 9 |
|
| 10 |
if __name__ == "__main__":
|
| 11 |
try:
|
|
|
|
| 25 |
|
| 26 |
import gradio as gr
|
| 27 |
from BFDS_train import Argument
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
from utils.logger import setlogger
|
| 29 |
+
from utils.predict import predict
|
| 30 |
from utils.train import train_utils
|
| 31 |
|
| 32 |
+
|
| 33 |
# 初始化 Argument 实例
|
| 34 |
args = Argument()
|
| 35 |
args.set_recommended_params()
|
Dockerfile.cpu
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-alpine
|
| 2 |
+
|
| 3 |
+
# 设置环境变量以减少缓存和输出
|
| 4 |
+
ENV PYTHONDONTWRITEBYTECODE=1 \
|
| 5 |
+
PYTHONUNBUFFERED=1
|
| 6 |
+
|
| 7 |
+
# 设置工作目录
|
| 8 |
+
WORKDIR /app
|
| 9 |
+
|
| 10 |
+
# 安装必要的系统依赖
|
| 11 |
+
RUN apk add --no-cache --virtual .build-deps \
|
| 12 |
+
gcc \
|
| 13 |
+
musl-dev \
|
| 14 |
+
libffi-dev \
|
| 15 |
+
openssl-dev \
|
| 16 |
+
make \
|
| 17 |
+
&& apk add --no-cache \
|
| 18 |
+
bash \
|
| 19 |
+
git \
|
| 20 |
+
&& pip install --upgrade pip
|
| 21 |
+
|
| 22 |
+
# 安装 Python 依赖并清理构建依赖
|
| 23 |
+
COPY requirements-cpu.txt .
|
| 24 |
+
RUN pip install --no-cache-dir -r requirements-cpu.txt \
|
| 25 |
+
&& apk del .build-deps
|
| 26 |
+
|
| 27 |
+
# 复制项目文件
|
| 28 |
+
COPY . .
|
| 29 |
+
|
| 30 |
+
# 设置默认运行命令
|
| 31 |
+
CMD ["python", "BFDS_web.py"]
|
README.md
CHANGED
|
@@ -1,7 +1 @@
|
|
| 1 |
-
# Bearing-Fault-Diagnosis-System
|
| 2 |
-
|
| 3 |
-
## Requirements
|
| 4 |
-
- Python 3.10
|
| 5 |
-
- matplotlib 3.10.0
|
| 6 |
-
- numpy 2.2.2
|
| 7 |
-
- PyWavelets 1.8.0
|
|
|
|
| 1 |
+
# Bearing-Fault-Diagnosis-System
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dataset/get_data.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
-
import pandas as pd
|
| 2 |
-
import numpy as np
|
| 3 |
-
from datasets import load_dataset
|
| 4 |
import librosa
|
| 5 |
import mimetypes
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
# ===============================================================
|
| 8 |
# 加载有标签的数据集(n , m + 1)最后一列是标签
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import librosa
|
| 2 |
import mimetypes
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from datasets import load_dataset
|
| 6 |
|
| 7 |
# ===============================================================
|
| 8 |
# 加载有标签的数据集(n , m + 1)最后一列是标签
|
dataset/get_dataset.py
CHANGED
|
@@ -1,7 +1,9 @@
|
|
|
|
|
|
|
|
| 1 |
import pandas as pd
|
| 2 |
import torch
|
| 3 |
-
from torch.utils.data import
|
| 4 |
-
|
| 5 |
from dataset.get_data import get_huggingface_dataset, get_local_dataset, get_user_dataset
|
| 6 |
|
| 7 |
|
|
|
|
| 1 |
+
from typing import Literal, Optional
|
| 2 |
+
|
| 3 |
import pandas as pd
|
| 4 |
import torch
|
| 5 |
+
from torch.utils.data import DataLoader, Dataset, Subset, random_split
|
| 6 |
+
|
| 7 |
from dataset.get_data import get_huggingface_dataset, get_local_dataset, get_user_dataset
|
| 8 |
|
| 9 |
|
models/CNN.py
CHANGED
|
@@ -33,7 +33,7 @@ class CNN(nn.Module):
|
|
| 33 |
nn.AdaptiveMaxPool1d(4),
|
| 34 |
) # 128, 4,4
|
| 35 |
|
| 36 |
-
self.layer5 = nn.Sequential(nn.Linear(128 * 4,self.__in_features), nn.ReLU(inplace=True), nn.Dropout())
|
| 37 |
|
| 38 |
def forward(self, x):
|
| 39 |
x = self.layer1(x)
|
|
|
|
| 33 |
nn.AdaptiveMaxPool1d(4),
|
| 34 |
) # 128, 4,4
|
| 35 |
|
| 36 |
+
self.layer5 = nn.Sequential(nn.Linear(128 * 4, self.__in_features), nn.ReLU(inplace=True), nn.Dropout())
|
| 37 |
|
| 38 |
def forward(self, x):
|
| 39 |
x = self.layer1(x)
|
models/ResNet18_1d.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
|
|
|
|
| 4 |
class BasicBlock1D(nn.Module):
|
| 5 |
expansion = 1 # 扩展倍数,用于调整输出通道数
|
| 6 |
|
|
@@ -32,6 +33,7 @@ class BasicBlock1D(nn.Module):
|
|
| 32 |
|
| 33 |
return out
|
| 34 |
|
|
|
|
| 35 |
class ResNet1D(nn.Module):
|
| 36 |
def __init__(self, block=BasicBlock1D, layers=[2, 2, 2, 2]):
|
| 37 |
super(ResNet1D, self).__init__()
|
|
@@ -58,8 +60,7 @@ class ResNet1D(nn.Module):
|
|
| 58 |
# 如果需要调整通道数或步幅不为1,则定义下采样层
|
| 59 |
if stride != 1 or self.in_channels != out_channels * block.expansion:
|
| 60 |
downsample = nn.Sequential(
|
| 61 |
-
nn.Conv1d(self.in_channels, out_channels * block.expansion,
|
| 62 |
-
kernel_size=1, stride=stride, bias=False),
|
| 63 |
nn.BatchNorm1d(out_channels * block.expansion),
|
| 64 |
)
|
| 65 |
|
|
@@ -96,19 +97,16 @@ class ResNet1D(nn.Module):
|
|
| 96 |
def output_num(self):
|
| 97 |
# 返回输出特征维度
|
| 98 |
return self.__in_features
|
| 99 |
-
|
|
|
|
| 100 |
def resnet1d18():
|
| 101 |
# 构建 ResNet1D-18 模型
|
| 102 |
return ResNet1D(layers=[2, 2, 2, 2])
|
| 103 |
|
|
|
|
| 104 |
if __name__ == "__main__":
|
| 105 |
-
|
| 106 |
-
model = resnet1d18() # 输出固定为 256 特征
|
| 107 |
print(model)
|
| 108 |
-
|
| 109 |
-
# 创建一个随机输入张量,批量大小为 8,信号长度为 1024
|
| 110 |
input_tensor = torch.randn(8, 1, 1024)
|
| 111 |
output = model(input_tensor)
|
| 112 |
-
|
| 113 |
-
# 打印输出形状
|
| 114 |
print("Output shape:", output.shape)
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
|
| 4 |
+
|
| 5 |
class BasicBlock1D(nn.Module):
|
| 6 |
expansion = 1 # 扩展倍数,用于调整输出通道数
|
| 7 |
|
|
|
|
| 33 |
|
| 34 |
return out
|
| 35 |
|
| 36 |
+
|
| 37 |
class ResNet1D(nn.Module):
|
| 38 |
def __init__(self, block=BasicBlock1D, layers=[2, 2, 2, 2]):
|
| 39 |
super(ResNet1D, self).__init__()
|
|
|
|
| 60 |
# 如果需要调整通道数或步幅不为1,则定义下采样层
|
| 61 |
if stride != 1 or self.in_channels != out_channels * block.expansion:
|
| 62 |
downsample = nn.Sequential(
|
| 63 |
+
nn.Conv1d(self.in_channels, out_channels * block.expansion, kernel_size=1, stride=stride, bias=False),
|
|
|
|
| 64 |
nn.BatchNorm1d(out_channels * block.expansion),
|
| 65 |
)
|
| 66 |
|
|
|
|
| 97 |
def output_num(self):
|
| 98 |
# 返回输出特征维度
|
| 99 |
return self.__in_features
|
| 100 |
+
|
| 101 |
+
|
| 102 |
def resnet1d18():
|
| 103 |
# 构建 ResNet1D-18 模型
|
| 104 |
return ResNet1D(layers=[2, 2, 2, 2])
|
| 105 |
|
| 106 |
+
|
| 107 |
if __name__ == "__main__":
|
| 108 |
+
model = resnet1d18()
|
|
|
|
| 109 |
print(model)
|
|
|
|
|
|
|
| 110 |
input_tensor = torch.randn(8, 1, 1024)
|
| 111 |
output = model(input_tensor)
|
|
|
|
|
|
|
| 112 |
print("Output shape:", output.shape)
|
requirements-cpu.txt
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
aiofiles==24.1.0
|
| 2 |
+
aiohappyeyeballs==2.6.1
|
| 3 |
+
aiohttp==3.11.16
|
| 4 |
+
aiosignal==1.3.2
|
| 5 |
+
annotated-types==0.7.0
|
| 6 |
+
anyio==4.9.0
|
| 7 |
+
attrs==25.3.0
|
| 8 |
+
audioread==3.0.1
|
| 9 |
+
certifi==2025.1.31
|
| 10 |
+
cffi==1.17.1
|
| 11 |
+
charset-normalizer==3.4.1
|
| 12 |
+
click==8.1.8
|
| 13 |
+
colorama==0.4.6
|
| 14 |
+
contourpy==1.3.1
|
| 15 |
+
cycler==0.12.1
|
| 16 |
+
datasets==3.5.0
|
| 17 |
+
decorator==5.2.1
|
| 18 |
+
dill==0.3.8
|
| 19 |
+
fastapi==0.115.12
|
| 20 |
+
ffmpy==0.5.0
|
| 21 |
+
filelock==3.18.0
|
| 22 |
+
fonttools==4.57.0
|
| 23 |
+
frozenlist==1.5.0
|
| 24 |
+
fsspec==2024.12.0
|
| 25 |
+
gradio==5.24.0
|
| 26 |
+
gradio_client==1.8.0
|
| 27 |
+
groovy==0.1.2
|
| 28 |
+
h11==0.14.0
|
| 29 |
+
httpcore==1.0.8
|
| 30 |
+
httpx==0.28.1
|
| 31 |
+
huggingface-hub==0.30.2
|
| 32 |
+
idna==3.10
|
| 33 |
+
Jinja2==3.1.6
|
| 34 |
+
joblib==1.4.2
|
| 35 |
+
kiwisolver==1.4.8
|
| 36 |
+
lazy_loader==0.4
|
| 37 |
+
librosa==0.11.0
|
| 38 |
+
llvmlite==0.44.0
|
| 39 |
+
markdown-it-py==3.0.0
|
| 40 |
+
MarkupSafe==3.0.2
|
| 41 |
+
matplotlib==3.10.1
|
| 42 |
+
mdurl==0.1.2
|
| 43 |
+
mpmath==1.3.0
|
| 44 |
+
msgpack==1.1.0
|
| 45 |
+
multidict==6.4.3
|
| 46 |
+
multiprocess==0.70.16
|
| 47 |
+
networkx==3.4.2
|
| 48 |
+
numba==0.61.2
|
| 49 |
+
numpy==2.2.4
|
| 50 |
+
orjson==3.10.16
|
| 51 |
+
packaging==24.2
|
| 52 |
+
pandas==2.2.3
|
| 53 |
+
pillow==11.1.0
|
| 54 |
+
platformdirs==4.3.7
|
| 55 |
+
pooch==1.8.2
|
| 56 |
+
propcache==0.3.1
|
| 57 |
+
pyarrow==19.0.1
|
| 58 |
+
pycparser==2.22
|
| 59 |
+
pydantic==2.11.3
|
| 60 |
+
pydantic_core==2.33.1
|
| 61 |
+
pydub==0.25.1
|
| 62 |
+
Pygments==2.19.1
|
| 63 |
+
pyparsing==3.2.3
|
| 64 |
+
python-dateutil==2.9.0.post0
|
| 65 |
+
python-multipart==0.0.20
|
| 66 |
+
pytz==2025.2
|
| 67 |
+
PyYAML==6.0.2
|
| 68 |
+
requests==2.32.3
|
| 69 |
+
rich==14.0.0
|
| 70 |
+
ruff==0.11.5
|
| 71 |
+
safehttpx==0.1.6
|
| 72 |
+
scikit-learn==1.6.1
|
| 73 |
+
scipy==1.15.2
|
| 74 |
+
semantic-version==2.10.0
|
| 75 |
+
shellingham==1.5.4
|
| 76 |
+
six==1.17.0
|
| 77 |
+
sniffio==1.3.1
|
| 78 |
+
soundfile==0.13.1
|
| 79 |
+
soxr==0.5.0.post1
|
| 80 |
+
starlette==0.46.1
|
| 81 |
+
sympy==1.13.1
|
| 82 |
+
threadpoolctl==3.6.0
|
| 83 |
+
tomlkit==0.13.2
|
| 84 |
+
torch==2.6.0
|
| 85 |
+
torchaudio==2.6.0
|
| 86 |
+
torchvision==0.21.0
|
| 87 |
+
tqdm==4.67.1
|
| 88 |
+
typer==0.15.2
|
| 89 |
+
typing-inspection==0.4.0
|
| 90 |
+
typing_extensions==4.13.2
|
| 91 |
+
tzdata==2025.2
|
| 92 |
+
urllib3==2.4.0
|
| 93 |
+
uvicorn==0.34.0
|
| 94 |
+
websockets==15.0.1
|
| 95 |
+
xxhash==3.5.0
|
| 96 |
+
yarl==1.19.0
|
utils/fetch_conditions.py
CHANGED
|
@@ -18,7 +18,6 @@ if __name__ == "__main__":
|
|
| 18 |
os.environ["HF_DATASETS_CACHE"] = "./cache"
|
| 19 |
|
| 20 |
from datasets import get_dataset_config_names, get_dataset_split_names
|
| 21 |
-
import json
|
| 22 |
|
| 23 |
|
| 24 |
def fetch_all_conditions_from_huggingface(dataset_name):
|
|
|
|
| 18 |
os.environ["HF_DATASETS_CACHE"] = "./cache"
|
| 19 |
|
| 20 |
from datasets import get_dataset_config_names, get_dataset_split_names
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
def fetch_all_conditions_from_huggingface(dataset_name):
|
utils/future_use.py
DELETED
|
@@ -1,208 +0,0 @@
|
|
| 1 |
-
# 暂时不使用的代码
|
| 2 |
-
import numpy as np
|
| 3 |
-
import pywt
|
| 4 |
-
import matplotlib.pyplot as plt
|
| 5 |
-
import os
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
class ContinuousWaveletTransform:
|
| 9 |
-
def __init__(self, fs, signals, save_path=None, wavelet="cmor1.5-1.0", freqNum=224):
|
| 10 |
-
"""
|
| 11 |
-
连续小波变换 (CWT) 计算类,支持单个信号或批量信号输入,并保存为 .npy 文件。
|
| 12 |
-
|
| 13 |
-
Args:
|
| 14 |
-
fs (_int_): 采样频率
|
| 15 |
-
signals (_np.array_): 输入信号,形状可以是 (signal_length,) 或 (batch_size, signal_length)
|
| 16 |
-
save_path (_str_): 如果提供,将保存 CWT 变换后的数据到 .npy 文件,默认为None,为None时不保存
|
| 17 |
-
wavelet (_str_): 连续小波类型(默认 'cmor1.5-1.0')
|
| 18 |
-
freqNum (_int_): 频率点个数(默认 224)
|
| 19 |
-
"""
|
| 20 |
-
self.fs = fs
|
| 21 |
-
self.save_path = save_path
|
| 22 |
-
|
| 23 |
-
# 确保路径存在
|
| 24 |
-
if save_path:
|
| 25 |
-
os.makedirs(save_path, exist_ok=True)
|
| 26 |
-
|
| 27 |
-
# 确保输入是 NumPy 数组
|
| 28 |
-
signals = np.asarray(signals, dtype=np.float32) # 使用 float32 节省内存
|
| 29 |
-
|
| 30 |
-
# 处理 batch 维度
|
| 31 |
-
if signals.ndim == 1:
|
| 32 |
-
signals = signals[np.newaxis, :] # (signal_length,) -> (1, signal_length)
|
| 33 |
-
|
| 34 |
-
self.batch_size, self.signal_length = signals.shape
|
| 35 |
-
self.time = np.arange(0, self.signal_length) / fs # 时间轴
|
| 36 |
-
self.widths = np.geomspace(1, 512, num=freqNum).astype(np.float32) # 频率尺度
|
| 37 |
-
|
| 38 |
-
# 预分配 CWT 结果矩阵
|
| 39 |
-
self.cwt_results = np.empty((self.batch_size, freqNum, self.signal_length), dtype=np.float32)
|
| 40 |
-
|
| 41 |
-
for i in range(self.batch_size):
|
| 42 |
-
signal = signals[i] - np.mean(signals[i]) # 去均值(去直流分量)
|
| 43 |
-
cwtmatr, freqs = pywt.cwt(signal, self.widths, wavelet, sampling_period=1 / fs)
|
| 44 |
-
cwt_result = np.abs(cwtmatr).astype(np.float32) # 取模值,转换为 float32
|
| 45 |
-
self.cwt_results[i] = cwt_result
|
| 46 |
-
|
| 47 |
-
# 保存 CWT 结果到 .npy 文件
|
| 48 |
-
if save_path:
|
| 49 |
-
np.save(os.path.join(save_path, f"cwt_{i:04d}.npy"), cwt_result)
|
| 50 |
-
print(f"CWT 结果已保存到 {os.path.join(save_path, f'cwt_{i:04d}.npy')}")
|
| 51 |
-
|
| 52 |
-
self.freqs = freqs.astype(np.float32) # 存储频率信息,节省内存
|
| 53 |
-
|
| 54 |
-
def plot(self, index=0, logspace=True, save_path=None):
|
| 55 |
-
"""
|
| 56 |
-
绘制并可选保存 CWT 结果。
|
| 57 |
-
|
| 58 |
-
Args:
|
| 59 |
-
index (_int_): 选择绘制的信号索引
|
| 60 |
-
logspace (_bool_): 是否使用对数坐标绘制频率轴
|
| 61 |
-
save_path (_str_ 或 None): 如果提供路径,则保存 `.npy` 文件,否则不保存
|
| 62 |
-
"""
|
| 63 |
-
if index >= self.batch_size:
|
| 64 |
-
raise ValueError(f"Index 超出范围!batch_size = {self.batch_size}, 但 index = {index}")
|
| 65 |
-
|
| 66 |
-
# 获取 CWT 结果
|
| 67 |
-
cwt_matrix = self.cwt_results[index]
|
| 68 |
-
|
| 69 |
-
# 选择是否保存 .npy
|
| 70 |
-
if save_path:
|
| 71 |
-
np.save(save_path, cwt_matrix)
|
| 72 |
-
print(f"CWT 结果已保存到 {save_path}")
|
| 73 |
-
|
| 74 |
-
# 绘图
|
| 75 |
-
fig, ax = plt.subplots(figsize=(10, 5))
|
| 76 |
-
pcm = ax.pcolormesh(self.time, self.freqs, cwt_matrix, shading="auto")
|
| 77 |
-
|
| 78 |
-
ax.set_yscale("log" if logspace else "linear")
|
| 79 |
-
ax.set_xlabel("Time (s)")
|
| 80 |
-
ax.set_ylabel("Frequency (Hz)")
|
| 81 |
-
ax.set_title(f"CWT Scaleogram (Signal {index})")
|
| 82 |
-
fig.colorbar(pcm, ax=ax)
|
| 83 |
-
|
| 84 |
-
plt.show() # 显示图像
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
if __name__ == "__main__":
|
| 88 |
-
fs = 1e3
|
| 89 |
-
N = 1e4
|
| 90 |
-
noise_power = 1e-3 * fs
|
| 91 |
-
time = np.arange(N) / float(fs)
|
| 92 |
-
mod = 2 * np.pi * 20 * np.cos(time)
|
| 93 |
-
carrier = np.sin(2 * np.pi * 100 * time + mod) # 频率调制
|
| 94 |
-
|
| 95 |
-
rng = np.random.default_rng()
|
| 96 |
-
noise = rng.normal(scale=np.sqrt(noise_power), size=time.shape)
|
| 97 |
-
noise *= np.exp(-time / 5)
|
| 98 |
-
x = carrier + noise
|
| 99 |
-
|
| 100 |
-
CWT = ContinuousWaveletTransform(fs, x)
|
| 101 |
-
CWT.plot(0, logspace=False)
|
| 102 |
-
|
| 103 |
-
# plt.plot(time, x)
|
| 104 |
-
# plt.show()
|
| 105 |
-
|
| 106 |
-
# ================================================================
|
| 107 |
-
# %%
|
| 108 |
-
import torch
|
| 109 |
-
import torch.nn as nn
|
| 110 |
-
import pandas as pd
|
| 111 |
-
from dataset.get_dataset import SignalDatasetCreator
|
| 112 |
-
from pathlib import Path
|
| 113 |
-
|
| 114 |
-
data_set = "BFDS-Project/Bearing-Fault-Diagnosis-System" # 数据集huggingface地址
|
| 115 |
-
labels = {"Normal Baseline Data": 0, "Ball": 1, "Inner Race": 2, "Outer Race Centered": 3, "Outer Race Opposite": 4, "Outer Race Orthogonal": 5} # 标签
|
| 116 |
-
transfer_task = [["CWRU224", "12kDriveEnd"], ["CWRU224", "12kFanEnd"]] # 迁移方向
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
signal_dataset_creator = SignalDatasetCreator(data_set, labels, transfer_task, stratified_sampling=True)
|
| 120 |
-
dataloaders = {}
|
| 121 |
-
dataloaders["source_train"], dataloaders["source_val"], dataloaders["target_train"], dataloaders["target_val"] = signal_dataset_creator.data_split(
|
| 122 |
-
64, 0, device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 123 |
-
)
|
| 124 |
-
|
| 125 |
-
# %%
|
| 126 |
-
import models
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 130 |
-
model = getattr(models, "ResNet")().to(device)
|
| 131 |
-
bottleneck_layer = nn.Sequential(
|
| 132 |
-
nn.Linear(model.output_num(), 256),
|
| 133 |
-
nn.ReLU(inplace=True),
|
| 134 |
-
nn.Dropout(),
|
| 135 |
-
).to(device)
|
| 136 |
-
classifier_layer = nn.Linear(256, len(labels)).to(device)
|
| 137 |
-
model_all = nn.Sequential(model, bottleneck_layer, classifier_layer).to(device)
|
| 138 |
-
model_all.load_state_dict(torch.load("checkpoint/150_0/149-0.3942-best_model.bin")) # 加载模型参数
|
| 139 |
-
model_without_head = nn.Sequential(*list(model_all.children())[:-1])
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
# %%
|
| 143 |
-
from sklearn.manifold import TSNE
|
| 144 |
-
import matplotlib.pyplot as plt
|
| 145 |
-
import numpy as np
|
| 146 |
-
from matplotlib.colors import ListedColormap
|
| 147 |
-
|
| 148 |
-
# 定义一个固定的颜色映射
|
| 149 |
-
num_classes = len(set(label for dataloader in dataloaders.values() for _, labels in dataloader for label in labels.numpy()))
|
| 150 |
-
colors = plt.cm.get_cmap("tab10", num_classes) # 使用 "tab10" 颜色映射
|
| 151 |
-
cmap = ListedColormap(colors.colors)
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
def plot_tsne(dataloader, title, ax):
|
| 155 |
-
model_all.eval()
|
| 156 |
-
with torch.no_grad():
|
| 157 |
-
for i, (inputs, labels) in enumerate(dataloader):
|
| 158 |
-
inputs = inputs.to(device)
|
| 159 |
-
labels = labels.to(device)
|
| 160 |
-
outputs = model_without_head(inputs)
|
| 161 |
-
# Collect all points across all batches
|
| 162 |
-
if i == 0:
|
| 163 |
-
all_points = outputs.cpu().numpy()
|
| 164 |
-
all_labels = labels.cpu().numpy()
|
| 165 |
-
else:
|
| 166 |
-
all_points = np.concatenate((all_points, outputs.cpu().numpy()), axis=0)
|
| 167 |
-
all_labels = np.concatenate((all_labels, labels.cpu().numpy()), axis=0)
|
| 168 |
-
|
| 169 |
-
# Apply t-SNE to reduce dimensions to 2D
|
| 170 |
-
tsne = TSNE(n_components=2, random_state=42)
|
| 171 |
-
reduced_points = tsne.fit_transform(all_points)
|
| 172 |
-
|
| 173 |
-
# Plot the reduced points
|
| 174 |
-
scatter = ax.scatter(reduced_points[:, 0], reduced_points[:, 1], c=all_labels, cmap=cmap, s=10)
|
| 175 |
-
ax.set_title(title)
|
| 176 |
-
ax.set_xlabel("Dimension 1")
|
| 177 |
-
ax.set_ylabel("Dimension 2")
|
| 178 |
-
return scatter, reduced_points, all_labels
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
# Create a 2x2 subplot
|
| 182 |
-
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
|
| 183 |
-
|
| 184 |
-
# Plot each dataloader
|
| 185 |
-
# sc1,_ = plot_tsne(dataloaders["source_train"], "Source Train", axes[0, 0])
|
| 186 |
-
sc2, reduced_points2, all_labels2 = plot_tsne(dataloaders["source_val"], "Source Val", axes[0, 1])
|
| 187 |
-
# sc3,_ = plot_tsne(dataloaders["target_train"], "Target Train", axes[1, 0])
|
| 188 |
-
# sc4,_ = plot_tsne(dataloaders["target_val"], "Target Val", axes[1, 1])
|
| 189 |
-
|
| 190 |
-
# Add a colorbar to the figure
|
| 191 |
-
# cbar = fig.colorbar(sc1, ax=axes, orientation="vertical", fraction=0.02, pad=0.04)
|
| 192 |
-
# cbar.set_label("Labels")
|
| 193 |
-
|
| 194 |
-
# Adjust layout and show the plot
|
| 195 |
-
plt.tight_layout()
|
| 196 |
-
plt.show()
|
| 197 |
-
|
| 198 |
-
# %%
|
| 199 |
-
reduced_points2, all_labels2
|
| 200 |
-
|
| 201 |
-
# %%
|
| 202 |
-
import pandas as pd
|
| 203 |
-
|
| 204 |
-
df = pd.DataFrame(reduced_points2)
|
| 205 |
-
df["label"] = all_labels2 # 将标签添加为新列
|
| 206 |
-
|
| 207 |
-
# 保存为 CSV 文件
|
| 208 |
-
df.to_csv("checkpoint/reduced_points_with_labels.csv", index=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/predict.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
|
|
|
|
|
|
| 3 |
import models
|
| 4 |
from dataset.get_dataset import get_user_dataset
|
| 5 |
-
import torch.nn.functional as F
|
| 6 |
|
| 7 |
|
| 8 |
def predict(model_state_dict, signal_file, args):
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
import models
|
| 6 |
from dataset.get_dataset import get_user_dataset
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
def predict(model_state_dict, signal_file, args):
|