caobin commited on
Commit
38f7d61
·
verified ·
1 Parent(s): d93aac3

Upload 24 files

Browse files
src/README.md ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Instructions for replication
2
+
3
+ This directory contains all the source code needed to reproduce this work.
4
+
5
+ ### Data preparation
6
+
7
+ To directly run the train and validation script in this directory, data preparation needs to be done. The [OneDrive link](https://hkustgz-my.sharepoint.com/:f:/g/personal/bcao686_connect_hkust-gz_edu_cn/EhdJLtou8I1MoUJCu-KCoboBfi-wOp00WAlQCrONxjoYgg?e=rltgFE) contains all the training and synthetic testing data used in this work, stored in data.zip. This link also contains the pretrained model for single-phase and di-phase identification.
8
+
9
+ Files single-phase_checkpoint_0200.pth and bi-phase_checkpoint_2000.pth from the link above are the pretrained models; place them under the directory "pretrained".
10
+
11
+ File data.zip contains the data and the annotation file. Place directories "train" and "val" from data.zip under directory "data", and place the annotation files anno_train.csv and anno_val.csv under directory "annotation".
12
+
13
+ ### Model Training
14
+
15
+ #### Single-phase
16
+
17
+ Run ```python train_single-phase.py``` to train the single-phase identification model from scratch. To train the model on your data, additional parameters need to be set: ```python train_single-phase.py --data_dir_train=[your training data] --data_dir_val=[your validation data] --anno_train=[your anno file for training data] --anno_val=[your anno file for validation data]```.
18
+
19
+ #### Bi-phase
20
+
21
+ Run ```python train_bi-phase.py``` to train the bi-phase identification model. The bi-phase identification model is trained based on single-phase model, you can change the default setting by set the parameter ```load_path=[your pretrained single-phase model]```.
22
+
23
+ ### Model validation
24
+
25
+ Run ```python val_single-phase.py``` and ```python val_bi-phase.py``` to run the validation code at the default settings.
26
+
27
+ If you wish to validate the model on your data, please format your data using data_format.py.
src/annotation/# place anno_train.csv and anno_val.csv here ADDED
File without changes
src/annotation/anno_struc.csv ADDED
The diff for this file is too large to render. See raw diff
 
src/data/# place directory train and directory val here ADDED
File without changes
src/data_format.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import warnings

import numpy as np
import pandas as pd
from scipy import interpolate
4
+
5
+ global dataWriter
6
+
7
def convert_file(file_path):
    """Convert a diffraction data file into the resampled intensity array.

    Dispatches on the file extension ('txt', 'csv' or 'xy'); any other
    extension is skipped with a warning.

    :param file_path: path to the input data file.
    :return: the value returned by the matching parser (a 1-D numpy array
        of resampled intensities, or None for empty data), or None when
        the file type is unsupported.
    """
    suffix = file_path.split('.')[-1]
    if suffix == 'txt':
        return txt_to_csv(file_path)
    if suffix == 'csv':
        return csv_to_csv(file_path)
    if suffix == 'xy':
        return xy_to_csv(file_path)
    # Bug fix: the original built a Warning object without ever emitting it,
    # so unsupported files were skipped completely silently.
    warnings.warn(f'File {file_path} not supported, skipping...')
    return None
19
+
20
def txt_to_csv(file_path):
    """Parse a whitespace/tab separated text file of diffraction data.

    Each line is expected to hold either two fields (angle, intensity) or
    three fields (angle, intensity, background); with three fields the
    background is subtracted from the intensity.  Lines with any other
    field count, or three fields that fail float conversion (e.g. a text
    header), are silently skipped.

    :param file_path: path to the .txt (or .xy) file.
    :return: resampled intensity array from upsample(), or None if no
        usable rows were found.
    """
    rows = []
    # 'with' guarantees the handle is closed even if parsing raises;
    # iterating the file avoids materializing it via readlines().
    with open(file_path, 'r') as f:
        for line in f:
            fields = [tok for tok in line.strip('\n').replace('\t', ' ').split(' ') if tok]
            if len(fields) == 3:
                try:
                    fields = [fields[0], float(fields[1]) - float(fields[2])]
                except ValueError:
                    # Non-numeric 3-field line -- skip it.
                    continue
            elif len(fields) != 2:
                continue
            rows.append(fields)
    return upsample(rows)
39
+
40
def csv_to_csv(file_path):
    """Load a CSV of (angle, intensity) rows and resample it.

    :param file_path: path to the .csv file.
    :return: resampled intensity array from upsample(), or None if empty.
    """
    values = pd.read_csv(file_path).values
    return upsample(list(values))
44
+
45
def xy_to_csv(file_path):
    """Parse an .xy file; the layout matches .txt, so reuse that parser."""
    return txt_to_csv(file_path)
47
+
48
def upsample(rows):
    """Resample (angle, intensity) pairs onto a fixed 4500-point grid.

    The data is linearly interpolated ('slinear') over the 2-theta range
    [10, 80]; if the input does not span that full range, it is padded by
    repeating the first/last intensity at the missing boundary so that
    interp1d can evaluate every grid point.

    :param rows: sequence of (angle, intensity) pairs; values may be
        numeric strings or floats.  The input is not modified.
    :return: numpy array of 4500 interpolated intensities, or None when
        *rows* is empty.
    """
    if not rows:
        # Bug fix: the original constructed a Warning without emitting it,
        # so empty input produced no diagnostic at all.
        warnings.warn('Empty data!')
        return None

    # Work on a copy instead of mutating the caller's list (the original
    # also abused conditional expressions for their insert/append side
    # effects).
    padded = list(rows)
    if float(padded[0][0]) > 10:
        padded.insert(0, [10.0, padded[0][1]])
    if float(padded[-1][0]) < 80:
        padded.append([80.0, padded[-1][1]])

    data = np.array(padded, dtype=np.float32)
    f = interpolate.interp1d(data[:, 0], data[:, 1], kind='slinear')
    xnew = np.linspace(10, 80, 4500)
    return f(xnew)
src/inference&case/.DS_Store ADDED
Binary file (8.2 kB). View file
 
src/inference&case/.ipynb_checkpoints/CPICANNcode-checkpoint.ipynb ADDED
@@ -0,0 +1,778 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "f2299629",
6
+ "metadata": {},
7
+ "source": [
8
+ "# It is a template for applying CPICANN to X-ray powder diffraction phase identification"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "cad44131",
14
+ "metadata": {},
15
+ "source": [
16
+ "### 1: install WPEMPhase package "
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "markdown",
21
+ "id": "af59fa9d",
22
+ "metadata": {},
23
+ "source": [
24
+ "pip install WPEMPhase"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "markdown",
29
+ "id": "7d9d6d02",
30
+ "metadata": {},
31
+ "source": [
32
+ "### 2: The first time you execute CPICANN on your computer, you should initialize the system documents. After that, you do not need to do any additional execution to run CPICANN."
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "raw",
37
+ "id": "4787c1b4",
38
+ "metadata": {},
39
+ "source": [
40
+ "Signature:\n",
41
+ "CPICANN.PhaseIdentifier(\n",
42
+ " FilePath,\n",
43
+ " Task='single-phase',\n",
44
+ " Model='default',\n",
45
+ " ElementsSystem='',\n",
46
+ " ElementsContained='',\n",
47
+ " Device='cuda:0',\n",
48
+ ")\n",
49
+ "Docstring:\n",
50
+ "CPICANN : Crystallographic Phase Identifier of Convolutional self-Attention Neural Network\n",
51
+ "\n",
52
+ "Contributors : Shouyang Zhang & Bin Cao\n",
53
+ "================================================================\n",
54
+ " Please feel free to open issues in the Github :\n",
55
+ " https://github.com/WPEM/CPICANN\n",
56
+ " or\n",
57
+ " contact Mr.Bin Cao (bcao686@connect.hkust-gz.edu.cn)\n",
58
+ " in case of any problems/comments/suggestions in using the code.\n",
59
+ "==================================================================\n",
60
+ "\n",
61
+ ":param FilePath\n",
62
+ "\n",
63
+ ":param Task, type=str, default='single-phase'\n",
64
+ " if Task = 'single-phase', CPICANN executes a single phase identification task\n",
65
+ " if Task = 'di-phase', CPICANN executes a dual phase identification task\n",
66
+ "\n",
67
+ ":param Model, type=str, default='default'\n",
68
+ " if Model = 'noise_model', CPICANN executes a single phase identification by noise-contained model\n",
69
+ " if Model = 'bca_model', CPICANN executes a single phase identification by background-contained model\n",
70
+ "\n",
71
+ ":param ElementsSystem, type=str, default=''\n",
72
+ " Specifies the elements to be included at least in the prediction, example: 'Fe'.\n",
73
+ "\n",
74
+ ":param ElementsContained, type=str, default=''\n",
75
+ " Specifies the elements to be included, with at least one of them in the prediction, example: 'O_C_S'.\n",
76
+ "\n",
77
+ ":param Device, type=str, default='cuda:0',\n",
78
+ " Which device to run the CPICANN, example: 'cuda:0', 'cpu'.\n",
79
+ "\n",
80
+ "examples:\n",
81
+ "from WPEMPhase import CPICANN\n",
82
+ "CPICANN.PhaseIdentifier(FilePath='./single-phase',Device='cpu')\n",
83
+ "File: ~/miniconda3/lib/python3.9/site-packages/WPEMPhase/CPICANN.py\n",
84
+ "Type: function"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": 2,
90
+ "id": "0922c99d",
91
+ "metadata": {},
92
+ "outputs": [
93
+ {
94
+ "name": "stdout",
95
+ "output_type": "stream",
96
+ "text": [
97
+ "Collecting WPEMPhase\n",
98
+ " Downloading WPEMPhase-0.1.0-py3-none-any.whl.metadata (1.0 kB)\n",
99
+ "Requirement already satisfied: torch in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (2.0.0)\n",
100
+ "Requirement already satisfied: plot in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (0.6.5)\n",
101
+ "Requirement already satisfied: scipy in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (1.9.3)\n",
102
+ "Requirement already satisfied: pandas in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (1.5.1)\n",
103
+ "Requirement already satisfied: numpy in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (1.23.3)\n",
104
+ "Requirement already satisfied: art in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (6.1)\n",
105
+ "Requirement already satisfied: pymatgen in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (2023.3.23)\n",
106
+ "Requirement already satisfied: wget in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (3.2)\n",
107
+ "Requirement already satisfied: python-dateutil>=2.8.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pandas->WPEMPhase) (2.8.2)\n",
108
+ "Requirement already satisfied: pytz>=2020.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pandas->WPEMPhase) (2022.5)\n",
109
+ "Requirement already satisfied: matplotlib in /Users/jacob/miniconda3/lib/python3.9/site-packages (from plot->WPEMPhase) (3.7.1)\n",
110
+ "Requirement already satisfied: typing in /Users/jacob/miniconda3/lib/python3.9/site-packages (from plot->WPEMPhase) (3.7.4.3)\n",
111
+ "Requirement already satisfied: pyyaml in /Users/jacob/miniconda3/lib/python3.9/site-packages (from plot->WPEMPhase) (6.0)\n",
112
+ "Requirement already satisfied: monty>=3.0.2 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (2023.4.10)\n",
113
+ "Requirement already satisfied: mp-api>=0.27.3 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (0.31.2)\n",
114
+ "Requirement already satisfied: networkx>=2.2 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (2.8.8)\n",
115
+ "Requirement already satisfied: palettable>=3.1.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (3.3.3)\n",
116
+ "Requirement already satisfied: plotly>=4.5.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (5.14.1)\n",
117
+ "Requirement already satisfied: pybtex in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (0.24.0)\n",
118
+ "Requirement already satisfied: requests in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (2.28.2)\n",
119
+ "Requirement already satisfied: ruamel.yaml>=0.17.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (0.17.21)\n",
120
+ "Requirement already satisfied: spglib>=2.0.2 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (2.0.2)\n",
121
+ "Requirement already satisfied: sympy in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (1.11.1)\n",
122
+ "Requirement already satisfied: tabulate in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (0.9.0)\n",
123
+ "Requirement already satisfied: tqdm in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (4.66.1)\n",
124
+ "Requirement already satisfied: uncertainties>=3.1.4 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (3.1.7)\n",
125
+ "Requirement already satisfied: filelock in /Users/jacob/miniconda3/lib/python3.9/site-packages (from torch->WPEMPhase) (3.10.7)\n",
126
+ "Requirement already satisfied: typing-extensions in /Users/jacob/miniconda3/lib/python3.9/site-packages (from torch->WPEMPhase) (4.11.0)\n",
127
+ "Requirement already satisfied: jinja2 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from torch->WPEMPhase) (3.1.2)\n",
128
+ "Requirement already satisfied: contourpy>=1.0.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (1.0.5)\n",
129
+ "Requirement already satisfied: cycler>=0.10 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (0.11.0)\n",
130
+ "Requirement already satisfied: fonttools>=4.22.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (4.38.0)\n",
131
+ "Requirement already satisfied: kiwisolver>=1.0.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (1.4.4)\n",
132
+ "Requirement already satisfied: packaging>=20.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (23.0)\n",
133
+ "Requirement already satisfied: pillow>=6.2.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (9.5.0)\n",
134
+ "Requirement already satisfied: pyparsing>=2.3.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (3.0.9)\n",
135
+ "Requirement already satisfied: importlib-resources>=3.2.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (5.12.0)\n",
136
+ "Requirement already satisfied: setuptools in /Users/jacob/miniconda3/lib/python3.9/site-packages (from mp-api>=0.27.3->pymatgen->WPEMPhase) (67.6.1)\n",
137
+ "Requirement already satisfied: msgpack in /Users/jacob/miniconda3/lib/python3.9/site-packages (from mp-api>=0.27.3->pymatgen->WPEMPhase) (1.0.5)\n",
138
+ "Requirement already satisfied: emmet-core<=0.50.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from mp-api>=0.27.3->pymatgen->WPEMPhase) (0.50.0)\n",
139
+ "Requirement already satisfied: tenacity>=6.2.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from plotly>=4.5.0->pymatgen->WPEMPhase) (8.2.2)\n",
140
+ "Requirement already satisfied: six>=1.5 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from python-dateutil>=2.8.1->pandas->WPEMPhase) (1.16.0)\n",
141
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from requests->pymatgen->WPEMPhase) (2.0.4)\n",
142
+ "Requirement already satisfied: idna<4,>=2.5 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from requests->pymatgen->WPEMPhase) (3.3)\n",
143
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from requests->pymatgen->WPEMPhase) (1.26.9)\n",
144
+ "Requirement already satisfied: certifi>=2017.4.17 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from requests->pymatgen->WPEMPhase) (2022.12.7)\n",
145
+ "Requirement already satisfied: ruamel.yaml.clib>=0.2.6 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from ruamel.yaml>=0.17.0->pymatgen->WPEMPhase) (0.2.6)\n",
146
+ "Requirement already satisfied: future in /Users/jacob/miniconda3/lib/python3.9/site-packages (from uncertainties>=3.1.4->pymatgen->WPEMPhase) (0.18.3)\n",
147
+ "Requirement already satisfied: MarkupSafe>=2.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from jinja2->torch->WPEMPhase) (2.1.1)\n",
148
+ "Requirement already satisfied: latexcodec>=1.0.4 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pybtex->pymatgen->WPEMPhase) (2.0.1)\n",
149
+ "Requirement already satisfied: mpmath>=0.19 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from sympy->pymatgen->WPEMPhase) (1.3.0)\n",
150
+ "Requirement already satisfied: pydantic>=1.10.2 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from emmet-core<=0.50.0->mp-api>=0.27.3->pymatgen->WPEMPhase) (1.10.7)\n",
151
+ "Requirement already satisfied: zipp>=3.1.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from importlib-resources>=3.2.0->matplotlib->plot->WPEMPhase) (3.9.0)\n",
152
+ "Downloading WPEMPhase-0.1.0-py3-none-any.whl (710 kB)\n",
153
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m710.2/710.2 kB\u001b[0m \u001b[31m2.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0mm\n",
154
+ "\u001b[?25hInstalling collected packages: WPEMPhase\n",
155
+ "Successfully installed WPEMPhase-0.1.0\n",
156
+ "Note: you may need to restart the kernel to use updated packages.\n"
157
+ ]
158
+ }
159
+ ],
160
+ "source": [
161
+ "pip install WPEMPhase"
162
+ ]
163
+ },
164
+ {
165
+ "cell_type": "code",
166
+ "execution_count": 3,
167
+ "id": "8e1680a6",
168
+ "metadata": {},
169
+ "outputs": [],
170
+ "source": [
171
+ "from WPEMPhase import CPICANN"
172
+ ]
173
+ },
174
+ {
175
+ "cell_type": "code",
176
+ "execution_count": 4,
177
+ "id": "7625eb66",
178
+ "metadata": {},
179
+ "outputs": [
180
+ {
181
+ "name": "stdout",
182
+ "output_type": "stream",
183
+ "text": [
184
+ "This is the first time CPICANN is being executed on your computer, configuring...\n",
185
+ "Downloading: 3% [24690688 / 776454342] bytes"
186
+ ]
187
+ },
188
+ {
189
+ "name": "stderr",
190
+ "output_type": "stream",
191
+ "text": [
192
+ "IOPub message rate exceeded.\n",
193
+ "The notebook server will temporarily stop sending output\n",
194
+ "to the client in order to avoid crashing it.\n",
195
+ "To change this limit, set the config variable\n",
196
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
197
+ "\n",
198
+ "Current values:\n",
199
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
200
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
201
+ "\n"
202
+ ]
203
+ },
204
+ {
205
+ "name": "stdout",
206
+ "output_type": "stream",
207
+ "text": [
208
+ "Downloading: 7% [61341696 / 776454342] bytes"
209
+ ]
210
+ },
211
+ {
212
+ "name": "stderr",
213
+ "output_type": "stream",
214
+ "text": [
215
+ "IOPub message rate exceeded.\n",
216
+ "The notebook server will temporarily stop sending output\n",
217
+ "to the client in order to avoid crashing it.\n",
218
+ "To change this limit, set the config variable\n",
219
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
220
+ "\n",
221
+ "Current values:\n",
222
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
223
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
224
+ "\n"
225
+ ]
226
+ },
227
+ {
228
+ "name": "stdout",
229
+ "output_type": "stream",
230
+ "text": [
231
+ "Downloading: 13% [107954176 / 776454342] bytes"
232
+ ]
233
+ },
234
+ {
235
+ "name": "stderr",
236
+ "output_type": "stream",
237
+ "text": [
238
+ "IOPub message rate exceeded.\n",
239
+ "The notebook server will temporarily stop sending output\n",
240
+ "to the client in order to avoid crashing it.\n",
241
+ "To change this limit, set the config variable\n",
242
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
243
+ "\n",
244
+ "Current values:\n",
245
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
246
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
247
+ "\n"
248
+ ]
249
+ },
250
+ {
251
+ "name": "stdout",
252
+ "output_type": "stream",
253
+ "text": [
254
+ "Downloading: 19% [148324352 / 776454342] bytes"
255
+ ]
256
+ },
257
+ {
258
+ "name": "stderr",
259
+ "output_type": "stream",
260
+ "text": [
261
+ "IOPub message rate exceeded.\n",
262
+ "The notebook server will temporarily stop sending output\n",
263
+ "to the client in order to avoid crashing it.\n",
264
+ "To change this limit, set the config variable\n",
265
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
266
+ "\n",
267
+ "Current values:\n",
268
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
269
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
270
+ "\n"
271
+ ]
272
+ },
273
+ {
274
+ "name": "stdout",
275
+ "output_type": "stream",
276
+ "text": [
277
+ "Downloading: 24% [189382656 / 776454342] bytes"
278
+ ]
279
+ },
280
+ {
281
+ "name": "stderr",
282
+ "output_type": "stream",
283
+ "text": [
284
+ "IOPub message rate exceeded.\n",
285
+ "The notebook server will temporarily stop sending output\n",
286
+ "to the client in order to avoid crashing it.\n",
287
+ "To change this limit, set the config variable\n",
288
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
289
+ "\n",
290
+ "Current values:\n",
291
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
292
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
293
+ "\n"
294
+ ]
295
+ },
296
+ {
297
+ "name": "stdout",
298
+ "output_type": "stream",
299
+ "text": [
300
+ "Downloading: 28% [221265920 / 776454342] bytes"
301
+ ]
302
+ },
303
+ {
304
+ "name": "stderr",
305
+ "output_type": "stream",
306
+ "text": [
307
+ "IOPub message rate exceeded.\n",
308
+ "The notebook server will temporarily stop sending output\n",
309
+ "to the client in order to avoid crashing it.\n",
310
+ "To change this limit, set the config variable\n",
311
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
312
+ "\n",
313
+ "Current values:\n",
314
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
315
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
316
+ "\n"
317
+ ]
318
+ },
319
+ {
320
+ "name": "stdout",
321
+ "output_type": "stream",
322
+ "text": [
323
+ "Downloading: 33% [262488064 / 776454342] bytes"
324
+ ]
325
+ },
326
+ {
327
+ "name": "stderr",
328
+ "output_type": "stream",
329
+ "text": [
330
+ "IOPub message rate exceeded.\n",
331
+ "The notebook server will temporarily stop sending output\n",
332
+ "to the client in order to avoid crashing it.\n",
333
+ "To change this limit, set the config variable\n",
334
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
335
+ "\n",
336
+ "Current values:\n",
337
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
338
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
339
+ "\n"
340
+ ]
341
+ },
342
+ {
343
+ "name": "stdout",
344
+ "output_type": "stream",
345
+ "text": [
346
+ "Downloading: 39% [304799744 / 776454342] bytes"
347
+ ]
348
+ },
349
+ {
350
+ "name": "stderr",
351
+ "output_type": "stream",
352
+ "text": [
353
+ "IOPub message rate exceeded.\n",
354
+ "The notebook server will temporarily stop sending output\n",
355
+ "to the client in order to avoid crashing it.\n",
356
+ "To change this limit, set the config variable\n",
357
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
358
+ "\n",
359
+ "Current values:\n",
360
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
361
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
362
+ "\n"
363
+ ]
364
+ },
365
+ {
366
+ "name": "stdout",
367
+ "output_type": "stream",
368
+ "text": [
369
+ "Downloading: 44% [346030080 / 776454342] bytes"
370
+ ]
371
+ },
372
+ {
373
+ "name": "stderr",
374
+ "output_type": "stream",
375
+ "text": [
376
+ "IOPub message rate exceeded.\n",
377
+ "The notebook server will temporarily stop sending output\n",
378
+ "to the client in order to avoid crashing it.\n",
379
+ "To change this limit, set the config variable\n",
380
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
381
+ "\n",
382
+ "Current values:\n",
383
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
384
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
385
+ "\n"
386
+ ]
387
+ },
388
+ {
389
+ "name": "stdout",
390
+ "output_type": "stream",
391
+ "text": [
392
+ "Downloading: 50% [388333568 / 776454342] bytes"
393
+ ]
394
+ },
395
+ {
396
+ "name": "stderr",
397
+ "output_type": "stream",
398
+ "text": [
399
+ "IOPub message rate exceeded.\n",
400
+ "The notebook server will temporarily stop sending output\n",
401
+ "to the client in order to avoid crashing it.\n",
402
+ "To change this limit, set the config variable\n",
403
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
404
+ "\n",
405
+ "Current values:\n",
406
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
407
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
408
+ "\n"
409
+ ]
410
+ },
411
+ {
412
+ "name": "stdout",
413
+ "output_type": "stream",
414
+ "text": [
415
+ "Downloading: 55% [429015040 / 776454342] bytes"
416
+ ]
417
+ },
418
+ {
419
+ "name": "stderr",
420
+ "output_type": "stream",
421
+ "text": [
422
+ "IOPub message rate exceeded.\n",
423
+ "The notebook server will temporarily stop sending output\n",
424
+ "to the client in order to avoid crashing it.\n",
425
+ "To change this limit, set the config variable\n",
426
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
427
+ "\n",
428
+ "Current values:\n",
429
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
430
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
431
+ "\n"
432
+ ]
433
+ },
434
+ {
435
+ "name": "stdout",
436
+ "output_type": "stream",
437
+ "text": [
438
+ "Downloading: 60% [470278144 / 776454342] bytes"
439
+ ]
440
+ },
441
+ {
442
+ "name": "stderr",
443
+ "output_type": "stream",
444
+ "text": [
445
+ "IOPub message rate exceeded.\n",
446
+ "The notebook server will temporarily stop sending output\n",
447
+ "to the client in order to avoid crashing it.\n",
448
+ "To change this limit, set the config variable\n",
449
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
450
+ "\n",
451
+ "Current values:\n",
452
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
453
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
454
+ "\n"
455
+ ]
456
+ },
457
+ {
458
+ "name": "stdout",
459
+ "output_type": "stream",
460
+ "text": [
461
+ "Downloading: 65% [507609088 / 776454342] bytes"
462
+ ]
463
+ },
464
+ {
465
+ "name": "stderr",
466
+ "output_type": "stream",
467
+ "text": [
468
+ "IOPub message rate exceeded.\n",
469
+ "The notebook server will temporarily stop sending output\n",
470
+ "to the client in order to avoid crashing it.\n",
471
+ "To change this limit, set the config variable\n",
472
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
473
+ "\n",
474
+ "Current values:\n",
475
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
476
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
477
+ "\n"
478
+ ]
479
+ },
480
+ {
481
+ "name": "stdout",
482
+ "output_type": "stream",
483
+ "text": [
484
+ "Downloading: 70% [549601280 / 776454342] bytes"
485
+ ]
486
+ },
487
+ {
488
+ "name": "stderr",
489
+ "output_type": "stream",
490
+ "text": [
491
+ "IOPub message rate exceeded.\n",
492
+ "The notebook server will temporarily stop sending output\n",
493
+ "to the client in order to avoid crashing it.\n",
494
+ "To change this limit, set the config variable\n",
495
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
496
+ "\n",
497
+ "Current values:\n",
498
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
499
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
500
+ "\n"
501
+ ]
502
+ },
503
+ {
504
+ "name": "stdout",
505
+ "output_type": "stream",
506
+ "text": [
507
+ "Downloading: 75% [587497472 / 776454342] bytes"
508
+ ]
509
+ },
510
+ {
511
+ "name": "stderr",
512
+ "output_type": "stream",
513
+ "text": [
514
+ "IOPub message rate exceeded.\n",
515
+ "The notebook server will temporarily stop sending output\n",
516
+ "to the client in order to avoid crashing it.\n",
517
+ "To change this limit, set the config variable\n",
518
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
519
+ "\n",
520
+ "Current values:\n",
521
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
522
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
523
+ "\n"
524
+ ]
525
+ },
526
+ {
527
+ "name": "stdout",
528
+ "output_type": "stream",
529
+ "text": [
530
+ "Downloading: 80% [622919680 / 776454342] bytes"
531
+ ]
532
+ },
533
+ {
534
+ "name": "stderr",
535
+ "output_type": "stream",
536
+ "text": [
537
+ "IOPub message rate exceeded.\n",
538
+ "The notebook server will temporarily stop sending output\n",
539
+ "to the client in order to avoid crashing it.\n",
540
+ "To change this limit, set the config variable\n",
541
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
542
+ "\n",
543
+ "Current values:\n",
544
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
545
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
546
+ "\n"
547
+ ]
548
+ },
549
+ {
550
+ "name": "stdout",
551
+ "output_type": "stream",
552
+ "text": [
553
+ "Downloading: 86% [668491776 / 776454342] bytes"
554
+ ]
555
+ },
556
+ {
557
+ "name": "stderr",
558
+ "output_type": "stream",
559
+ "text": [
560
+ "IOPub message rate exceeded.\n",
561
+ "The notebook server will temporarily stop sending output\n",
562
+ "to the client in order to avoid crashing it.\n",
563
+ "To change this limit, set the config variable\n",
564
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
565
+ "\n",
566
+ "Current values:\n",
567
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
568
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
569
+ "\n"
570
+ ]
571
+ },
572
+ {
573
+ "name": "stdout",
574
+ "output_type": "stream",
575
+ "text": [
576
+ "Downloading: 89% [698474496 / 776454342] bytes"
577
+ ]
578
+ },
579
+ {
580
+ "name": "stderr",
581
+ "output_type": "stream",
582
+ "text": [
583
+ "IOPub message rate exceeded.\n",
584
+ "The notebook server will temporarily stop sending output\n",
585
+ "to the client in order to avoid crashing it.\n",
586
+ "To change this limit, set the config variable\n",
587
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
588
+ "\n",
589
+ "Current values:\n",
590
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
591
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
592
+ "\n"
593
+ ]
594
+ },
595
+ {
596
+ "name": "stdout",
597
+ "output_type": "stream",
598
+ "text": [
599
+ "Downloading: 91% [713318400 / 776454342] bytes"
600
+ ]
601
+ },
602
+ {
603
+ "name": "stderr",
604
+ "output_type": "stream",
605
+ "text": [
606
+ "IOPub message rate exceeded.\n",
607
+ "The notebook server will temporarily stop sending output\n",
608
+ "to the client in order to avoid crashing it.\n",
609
+ "To change this limit, set the config variable\n",
610
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
611
+ "\n",
612
+ "Current values:\n",
613
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
614
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
615
+ "\n"
616
+ ]
617
+ },
618
+ {
619
+ "name": "stdout",
620
+ "output_type": "stream",
621
+ "text": [
622
+ "Downloading: 96% [746487808 / 776454342] bytes"
623
+ ]
624
+ },
625
+ {
626
+ "name": "stderr",
627
+ "output_type": "stream",
628
+ "text": [
629
+ "IOPub message rate exceeded.\n",
630
+ "The notebook server will temporarily stop sending output\n",
631
+ "to the client in order to avoid crashing it.\n",
632
+ "To change this limit, set the config variable\n",
633
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
634
+ "\n",
635
+ "Current values:\n",
636
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
637
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
638
+ "\n"
639
+ ]
640
+ },
641
+ {
642
+ "name": "stdout",
643
+ "output_type": "stream",
644
+ "text": [
645
+ "Downloading: 100% [776454342 / 776454342] bytes ____ ____ ___ ____ _ _ _ _ _ \n",
646
+ " / ___|| _ \\ |_ _| / ___| / \\ | \\ | || \\ | |\n",
647
+ "| | | |_) | | | | | / _ \\ | \\| || \\| |\n",
648
+ "| |___ | __/ | | | |___ / ___ \\ | |\\ || |\\ |\n",
649
+ " \\____||_| |___| \\____|/_/ \\_\\|_| \\_||_| \\_|\n",
650
+ " \n",
651
+ "\n",
652
+ "The phase identification module of WPEM\n",
653
+ "URL : https://github.com/WPEM/CPICANN\n",
654
+ "Executed on : 2024-04-21 14:14:25 | Have a great day.\n",
655
+ "================================================================================\n",
656
+ "loaded model from /Users/jacob/miniconda3/lib/python3.9/site-packages/WPEMPhase/pretrained/CPICANN_single-phase_back3.pth\n",
657
+ "\n",
658
+ ">>>>>> RUNNING: ./testdata/.DS_Store\n",
659
+ "\n",
660
+ ">>>>>> RUNNING: ./testdata/PbSO4.csv\n",
661
+ "pred cls_id : 2475 confidence : 98.89%\n",
662
+ "pred cod_id : 9009622 formula : Pb2 S2 O6\n",
663
+ "pred space group No: 11 space group : P2_1/m\n",
664
+ "\n",
665
+ "inference result saved in infResults_testdata.csv\n",
666
+ "inference figures saved at figs/\n",
667
+ "THE END\n"
668
+ ]
669
+ },
670
+ {
671
+ "data": {
672
+ "text/plain": [
673
+ "True"
674
+ ]
675
+ },
676
+ "execution_count": 4,
677
+ "metadata": {},
678
+ "output_type": "execute_result"
679
+ }
680
+ ],
681
+ "source": [
682
+ "# Here, illustrate the system requirements and how to initialize the system files at the first time of execution.\n",
683
+ "\n",
684
+ "CPICANN.PhaseIdentifier(FilePath='./testdata',Model='bca_model',Task='single-phase',Device='cpu',)"
685
+ ]
686
+ },
687
+ {
688
+ "cell_type": "code",
689
+ "execution_count": 7,
690
+ "id": "304b62b5",
691
+ "metadata": {},
692
+ "outputs": [
693
+ {
694
+ "name": "stdout",
695
+ "output_type": "stream",
696
+ "text": [
697
+ " ____ ____ ___ ____ _ _ _ _ _ \n",
698
+ " / ___|| _ \\ |_ _| / ___| / \\ | \\ | || \\ | |\n",
699
+ "| | | |_) | | | | | / _ \\ | \\| || \\| |\n",
700
+ "| |___ | __/ | | | |___ / ___ \\ | |\\ || |\\ |\n",
701
+ " \\____||_| |___| \\____|/_/ \\_\\|_| \\_||_| \\_|\n",
702
+ " \n",
703
+ "\n",
704
+ "The phase identification module of WPEM\n",
705
+ "URL : https://github.com/WPEM/CPICANN\n",
706
+ "Executed on : 2024-04-21 14:14:53 | Have a great day.\n",
707
+ "================================================================================\n",
708
+ "loaded model from /Users/jacob/miniconda3/lib/python3.9/site-packages/WPEMPhase/pretrained/CPICANN_single-phase_noise3.pth\n",
709
+ "\n",
710
+ ">>>>>> RUNNING: ./testdata/.DS_Store\n",
711
+ "\n",
712
+ ">>>>>> RUNNING: ./testdata/PbSO4.csv\n",
713
+ "pred cls_id : 3378 confidence : 100.00%\n",
714
+ "pred cod_id : 9004484 formula : Pb4 S4 O16\n",
715
+ "pred space group No: 62 space group : Pnma\n",
716
+ "\n",
717
+ "inference result saved in infResults_testdata.csv\n",
718
+ "inference figures saved at figs/\n",
719
+ "THE END\n"
720
+ ]
721
+ },
722
+ {
723
+ "data": {
724
+ "text/plain": [
725
+ "True"
726
+ ]
727
+ },
728
+ "execution_count": 7,
729
+ "metadata": {},
730
+ "output_type": "execute_result"
731
+ }
732
+ ],
733
+ "source": [
734
+ "from WPEMPhase import CPICANN\n",
735
+ "# Here, illustrate the system requirements and how to initialize the system files at the first time of execution.\n",
736
+ "\n",
737
+ "CPICANN.PhaseIdentifier(FilePath='./testdata',Model='noise_model',Task='single-phase',ElementsContained='Pb_S_O',Device='cpu',)"
738
+ ]
739
+ },
740
+ {
741
+ "cell_type": "raw",
742
+ "id": "977ede12",
743
+ "metadata": {},
744
+ "source": [
745
+ "For inquiries or assistance, please don't hesitate to contact us at bcao686@connect.hkust-gz.edu.cn (Dr. CAO Bin)."
746
+ ]
747
+ },
748
+ {
749
+ "cell_type": "code",
750
+ "execution_count": null,
751
+ "id": "bb101480",
752
+ "metadata": {},
753
+ "outputs": [],
754
+ "source": []
755
+ }
756
+ ],
757
+ "metadata": {
758
+ "kernelspec": {
759
+ "display_name": "Python 3 (ipykernel)",
760
+ "language": "python",
761
+ "name": "python3"
762
+ },
763
+ "language_info": {
764
+ "codemirror_mode": {
765
+ "name": "ipython",
766
+ "version": 3
767
+ },
768
+ "file_extension": ".py",
769
+ "mimetype": "text/x-python",
770
+ "name": "python",
771
+ "nbconvert_exporter": "python",
772
+ "pygments_lexer": "ipython3",
773
+ "version": "3.9.12"
774
+ }
775
+ },
776
+ "nbformat": 4,
777
+ "nbformat_minor": 5
778
+ }
src/inference&case/CPICANNcode.ipynb ADDED
@@ -0,0 +1,778 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "f2299629",
6
+ "metadata": {},
7
+ "source": [
8
+ "# It is a template for applying CPICANN to X-ray powder diffraction phase identification"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "cad44131",
14
+ "metadata": {},
15
+ "source": [
16
+ "### 1: install WPEMPhase package "
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "markdown",
21
+ "id": "af59fa9d",
22
+ "metadata": {},
23
+ "source": [
24
+ "pip install WPEMPhase"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "markdown",
29
+ "id": "7d9d6d02",
30
+ "metadata": {},
31
+ "source": [
32
+ "### 2: The first time you execute CPICANN on your computer, you should initialize the system documents. After that, you do not need to do any additional execution to run CPICANN."
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "raw",
37
+ "id": "4787c1b4",
38
+ "metadata": {},
39
+ "source": [
40
+ "Signature:\n",
41
+ "CPICANN.PhaseIdentifier(\n",
42
+ " FilePath,\n",
43
+ " Task='single-phase',\n",
44
+ " Model='default',\n",
45
+ " ElementsSystem='',\n",
46
+ " ElementsContained='',\n",
47
+ " Device='cuda:0',\n",
48
+ ")\n",
49
+ "Docstring:\n",
50
+ "CPICANN : Crystallographic Phase Identifier of Convolutional self-Attention Neural Network\n",
51
+ "\n",
52
+ "Contributors : Shouyang Zhang & Bin Cao\n",
53
+ "================================================================\n",
54
+ " Please feel free to open issues in the Github :\n",
55
+ " https://github.com/WPEM/CPICANN\n",
56
+ " or\n",
57
+ " contact Mr.Bin Cao (bcao686@connect.hkust-gz.edu.cn)\n",
58
+ " in case of any problems/comments/suggestions in using the code.\n",
59
+ "==================================================================\n",
60
+ "\n",
61
+ ":param FilePath\n",
62
+ "\n",
63
+ ":param Task, type=str, default='single-phase'\n",
64
+ " if Task = 'single-phase', CPICANN executes a single phase identification task\n",
65
+ " if Task = 'di-phase', CPICANN executes a dual phase identification task\n",
66
+ "\n",
67
+ ":param Model, type=str, default='default'\n",
68
+ " if Model = 'noise_model', CPICANN executes a single phase identification by noise-contained model\n",
69
+ " if Model = 'bca_model', CPICANN executes a single phase identification by background-contained model\n",
70
+ "\n",
71
+ ":param ElementsSystem, type=str, default=''\n",
72
+ " Specifies the elements to be included at least in the prediction, example: 'Fe'.\n",
73
+ "\n",
74
+ ":param ElementsContained, type=str, default=''\n",
75
+ " Specifies the elements to be included, with at least one of them in the prediction, example: 'O_C_S'.\n",
76
+ "\n",
77
+ ":param Device, type=str, default='cuda:0',\n",
78
+ " Which device to run the CPICANN, example: 'cuda:0', 'cpu'.\n",
79
+ "\n",
80
+ "examples:\n",
81
+ "from WPEMPhase import CPICANN\n",
82
+ "CPICANN.PhaseIdentifier(FilePath='./single-phase',Device='cpu')\n",
83
+ "File: ~/miniconda3/lib/python3.9/site-packages/WPEMPhase/CPICANN.py\n",
84
+ "Type: function"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": 2,
90
+ "id": "0922c99d",
91
+ "metadata": {},
92
+ "outputs": [
93
+ {
94
+ "name": "stdout",
95
+ "output_type": "stream",
96
+ "text": [
97
+ "Collecting WPEMPhase\n",
98
+ " Downloading WPEMPhase-0.1.0-py3-none-any.whl.metadata (1.0 kB)\n",
99
+ "Requirement already satisfied: torch in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (2.0.0)\n",
100
+ "Requirement already satisfied: plot in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (0.6.5)\n",
101
+ "Requirement already satisfied: scipy in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (1.9.3)\n",
102
+ "Requirement already satisfied: pandas in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (1.5.1)\n",
103
+ "Requirement already satisfied: numpy in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (1.23.3)\n",
104
+ "Requirement already satisfied: art in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (6.1)\n",
105
+ "Requirement already satisfied: pymatgen in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (2023.3.23)\n",
106
+ "Requirement already satisfied: wget in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (3.2)\n",
107
+ "Requirement already satisfied: python-dateutil>=2.8.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pandas->WPEMPhase) (2.8.2)\n",
108
+ "Requirement already satisfied: pytz>=2020.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pandas->WPEMPhase) (2022.5)\n",
109
+ "Requirement already satisfied: matplotlib in /Users/jacob/miniconda3/lib/python3.9/site-packages (from plot->WPEMPhase) (3.7.1)\n",
110
+ "Requirement already satisfied: typing in /Users/jacob/miniconda3/lib/python3.9/site-packages (from plot->WPEMPhase) (3.7.4.3)\n",
111
+ "Requirement already satisfied: pyyaml in /Users/jacob/miniconda3/lib/python3.9/site-packages (from plot->WPEMPhase) (6.0)\n",
112
+ "Requirement already satisfied: monty>=3.0.2 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (2023.4.10)\n",
113
+ "Requirement already satisfied: mp-api>=0.27.3 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (0.31.2)\n",
114
+ "Requirement already satisfied: networkx>=2.2 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (2.8.8)\n",
115
+ "Requirement already satisfied: palettable>=3.1.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (3.3.3)\n",
116
+ "Requirement already satisfied: plotly>=4.5.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (5.14.1)\n",
117
+ "Requirement already satisfied: pybtex in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (0.24.0)\n",
118
+ "Requirement already satisfied: requests in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (2.28.2)\n",
119
+ "Requirement already satisfied: ruamel.yaml>=0.17.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (0.17.21)\n",
120
+ "Requirement already satisfied: spglib>=2.0.2 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (2.0.2)\n",
121
+ "Requirement already satisfied: sympy in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (1.11.1)\n",
122
+ "Requirement already satisfied: tabulate in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (0.9.0)\n",
123
+ "Requirement already satisfied: tqdm in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (4.66.1)\n",
124
+ "Requirement already satisfied: uncertainties>=3.1.4 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (3.1.7)\n",
125
+ "Requirement already satisfied: filelock in /Users/jacob/miniconda3/lib/python3.9/site-packages (from torch->WPEMPhase) (3.10.7)\n",
126
+ "Requirement already satisfied: typing-extensions in /Users/jacob/miniconda3/lib/python3.9/site-packages (from torch->WPEMPhase) (4.11.0)\n",
127
+ "Requirement already satisfied: jinja2 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from torch->WPEMPhase) (3.1.2)\n",
128
+ "Requirement already satisfied: contourpy>=1.0.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (1.0.5)\n",
129
+ "Requirement already satisfied: cycler>=0.10 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (0.11.0)\n",
130
+ "Requirement already satisfied: fonttools>=4.22.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (4.38.0)\n",
131
+ "Requirement already satisfied: kiwisolver>=1.0.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (1.4.4)\n",
132
+ "Requirement already satisfied: packaging>=20.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (23.0)\n",
133
+ "Requirement already satisfied: pillow>=6.2.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (9.5.0)\n",
134
+ "Requirement already satisfied: pyparsing>=2.3.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (3.0.9)\n",
135
+ "Requirement already satisfied: importlib-resources>=3.2.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (5.12.0)\n",
136
+ "Requirement already satisfied: setuptools in /Users/jacob/miniconda3/lib/python3.9/site-packages (from mp-api>=0.27.3->pymatgen->WPEMPhase) (67.6.1)\n",
137
+ "Requirement already satisfied: msgpack in /Users/jacob/miniconda3/lib/python3.9/site-packages (from mp-api>=0.27.3->pymatgen->WPEMPhase) (1.0.5)\n",
138
+ "Requirement already satisfied: emmet-core<=0.50.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from mp-api>=0.27.3->pymatgen->WPEMPhase) (0.50.0)\n",
139
+ "Requirement already satisfied: tenacity>=6.2.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from plotly>=4.5.0->pymatgen->WPEMPhase) (8.2.2)\n",
140
+ "Requirement already satisfied: six>=1.5 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from python-dateutil>=2.8.1->pandas->WPEMPhase) (1.16.0)\n",
141
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from requests->pymatgen->WPEMPhase) (2.0.4)\n",
142
+ "Requirement already satisfied: idna<4,>=2.5 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from requests->pymatgen->WPEMPhase) (3.3)\n",
143
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from requests->pymatgen->WPEMPhase) (1.26.9)\n",
144
+ "Requirement already satisfied: certifi>=2017.4.17 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from requests->pymatgen->WPEMPhase) (2022.12.7)\n",
145
+ "Requirement already satisfied: ruamel.yaml.clib>=0.2.6 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from ruamel.yaml>=0.17.0->pymatgen->WPEMPhase) (0.2.6)\n",
146
+ "Requirement already satisfied: future in /Users/jacob/miniconda3/lib/python3.9/site-packages (from uncertainties>=3.1.4->pymatgen->WPEMPhase) (0.18.3)\n",
147
+ "Requirement already satisfied: MarkupSafe>=2.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from jinja2->torch->WPEMPhase) (2.1.1)\n",
148
+ "Requirement already satisfied: latexcodec>=1.0.4 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pybtex->pymatgen->WPEMPhase) (2.0.1)\n",
149
+ "Requirement already satisfied: mpmath>=0.19 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from sympy->pymatgen->WPEMPhase) (1.3.0)\n",
150
+ "Requirement already satisfied: pydantic>=1.10.2 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from emmet-core<=0.50.0->mp-api>=0.27.3->pymatgen->WPEMPhase) (1.10.7)\n",
151
+ "Requirement already satisfied: zipp>=3.1.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from importlib-resources>=3.2.0->matplotlib->plot->WPEMPhase) (3.9.0)\n",
152
+ "Downloading WPEMPhase-0.1.0-py3-none-any.whl (710 kB)\n",
153
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m710.2/710.2 kB\u001b[0m \u001b[31m2.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0mm\n",
154
+ "\u001b[?25hInstalling collected packages: WPEMPhase\n",
155
+ "Successfully installed WPEMPhase-0.1.0\n",
156
+ "Note: you may need to restart the kernel to use updated packages.\n"
157
+ ]
158
+ }
159
+ ],
160
+ "source": [
161
+ "pip install WPEMPhase"
162
+ ]
163
+ },
164
+ {
165
+ "cell_type": "code",
166
+ "execution_count": 3,
167
+ "id": "8e1680a6",
168
+ "metadata": {},
169
+ "outputs": [],
170
+ "source": [
171
+ "from WPEMPhase import CPICANN"
172
+ ]
173
+ },
174
+ {
175
+ "cell_type": "code",
176
+ "execution_count": 4,
177
+ "id": "7625eb66",
178
+ "metadata": {},
179
+ "outputs": [
180
+ {
181
+ "name": "stdout",
182
+ "output_type": "stream",
183
+ "text": [
184
+ "This is the first time CPICANN is being executed on your computer, configuring...\n",
185
+ "Downloading: 3% [24690688 / 776454342] bytes"
186
+ ]
187
+ },
188
+ {
189
+ "name": "stderr",
190
+ "output_type": "stream",
191
+ "text": [
192
+ "IOPub message rate exceeded.\n",
193
+ "The notebook server will temporarily stop sending output\n",
194
+ "to the client in order to avoid crashing it.\n",
195
+ "To change this limit, set the config variable\n",
196
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
197
+ "\n",
198
+ "Current values:\n",
199
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
200
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
201
+ "\n"
202
+ ]
203
+ },
204
+ {
205
+ "name": "stdout",
206
+ "output_type": "stream",
207
+ "text": [
208
+ "Downloading: 7% [61341696 / 776454342] bytes"
209
+ ]
210
+ },
211
+ {
212
+ "name": "stderr",
213
+ "output_type": "stream",
214
+ "text": [
215
+ "IOPub message rate exceeded.\n",
216
+ "The notebook server will temporarily stop sending output\n",
217
+ "to the client in order to avoid crashing it.\n",
218
+ "To change this limit, set the config variable\n",
219
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
220
+ "\n",
221
+ "Current values:\n",
222
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
223
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
224
+ "\n"
225
+ ]
226
+ },
227
+ {
228
+ "name": "stdout",
229
+ "output_type": "stream",
230
+ "text": [
231
+ "Downloading: 13% [107954176 / 776454342] bytes"
232
+ ]
233
+ },
234
+ {
235
+ "name": "stderr",
236
+ "output_type": "stream",
237
+ "text": [
238
+ "IOPub message rate exceeded.\n",
239
+ "The notebook server will temporarily stop sending output\n",
240
+ "to the client in order to avoid crashing it.\n",
241
+ "To change this limit, set the config variable\n",
242
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
243
+ "\n",
244
+ "Current values:\n",
245
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
246
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
247
+ "\n"
248
+ ]
249
+ },
250
+ {
251
+ "name": "stdout",
252
+ "output_type": "stream",
253
+ "text": [
254
+ "Downloading: 19% [148324352 / 776454342] bytes"
255
+ ]
256
+ },
257
+ {
258
+ "name": "stderr",
259
+ "output_type": "stream",
260
+ "text": [
261
+ "IOPub message rate exceeded.\n",
262
+ "The notebook server will temporarily stop sending output\n",
263
+ "to the client in order to avoid crashing it.\n",
264
+ "To change this limit, set the config variable\n",
265
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
266
+ "\n",
267
+ "Current values:\n",
268
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
269
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
270
+ "\n"
271
+ ]
272
+ },
273
+ {
274
+ "name": "stdout",
275
+ "output_type": "stream",
276
+ "text": [
277
+ "Downloading: 24% [189382656 / 776454342] bytes"
278
+ ]
279
+ },
280
+ {
281
+ "name": "stderr",
282
+ "output_type": "stream",
283
+ "text": [
284
+ "IOPub message rate exceeded.\n",
285
+ "The notebook server will temporarily stop sending output\n",
286
+ "to the client in order to avoid crashing it.\n",
287
+ "To change this limit, set the config variable\n",
288
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
289
+ "\n",
290
+ "Current values:\n",
291
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
292
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
293
+ "\n"
294
+ ]
295
+ },
296
+ {
297
+ "name": "stdout",
298
+ "output_type": "stream",
299
+ "text": [
300
+ "Downloading: 28% [221265920 / 776454342] bytes"
301
+ ]
302
+ },
303
+ {
304
+ "name": "stderr",
305
+ "output_type": "stream",
306
+ "text": [
307
+ "IOPub message rate exceeded.\n",
308
+ "The notebook server will temporarily stop sending output\n",
309
+ "to the client in order to avoid crashing it.\n",
310
+ "To change this limit, set the config variable\n",
311
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
312
+ "\n",
313
+ "Current values:\n",
314
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
315
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
316
+ "\n"
317
+ ]
318
+ },
319
+ {
320
+ "name": "stdout",
321
+ "output_type": "stream",
322
+ "text": [
323
+ "Downloading: 33% [262488064 / 776454342] bytes"
324
+ ]
325
+ },
326
+ {
327
+ "name": "stderr",
328
+ "output_type": "stream",
329
+ "text": [
330
+ "IOPub message rate exceeded.\n",
331
+ "The notebook server will temporarily stop sending output\n",
332
+ "to the client in order to avoid crashing it.\n",
333
+ "To change this limit, set the config variable\n",
334
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
335
+ "\n",
336
+ "Current values:\n",
337
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
338
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
339
+ "\n"
340
+ ]
341
+ },
342
+ {
343
+ "name": "stdout",
344
+ "output_type": "stream",
345
+ "text": [
346
+ "Downloading: 39% [304799744 / 776454342] bytes"
347
+ ]
348
+ },
349
+ {
350
+ "name": "stderr",
351
+ "output_type": "stream",
352
+ "text": [
353
+ "IOPub message rate exceeded.\n",
354
+ "The notebook server will temporarily stop sending output\n",
355
+ "to the client in order to avoid crashing it.\n",
356
+ "To change this limit, set the config variable\n",
357
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
358
+ "\n",
359
+ "Current values:\n",
360
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
361
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
362
+ "\n"
363
+ ]
364
+ },
365
+ {
366
+ "name": "stdout",
367
+ "output_type": "stream",
368
+ "text": [
369
+ "Downloading: 44% [346030080 / 776454342] bytes"
370
+ ]
371
+ },
372
+ {
373
+ "name": "stderr",
374
+ "output_type": "stream",
375
+ "text": [
376
+ "IOPub message rate exceeded.\n",
377
+ "The notebook server will temporarily stop sending output\n",
378
+ "to the client in order to avoid crashing it.\n",
379
+ "To change this limit, set the config variable\n",
380
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
381
+ "\n",
382
+ "Current values:\n",
383
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
384
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
385
+ "\n"
386
+ ]
387
+ },
388
+ {
389
+ "name": "stdout",
390
+ "output_type": "stream",
391
+ "text": [
392
+ "Downloading: 50% [388333568 / 776454342] bytes"
393
+ ]
394
+ },
395
+ {
396
+ "name": "stderr",
397
+ "output_type": "stream",
398
+ "text": [
399
+ "IOPub message rate exceeded.\n",
400
+ "The notebook server will temporarily stop sending output\n",
401
+ "to the client in order to avoid crashing it.\n",
402
+ "To change this limit, set the config variable\n",
403
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
404
+ "\n",
405
+ "Current values:\n",
406
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
407
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
408
+ "\n"
409
+ ]
410
+ },
411
+ {
412
+ "name": "stdout",
413
+ "output_type": "stream",
414
+ "text": [
415
+ "Downloading: 55% [429015040 / 776454342] bytes"
416
+ ]
417
+ },
418
+ {
419
+ "name": "stderr",
420
+ "output_type": "stream",
421
+ "text": [
422
+ "IOPub message rate exceeded.\n",
423
+ "The notebook server will temporarily stop sending output\n",
424
+ "to the client in order to avoid crashing it.\n",
425
+ "To change this limit, set the config variable\n",
426
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
427
+ "\n",
428
+ "Current values:\n",
429
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
430
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
431
+ "\n"
432
+ ]
433
+ },
434
+ {
435
+ "name": "stdout",
436
+ "output_type": "stream",
437
+ "text": [
438
+ "Downloading: 60% [470278144 / 776454342] bytes"
439
+ ]
440
+ },
441
+ {
442
+ "name": "stderr",
443
+ "output_type": "stream",
444
+ "text": [
445
+ "IOPub message rate exceeded.\n",
446
+ "The notebook server will temporarily stop sending output\n",
447
+ "to the client in order to avoid crashing it.\n",
448
+ "To change this limit, set the config variable\n",
449
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
450
+ "\n",
451
+ "Current values:\n",
452
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
453
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
454
+ "\n"
455
+ ]
456
+ },
457
+ {
458
+ "name": "stdout",
459
+ "output_type": "stream",
460
+ "text": [
461
+ "Downloading: 65% [507609088 / 776454342] bytes"
462
+ ]
463
+ },
464
+ {
465
+ "name": "stderr",
466
+ "output_type": "stream",
467
+ "text": [
468
+ "IOPub message rate exceeded.\n",
469
+ "The notebook server will temporarily stop sending output\n",
470
+ "to the client in order to avoid crashing it.\n",
471
+ "To change this limit, set the config variable\n",
472
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
473
+ "\n",
474
+ "Current values:\n",
475
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
476
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
477
+ "\n"
478
+ ]
479
+ },
480
+ {
481
+ "name": "stdout",
482
+ "output_type": "stream",
483
+ "text": [
484
+ "Downloading: 70% [549601280 / 776454342] bytes"
485
+ ]
486
+ },
487
+ {
488
+ "name": "stderr",
489
+ "output_type": "stream",
490
+ "text": [
491
+ "IOPub message rate exceeded.\n",
492
+ "The notebook server will temporarily stop sending output\n",
493
+ "to the client in order to avoid crashing it.\n",
494
+ "To change this limit, set the config variable\n",
495
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
496
+ "\n",
497
+ "Current values:\n",
498
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
499
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
500
+ "\n"
501
+ ]
502
+ },
503
+ {
504
+ "name": "stdout",
505
+ "output_type": "stream",
506
+ "text": [
507
+ "Downloading: 75% [587497472 / 776454342] bytes"
508
+ ]
509
+ },
510
+ {
511
+ "name": "stderr",
512
+ "output_type": "stream",
513
+ "text": [
514
+ "IOPub message rate exceeded.\n",
515
+ "The notebook server will temporarily stop sending output\n",
516
+ "to the client in order to avoid crashing it.\n",
517
+ "To change this limit, set the config variable\n",
518
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
519
+ "\n",
520
+ "Current values:\n",
521
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
522
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
523
+ "\n"
524
+ ]
525
+ },
526
+ {
527
+ "name": "stdout",
528
+ "output_type": "stream",
529
+ "text": [
530
+ "Downloading: 80% [622919680 / 776454342] bytes"
531
+ ]
532
+ },
533
+ {
534
+ "name": "stderr",
535
+ "output_type": "stream",
536
+ "text": [
537
+ "IOPub message rate exceeded.\n",
538
+ "The notebook server will temporarily stop sending output\n",
539
+ "to the client in order to avoid crashing it.\n",
540
+ "To change this limit, set the config variable\n",
541
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
542
+ "\n",
543
+ "Current values:\n",
544
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
545
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
546
+ "\n"
547
+ ]
548
+ },
549
+ {
550
+ "name": "stdout",
551
+ "output_type": "stream",
552
+ "text": [
553
+ "Downloading: 86% [668491776 / 776454342] bytes"
554
+ ]
555
+ },
556
+ {
557
+ "name": "stderr",
558
+ "output_type": "stream",
559
+ "text": [
560
+ "IOPub message rate exceeded.\n",
561
+ "The notebook server will temporarily stop sending output\n",
562
+ "to the client in order to avoid crashing it.\n",
563
+ "To change this limit, set the config variable\n",
564
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
565
+ "\n",
566
+ "Current values:\n",
567
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
568
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
569
+ "\n"
570
+ ]
571
+ },
572
+ {
573
+ "name": "stdout",
574
+ "output_type": "stream",
575
+ "text": [
576
+ "Downloading: 89% [698474496 / 776454342] bytes"
577
+ ]
578
+ },
579
+ {
580
+ "name": "stderr",
581
+ "output_type": "stream",
582
+ "text": [
583
+ "IOPub message rate exceeded.\n",
584
+ "The notebook server will temporarily stop sending output\n",
585
+ "to the client in order to avoid crashing it.\n",
586
+ "To change this limit, set the config variable\n",
587
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
588
+ "\n",
589
+ "Current values:\n",
590
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
591
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
592
+ "\n"
593
+ ]
594
+ },
595
+ {
596
+ "name": "stdout",
597
+ "output_type": "stream",
598
+ "text": [
599
+ "Downloading: 91% [713318400 / 776454342] bytes"
600
+ ]
601
+ },
602
+ {
603
+ "name": "stderr",
604
+ "output_type": "stream",
605
+ "text": [
606
+ "IOPub message rate exceeded.\n",
607
+ "The notebook server will temporarily stop sending output\n",
608
+ "to the client in order to avoid crashing it.\n",
609
+ "To change this limit, set the config variable\n",
610
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
611
+ "\n",
612
+ "Current values:\n",
613
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
614
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
615
+ "\n"
616
+ ]
617
+ },
618
+ {
619
+ "name": "stdout",
620
+ "output_type": "stream",
621
+ "text": [
622
+ "Downloading: 96% [746487808 / 776454342] bytes"
623
+ ]
624
+ },
625
+ {
626
+ "name": "stderr",
627
+ "output_type": "stream",
628
+ "text": [
629
+ "IOPub message rate exceeded.\n",
630
+ "The notebook server will temporarily stop sending output\n",
631
+ "to the client in order to avoid crashing it.\n",
632
+ "To change this limit, set the config variable\n",
633
+ "`--NotebookApp.iopub_msg_rate_limit`.\n",
634
+ "\n",
635
+ "Current values:\n",
636
+ "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
637
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
638
+ "\n"
639
+ ]
640
+ },
641
+ {
642
+ "name": "stdout",
643
+ "output_type": "stream",
644
+ "text": [
645
+ "Downloading: 100% [776454342 / 776454342] bytes ____ ____ ___ ____ _ _ _ _ _ \n",
646
+ " / ___|| _ \\ |_ _| / ___| / \\ | \\ | || \\ | |\n",
647
+ "| | | |_) | | | | | / _ \\ | \\| || \\| |\n",
648
+ "| |___ | __/ | | | |___ / ___ \\ | |\\ || |\\ |\n",
649
+ " \\____||_| |___| \\____|/_/ \\_\\|_| \\_||_| \\_|\n",
650
+ " \n",
651
+ "\n",
652
+ "The phase identification module of WPEM\n",
653
+ "URL : https://github.com/WPEM/CPICANN\n",
654
+ "Executed on : 2024-04-21 14:14:25 | Have a great day.\n",
655
+ "================================================================================\n",
656
+ "loaded model from /Users/jacob/miniconda3/lib/python3.9/site-packages/WPEMPhase/pretrained/CPICANN_single-phase_back3.pth\n",
657
+ "\n",
658
+ ">>>>>> RUNNING: ./testdata/.DS_Store\n",
659
+ "\n",
660
+ ">>>>>> RUNNING: ./testdata/PbSO4.csv\n",
661
+ "pred cls_id : 2475 confidence : 98.89%\n",
662
+ "pred cod_id : 9009622 formula : Pb2 S2 O6\n",
663
+ "pred space group No: 11 space group : P2_1/m\n",
664
+ "\n",
665
+ "inference result saved in infResults_testdata.csv\n",
666
+ "inference figures saved at figs/\n",
667
+ "THE END\n"
668
+ ]
669
+ },
670
+ {
671
+ "data": {
672
+ "text/plain": [
673
+ "True"
674
+ ]
675
+ },
676
+ "execution_count": 4,
677
+ "metadata": {},
678
+ "output_type": "execute_result"
679
+ }
680
+ ],
681
+ "source": [
682
+ "# Here, illustrate the system requirements and how to initialize the system files at the first time of execution.\n",
683
+ "\n",
684
+ "CPICANN.PhaseIdentifier(FilePath='./testdata',Model='bca_model',Task='single-phase',Device='cpu',)"
685
+ ]
686
+ },
687
+ {
688
+ "cell_type": "code",
689
+ "execution_count": 7,
690
+ "id": "304b62b5",
691
+ "metadata": {},
692
+ "outputs": [
693
+ {
694
+ "name": "stdout",
695
+ "output_type": "stream",
696
+ "text": [
697
+ " ____ ____ ___ ____ _ _ _ _ _ \n",
698
+ " / ___|| _ \\ |_ _| / ___| / \\ | \\ | || \\ | |\n",
699
+ "| | | |_) | | | | | / _ \\ | \\| || \\| |\n",
700
+ "| |___ | __/ | | | |___ / ___ \\ | |\\ || |\\ |\n",
701
+ " \\____||_| |___| \\____|/_/ \\_\\|_| \\_||_| \\_|\n",
702
+ " \n",
703
+ "\n",
704
+ "The phase identification module of WPEM\n",
705
+ "URL : https://github.com/WPEM/CPICANN\n",
706
+ "Executed on : 2024-04-21 14:14:53 | Have a great day.\n",
707
+ "================================================================================\n",
708
+ "loaded model from /Users/jacob/miniconda3/lib/python3.9/site-packages/WPEMPhase/pretrained/CPICANN_single-phase_noise3.pth\n",
709
+ "\n",
710
+ ">>>>>> RUNNING: ./testdata/.DS_Store\n",
711
+ "\n",
712
+ ">>>>>> RUNNING: ./testdata/PbSO4.csv\n",
713
+ "pred cls_id : 3378 confidence : 100.00%\n",
714
+ "pred cod_id : 9004484 formula : Pb4 S4 O16\n",
715
+ "pred space group No: 62 space group : Pnma\n",
716
+ "\n",
717
+ "inference result saved in infResults_testdata.csv\n",
718
+ "inference figures saved at figs/\n",
719
+ "THE END\n"
720
+ ]
721
+ },
722
+ {
723
+ "data": {
724
+ "text/plain": [
725
+ "True"
726
+ ]
727
+ },
728
+ "execution_count": 7,
729
+ "metadata": {},
730
+ "output_type": "execute_result"
731
+ }
732
+ ],
733
+ "source": [
734
+ "from WPEMPhase import CPICANN\n",
735
+ "# Here, illustrate the system requirements and how to initialize the system files at the first time of execution.\n",
736
+ "\n",
737
+ "CPICANN.PhaseIdentifier(FilePath='./testdata',Model='noise_model',Task='single-phase',ElementsContained='Pb_S_O',Device='cpu',)"
738
+ ]
739
+ },
740
+ {
741
+ "cell_type": "raw",
742
+ "id": "977ede12",
743
+ "metadata": {},
744
+ "source": [
745
+ "For inquiries or assistance, please don't hesitate to contact us at bcao686@connect.hkust-gz.edu.cn (Dr. CAO Bin)."
746
+ ]
747
+ },
748
+ {
749
+ "cell_type": "code",
750
+ "execution_count": null,
751
+ "id": "bb101480",
752
+ "metadata": {},
753
+ "outputs": [],
754
+ "source": []
755
+ }
756
+ ],
757
+ "metadata": {
758
+ "kernelspec": {
759
+ "display_name": "Python 3 (ipykernel)",
760
+ "language": "python",
761
+ "name": "python3"
762
+ },
763
+ "language_info": {
764
+ "codemirror_mode": {
765
+ "name": "ipython",
766
+ "version": 3
767
+ },
768
+ "file_extension": ".py",
769
+ "mimetype": "text/x-python",
770
+ "name": "python",
771
+ "nbconvert_exporter": "python",
772
+ "pygments_lexer": "ipython3",
773
+ "version": "3.9.12"
774
+ }
775
+ },
776
+ "nbformat": 4,
777
+ "nbformat_minor": 5
778
+ }
src/inference&case/config/elem_setting.csv ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ FileName,must include elements,include at least one of,exclude elements
2
+ CdS.csv,Cd_S,,
3
+ MnS.csv,Mn_S,,
4
+ NiO2H2.csv,Ni_O_H,,
5
+ PbSO4.csv,Pb_O,S,
src/inference&case/figs/PbSO4.csv.png ADDED
src/inference&case/infResults_testdata.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ path,fileName,predRank,pred,codId,formula,spaceGroupNo,spaceGroup
2
+ ./testdata,PbSO4.csv,1,3378,9004484.0,Pb4 S4 O16,62,Pnma
src/inference&case/testdata/.DS_Store ADDED
Binary file (6.15 kB). View file
 
src/inference&case/testdata/PbSO4.csv ADDED
The diff for this file is too large to render. See raw diff
 
src/model/CPICANN.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+ from torch.nn.init import trunc_normal_
8
+
9
+
10
+ class CPICANN(nn.Module):
11
+
12
+ def __init__(self, embed_dim=64, nhead=8, num_encoder_layers=6, dim_feedforward=1024,
13
+ dropout=0.1, activation="relu", num_classes=23073):
14
+ super().__init__()
15
+
16
+ self.embed_dim = embed_dim
17
+ self.num_classes = num_classes
18
+
19
+ self.conv = ConvModule(drop_rate=dropout)
20
+
21
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
22
+ self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 142))
23
+
24
+ # -------------encoder----------------
25
+ sa_layer = SelfAttnLayer(embed_dim, nhead, dim_feedforward, dropout, activation)
26
+ self.encoder = SelfAttnModule(sa_layer, num_encoder_layers)
27
+ # ------------------------------------
28
+
29
+ self.norm_after = nn.LayerNorm(embed_dim)
30
+
31
+ self.cls_head = nn.Sequential(
32
+ nn.Linear(embed_dim, int(embed_dim * 4)),
33
+ nn.BatchNorm1d(int(embed_dim * 4)),
34
+ nn.ReLU(inplace=True),
35
+ nn.Dropout(0.5),
36
+ nn.Linear(int(embed_dim * 4), int(embed_dim * 4)),
37
+ nn.BatchNorm1d(int(embed_dim * 4)),
38
+ nn.ReLU(inplace=True),
39
+ nn.Dropout(0.5),
40
+ nn.Linear(int(embed_dim * 4), num_classes)
41
+ )
42
+
43
+ self._reset_parameters()
44
+ self.init_weights()
45
+
46
+ def _reset_parameters(self):
47
+ for p in self.parameters():
48
+ if p.dim() > 1:
49
+ nn.init.xavier_uniform_(p)
50
+
51
+ def init_weights(self):
52
+ trunc_normal_(self.cls_token, std=.02)
53
+
54
+ self.pos_embed.requires_grad = False
55
+
56
+ pos_embed = get_1d_sincos_pos_embed_from_grid(self.embed_dim, np.array(range(self.pos_embed.shape[2])))
57
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).T.unsqueeze(0))
58
+
59
+ def bce_fineTune_init_weights(self):
60
+ for p in self.conv.parameters():
61
+ p.requires_grad = False
62
+
63
+ for p in self.encoder.parameters():
64
+ if p.dim() > 1:
65
+ nn.init.xavier_uniform_(p)
66
+ for p in self.cls_head.parameters():
67
+ if p.dim() > 1:
68
+ nn.init.xavier_uniform_(p)
69
+
70
+ def forward(self, x):
71
+ N = x.shape[0]
72
+ if x.shape[1] == 2:
73
+ x = x[:, 1:, :]
74
+
75
+ x = x / 100
76
+ x = self.conv(x)
77
+
78
+ # flatten NxCxL to LxNxC
79
+ x = x.permute(2, 0, 1).contiguous()
80
+
81
+ cls_token = self.cls_token.expand(-1, N, -1)
82
+ x = torch.cat((cls_token, x), dim=0)
83
+
84
+ pos_embed = self.pos_embed.permute(2, 0, 1).contiguous().repeat(1, N, 1)
85
+ feats = self.encoder(x, pos_embed)
86
+ feats = self.norm_after(feats)
87
+ logits = self.cls_head(feats[0])
88
+ return logits
89
+
90
+
91
class ConvModule(nn.Module):
    """Convolutional stem: (N, 1, L) signal -> (N, 128, ~L/32) features.

    One stride-2 conv + stride-2 max-pool, two stride-2 residual stages and
    a final stride-2 max-pool give an overall temporal stride of 32.
    """
    def __init__(self, drop_rate=0.):
        super().__init__()
        # NOTE(review): drop_rate is stored but never applied in this module
        # -- dead parameter, presumably kept for API symmetry.
        self.drop_rate = drop_rate

        self.conv1 = nn.Conv1d(1, 64, kernel_size=35, stride=2, padding=17)
        self.bn1 = nn.BatchNorm1d(64)
        self.act1 = nn.ReLU()
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        self.layer1 = Layer(64, 64, kernel_size=3, stride=2, downsample=True)
        self.layer2 = Layer(64, 128, kernel_size=3, stride=2, downsample=True)
        # self.layer3 = Layer(256, 256, kernel_size=3, stride=2, downsample=True)
        self.maxpool2 = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        # x = self.layer3(x)
        x = self.maxpool2(x)
        return x
117
+
118
+
119
class SelfAttnModule(nn.Module):
    """Stack of ``num_layers`` deep-copied encoder layers applied in order.

    Args:
        encoder_layer: prototype layer; each stacked layer is a deep copy.
        num_layers: number of layers in the stack.
        norm: optional normalization applied to the final output.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, pos):
        # The same positional embedding is fed to every layer.
        output = src

        for layer in self.layers:
            output = layer(output, pos)

        if self.norm is not None:
            output = self.norm(output)

        return output
137
+
138
+
139
class SelfAttnLayer(nn.Module):
    """Post-norm Transformer encoder layer (self-attention + feed-forward).

    The positional embedding is added to queries and keys only (DETR-style);
    values are the raw sequence.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu"):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of the feed-forward sub-layer.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def forward(self, src, pos):
        # src, pos: (L, N, d_model) -- nn.MultiheadAttention's default layout.
        q = k = with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src)[0]
        # Residual + post-norm around the attention sub-layer.
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        # Residual + post-norm around the feed-forward sub-layer.
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src
166
+
167
+
168
class Layer(nn.Module):
    """A residual stage: one (optionally strided/widening) BasicBlock
    followed by a stride-1 BasicBlock at the same width."""

    def __init__(self, inchannel, outchannel, kernel_size, stride, downsample):
        super(Layer, self).__init__()
        # First block may change width/resolution; second refines in place.
        self.block1 = BasicBlock(inchannel, outchannel, kernel_size=kernel_size,
                                 stride=stride, downsample=downsample)
        self.block2 = BasicBlock(outchannel, outchannel, kernel_size=kernel_size,
                                 stride=1)

    def forward(self, x):
        out = self.block1(x)
        return self.block2(out)
178
+
179
+
180
class BasicBlock(nn.Module):
    """1-D residual block: conv-bn -> conv-bn + identity/projection shortcut.

    NOTE(review): ``self.act1`` is created but never applied in forward(),
    so there is no activation between the two convolutions (unlike a
    standard ResNet BasicBlock). Fixing this would change the behaviour of
    already-trained checkpoints -- confirm before touching.
    """
    def __init__(self, inchannel, outchannel, kernel_size, stride, downsample=False):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv1d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2)
        self.bn1 = nn.BatchNorm1d(outchannel)
        self.act1 = nn.ReLU(inplace=True)  # unused in forward(); see class note
        self.conv2 = nn.Conv1d(outchannel, outchannel, kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
        self.bn2 = nn.BatchNorm1d(outchannel)
        self.act2 = nn.ReLU(inplace=True)
        # NOTE(review): the projection shortcut hard-codes stride=2, so it
        # only matches conv1 when downsample=True is paired with stride=2
        # (as every call site in this file does).
        self.downsample = nn.Sequential(
            nn.Conv1d(inchannel, outchannel, kernel_size=1, stride=2),
            nn.BatchNorm1d(outchannel)
        ) if downsample else None

    def forward(self, x):
        shortcut = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.downsample is not None:
            shortcut = self.downsample(shortcut)
        x += shortcut
        x = self.act2(x)
        return x
205
+
206
+
207
+ def _get_clones(module, N):
208
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
209
+
210
+
211
+ def _get_activation_fn(activation):
212
+ """Return an activation function given a string"""
213
+ if activation == "relu":
214
+ return F.relu
215
+ if activation == "gelu":
216
+ return F.gelu
217
+ if activation == "glu":
218
+ return F.glu
219
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
220
+
221
+
222
def with_pos_embed(tensor, pos):
    """Add a positional embedding to *tensor*; pass through when *pos* is None."""
    if pos is None:
        return tensor
    return tensor + pos
224
+
225
+
226
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Build fixed 1-D sin/cos positional embeddings.

    Args:
        embed_dim: size D of each embedding (must be even); the first D/2
            channels hold sines, the last D/2 hold cosines.
        pos: positions to encode, any array of M values.

    Returns:
        float32 array of shape (M, D).
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # Geometric frequency ladder 1 / 10000^(2i/D), as in "Attention Is All You Need".
    freqs = 1.0 / 10000 ** (np.arange(half, dtype=np.float32) / (embed_dim / 2.))

    angles = np.outer(pos.reshape(-1), freqs)  # (M, D/2)
    return np.concatenate(
        [np.sin(angles).astype(np.float32), np.cos(angles).astype(np.float32)],
        axis=1,
    )
src/model/dataset.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+
5
+ from torch.utils.data import Dataset
6
+
7
+
8
class XrdDataset(Dataset):
    """Single-phase XRD dataset.

    Each annotation row holds (data id, class label); sample ``i`` is loaded
    from ``<data_dir>/<id>.csv``, transposed and cast to float32, and
    returned together with its label.
    """

    def __init__(self, data_dir, annotations_file):
        self.labels = pd.read_csv(annotations_file)
        self.data_dir = data_dir

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        row = self.labels.iloc[idx]
        csv_path = os.path.join(self.data_dir, '{}.csv'.format(row.iloc[0]))
        pattern = pd.read_csv(csv_path).values.astype(np.float32).T
        return pattern, row.iloc[1]
25
+
26
+
27
class mixDataset_cls_dynamic(Dataset):
    """Synthetic bi-phase dataset generated on the fly.

    Each sample mixes two randomly chosen single-phase patterns (each phase
    has several pre-simulated variants on disk named ``<codId>_<k>.csv``)
    with a random 20:80 .. 80:20 intensity ratio. The multi-hot label marks
    both constituent classes with 0.4.

    Args:
        data_dir: directory holding the simulated pattern CSV files.
        anno_struc: CSV whose first column lists the COD ids (class order).
        mode: 'train' samples variants 1..24, anything else samples 1..6.
    """

    def __init__(self, data_dir, anno_struc, mode):
        self.data_dir = data_dir
        self.codIdList = pd.read_csv(anno_struc).values[:, 0].astype(np.int32)
        self.mode = mode

    def __len__(self):
        # Virtual epoch length: samples are generated randomly per request.
        return 1000000

    def __getitem__(self, idx):
        # Fix: the class count was hard-coded to 23073; derive it from the
        # annotation file so the dataset also works with other class lists
        # (the hard-coded value crashed on anything smaller).
        n_classes = len(self.codIdList)
        choice1, choice2 = np.random.randint(0, n_classes, 2)
        if self.mode == 'train':
            rand1, rand2 = np.random.randint(1, 25, 2)
        else:
            rand1, rand2 = np.random.randint(1, 7, 2)
        data_path1 = os.path.join(self.data_dir, '{}_{}.csv'.format(self.codIdList[choice1], rand1))
        data_path2 = os.path.join(self.data_dir, '{}_{}.csv'.format(self.codIdList[choice2], rand2))
        data1 = pd.read_csv(data_path1).values.astype(np.float32).T
        data2 = pd.read_csv(data_path2).values.astype(np.float32).T

        # Mixing ratio in percent; the caller combines the two patterns.
        ratio1 = np.random.randint(20, 81)
        ratio2 = 100 - ratio1

        label = np.zeros(n_classes).astype(np.float32)

        # NOTE(review): both phases are marked 0.4 (not 0.5 or 1.0), and when
        # choice1 == choice2 the single entry stays 0.4 -- confirm this soft
        # target is intended by the BCE training recipe.
        label[choice1] = 0.4
        label[choice2] = 0.4

        return data1, data2, ratio1, ratio2, label
src/model/focal_loss.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.autograd import Variable
5
+
6
+
7
class FocalLoss(nn.Module):
    r"""Focal Loss (Lin et al., "Focal Loss for Dense Object Detection").

    Loss(x, class) = -\alpha (1 - softmax(x)[class])^\gamma \log(softmax(x)[class])

    Losses are averaged over the minibatch when ``size_average`` is True,
    otherwise summed.

    Args:
        class_num: number of classes C.
        alpha: per-class weights, tensor-like of shape (C, 1); defaults to ones.
        gamma: focusing parameter (>= 0); larger values down-weight
            well-classified examples.
        size_average: mean over the batch if True, else sum.
        device: kept for interface compatibility; ``alpha`` now simply
            follows the inputs' device, which also fixes the previously
            unhandled CPU / multi-GPU cases.

    Note:
        This version drops the deprecated ``torch.autograd.Variable``
        wrappers -- plain tensors participate in autograd since PyTorch 0.4.
    """

    def __init__(self, class_num, alpha=None, gamma=2, size_average=True, device='cuda:0'):
        super(FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = torch.ones(class_num, 1)
        else:
            self.alpha = torch.as_tensor(alpha)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average
        self.device = device

    def forward(self, inputs, targets):
        """Compute the focal loss.

        Args:
            inputs: raw logits of shape (N, C).
            targets: class indices of shape (N,), dtype long.

        Returns:
            Scalar loss tensor.
        """
        P = F.softmax(inputs, dim=1)

        # One-hot mask selecting the target class of every sample.
        ids = targets.view(-1, 1)
        class_mask = torch.zeros_like(inputs).scatter_(1, ids, 1.)

        # Keep the class weights on the same device as the inputs (the
        # original only covered the inputs-on-CUDA case via self.device).
        if self.alpha.device != inputs.device:
            self.alpha = self.alpha.to(inputs.device)
        alpha = self.alpha[ids.view(-1)]

        # p_t: probability assigned to the true class, shape (N, 1).
        probs = (P * class_mask).sum(1).view(-1, 1)
        log_p = probs.log()

        batch_loss = -alpha * torch.pow(1 - probs, self.gamma) * log_p

        if self.size_average:
            return batch_loss.mean()
        return batch_loss.sum()
69
+
70
+
71
+ # import torch
72
+ # import torch.nn as nn
73
+ #
74
+ #
75
+ # class FocalLoss(nn.Module):
76
+ #
77
+ # def __init__(self, gamma=0, eps=1e-7):
78
+ # super(FocalLoss, self).__init__()
79
+ # self.gamma = gamma
80
+ # self.eps = eps
81
+ # self.ce = torch.nn.CrossEntropyLoss()
82
+ #
83
+ # def forward(self, input, target):
84
+ # logp = self.ce(input, target)
85
+ # p = torch.exp(-logp)
86
+ # loss = (1 - p) ** self.gamma * logp
87
+ # return loss.mean()
src/othermodels/ATTENTIONonly.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+ from torch.nn.init import trunc_normal_
8
+
9
+
10
class VIT(nn.Module):
    """Ablation variant of CPICANN: ViT-style patch embedding, no CNN stem.

    A single stride-32 Conv1d patchifies the pattern into embed_dim tokens;
    the rest (CLS token, fixed sin/cos positions, self-attention encoder,
    MLP head) matches CPICANN.
    """

    def __init__(self, embed_dim=64, nhead=8, num_encoder_layers=6, dim_feedforward=1024,
                 dropout=0.1, activation="relu", num_classes=23073):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_classes = num_classes

        # Non-overlapping 32-sample "patches" -> embed_dim channels.
        self.conv = torch.nn.Conv1d(1, embed_dim, kernel_size=32, stride=32, padding=0)

        # Learnable CLS token; fixed sin/cos positions for 141 slots
        # (presumably number of patches + 1 CLS slot for the expected input
        # length -- TODO confirm).
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 141))

        # -------------encoder----------------
        sa_layer = SelfAttnLayer(embed_dim, nhead, dim_feedforward, dropout, activation)
        self.encoder = SelfAttnModule(sa_layer, num_encoder_layers)
        # ------------------------------------

        self.norm_after = nn.LayerNorm(embed_dim)

        # Two-hidden-layer MLP classification head applied to the CLS slot.
        self.cls_head = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * 4)),
            nn.BatchNorm1d(int(embed_dim * 4)),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(int(embed_dim * 4), int(embed_dim * 4)),
            nn.BatchNorm1d(int(embed_dim * 4)),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(int(embed_dim * 4), num_classes)
        )

        self._reset_parameters()
        self.init_weights()

    def _reset_parameters(self):
        # Xavier init for every weight matrix (dim > 1 skips biases/norms).
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def init_weights(self):
        trunc_normal_(self.cls_token, std=.02)

        # Positional embedding is fixed (sin/cos), not learned.
        self.pos_embed.requires_grad = False

        pos_embed = get_1d_sincos_pos_embed_from_grid(self.embed_dim, np.array(range(self.pos_embed.shape[2])))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).T.unsqueeze(0))

    def bce_fineTune_init_weights(self):
        """Freeze the patch embedding and re-init encoder + head for
        bi-phase (BCE) fine-tuning."""
        for p in self.conv.parameters():
            p.requires_grad = False

        for p in self.encoder.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for p in self.cls_head.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, x):
        # x: (N, 1 or 2, L); drop the 2-theta row if present.
        N = x.shape[0]
        if x.shape[1] == 2:
            x = x[:, 1:, :]

        # Intensities are expected on a 0..100 scale; rescale to 0..1.
        x = x / 100
        x = self.conv(x)

        # flatten NxCxL to LxNxC (MultiheadAttention's default layout)
        x = x.permute(2, 0, 1).contiguous()

        cls_token = self.cls_token.expand(-1, N, -1)
        x = torch.cat((cls_token, x), dim=0)

        pos_embed = self.pos_embed.permute(2, 0, 1).contiguous().repeat(1, N, 1)
        feats = self.encoder(x, pos_embed)
        feats = self.norm_after(feats)
        # Classify from the CLS slot.
        logits = self.cls_head(feats[0])
        return logits
+
90
+
91
class ConvModule(nn.Module):
    """Convolutional stem: (N, 1, L) signal -> (N, 128, ~L/32) features.

    NOTE(review): appears unused within this file -- VIT patchifies with a
    single Conv1d; this class was presumably carried over from CPICANN.
    """
    def __init__(self, drop_rate=0.):
        super().__init__()
        # drop_rate is stored but never applied in this module.
        self.drop_rate = drop_rate

        self.conv1 = nn.Conv1d(1, 64, kernel_size=35, stride=2, padding=17)
        self.bn1 = nn.BatchNorm1d(64)
        self.act1 = nn.ReLU()
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        self.layer1 = Layer(64, 64, kernel_size=3, stride=2, downsample=True)
        self.layer2 = Layer(64, 128, kernel_size=3, stride=2, downsample=True)
        # self.layer3 = Layer(256, 256, kernel_size=3, stride=2, downsample=True)
        self.maxpool2 = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        # x = self.layer3(x)
        x = self.maxpool2(x)
        return x
+
118
+
119
class SelfAttnModule(nn.Module):
    """Stack of ``num_layers`` deep-copied encoder layers applied in order.

    Args:
        encoder_layer: prototype layer; each stacked layer is a deep copy.
        num_layers: number of layers in the stack.
        norm: optional normalization applied to the final output.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, pos):
        # The same positional embedding is fed to every layer.
        output = src

        for layer in self.layers:
            output = layer(output, pos)

        if self.norm is not None:
            output = self.norm(output)

        return output
+
138
+
139
+ class SelfAttnLayer(nn.Module):
140
+
141
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
142
+ activation="relu"):
143
+ super().__init__()
144
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
145
+ # Implementation of Feedforward model
146
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
147
+ self.dropout = nn.Dropout(dropout)
148
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
149
+
150
+ self.norm1 = nn.LayerNorm(d_model)
151
+ self.norm2 = nn.LayerNorm(d_model)
152
+ self.dropout1 = nn.Dropout(dropout)
153
+ self.dropout2 = nn.Dropout(dropout)
154
+
155
+ self.activation = _get_activation_fn(activation)
156
+
157
+ def forward(self, src, pos):
158
+ q = k = with_pos_embed(src, pos)
159
+ src2 = self.self_attn(q, k, value=src)[0]
160
+ src = src + self.dropout1(src2)
161
+ src = self.norm1(src)
162
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
163
+ src = src + self.dropout2(src2)
164
+ src = self.norm2(src)
165
+ return src
166
+
167
+
168
+ class Layer(nn.Module):
169
+ def __init__(self, inchannel, outchannel, kernel_size, stride, downsample):
170
+ super(Layer, self).__init__()
171
+ self.block1 = BasicBlock(inchannel, outchannel, kernel_size=kernel_size, stride=stride, downsample=downsample)
172
+ self.block2 = BasicBlock(outchannel, outchannel, kernel_size=kernel_size, stride=1)
173
+
174
+ def forward(self, x):
175
+ x = self.block1(x)
176
+ x = self.block2(x)
177
+ return x
178
+
179
+
180
+ class BasicBlock(nn.Module):
181
+ def __init__(self, inchannel, outchannel, kernel_size, stride, downsample=False):
182
+ super(BasicBlock, self).__init__()
183
+ self.conv1 = nn.Conv1d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2)
184
+ self.bn1 = nn.BatchNorm1d(outchannel)
185
+ self.act1 = nn.ReLU(inplace=True)
186
+ self.conv2 = nn.Conv1d(outchannel, outchannel, kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
187
+ self.bn2 = nn.BatchNorm1d(outchannel)
188
+ self.act2 = nn.ReLU(inplace=True)
189
+ self.downsample = nn.Sequential(
190
+ nn.Conv1d(inchannel, outchannel, kernel_size=1, stride=2),
191
+ nn.BatchNorm1d(outchannel)
192
+ ) if downsample else None
193
+
194
+ def forward(self, x):
195
+ shortcut = x
196
+ x = self.conv1(x)
197
+ x = self.bn1(x)
198
+ x = self.conv2(x)
199
+ x = self.bn2(x)
200
+ if self.downsample is not None:
201
+ shortcut = self.downsample(shortcut)
202
+ x += shortcut
203
+ x = self.act2(x)
204
+ return x
205
+
206
+
207
+ def _get_clones(module, N):
208
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
209
+
210
+
211
+ def _get_activation_fn(activation):
212
+ """Return an activation function given a string"""
213
+ if activation == "relu":
214
+ return F.relu
215
+ if activation == "gelu":
216
+ return F.gelu
217
+ if activation == "glu":
218
+ return F.glu
219
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
220
+
221
+
222
+ def with_pos_embed(tensor, pos):
223
+ return tensor if pos is None else tensor + pos
224
+
225
+
226
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
227
+ """
228
+ embed_dim: output dimension for each position
229
+ pos: a list of positions to be encoded: size (M,)
230
+ out: (M, D)
231
+ """
232
+ assert embed_dim % 2 == 0
233
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
234
+ omega /= embed_dim / 2.
235
+ omega = 1. / 10000 ** omega # (D/2,)
236
+
237
+ pos = pos.reshape(-1) # (M,)
238
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
239
+
240
+ emb_sin = np.sin(out).astype(np.float32) # (M, D/2)
241
+ emb_cos = np.cos(out).astype(np.float32) # (M, D/2)
242
+
243
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
244
+ return emb
src/othermodels/CNNonly.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+ from torch.nn.init import trunc_normal_
8
+
9
+
10
class CNN(nn.Module):
    """Ablation variant of CPICANN: CNN stem + MLP head, no attention.

    ``nhead`` / ``num_encoder_layers`` / ``dim_feedforward`` / ``activation``
    are accepted only for interface parity with CPICANN and are unused here.

    NOTE(review): the stem outputs 128 channels while ``cls_head`` takes
    ``embed_dim`` inputs, so the default embed_dim=64 does not line up --
    presumably instantiated with embed_dim=128; confirm against the
    training scripts.
    """

    def __init__(self, embed_dim=64, nhead=8, num_encoder_layers=6, dim_feedforward=1024,
                 dropout=0.1, activation="relu", num_classes=23073):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_classes = num_classes

        self.conv = ConvModule(drop_rate=dropout)

        # Collapse the 141-step temporal axis to one vector per channel.
        self.proj = nn.Linear(141, 1)

        # Two-hidden-layer MLP classification head.
        self.cls_head = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * 4)),
            nn.BatchNorm1d(int(embed_dim * 4)),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(int(embed_dim * 4), int(embed_dim * 4)),
            nn.BatchNorm1d(int(embed_dim * 4)),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(int(embed_dim * 4), num_classes)
        )

        self._reset_parameters()

    def _reset_parameters(self):
        # Xavier init for every weight matrix (dim > 1 skips biases/norms).
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, x):
        # x: (N, 1 or 2, L); drop the 2-theta row if present.
        N = x.shape[0]
        if x.shape[1] == 2:
            x = x[:, 1:, :]

        # Intensities are expected on a 0..100 scale; rescale to 0..1.
        x = x / 100
        x = self.conv(x)

        # flatten NxCxL to LxNxC
        # x = x.permute(2, 0, 1).contiguous()
        x = self.proj(x).flatten(1)

        logits = self.cls_head(x)
        return logits
57
+
58
+
59
class ConvModule(nn.Module):
    """Convolutional stem: (N, 1, L) signal -> (N, 128, ~L/32) features.

    One stride-2 conv + stride-2 max-pool, two stride-2 residual stages and
    a final stride-2 max-pool give an overall temporal stride of 32.
    """
    def __init__(self, drop_rate=0.):
        super().__init__()
        # NOTE(review): drop_rate is stored but never applied in this module.
        self.drop_rate = drop_rate

        self.conv1 = nn.Conv1d(1, 64, kernel_size=35, stride=2, padding=17)
        self.bn1 = nn.BatchNorm1d(64)
        self.act1 = nn.ReLU()
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        self.layer1 = Layer(64, 64, kernel_size=3, stride=2, downsample=True)
        self.layer2 = Layer(64, 128, kernel_size=3, stride=2, downsample=True)
        # self.layer3 = Layer(256, 256, kernel_size=3, stride=2, downsample=True)
        self.maxpool2 = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        # x = self.layer3(x)
        x = self.maxpool2(x)
        return x
85
+
86
class Layer(nn.Module):
    """A residual stage: one (optionally strided/widening) BasicBlock
    followed by a stride-1 BasicBlock at the same width."""

    def __init__(self, inchannel, outchannel, kernel_size, stride, downsample):
        super(Layer, self).__init__()
        # First block may change width/resolution; second refines in place.
        self.block1 = BasicBlock(inchannel, outchannel, kernel_size=kernel_size,
                                 stride=stride, downsample=downsample)
        self.block2 = BasicBlock(outchannel, outchannel, kernel_size=kernel_size,
                                 stride=1)

    def forward(self, x):
        out = self.block1(x)
        return self.block2(out)
96
+
97
class BasicBlock(nn.Module):
    """1-D residual block: conv-bn -> conv-bn + identity/projection shortcut.

    NOTE(review): ``self.act1`` is created but never applied in forward(),
    so there is no activation between the two convolutions (unlike a
    standard ResNet BasicBlock). Fixing this would change the behaviour of
    already-trained checkpoints -- confirm before touching.
    """
    def __init__(self, inchannel, outchannel, kernel_size, stride, downsample=False):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv1d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2)
        self.bn1 = nn.BatchNorm1d(outchannel)
        self.act1 = nn.ReLU(inplace=True)  # unused in forward(); see class note
        self.conv2 = nn.Conv1d(outchannel, outchannel, kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
        self.bn2 = nn.BatchNorm1d(outchannel)
        self.act2 = nn.ReLU(inplace=True)
        # NOTE(review): the projection shortcut hard-codes stride=2, so it
        # only matches conv1 when downsample=True is paired with stride=2.
        self.downsample = nn.Sequential(
            nn.Conv1d(inchannel, outchannel, kernel_size=1, stride=2),
            nn.BatchNorm1d(outchannel)
        ) if downsample else None

    def forward(self, x):
        shortcut = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.downsample is not None:
            shortcut = self.downsample(shortcut)
        x += shortcut
        x = self.act2(x)
        return x
122
+
123
+
124
+ def _get_clones(module, N):
125
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
126
+
127
+
128
+ def _get_activation_fn(activation):
129
+ """Return an activation function given a string"""
130
+ if activation == "relu":
131
+ return F.relu
132
+ if activation == "gelu":
133
+ return F.gelu
134
+ if activation == "glu":
135
+ return F.glu
136
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
137
+
138
+
139
def with_pos_embed(tensor, pos):
    """Add a positional embedding to *tensor*; pass through when *pos* is None."""
    if pos is None:
        return tensor
    return tensor + pos
src/pretrained/# place pretrained .pth files here ADDED
File without changes
src/train_bi-phase.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import sys
3
+
4
+ import os
5
+ import argparse
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ from torch import optim, nn
10
+ from torch.nn.parallel import DistributedDataParallel as DDP
11
+ from torch.utils.data import DataLoader
12
+ from tqdm import tqdm
13
+
14
+ from model.CPICANN import CPICANN
15
+ from model.dataset import mixDataset_cls_dynamic
16
+ from util.logger import Logger
17
+
18
+
19
def run_one_epoch(model, dataloader, criterion, optimizer, epoch, mode):
    """Run one training or evaluation epoch over dynamically mixed bi-phase data.

    Each batch provides two single-phase patterns plus per-sample mixing
    ratios; the mixed pattern is min-max rescaled to [0, 100] before being
    fed to the model. Uses module-level globals `args` and `device`.

    Parameters
    ----------
    model, criterion : network and loss; both toggled train()/eval().
    optimizer : stepped only when mode == 'Train'.
    epoch : 1-based epoch index, feeds the warmup/cosine LR schedule.
    mode : 'Train' for optimization, anything else for no-grad evaluation.

    Returns
    -------
    float : mean loss over the epoch.
    """
    if mode == 'Train':
        model.train()
        criterion.train()
        desc = 'Training... '
    else:
        model.eval()
        criterion.eval()
        desc = 'Evaluating... '

    epoch_loss = 0  # fix: removed the unused cls_acc accumulator
    pbar = None
    if args.progress_bar:
        pbar = tqdm(total=len(dataloader.dataset), desc=desc, unit='data')
    iters = len(dataloader)
    for i, batch in enumerate(dataloader):
        data1 = batch[0].to(device)
        data2 = batch[1].to(device)
        ratio1 = batch[2].to(device)
        ratio2 = batch[3].to(device)
        label_cls = batch[4].to(device)

        # Mix the two patterns with per-sample scalar ratios, then min-max
        # normalize each mixed pattern to [0, 100].
        data = torch.einsum('ijk,i->ijk', data1, ratio1) + torch.einsum('ijk,i->ijk', data2, ratio2)
        min_i = data.min(dim=2, keepdim=True)[0]
        max_i = data.max(dim=2, keepdim=True)[0]
        data = (data - min_i) / (max_i - min_i) * 100

        if mode == 'Train':
            # Fractional epoch gives a smooth per-iteration LR schedule.
            adjust_learning_rate_withWarmup(optimizer, epoch + i / iters, args)

            logits = model(data)
            loss = criterion(logits, label_cls)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        else:
            with torch.no_grad():
                logits = model(data)
                loss = criterion(logits, label_cls)

        epoch_loss += loss.item()
        if pbar is not None:
            pbar.update(len(data))
            pbar.set_postfix(**{'loss': loss.item()})

    if pbar is not None:
        pbar.close()  # fix: the bar was never closed, garbling later output

    return epoch_loss / iters
65
+
66
+
67
def print_log(epoch, loss_train, loss_val, lr):
    """Write one epoch's losses and learning rate to the text log and tensorboard."""
    log.printlog('---------------- Epoch {} ----------------'.format(epoch))

    for template, value in (('loss_train : {}', round(loss_train, 6)),
                            ('loss_val : {}', round(loss_val, 6))):
        log.printlog(template.format(value))

    for writer, value in ((log.train_writer, loss_train),
                          (log.val_writer, loss_val)):
        writer.add_scalar('mix_loss', value, epoch)

    log.train_writer.add_scalar('lr', lr, epoch)
77
+
78
+
79
def save_checkpoint(state, is_best, filepath, filename):
    """Save `state` every 10 epochs (and at epoch 1); optionally also save a best-model copy."""
    epoch = state['epoch']
    if epoch % 10 == 0 or epoch == 1:
        os.makedirs(filepath, exist_ok=True)
        torch.save(state, filepath + filename)
        log.printlog('checkpoint saved!')
    if is_best:
        torch.save(state, '{}/model_best.pth'.format(filepath))
        log.printlog('best model saved!')
87
+
88
+
89
def adjust_learning_rate(optimizer, epoch, schedule):
    """Step-decay: multiply the base LR by 0.1 for every milestone already reached."""
    lr = optimizer.defaults['lr']
    for milestone in schedule:
        if epoch >= milestone:
            lr *= 0.1
    for group in optimizer.param_groups:
        group['lr'] = lr
96
+
97
+
98
def adjust_learning_rate_withWarmup(optimizer, epoch, args):
    """Linear warmup to args.lr over args.warmup_epochs, then half-cycle cosine decay.

    `epoch` may be fractional for per-iteration scheduling. Returns the LR set.
    """
    if epoch < args.warmup_epochs:
        lr = args.lr * epoch / args.warmup_epochs
    else:
        progress = (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)
        lr = args.lr * 0.5 * (1. + math.cos(math.pi * progress))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
107
+
108
+
109
def main():
    """Fine-tune the pretrained single-phase model for bi-phase identification.

    Relies on module-level globals: args, device, rank, distributed, log
    (and local_rank when distributed).
    """
    print('>>>> Running on {} <<<<'.format(device))

    model = CPICANN(embed_dim=128, num_classes=args.num_classes)

    # LOAD PRETRAINED MODEL
    loaded = torch.load(args.load_path)
    model.load_state_dict(loaded['model'])

    # Re-initialize the classification head for fine-tuning (method defined on CPICANN).
    model.bce_fineTune_init_weights()
    model.to(device)
    if rank == 0:
        log.printlog(model)

    trainset = mixDataset_cls_dynamic(args.data_dir_train, args.anno_struc, mode='Train')
    valset = mixDataset_cls_dynamic(args.data_dir_val, args.anno_struc, mode='Eval')

    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(trainset, shuffle=True)
        val_sampler = torch.utils.data.distributed.DistributedSampler(valset, shuffle=True)

        train_loader = DataLoader(trainset, batch_size=512, num_workers=16, pin_memory=True, drop_last=True, sampler=train_sampler)
        val_loader = DataLoader(valset, batch_size=512, num_workers=16, pin_memory=True, drop_last=True, sampler=val_sampler)

        model = DDP(model, device_ids=[device], output_device=local_rank, find_unused_parameters=False)
    else:
        train_loader = DataLoader(trainset, batch_size=512, num_workers=16, pin_memory=True, shuffle=True)
        val_loader = DataLoader(valset, batch_size=512, num_workers=16, pin_memory=True, shuffle=True)

    criterion = nn.CrossEntropyLoss()

    optimizer = optim.AdamW(model.parameters(), args.lr, weight_decay=1e-4)
    start_epoch = 0

    for epoch in range(start_epoch + 1, args.epochs + 1):
        if distributed:
            # Re-seed the samplers so shards reshuffle every epoch.
            train_sampler.set_epoch(epoch)
            val_sampler.set_epoch(epoch)

        loss_train = run_one_epoch(model, train_loader, criterion, optimizer, epoch, mode='Train')

        loss_val = run_one_epoch(model, val_loader, criterion, optimizer, epoch, mode='Eval')

        if rank == 0:
            print_log(epoch, loss_train, loss_val, optimizer.param_groups[0]['lr'])
            # NOTE(review): the whole optimizer object is pickled here, not
            # optimizer.state_dict() — confirm this is intentional.
            save_checkpoint({'epoch': epoch,
                             'model': model.module.state_dict() if distributed else model.state_dict(),
                             'optimizer': optimizer}, is_best=False,
                            filepath='{}/checkpoints/'.format(log.get_path()),
                            filename='checkpoint_{:04d}.pth'.format(epoch))
159
+
160
+
161
if __name__ == '__main__':
    # Distributed setup: torchrun/launch exports RANK and WORLD_SIZE; fall
    # back to single-GPU mode otherwise.
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(rank % torch.cuda.device_count())
        dist.init_process_group(backend="nccl")
        device = torch.device("cuda", local_rank)
        print(f"[init] == local rank: {local_rank}, global rank: {rank} ==")

        distributed = True
    else:
        rank = 0
        device = 'cuda:0'
        distributed = False

    parser = argparse.ArgumentParser()
    # NOTE(review): single dash here vs "--progress_bar" in
    # train_single-phase.py; also argparse type=bool treats any non-empty
    # string (including "False") as True — consider a store_true flag.
    parser.add_argument("-progress_bar", type=bool, default=True)

    parser.add_argument('--epochs', default=200, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--warmup-epochs', default=20, type=int, metavar='N',
                        help='number of warmup epochs')
    parser.add_argument('--lr', '--learning-rate', default=8e-4, type=float,
                        metavar='LR', help='initial (base) learning rate', dest='lr')

    parser.add_argument('--load_path', default='pretrained/single-phase_checkpoint_0200.pth', type=str,
                        help='path to load pretrained single-phase identification model')
    parser.add_argument('--data_dir_train', default='data/train', type=str)
    parser.add_argument('--data_dir_val', default='data/val', type=str)
    parser.add_argument('--anno_struc', default='annotation/anno_struc.csv', type=str,
                        help='path to annotation file for structures')
    parser.add_argument('--num_classes', default=23073, type=int, metavar='N')

    args = parser.parse_args()

    # Only rank 0 owns the output directory and tensorboard writers.
    if rank == 0:
        log = Logger(val=True)

    main()
    print('THE END')
src/train_single-phase.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import math
3
+ import os
4
+
5
+ import torch
6
+ import torch.distributed as dist
7
+ from torch import optim
8
+ from torch.nn.parallel import DistributedDataParallel as DDP
9
+ from torch.utils.data import DataLoader
10
+ from tqdm import tqdm
11
+
12
+ from model.dataset import XrdDataset
13
+ from model.CPICANN import CPICANN
14
+ from model.focal_loss import FocalLoss
15
+ from util.logger import Logger
16
+
17
+
18
def get_acc(cls, label):
    """Return top-1 accuracy of logits `cls` against integer labels, as a tensor."""
    predictions = cls.argmax(1)
    hits = predictions == label.int()
    return sum(hits) / cls.shape[0]
21
+
22
+
23
def run_one_epoch(model, dataloader, criterion, optimizer, epoch, mode):
    """Run one training or evaluation epoch for single-phase identification.

    When mode == 'Train', the LR is updated per iteration by the
    warmup/cosine schedule and the model is optimized; otherwise the forward
    pass runs under torch.no_grad(). Uses module-level globals `args` and
    `device`.

    Returns (mean_loss, mean_batch_accuracy_percent).
    """
    if mode == 'Train':
        model.train()
        criterion.train()
        desc = 'Training... '
    else:
        model.eval()
        criterion.eval()
        desc = 'Evaluating... '

    epoch_loss, cls_acc = 0, 0
    if args.progress_bar:
        pbar = tqdm(total=len(dataloader.dataset), desc=desc, unit='data')
    iters = len(dataloader)
    for i, batch in enumerate(dataloader):
        data = batch[0].to(device)
        label_cls = batch[1].to(device)

        if mode == 'Train':
            # Fractional epoch -> smooth per-iteration LR schedule.
            adjust_learning_rate_withWarmup(optimizer, epoch + i / iters, args)

            logits = model(data)
            loss = criterion(logits, label_cls.long())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        else:
            with torch.no_grad():
                logits = model(data)
                loss = criterion(logits, label_cls.long())

        epoch_loss += loss.item()
        if args.progress_bar:
            pbar.update(len(data))
            pbar.set_postfix(**{'loss': loss.item()})

        # Batch-level top-1 accuracy; averaged over batches in the return.
        _cls_acc = get_acc(logits, label_cls)
        cls_acc += _cls_acc.item()

    return epoch_loss / iters, cls_acc * 100 / iters
64
+
65
+
66
def print_log(epoch, loss_train, loss_val, acc_train, acc_val, lr):
    """Write one epoch's losses, accuracies and learning rate to the text log and tensorboard."""
    log.printlog('---------------- Epoch {} ----------------'.format(epoch))

    for template, value in (('loss_train : {}', round(loss_train, 4)),
                            ('loss_val : {}', round(loss_val, 4)),
                            ('acc_train : {}%', round(acc_train, 4)),
                            ('acc_val : {}%', round(acc_val, 4))):
        log.printlog(template.format(value))

    for tag, train_value, val_value in (('loss', loss_train, loss_val),
                                        ('acc', acc_train, acc_val)):
        log.train_writer.add_scalar(tag, train_value, epoch)
        log.val_writer.add_scalar(tag, val_value, epoch)

    log.train_writer.add_scalar('lr', lr, epoch)
82
+
83
+
84
def save_checkpoint(state, is_best, filepath, filename):
    """Save `state` every 10 epochs (and at epoch 1); optionally also save a best-model copy."""
    epoch = state['epoch']
    if epoch % 10 == 0 or epoch == 1:
        os.makedirs(filepath, exist_ok=True)
        torch.save(state, filepath + filename)
        log.printlog('checkpoint saved!')
    if is_best:
        torch.save(state, '{}/model_best.pth'.format(filepath))
        log.printlog('best model saved!')
92
+
93
+
94
def adjust_learning_rate_withWarmup(optimizer, epoch, args):
    """Linear warmup to args.lr over args.warmup_epochs, then half-cycle cosine decay.

    `epoch` may be fractional for per-iteration scheduling. Returns the LR set.
    """
    if epoch < args.warmup_epochs:
        lr = args.lr * epoch / args.warmup_epochs
    else:
        progress = (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)
        lr = args.lr * 0.5 * (1. + math.cos(math.pi * progress))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
103
+
104
+
105
def main():
    """Train the single-phase identification model from scratch.

    Relies on module-level globals: args, device, rank, distributed, log
    (and local_rank when distributed).
    """
    print('>>>> Running on {} <<<<'.format(device))

    model = CPICANN(embed_dim=128, num_classes=args.num_classes)
    model.to(device)
    if rank == 0:
        log.printlog(model)

    trainset = XrdDataset(args.data_dir_train, args.anno_train)
    valset = XrdDataset(args.data_dir_val, args.anno_val)

    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(trainset, shuffle=True)
        val_sampler = torch.utils.data.distributed.DistributedSampler(valset, shuffle=True)

        train_loader = DataLoader(trainset, batch_size=128, num_workers=16, pin_memory=True, drop_last=True, sampler=train_sampler)
        val_loader = DataLoader(valset, batch_size=128, num_workers=16, pin_memory=True, drop_last=True, sampler=val_sampler)

        model = DDP(model, device_ids=[device], output_device=local_rank, find_unused_parameters=False)
    else:
        train_loader = DataLoader(trainset, batch_size=128, num_workers=16, pin_memory=True, shuffle=True)
        val_loader = DataLoader(valset, batch_size=128, num_workers=16, pin_memory=True, shuffle=True)

    # Project-local FocalLoss over args.num_classes classes.
    criterion = FocalLoss(class_num=args.num_classes, device=device)

    optimizer = optim.AdamW(model.parameters(), args.lr, weight_decay=1e-4)
    start_epoch = 0

    for epoch in range(start_epoch + 1, args.epochs + 1):
        if distributed:
            # Re-seed the samplers so shards reshuffle every epoch.
            train_sampler.set_epoch(epoch)
            val_sampler.set_epoch(epoch)

        loss_train, acc_train = run_one_epoch(model, train_loader, criterion, optimizer, epoch, mode='Train')

        loss_val, acc_val = run_one_epoch(model, val_loader, criterion, optimizer, epoch, mode='Eval')

        if rank == 0:
            print_log(epoch, loss_train, loss_val, acc_train, acc_val, optimizer.param_groups[0]['lr'])
            # NOTE(review): the whole optimizer object is pickled, not
            # optimizer.state_dict() — confirm this is intentional.
            save_checkpoint({'epoch': epoch,
                             'model': model.module.state_dict() if distributed else model.state_dict(),
                             'optimizer': optimizer}, is_best=False,
                            filepath='{}/checkpoints/'.format(log.get_path()),
                            filename='checkpoint_{:04d}.pth'.format(epoch))
150
+
151
if __name__ == '__main__':
    # Distributed setup: torchrun/launch exports RANK and WORLD_SIZE; fall
    # back to single-GPU mode otherwise.
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(rank % torch.cuda.device_count())
        dist.init_process_group(backend="nccl")
        device = torch.device("cuda", local_rank)
        print(f"[init] == local rank: {local_rank}, global rank: {rank} ==")
        distributed = True
    else:
        rank = 0
        device = 'cuda:0'
        distributed = False

    parser = argparse.ArgumentParser()
    # NOTE(review): argparse type=bool treats any non-empty string
    # (including "False") as True — consider a store_true flag instead.
    parser.add_argument("--progress_bar", type=bool, default=True)

    parser.add_argument('--epochs', default=200, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--warmup-epochs', default=20, type=int, metavar='N',
                        help='number of warmup epochs')
    parser.add_argument('--lr', '--learning-rate', default=8e-5, type=float,
                        metavar='LR', help='initial (base) learning rate', dest='lr')

    parser.add_argument('--data_dir_train', default='data/train/', type=str)
    parser.add_argument('--data_dir_val', default='data/val/', type=str)
    parser.add_argument('--anno_train', default='annotation/anno_train.csv', type=str,
                        help='path to annotation file for training data')
    parser.add_argument('--anno_val', default='annotation/anno_val.csv', type=str,
                        help='path to annotation file for validation data')
    parser.add_argument('--num_classes', default=23073, type=int, metavar='N')

    args = parser.parse_args()

    # Only rank 0 owns the output directory and tensorboard writers.
    if rank == 0:
        log = Logger(val=True)

    main()
    print('THE END')
src/util/logger.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ from tensorboardX import SummaryWriter
5
+ from datetime import datetime
6
+
7
+
8
class Logger(object):
    """Combined stdout/file logger plus tensorboardX writers.

    Creates a timestamped run directory under output/ and mirrors every
    printlog() message both to the terminal and to a log file inside it.
    """

    def __init__(self, val=True, filename="print.log"):
        self.Time = datetime.now().strftime('%Y-%m-%d_%H%M')
        self.path = 'output/' + self.Time
        self.log_filename = filename
        # fix: idiomatic and race-free — the old
        # `os.makedirs(...) if os.path.exists(...) is False else None`
        # raced between the exists() check and the makedirs() call.
        os.makedirs(self.path, exist_ok=True)
        self.run_path = '{}/{}'.format(self.path, 'tb')

        # common log
        self.terminal = sys.stdout
        self.terminal.write(self.path)

        # init tensorboardX
        self.train_writer = None
        self.val_writer = None
        self.tensorboard_init(val)

    def printlog(self, message):
        """Write `message` (plus newline) to stdout and append it to the log file."""
        message = str(message)
        self.terminal.write(message + '\n')

        # fix: context manager guarantees the handle is closed even on error
        with open(os.path.join(self.path, self.log_filename), "a", encoding='utf8') as log:
            log.write(message + '\n')

    def tensorboard_init(self, val=True):
        """Create the train (and, when val, the val) SummaryWriter under <run>/tb."""
        if val:
            self.train_writer = SummaryWriter(self.run_path + '/train')
            self.val_writer = SummaryWriter(self.run_path + '/val')
        else:
            self.train_writer = SummaryWriter(self.run_path)

    def get_path(self):
        """Return the run directory created for this logger instance."""
        return self.path
src/val_bi-phase.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ import torch
7
+ from tqdm import tqdm
8
+
9
+ from model.CPICANN import CPICANN
10
+
11
+
12
def getAnnoMap():
    """Build lookup tables from the structure annotation CSV (args.anno_struc).

    Returns (annos, elems): annos maps class id (column 1) to its full
    annotation row; elems maps class id to the set of element symbols parsed
    from the space-separated string in column 3.
    """
    annos, elems = {}, {}
    for row in pd.read_csv(args.anno_struc).values:
        class_id = row[1]
        annos[class_id] = row
        elems[class_id] = set(row[3].split(' '))
    return annos, elems
20
+
21
+
22
def filter_by_elem(logits, elemMap, elem):
    """Suppress (in place) logits of classes whose element set is not contained in `elem`.

    Masked classes get a large negative score so softmax drives them to ~0.
    Returns the same `logits` tensor for chaining.
    """
    for class_id, class_elems in elemMap.items():
        if class_elems <= elem:
            continue
        logits[:, class_id] = -10 ** 9
    return logits
28
+
29
+
30
def main():
    """Evaluate bi-phase identification on randomly mixed pairs of validation patterns.

    For each of args.infTimes trials: sample two validation patterns whose
    annotation column 6 differs, mix them at a random 20-80% ratio, run the
    model, and check whether BOTH ground-truth phases appear in the top-k
    predictions for k = 2..10.
    """
    annoMap, elemMap = getAnnoMap()

    model = CPICANN(embed_dim=128, num_classes=args.num_classes)

    loaded = torch.load(args.load_path)
    model.load_state_dict(loaded['model'])
    model.to(args.device)
    model.eval()
    print('loaded model from {}'.format(args.load_path))
    print(model)

    if args.elem_filtration:
        print('elem_filtration activated!')
    else:
        print('elem_filtration deactivated!')

    lst = pd.read_csv(args.anno_val).values

    # top10Hits[k] counts trials where both phases appear within the top-(k+1)
    # predictions; index 0 stays 0 (two phases cannot both be top-1).
    top10Hits = np.array([0] * 10, dtype=np.int32)

    dataLen = len(lst)
    pbar = tqdm(range(args.infTimes))
    for i in range(args.infTimes):
        # Re-sample until the two patterns differ in annotation column 6
        # (presumably the crystal system — confirm against anno_struc.csv).
        while True:
            c1, c2 = np.random.randint(0, dataLen, 2)
            anno1, anno2 = lst[c1], lst[c2]
            if anno1[6] != anno2[6]:
                break

        # id1, id2 = int(lst[c1][0].split('_')[0]), int(lst[c2][0].split('_')[0])
        # formula1, formula2 = lst[c1][2], lst[c2][2]
        data1 = pd.read_csv(os.path.join(args.data_dir, f'{lst[c1][0]}.csv')).values
        data2 = pd.read_csv(os.path.join(args.data_dir, f'{lst[c2][0]}.csv')).values

        # Random mixing ratio: mixRate1 in [20, 80], the rest for phase 2.
        mixRate1 = np.random.randint(20, 81)
        mixRate2 = 100 - mixRate1

        data = mixRate1 * data1 + mixRate2 * data2
        # Union of both phases' element sets (column 3 is space-separated),
        # used for the optional element-based filtering below.
        elem = set(lst[c2][3].strip().split(' ')) | set(lst[c1][3].strip().split(' '))

        def runFile(v):
            # Min-max normalize to [0, 100], matching the training-time range.
            # NOTE(review): min()/max() iterate the 2-D `.values` array by
            # row — this assumes each pattern CSV has a single column; confirm.
            min_i, scale = min(v), max(v) - min(v)
            v = (v - min_i) / scale * 100

            v = torch.tensor(v, dtype=torch.float32).reshape(1, 1, -1)
            v = v.to(args.device)
            with torch.no_grad():
                logits = model(v)

            # filter by elements
            if args.elem_filtration:
                logits = filter_by_elem(logits, elemMap, elem)

            _pred = torch.nn.functional.softmax(logits.squeeze(), dim=0)
            return _pred.topk(10)

        top10 = runFile(data)

        # m[no] marks which ground-truth phase (1 or 2) the no-th prediction
        # matches; 0 means no match. (`indice` and `rate` are unused — the
        # loop only needs `no`.)
        m = [0] * 10
        for no, (indice, rate) in enumerate(zip(top10.indices, top10.values)):
            pred = annoMap[top10.indices[no].item()]

            # The first 7 characters of the file name are parsed as the
            # structure id (presumably — confirm against the annotation format).
            if pred[0] == int(anno1[0][:7]):
                m[no] = 1
            elif pred[0] == int(anno2[0][:7]):
                m[no] = 2

        # Credit the smallest top-k window containing both phases; the slice
        # assignment makes the counters cumulative over k.
        if 1 in m[:2] and 2 in m[:2]:
            top10Hits[1:] += 1
        elif 1 in m[:3] and 2 in m[:3]:
            top10Hits[2:] += 1
        elif 1 in m[:4] and 2 in m[:4]:
            top10Hits[3:] += 1
        elif 1 in m[:5] and 2 in m[:5]:
            top10Hits[4:] += 1
        elif 1 in m[:6] and 2 in m[:6]:
            top10Hits[5:] += 1
        elif 1 in m[:7] and 2 in m[:7]:
            top10Hits[6:] += 1
        elif 1 in m[:8] and 2 in m[:8]:
            top10Hits[7:] += 1
        elif 1 in m[:9] and 2 in m[:9]:
            top10Hits[8:] += 1
        elif 1 in m[:10] and 2 in m[:10]:
            top10Hits[9:] += 1

        pbar.update(1)
    pbar.close()

    for i in range(1, 10):
        print('top{}Hits: {}%'.format(i + 1, round(top10Hits[i] / args.infTimes * 100, 2)))
122
+
123
+
124
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--device', default='cuda:0', type=str)
    parser.add_argument('--data_dir', default='data/val/', type=str)
    parser.add_argument('--infTimes', default=1000, type=int, help='number of mixed pattern to be inferenced')
    # NOTE(review): help text says "single-phase" but the default checkpoint
    # is the bi-phase model — the help string looks copy-pasted.
    parser.add_argument('--load_path', default='pretrained/bi-phase_checkpoint_2000.pth', type=str,
                        help='path to load pretrained single-phase identification model')
    parser.add_argument('--anno_struc', default='annotation/anno_struc.csv', type=str,
                        help='path to annotation file for training data')
    parser.add_argument('--anno_val', default='annotation/anno_val.csv', type=str,
                        help='path to annotation file for validation data')
    parser.add_argument('--num_classes', default=23073, type=int, metavar='N')

    # NOTE(review): argparse type=bool treats any non-empty string (even
    # "False") as True; only the default (False) behaves as expected —
    # consider a store_true flag.
    parser.add_argument('--elem_filtration', default=False, type=bool)

    args = parser.parse_args()

    main()
    print('THE END')
src/val_single-phase.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import torch
6
+ from torch.utils.data import DataLoader
7
+ from tqdm import tqdm
8
+
9
+ from model.CPICANN import CPICANN
10
+ from model.dataset import XrdDataset
11
+
12
+
13
def get_cs_anno():
    """Map class id (column 1) to its column-6 value from the annotation CSV.

    The caller uses the value as a 7-way crystal-system bin index.
    """
    rows = pd.read_csv(args.anno_struc).values
    return {row[1]: row[6] for row in rows}
19
+
20
+
21
def get_acc(cls, label):
    """Return (top-1 accuracy, number of correct predictions) for a batch of logits."""
    matches = cls.argmax(1) == label.int()
    correct_cnt = sum(matches)
    return correct_cnt / cls.shape[0], correct_cnt
25
+
26
+
27
def run_one_epoch(model, dataloader):
    """Run one evaluation pass and collect overall plus per-crystal-system accuracy.

    Returns (mean_loss, mean_batch_acc_percent, correct_cnt, total_cnt,
    confusion_matrix, per_cs_correct, per_cs_total). Uses the module-level
    global `args`.
    """
    model.eval()

    csAnno = get_cs_anno()  # class id -> 7-way bin index (annotation column 6)

    csCorrect = [0 for _ in range(7)]
    csTotal = [0 for _ in range(7)]
    cMtrx = [[0 for _ in range(7)] for _ in range(7)]  # rows: ground truth, cols: prediction
    # NOTE(review): epoch_loss is never accumulated (no criterion here), so
    # the first returned value is always 0.
    epoch_loss, cls_acc = 0, 0
    correct_cnt, total_cnt = 0, 0
    pbar = tqdm(total=len(dataloader.dataset), desc='Evaluating... ', unit='data')
    iters = len(dataloader)
    for i, batch in enumerate(dataloader):

        data = batch[0].to(args.device)
        label_cls = batch[1].to(args.device)

        with torch.no_grad():
            logits = model(data)
            logits.to(args.device)  # NOTE(review): no-op — the result is discarded

        pbar.update(len(data))

        _cls_acc, correct = get_acc(logits, label_cls)
        cls_acc += _cls_acc.item()

        correct_cnt += correct.item()
        total_cnt += len(data)

        # Per-crystal-system confusion-matrix bookkeeping.
        preds = logits.argmax(1)
        for gt, pred in zip(label_cls, preds):
            cs_gt = csAnno[gt.item()]
            cMtrx[cs_gt][csAnno[pred.item()]] += 1
            csTotal[cs_gt] += 1
            if gt == pred:
                csCorrect[cs_gt] += 1

    return epoch_loss / iters, cls_acc * 100 / iters, correct_cnt, total_cnt, cMtrx, csCorrect, csTotal
65
+
66
+
67
def main():
    """Validate the pretrained single-phase model and print per-crystal-system stats."""
    model = CPICANN(embed_dim=128, num_classes=args.num_classes)

    loaded = torch.load(args.load_path)
    model.load_state_dict(loaded['model'])
    model.to(args.device)
    model.eval()
    print('loaded model from {}'.format(args.load_path))

    print(model)

    valset = XrdDataset(args.data_dir, args.anno_val)
    val_loader = DataLoader(valset, batch_size=128, num_workers=16, pin_memory=True, shuffle=True)

    loss_val, acc_val, correct_cnt, total_cnt, cMtrx, csCorrect, csTotal = run_one_epoch(model, val_loader)

    # NOTE(review): run_one_epoch never accumulates a loss, so loss_val is
    # always 0 — this print is misleading.
    print("loss_val: ", loss_val)
    print("acc_val: ", acc_val)
    print("{}% ({}/{})".format(round(correct_cnt / total_cnt, 5) * 100, correct_cnt, total_cnt))

    # Row-normalized confusion matrix: ground-truth rows, prediction columns.
    sums = np.array(cMtrx).sum(axis=1)
    for i, row in enumerate(cMtrx):
        buf = ""
        for j, v in enumerate(row):
            buf += "{}({}%) ".format(v, round(v / sums[i] * 100, 2))
        print(buf)

    print("csCorrect: ", csCorrect)
    print("csTotal: ", csTotal)
97
+
98
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--device', default='cuda:0', type=str)
    parser.add_argument('--data_dir', default='data/val/', type=str)
    parser.add_argument('--load_path', default='pretrained/single-phase_checkpoint_0200.pth', type=str,
                        help='path to load pretrained single-phase identification model')
    # NOTE(review): help text says "training data" but this is the structure
    # annotation file — looks copy-pasted.
    parser.add_argument('--anno_struc', default='annotation/anno_struc.csv', type=str,
                        help='path to annotation file for training data')
    parser.add_argument('--anno_val', default='annotation/anno_val.csv', type=str,
                        help='path to annotation file for validation data')
    parser.add_argument('--num_classes', default=23073, type=int, metavar='N')

    args = parser.parse_args()

    main()

    print('THE END')