Str0keOOOO committed on
Commit
ec09c5e
·
1 Parent(s): 9e14f38

v1:第一版修整代码

Browse files
BFDS_train.py CHANGED
@@ -2,8 +2,8 @@ import os
2
  import logging
3
  import warnings
4
  import json
5
- from datetime import datetime
6
  import requests
 
7
 
8
  if __name__ == "__main__":
9
  try:
@@ -17,7 +17,7 @@ if __name__ == "__main__":
17
  os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
18
  print(f"无法连接到 Hugging Face:换源到{os.environ['HF_ENDPOINT']}")
19
  if not os.path.exists("./cache"):
20
- os.makedirs("./cache") # 创建缓存目录
21
  os.environ["HF_DATASETS_CACHE"] = "./cache"
22
 
23
  from utils.logger import setlogger
@@ -35,14 +35,14 @@ class Argument:
35
  self.data_set = "BFDS-Project/Bearing-Fault-Diagnosis-System" # 数据集huggingface地址
36
  self.conditions = fetch_all_conditions_from_huggingface(self.data_set) # 数据集的配置和分割信息如果想要知道明确的信息来确定迁移方向请自行运行fetch_conditions.py
37
  self.labels = {"Normal Baseline Data": 0, "Ball": 1, "Inner Race": 2, "Outer Race Centered": 3, "Outer Race Opposite": 4, "Outer Race Orthogonal": 5} # 标签
38
- self.transfer_task = [["CWRU224", "12kDriveEnd"], ["CWRU224", "12kFanEnd"]] # 迁移方向
39
- self.target_domain_labeled = False # 表示目标域在训练中是否带有标签
40
 
41
  # 预处理
42
  self.normalize_type = None # 归一化方式, mean-std/min-max/None
43
  self.stratified_sampling = True # 是否分层采样, True/False
44
  # 模型
45
- self.model_name = "ResNet" # 模型名
46
  self.bottleneck = True # 是否使用bottleneck层
47
  self.bottleneck_num = 256 # bottleneck层的输出维数
48
 
@@ -118,7 +118,8 @@ class Argument:
118
  "data_set": "BFDS-Project/Bearing-Fault-Diagnosis-System",
119
  "conditions": fetch_all_conditions_from_huggingface("BFDS-Project/Bearing-Fault-Diagnosis-System"),
120
  "labels": {"Normal Baseline Data": 0, "Ball": 1, "Inner Race": 2, "Outer Race Centered": 3, "Outer Race Opposite": 4, "Outer Race Orthogonal": 5},
121
- "transfer_task": [["CWRU224", "12kDriveEnd"], ["CWRU224", "12kFanEnd"]],
 
122
  "normalize_type": None,
123
  "stratified_sampling": True,
124
  "model_name": "CNN",
@@ -126,7 +127,7 @@ class Argument:
126
  "bottleneck_num": 256,
127
  "batch_size": 64,
128
  "cuda_device": "0",
129
- "max_epoch": 2,
130
  "num_workers": 0,
131
  "checkpoint_dir": "./checkpoint",
132
  "print_step": 50,
@@ -136,8 +137,8 @@ class Argument:
136
  "lr": 1e-3,
137
  "lr_scheduler": "step",
138
  "gamma": 0.1,
139
- "steps": [150, 250],
140
- "middle_epoch": 0,
141
  "distance_option": True,
142
  "distance_loss": "JMMD",
143
  "distance_tradeoff": "Step",
 
2
  import logging
3
  import warnings
4
  import json
 
5
  import requests
6
+ from datetime import datetime
7
 
8
  if __name__ == "__main__":
9
  try:
 
17
  os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
18
  print(f"无法连接到 Hugging Face:换源到{os.environ['HF_ENDPOINT']}")
19
  if not os.path.exists("./cache"):
20
+ os.makedirs("./cache")
21
  os.environ["HF_DATASETS_CACHE"] = "./cache"
22
 
23
  from utils.logger import setlogger
 
35
  self.data_set = "BFDS-Project/Bearing-Fault-Diagnosis-System" # 数据集huggingface地址
36
  self.conditions = fetch_all_conditions_from_huggingface(self.data_set) # 数据集的配置和分割信息如果想要知道明确的信息来确定迁移方向请自行运行fetch_conditions.py
37
  self.labels = {"Normal Baseline Data": 0, "Ball": 1, "Inner Race": 2, "Outer Race Centered": 3, "Outer Race Opposite": 4, "Outer Race Orthogonal": 5} # 标签
38
+ self.transfer_task = [["CWRURPM", "12kDriveEndrpm1730"], ["CWRURPM", "12kDriveEndrpm1750"]] # 迁移方向
39
+ self.target_domain_labeled = True # 表示目标域在训练中是否带有标签
40
 
41
  # 预处理
42
  self.normalize_type = None # 归一化方式, mean-std/min-max/None
43
  self.stratified_sampling = True # 是否分层采样, True/False
44
  # 模型
45
+ self.model_name = "CNN" # 模型名
46
  self.bottleneck = True # 是否使用bottleneck层
47
  self.bottleneck_num = 256 # bottleneck层的输出维数
48
 
 
118
  "data_set": "BFDS-Project/Bearing-Fault-Diagnosis-System",
119
  "conditions": fetch_all_conditions_from_huggingface("BFDS-Project/Bearing-Fault-Diagnosis-System"),
120
  "labels": {"Normal Baseline Data": 0, "Ball": 1, "Inner Race": 2, "Outer Race Centered": 3, "Outer Race Opposite": 4, "Outer Race Orthogonal": 5},
121
+ "transfer_task": [["CWRURPM", "12kDriveEndrpm1730"], ["CWRURPM", "12kDriveEndrpm1750"]],
122
+ "target_domain_labeled": False,
123
  "normalize_type": None,
124
  "stratified_sampling": True,
125
  "model_name": "CNN",
 
127
  "bottleneck_num": 256,
128
  "batch_size": 64,
129
  "cuda_device": "0",
130
+ "max_epoch": 100,
131
  "num_workers": 0,
132
  "checkpoint_dir": "./checkpoint",
133
  "print_step": 50,
 
137
  "lr": 1e-3,
138
  "lr_scheduler": "step",
139
  "gamma": 0.1,
140
+ "steps": [25, 75],
141
+ "middle_epoch": 50,
142
  "distance_option": True,
143
  "distance_loss": "JMMD",
144
  "distance_tradeoff": "Step",
BFDS_web.py CHANGED
@@ -1,6 +1,11 @@
 
1
  import os
 
2
  import requests
 
 
3
  import zipfile
 
4
 
5
  if __name__ == "__main__":
6
  try:
@@ -20,18 +25,11 @@ if __name__ == "__main__":
20
 
21
  import gradio as gr
22
  from BFDS_train import Argument
23
- import torch
24
- from utils.predict import predict
25
- import pandas as pd
26
-
27
- import logging
28
- import warnings
29
- from datetime import datetime
30
-
31
-
32
  from utils.logger import setlogger
 
33
  from utils.train import train_utils
34
 
 
35
  # 初始化 Argument 实例
36
  args = Argument()
37
  args.set_recommended_params()
 
1
+ import logging
2
  import os
3
+ import pandas as pd
4
  import requests
5
+ import torch
6
+ import warnings
7
  import zipfile
8
+ from datetime import datetime
9
 
10
  if __name__ == "__main__":
11
  try:
 
25
 
26
  import gradio as gr
27
  from BFDS_train import Argument
 
 
 
 
 
 
 
 
 
28
  from utils.logger import setlogger
29
+ from utils.predict import predict
30
  from utils.train import train_utils
31
 
32
+
33
  # 初始化 Argument 实例
34
  args = Argument()
35
  args.set_recommended_params()
Dockerfile.cpu ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-alpine
2
+
3
+ # 设置环境变量以减少缓存和输出
4
+ ENV PYTHONDONTWRITEBYTECODE=1 \
5
+ PYTHONUNBUFFERED=1
6
+
7
+ # 设置工作目录
8
+ WORKDIR /app
9
+
10
+ # 安装必要的系统依赖
11
+ RUN apk add --no-cache --virtual .build-deps \
12
+ gcc \
13
+ musl-dev \
14
+ libffi-dev \
15
+ openssl-dev \
16
+ make \
17
+ && apk add --no-cache \
18
+ bash \
19
+ git \
20
+ && pip install --upgrade pip
21
+
22
+ # 安装 Python 依赖并清理构建依赖
23
+ COPY requirements-cpu.txt .
24
+ RUN pip install --no-cache-dir -r requirements-cpu.txt \
25
+ && apk del .build-deps
26
+
27
+ # 复制项目文件
28
+ COPY . .
29
+
30
+ # 设置默认运行命令
31
+ CMD ["python", "BFDS_web.py"]
README.md CHANGED
@@ -1,7 +1 @@
1
- # Bearing-Fault-Diagnosis-System
2
-
3
- ## Requirements
4
- - Python 3.10
5
- - matplotlib 3.10.0
6
- - numpy 2.2.2
7
- - PyWavelets 1.8.0
 
1
+ # Bearing-Fault-Diagnosis-System
 
 
 
 
 
 
dataset/get_data.py CHANGED
@@ -1,8 +1,8 @@
1
- import pandas as pd
2
- import numpy as np
3
- from datasets import load_dataset
4
  import librosa
5
  import mimetypes
 
 
 
6
 
7
  # ===============================================================
8
  # 加载有标签的数据集(n , m + 1)最后一列是标签
 
 
 
 
1
  import librosa
2
  import mimetypes
3
+ import numpy as np
4
+ import pandas as pd
5
+ from datasets import load_dataset
6
 
7
  # ===============================================================
8
  # 加载有标签的数据集(n , m + 1)最后一列是标签
dataset/get_dataset.py CHANGED
@@ -1,7 +1,9 @@
 
 
1
  import pandas as pd
2
  import torch
3
- from torch.utils.data import Dataset, DataLoader, Subset, random_split
4
- from typing import Optional, Literal
5
  from dataset.get_data import get_huggingface_dataset, get_local_dataset, get_user_dataset
6
 
7
 
 
1
+ from typing import Literal, Optional
2
+
3
  import pandas as pd
4
  import torch
5
+ from torch.utils.data import DataLoader, Dataset, Subset, random_split
6
+
7
  from dataset.get_data import get_huggingface_dataset, get_local_dataset, get_user_dataset
8
 
9
 
models/CNN.py CHANGED
@@ -33,7 +33,7 @@ class CNN(nn.Module):
33
  nn.AdaptiveMaxPool1d(4),
34
  ) # 128, 4,4
35
 
36
- self.layer5 = nn.Sequential(nn.Linear(128 * 4,self.__in_features), nn.ReLU(inplace=True), nn.Dropout())
37
 
38
  def forward(self, x):
39
  x = self.layer1(x)
 
33
  nn.AdaptiveMaxPool1d(4),
34
  ) # 128, 4,4
35
 
36
+ self.layer5 = nn.Sequential(nn.Linear(128 * 4, self.__in_features), nn.ReLU(inplace=True), nn.Dropout())
37
 
38
  def forward(self, x):
39
  x = self.layer1(x)
models/ResNet18_1d.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
2
  import torch.nn as nn
3
 
 
4
  class BasicBlock1D(nn.Module):
5
  expansion = 1 # 扩展倍数,用于调整输出通道数
6
 
@@ -32,6 +33,7 @@ class BasicBlock1D(nn.Module):
32
 
33
  return out
34
 
 
35
  class ResNet1D(nn.Module):
36
  def __init__(self, block=BasicBlock1D, layers=[2, 2, 2, 2]):
37
  super(ResNet1D, self).__init__()
@@ -58,8 +60,7 @@ class ResNet1D(nn.Module):
58
  # 如果需要调整通道数或步幅不为1,则定义下采样层
59
  if stride != 1 or self.in_channels != out_channels * block.expansion:
60
  downsample = nn.Sequential(
61
- nn.Conv1d(self.in_channels, out_channels * block.expansion,
62
- kernel_size=1, stride=stride, bias=False),
63
  nn.BatchNorm1d(out_channels * block.expansion),
64
  )
65
 
@@ -96,19 +97,16 @@ class ResNet1D(nn.Module):
96
  def output_num(self):
97
  # 返回输出特征维度
98
  return self.__in_features
99
-
 
100
  def resnet1d18():
101
  # 构建 ResNet1D-18 模型
102
  return ResNet1D(layers=[2, 2, 2, 2])
103
 
 
104
  if __name__ == "__main__":
105
- # 调试和测试模型
106
- model = resnet1d18() # 输出固定为 256 特征
107
  print(model)
108
-
109
- # 创建一个随机输入张量,批量大小为 8,信号长度为 1024
110
  input_tensor = torch.randn(8, 1, 1024)
111
  output = model(input_tensor)
112
-
113
- # 打印输出形状
114
  print("Output shape:", output.shape)
 
1
  import torch
2
  import torch.nn as nn
3
 
4
+
5
  class BasicBlock1D(nn.Module):
6
  expansion = 1 # 扩展倍数,用于调整输出通道数
7
 
 
33
 
34
  return out
35
 
36
+
37
  class ResNet1D(nn.Module):
38
  def __init__(self, block=BasicBlock1D, layers=[2, 2, 2, 2]):
39
  super(ResNet1D, self).__init__()
 
60
  # 如果需要调整通道数或步幅不为1,则定义下采样层
61
  if stride != 1 or self.in_channels != out_channels * block.expansion:
62
  downsample = nn.Sequential(
63
+ nn.Conv1d(self.in_channels, out_channels * block.expansion, kernel_size=1, stride=stride, bias=False),
 
64
  nn.BatchNorm1d(out_channels * block.expansion),
65
  )
66
 
 
97
  def output_num(self):
98
  # 返回输出特征维度
99
  return self.__in_features
100
+
101
+
102
  def resnet1d18():
103
  # 构建 ResNet1D-18 模型
104
  return ResNet1D(layers=[2, 2, 2, 2])
105
 
106
+
107
  if __name__ == "__main__":
108
+ model = resnet1d18()
 
109
  print(model)
 
 
110
  input_tensor = torch.randn(8, 1, 1024)
111
  output = model(input_tensor)
 
 
112
  print("Output shape:", output.shape)
requirements-cpu.txt ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==24.1.0
2
+ aiohappyeyeballs==2.6.1
3
+ aiohttp==3.11.16
4
+ aiosignal==1.3.2
5
+ annotated-types==0.7.0
6
+ anyio==4.9.0
7
+ attrs==25.3.0
8
+ audioread==3.0.1
9
+ certifi==2025.1.31
10
+ cffi==1.17.1
11
+ charset-normalizer==3.4.1
12
+ click==8.1.8
13
+ colorama==0.4.6
14
+ contourpy==1.3.1
15
+ cycler==0.12.1
16
+ datasets==3.5.0
17
+ decorator==5.2.1
18
+ dill==0.3.8
19
+ fastapi==0.115.12
20
+ ffmpy==0.5.0
21
+ filelock==3.18.0
22
+ fonttools==4.57.0
23
+ frozenlist==1.5.0
24
+ fsspec==2024.12.0
25
+ gradio==5.24.0
26
+ gradio_client==1.8.0
27
+ groovy==0.1.2
28
+ h11==0.14.0
29
+ httpcore==1.0.8
30
+ httpx==0.28.1
31
+ huggingface-hub==0.30.2
32
+ idna==3.10
33
+ Jinja2==3.1.6
34
+ joblib==1.4.2
35
+ kiwisolver==1.4.8
36
+ lazy_loader==0.4
37
+ librosa==0.11.0
38
+ llvmlite==0.44.0
39
+ markdown-it-py==3.0.0
40
+ MarkupSafe==3.0.2
41
+ matplotlib==3.10.1
42
+ mdurl==0.1.2
43
+ mpmath==1.3.0
44
+ msgpack==1.1.0
45
+ multidict==6.4.3
46
+ multiprocess==0.70.16
47
+ networkx==3.4.2
48
+ numba==0.61.2
49
+ numpy==2.2.4
50
+ orjson==3.10.16
51
+ packaging==24.2
52
+ pandas==2.2.3
53
+ pillow==11.1.0
54
+ platformdirs==4.3.7
55
+ pooch==1.8.2
56
+ propcache==0.3.1
57
+ pyarrow==19.0.1
58
+ pycparser==2.22
59
+ pydantic==2.11.3
60
+ pydantic_core==2.33.1
61
+ pydub==0.25.1
62
+ Pygments==2.19.1
63
+ pyparsing==3.2.3
64
+ python-dateutil==2.9.0.post0
65
+ python-multipart==0.0.20
66
+ pytz==2025.2
67
+ PyYAML==6.0.2
68
+ requests==2.32.3
69
+ rich==14.0.0
70
+ ruff==0.11.5
71
+ safehttpx==0.1.6
72
+ scikit-learn==1.6.1
73
+ scipy==1.15.2
74
+ semantic-version==2.10.0
75
+ shellingham==1.5.4
76
+ six==1.17.0
77
+ sniffio==1.3.1
78
+ soundfile==0.13.1
79
+ soxr==0.5.0.post1
80
+ starlette==0.46.1
81
+ sympy==1.13.1
82
+ threadpoolctl==3.6.0
83
+ tomlkit==0.13.2
84
+ torch==2.6.0
85
+ torchaudio==2.6.0
86
+ torchvision==0.21.0
87
+ tqdm==4.67.1
88
+ typer==0.15.2
89
+ typing-inspection==0.4.0
90
+ typing_extensions==4.13.2
91
+ tzdata==2025.2
92
+ urllib3==2.4.0
93
+ uvicorn==0.34.0
94
+ websockets==15.0.1
95
+ xxhash==3.5.0
96
+ yarl==1.19.0
utils/fetch_conditions.py CHANGED
@@ -18,7 +18,6 @@ if __name__ == "__main__":
18
  os.environ["HF_DATASETS_CACHE"] = "./cache"
19
 
20
  from datasets import get_dataset_config_names, get_dataset_split_names
21
- import json
22
 
23
 
24
  def fetch_all_conditions_from_huggingface(dataset_name):
 
18
  os.environ["HF_DATASETS_CACHE"] = "./cache"
19
 
20
  from datasets import get_dataset_config_names, get_dataset_split_names
 
21
 
22
 
23
  def fetch_all_conditions_from_huggingface(dataset_name):
utils/future_use.py DELETED
@@ -1,208 +0,0 @@
1
- # 暂时不使用的代码
2
- import numpy as np
3
- import pywt
4
- import matplotlib.pyplot as plt
5
- import os
6
-
7
-
8
- class ContinuousWaveletTransform:
9
- def __init__(self, fs, signals, save_path=None, wavelet="cmor1.5-1.0", freqNum=224):
10
- """
11
- 连续小波变换 (CWT) 计算类,支持单个信号或批量信号输入,并保存为 .npy 文件。
12
-
13
- Args:
14
- fs (_int_): 采样频率
15
- signals (_np.array_): 输入信号,形状可以是 (signal_length,) 或 (batch_size, signal_length)
16
- save_path (_str_): 如果提供,将保存 CWT 变换后的数据到 .npy 文件,默认为None,为None时不保存
17
- wavelet (_str_): 连续小波类型(默认 'cmor1.5-1.0')
18
- freqNum (_int_): 频率点个数(默认 224)
19
- """
20
- self.fs = fs
21
- self.save_path = save_path
22
-
23
- # 确保路径存在
24
- if save_path:
25
- os.makedirs(save_path, exist_ok=True)
26
-
27
- # 确保输入是 NumPy 数组
28
- signals = np.asarray(signals, dtype=np.float32) # 使用 float32 节省内存
29
-
30
- # 处理 batch 维度
31
- if signals.ndim == 1:
32
- signals = signals[np.newaxis, :] # (signal_length,) -> (1, signal_length)
33
-
34
- self.batch_size, self.signal_length = signals.shape
35
- self.time = np.arange(0, self.signal_length) / fs # 时间轴
36
- self.widths = np.geomspace(1, 512, num=freqNum).astype(np.float32) # 频率尺度
37
-
38
- # 预分配 CWT 结果矩阵
39
- self.cwt_results = np.empty((self.batch_size, freqNum, self.signal_length), dtype=np.float32)
40
-
41
- for i in range(self.batch_size):
42
- signal = signals[i] - np.mean(signals[i]) # 去均值(去直流分量)
43
- cwtmatr, freqs = pywt.cwt(signal, self.widths, wavelet, sampling_period=1 / fs)
44
- cwt_result = np.abs(cwtmatr).astype(np.float32) # 取模值,转换为 float32
45
- self.cwt_results[i] = cwt_result
46
-
47
- # 保存 CWT 结果到 .npy 文件
48
- if save_path:
49
- np.save(os.path.join(save_path, f"cwt_{i:04d}.npy"), cwt_result)
50
- print(f"CWT 结果已保存到 {os.path.join(save_path, f'cwt_{i:04d}.npy')}")
51
-
52
- self.freqs = freqs.astype(np.float32) # 存储频率信息,节省内存
53
-
54
- def plot(self, index=0, logspace=True, save_path=None):
55
- """
56
- 绘制并可选保存 CWT 结果。
57
-
58
- Args:
59
- index (_int_): 选择绘制的信号索引
60
- logspace (_bool_): 是否使用对数坐标绘制频率轴
61
- save_path (_str_ 或 None): 如果提供路径,则保存 `.npy` 文件,否则不保存
62
- """
63
- if index >= self.batch_size:
64
- raise ValueError(f"Index 超出范围!batch_size = {self.batch_size}, 但 index = {index}")
65
-
66
- # 获取 CWT 结果
67
- cwt_matrix = self.cwt_results[index]
68
-
69
- # 选择是否保存 .npy
70
- if save_path:
71
- np.save(save_path, cwt_matrix)
72
- print(f"CWT 结果已保存到 {save_path}")
73
-
74
- # 绘图
75
- fig, ax = plt.subplots(figsize=(10, 5))
76
- pcm = ax.pcolormesh(self.time, self.freqs, cwt_matrix, shading="auto")
77
-
78
- ax.set_yscale("log" if logspace else "linear")
79
- ax.set_xlabel("Time (s)")
80
- ax.set_ylabel("Frequency (Hz)")
81
- ax.set_title(f"CWT Scaleogram (Signal {index})")
82
- fig.colorbar(pcm, ax=ax)
83
-
84
- plt.show() # 显示图像
85
-
86
-
87
- if __name__ == "__main__":
88
- fs = 1e3
89
- N = 1e4
90
- noise_power = 1e-3 * fs
91
- time = np.arange(N) / float(fs)
92
- mod = 2 * np.pi * 20 * np.cos(time)
93
- carrier = np.sin(2 * np.pi * 100 * time + mod) # 频率调制
94
-
95
- rng = np.random.default_rng()
96
- noise = rng.normal(scale=np.sqrt(noise_power), size=time.shape)
97
- noise *= np.exp(-time / 5)
98
- x = carrier + noise
99
-
100
- CWT = ContinuousWaveletTransform(fs, x)
101
- CWT.plot(0, logspace=False)
102
-
103
- # plt.plot(time, x)
104
- # plt.show()
105
-
106
- # ================================================================
107
- # %%
108
- import torch
109
- import torch.nn as nn
110
- import pandas as pd
111
- from dataset.get_dataset import SignalDatasetCreator
112
- from pathlib import Path
113
-
114
- data_set = "BFDS-Project/Bearing-Fault-Diagnosis-System" # 数据集huggingface地址
115
- labels = {"Normal Baseline Data": 0, "Ball": 1, "Inner Race": 2, "Outer Race Centered": 3, "Outer Race Opposite": 4, "Outer Race Orthogonal": 5} # 标签
116
- transfer_task = [["CWRU224", "12kDriveEnd"], ["CWRU224", "12kFanEnd"]] # 迁移方向
117
-
118
-
119
- signal_dataset_creator = SignalDatasetCreator(data_set, labels, transfer_task, stratified_sampling=True)
120
- dataloaders = {}
121
- dataloaders["source_train"], dataloaders["source_val"], dataloaders["target_train"], dataloaders["target_val"] = signal_dataset_creator.data_split(
122
- 64, 0, device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
123
- )
124
-
125
- # %%
126
- import models
127
-
128
-
129
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
130
- model = getattr(models, "ResNet")().to(device)
131
- bottleneck_layer = nn.Sequential(
132
- nn.Linear(model.output_num(), 256),
133
- nn.ReLU(inplace=True),
134
- nn.Dropout(),
135
- ).to(device)
136
- classifier_layer = nn.Linear(256, len(labels)).to(device)
137
- model_all = nn.Sequential(model, bottleneck_layer, classifier_layer).to(device)
138
- model_all.load_state_dict(torch.load("checkpoint/150_0/149-0.3942-best_model.bin")) # 加载模型参数
139
- model_without_head = nn.Sequential(*list(model_all.children())[:-1])
140
-
141
-
142
- # %%
143
- from sklearn.manifold import TSNE
144
- import matplotlib.pyplot as plt
145
- import numpy as np
146
- from matplotlib.colors import ListedColormap
147
-
148
- # 定义一个固定的颜色映射
149
- num_classes = len(set(label for dataloader in dataloaders.values() for _, labels in dataloader for label in labels.numpy()))
150
- colors = plt.cm.get_cmap("tab10", num_classes) # 使用 "tab10" 颜色映射
151
- cmap = ListedColormap(colors.colors)
152
-
153
-
154
- def plot_tsne(dataloader, title, ax):
155
- model_all.eval()
156
- with torch.no_grad():
157
- for i, (inputs, labels) in enumerate(dataloader):
158
- inputs = inputs.to(device)
159
- labels = labels.to(device)
160
- outputs = model_without_head(inputs)
161
- # Collect all points across all batches
162
- if i == 0:
163
- all_points = outputs.cpu().numpy()
164
- all_labels = labels.cpu().numpy()
165
- else:
166
- all_points = np.concatenate((all_points, outputs.cpu().numpy()), axis=0)
167
- all_labels = np.concatenate((all_labels, labels.cpu().numpy()), axis=0)
168
-
169
- # Apply t-SNE to reduce dimensions to 2D
170
- tsne = TSNE(n_components=2, random_state=42)
171
- reduced_points = tsne.fit_transform(all_points)
172
-
173
- # Plot the reduced points
174
- scatter = ax.scatter(reduced_points[:, 0], reduced_points[:, 1], c=all_labels, cmap=cmap, s=10)
175
- ax.set_title(title)
176
- ax.set_xlabel("Dimension 1")
177
- ax.set_ylabel("Dimension 2")
178
- return scatter, reduced_points, all_labels
179
-
180
-
181
- # Create a 2x2 subplot
182
- fig, axes = plt.subplots(2, 2, figsize=(12, 10))
183
-
184
- # Plot each dataloader
185
- # sc1,_ = plot_tsne(dataloaders["source_train"], "Source Train", axes[0, 0])
186
- sc2, reduced_points2, all_labels2 = plot_tsne(dataloaders["source_val"], "Source Val", axes[0, 1])
187
- # sc3,_ = plot_tsne(dataloaders["target_train"], "Target Train", axes[1, 0])
188
- # sc4,_ = plot_tsne(dataloaders["target_val"], "Target Val", axes[1, 1])
189
-
190
- # Add a colorbar to the figure
191
- # cbar = fig.colorbar(sc1, ax=axes, orientation="vertical", fraction=0.02, pad=0.04)
192
- # cbar.set_label("Labels")
193
-
194
- # Adjust layout and show the plot
195
- plt.tight_layout()
196
- plt.show()
197
-
198
- # %%
199
- reduced_points2, all_labels2
200
-
201
- # %%
202
- import pandas as pd
203
-
204
- df = pd.DataFrame(reduced_points2)
205
- df["label"] = all_labels2 # 将标签添加为新列
206
-
207
- # 保存为 CSV 文件
208
- df.to_csv("checkpoint/reduced_points_with_labels.csv", index=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/predict.py CHANGED
@@ -1,8 +1,9 @@
1
  import torch
2
  import torch.nn as nn
 
 
3
  import models
4
  from dataset.get_dataset import get_user_dataset
5
- import torch.nn.functional as F
6
 
7
 
8
  def predict(model_state_dict, signal_file, args):
 
1
  import torch
2
  import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
  import models
6
  from dataset.get_dataset import get_user_dataset
 
7
 
8
 
9
  def predict(model_state_dict, signal_file, args):