Upload 24 files
Browse files- src/README.md +27 -0
- src/annotation/# place anno_train.csv and anno_val.csv here +0 -0
- src/annotation/anno_struc.csv +0 -0
- src/data/# place directory train and directory val here +0 -0
- src/data_format.py +60 -0
- src/inference&case/.DS_Store +0 -0
- src/inference&case/.ipynb_checkpoints/CPICANNcode-checkpoint.ipynb +778 -0
- src/inference&case/CPICANNcode.ipynb +778 -0
- src/inference&case/config/elem_setting.csv +5 -0
- src/inference&case/figs/PbSO4.csv.png +0 -0
- src/inference&case/infResults_testdata.csv +2 -0
- src/inference&case/testdata/.DS_Store +0 -0
- src/inference&case/testdata/PbSO4.csv +0 -0
- src/model/CPICANN.py +244 -0
- src/model/dataset.py +55 -0
- src/model/focal_loss.py +87 -0
- src/othermodels/ATTENTIONonly.py +244 -0
- src/othermodels/CNNonly.py +140 -0
- src/pretrained/# place pretrained .pth files here +0 -0
- src/train_bi-phase.py +200 -0
- src/train_single-phase.py +189 -0
- src/util/logger.py +41 -0
- src/val_bi-phase.py +143 -0
- src/val_single-phase.py +115 -0
src/README.md
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Instructions for replication
|
| 2 |
+
|
| 3 |
+
This directory contains all the source code needed to reproduce this work.
|
| 4 |
+
|
| 5 |
+
### Data preparation
|
| 6 |
+
|
| 7 |
+
To directly run the train and validation script in this directory, data preparation needs to be done. The [OneDrive link](https://hkustgz-my.sharepoint.com/:f:/g/personal/bcao686_connect_hkust-gz_edu_cn/EhdJLtou8I1MoUJCu-KCoboBfi-wOp00WAlQCrONxjoYgg?e=rltgFE) contains all the training and synthetic testing data used in this work, stored in data.zip. This link also contains the pretrained model for single-phase and di-phase identification.
|
| 8 |
+
|
| 9 |
+
File single-phase_checkpoint_0200.pth and file bi-phase_checkpoint_2000.pth from the link above are the pretrained models; place them under directory "pretrained".
|
| 10 |
+
|
| 11 |
+
File data.zip contains the data and the annotation file. Place directory "train" and "val" from data.zip under directory "data", place the annotation files anno_train.csv and anno_val.csv under directory "annotation".
|
| 12 |
+
|
| 13 |
+
### Model Training
|
| 14 |
+
|
| 15 |
+
#### Single-phase
|
| 16 |
+
|
| 17 |
+
Run ```python train_single-phase.py``` to train the single-phase identification model from scratch. To train the model on your data, additional parameters need to be set: ```python train_single-phase.py --data_dir_train=[your training data] --data_dir_val=[your validation data] --anno_train=[your anno file for training data] --anno_val=[your anno file for validation data]```.
|
| 18 |
+
|
| 19 |
+
#### Bi-phase
|
| 20 |
+
|
| 21 |
+
Run ```python train_bi-phase.py``` to train the bi-phase identification model. The bi-phase identification model is trained based on the single-phase model; you can change the default setting by setting the parameter ```load_path=[your pretrained single-phase model]```.
|
| 22 |
+
|
| 23 |
+
### Model validation
|
| 24 |
+
|
| 25 |
+
Run ```python val_single-phase.py``` and ```python val_bi-phase.py``` to run the validation code at the default setting.
|
| 26 |
+
|
| 27 |
+
If you wish to validate the model on your data, please format your data using data_format.py
|
src/annotation/# place anno_train.csv and anno_val.csv here
ADDED
|
File without changes
|
src/annotation/anno_struc.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
src/data/# place directory train and directory val here
ADDED
|
File without changes
|
src/data_format.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings

import numpy as np
import pandas as pd
from scipy import interpolate
|
| 4 |
+
|
| 5 |
+
global dataWriter
|
| 6 |
+
|
| 7 |
+
def convert_file(file_path):
    """Convert a diffraction data file to a resampled intensity array.

    Dispatches on the file extension ('txt', 'csv' or 'xy') to the matching
    parser; each parser returns the intensities interpolated onto the common
    2-theta grid (see ``upsample``).

    :param file_path: path to the input file.
    :return: 1-D numpy array of 4500 intensity values, or None when the
        extension is unsupported.
    """
    suffix = file_path.split('.')[-1]
    if suffix not in ('txt', 'csv', 'xy'):
        # The original `Warning(...)` only constructed an exception object and
        # discarded it — nothing was ever shown; actually emit the warning.
        warnings.warn(f'File {file_path} not supported, skipping...')
        return None

    if suffix == 'txt':
        return txt_to_csv(file_path)
    elif suffix == 'csv':
        return csv_to_csv(file_path)
    else:  # suffix == 'xy'
        return xy_to_csv(file_path)
|
| 19 |
+
|
| 20 |
+
def txt_to_csv(file_path):
    """Parse a whitespace/tab separated text file of diffraction data.

    Each line is expected to hold either two fields (angle, intensity) or
    three fields (angle, intensity, background); in the three-field case the
    background is subtracted from the intensity.  Lines that do not match
    either shape, or whose numeric fields fail to parse, are skipped.

    :param file_path: path to the .txt (or .xy) file.
    :return: interpolated intensity array from ``upsample`` (None when the
        file yielded no usable rows).
    """
    rows = []
    # Context manager guarantees the handle is closed even if parsing raises
    # (the original closed the file manually and could leak it).
    with open(file_path, 'r') as f:
        for raw in f:
            fields = [x for x in raw.strip('\n').replace('\t', ' ').split(' ') if x != '']
            if len(fields) == 3:
                try:
                    fields = [fields[0], float(fields[1]) - float(fields[2])]
                except ValueError:
                    continue  # non-numeric row, e.g. a header line
            elif len(fields) != 2:
                continue
            rows.append(fields)

    return upsample(rows)
|
| 39 |
+
|
| 40 |
+
def csv_to_csv(file_path):
    """Load a CSV of (angle, intensity) rows and resample it.

    :param file_path: path to the .csv file.
    :return: interpolated intensity array from ``upsample``.
    """
    frame = pd.read_csv(file_path)
    return upsample(list(frame.values))
|
| 44 |
+
|
| 45 |
+
def xy_to_csv(file_path):
    """Parse an .xy file; the format matches .txt, so delegate to that parser.

    :param file_path: path to the .xy file.
    :return: interpolated intensity array from ``txt_to_csv``.
    """
    return txt_to_csv(file_path)
|
| 47 |
+
|
| 48 |
+
def upsample(rows):
    """Interpolate (angle, intensity) rows onto a fixed 2-theta grid.

    The angle range is padded to cover [10, 80] degrees (flat extension using
    the first/last measured intensity), then linearly resampled to 4500
    evenly spaced points.  ``rows`` is mutated in place by the padding.

    :param rows: list of [angle, intensity] pairs; values may be strings.
    :return: numpy array of 4500 interpolated intensities, or None when
        ``rows`` is empty.
    """
    if not rows:
        # The original `Warning(...)` built an exception instance and threw it
        # away silently; emit a real warning instead.
        warnings.warn('Empty data!')
        return None

    # Pad so the interpolation domain always spans [10, 80] degrees; plain
    # `if` statements replace the original conditional-expression side effects.
    if float(rows[0][0]) > 10:
        rows.insert(0, ['10', rows[0][1]])
    if float(rows[-1][0]) < 80:
        rows.append(['80', rows[-1][1]])

    data = np.array(rows, dtype=np.float32)
    f = interpolate.interp1d(data[:, 0], data[:, 1], kind='slinear')
    xnew = np.linspace(10, 80, 4500)
    return f(xnew)
|
src/inference&case/.DS_Store
ADDED
|
Binary file (8.2 kB). View file
|
|
|
src/inference&case/.ipynb_checkpoints/CPICANNcode-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,778 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "f2299629",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# It is a template for applying CPICANN to X-ray powder diffraction phase identification"
|
| 9 |
+
]
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"cell_type": "markdown",
|
| 13 |
+
"id": "cad44131",
|
| 14 |
+
"metadata": {},
|
| 15 |
+
"source": [
|
| 16 |
+
"### 1: install WPEMPhase package "
|
| 17 |
+
]
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"cell_type": "markdown",
|
| 21 |
+
"id": "af59fa9d",
|
| 22 |
+
"metadata": {},
|
| 23 |
+
"source": [
|
| 24 |
+
"pip install WPEMPhase"
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"cell_type": "markdown",
|
| 29 |
+
"id": "7d9d6d02",
|
| 30 |
+
"metadata": {},
|
| 31 |
+
"source": [
|
| 32 |
+
"### 2: The first time you execute CPICANN on your computer, you should initialize the system documents. After that, you do not need to do any additional execution to run CPICANN."
|
| 33 |
+
]
|
| 34 |
+
},
|
| 35 |
+
{
|
| 36 |
+
"cell_type": "raw",
|
| 37 |
+
"id": "4787c1b4",
|
| 38 |
+
"metadata": {},
|
| 39 |
+
"source": [
|
| 40 |
+
"Signature:\n",
|
| 41 |
+
"CPICANN.PhaseIdentifier(\n",
|
| 42 |
+
" FilePath,\n",
|
| 43 |
+
" Task='single-phase',\n",
|
| 44 |
+
" Model='default',\n",
|
| 45 |
+
" ElementsSystem='',\n",
|
| 46 |
+
" ElementsContained='',\n",
|
| 47 |
+
" Device='cuda:0',\n",
|
| 48 |
+
")\n",
|
| 49 |
+
"Docstring:\n",
|
| 50 |
+
"CPICANN : Crystallographic Phase Identifier of Convolutional self-Attention Neural Network\n",
|
| 51 |
+
"\n",
|
| 52 |
+
"Contributors : Shouyang Zhang & Bin Cao\n",
|
| 53 |
+
"================================================================\n",
|
| 54 |
+
" Please feel free to open issues in the Github :\n",
|
| 55 |
+
" https://github.com/WPEM/CPICANN\n",
|
| 56 |
+
" or\n",
|
| 57 |
+
" contact Mr.Bin Cao (bcao686@connect.hkust-gz.edu.cn)\n",
|
| 58 |
+
" in case of any problems/comments/suggestions in using the code.\n",
|
| 59 |
+
"==================================================================\n",
|
| 60 |
+
"\n",
|
| 61 |
+
":param FilePath\n",
|
| 62 |
+
"\n",
|
| 63 |
+
":param Task, type=str, default='single-phase'\n",
|
| 64 |
+
" if Task = 'single-phase', CPICANN executes a single phase identification task\n",
|
| 65 |
+
" if Task = 'di-phase', CPICANN executes a dual phase identification task\n",
|
| 66 |
+
"\n",
|
| 67 |
+
":param Model, type=str, default='default'\n",
|
| 68 |
+
" if Model = 'noise_model', CPICANN executes a single phase identification by noise-contained model\n",
|
| 69 |
+
" if Model = 'bca_model', CPICANN executes a single phase identification by background-contained model\n",
|
| 70 |
+
"\n",
|
| 71 |
+
":param ElementsSystem, type=str, default=''\n",
|
| 72 |
+
" Specifies the elements to be included at least in the prediction, example: 'Fe'.\n",
|
| 73 |
+
"\n",
|
| 74 |
+
":param ElementsContained, type=str, default=''\n",
|
| 75 |
+
" Specifies the elements to be included, with at least one of them in the prediction, example: 'O_C_S'.\n",
|
| 76 |
+
"\n",
|
| 77 |
+
":param Device, type=str, default='cuda:0',\n",
|
| 78 |
+
" Which device to run the CPICANN, example: 'cuda:0', 'cpu'.\n",
|
| 79 |
+
"\n",
|
| 80 |
+
"examples:\n",
|
| 81 |
+
"from WPEMPhase import CPICANN\n",
|
| 82 |
+
"CPICANN.PhaseIdentifier(FilePath='./single-phase',Device='cpu')\n",
|
| 83 |
+
"File: ~/miniconda3/lib/python3.9/site-packages/WPEMPhase/CPICANN.py\n",
|
| 84 |
+
"Type: function"
|
| 85 |
+
]
|
| 86 |
+
},
|
| 87 |
+
{
|
| 88 |
+
"cell_type": "code",
|
| 89 |
+
"execution_count": 2,
|
| 90 |
+
"id": "0922c99d",
|
| 91 |
+
"metadata": {},
|
| 92 |
+
"outputs": [
|
| 93 |
+
{
|
| 94 |
+
"name": "stdout",
|
| 95 |
+
"output_type": "stream",
|
| 96 |
+
"text": [
|
| 97 |
+
"Collecting WPEMPhase\n",
|
| 98 |
+
" Downloading WPEMPhase-0.1.0-py3-none-any.whl.metadata (1.0 kB)\n",
|
| 99 |
+
"Requirement already satisfied: torch in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (2.0.0)\n",
|
| 100 |
+
"Requirement already satisfied: plot in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (0.6.5)\n",
|
| 101 |
+
"Requirement already satisfied: scipy in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (1.9.3)\n",
|
| 102 |
+
"Requirement already satisfied: pandas in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (1.5.1)\n",
|
| 103 |
+
"Requirement already satisfied: numpy in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (1.23.3)\n",
|
| 104 |
+
"Requirement already satisfied: art in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (6.1)\n",
|
| 105 |
+
"Requirement already satisfied: pymatgen in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (2023.3.23)\n",
|
| 106 |
+
"Requirement already satisfied: wget in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (3.2)\n",
|
| 107 |
+
"Requirement already satisfied: python-dateutil>=2.8.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pandas->WPEMPhase) (2.8.2)\n",
|
| 108 |
+
"Requirement already satisfied: pytz>=2020.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pandas->WPEMPhase) (2022.5)\n",
|
| 109 |
+
"Requirement already satisfied: matplotlib in /Users/jacob/miniconda3/lib/python3.9/site-packages (from plot->WPEMPhase) (3.7.1)\n",
|
| 110 |
+
"Requirement already satisfied: typing in /Users/jacob/miniconda3/lib/python3.9/site-packages (from plot->WPEMPhase) (3.7.4.3)\n",
|
| 111 |
+
"Requirement already satisfied: pyyaml in /Users/jacob/miniconda3/lib/python3.9/site-packages (from plot->WPEMPhase) (6.0)\n",
|
| 112 |
+
"Requirement already satisfied: monty>=3.0.2 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (2023.4.10)\n",
|
| 113 |
+
"Requirement already satisfied: mp-api>=0.27.3 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (0.31.2)\n",
|
| 114 |
+
"Requirement already satisfied: networkx>=2.2 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (2.8.8)\n",
|
| 115 |
+
"Requirement already satisfied: palettable>=3.1.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (3.3.3)\n",
|
| 116 |
+
"Requirement already satisfied: plotly>=4.5.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (5.14.1)\n",
|
| 117 |
+
"Requirement already satisfied: pybtex in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (0.24.0)\n",
|
| 118 |
+
"Requirement already satisfied: requests in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (2.28.2)\n",
|
| 119 |
+
"Requirement already satisfied: ruamel.yaml>=0.17.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (0.17.21)\n",
|
| 120 |
+
"Requirement already satisfied: spglib>=2.0.2 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (2.0.2)\n",
|
| 121 |
+
"Requirement already satisfied: sympy in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (1.11.1)\n",
|
| 122 |
+
"Requirement already satisfied: tabulate in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (0.9.0)\n",
|
| 123 |
+
"Requirement already satisfied: tqdm in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (4.66.1)\n",
|
| 124 |
+
"Requirement already satisfied: uncertainties>=3.1.4 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (3.1.7)\n",
|
| 125 |
+
"Requirement already satisfied: filelock in /Users/jacob/miniconda3/lib/python3.9/site-packages (from torch->WPEMPhase) (3.10.7)\n",
|
| 126 |
+
"Requirement already satisfied: typing-extensions in /Users/jacob/miniconda3/lib/python3.9/site-packages (from torch->WPEMPhase) (4.11.0)\n",
|
| 127 |
+
"Requirement already satisfied: jinja2 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from torch->WPEMPhase) (3.1.2)\n",
|
| 128 |
+
"Requirement already satisfied: contourpy>=1.0.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (1.0.5)\n",
|
| 129 |
+
"Requirement already satisfied: cycler>=0.10 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (0.11.0)\n",
|
| 130 |
+
"Requirement already satisfied: fonttools>=4.22.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (4.38.0)\n",
|
| 131 |
+
"Requirement already satisfied: kiwisolver>=1.0.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (1.4.4)\n",
|
| 132 |
+
"Requirement already satisfied: packaging>=20.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (23.0)\n",
|
| 133 |
+
"Requirement already satisfied: pillow>=6.2.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (9.5.0)\n",
|
| 134 |
+
"Requirement already satisfied: pyparsing>=2.3.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (3.0.9)\n",
|
| 135 |
+
"Requirement already satisfied: importlib-resources>=3.2.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (5.12.0)\n",
|
| 136 |
+
"Requirement already satisfied: setuptools in /Users/jacob/miniconda3/lib/python3.9/site-packages (from mp-api>=0.27.3->pymatgen->WPEMPhase) (67.6.1)\n",
|
| 137 |
+
"Requirement already satisfied: msgpack in /Users/jacob/miniconda3/lib/python3.9/site-packages (from mp-api>=0.27.3->pymatgen->WPEMPhase) (1.0.5)\n",
|
| 138 |
+
"Requirement already satisfied: emmet-core<=0.50.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from mp-api>=0.27.3->pymatgen->WPEMPhase) (0.50.0)\n",
|
| 139 |
+
"Requirement already satisfied: tenacity>=6.2.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from plotly>=4.5.0->pymatgen->WPEMPhase) (8.2.2)\n",
|
| 140 |
+
"Requirement already satisfied: six>=1.5 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from python-dateutil>=2.8.1->pandas->WPEMPhase) (1.16.0)\n",
|
| 141 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from requests->pymatgen->WPEMPhase) (2.0.4)\n",
|
| 142 |
+
"Requirement already satisfied: idna<4,>=2.5 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from requests->pymatgen->WPEMPhase) (3.3)\n",
|
| 143 |
+
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from requests->pymatgen->WPEMPhase) (1.26.9)\n",
|
| 144 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from requests->pymatgen->WPEMPhase) (2022.12.7)\n",
|
| 145 |
+
"Requirement already satisfied: ruamel.yaml.clib>=0.2.6 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from ruamel.yaml>=0.17.0->pymatgen->WPEMPhase) (0.2.6)\n",
|
| 146 |
+
"Requirement already satisfied: future in /Users/jacob/miniconda3/lib/python3.9/site-packages (from uncertainties>=3.1.4->pymatgen->WPEMPhase) (0.18.3)\n",
|
| 147 |
+
"Requirement already satisfied: MarkupSafe>=2.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from jinja2->torch->WPEMPhase) (2.1.1)\n",
|
| 148 |
+
"Requirement already satisfied: latexcodec>=1.0.4 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pybtex->pymatgen->WPEMPhase) (2.0.1)\n",
|
| 149 |
+
"Requirement already satisfied: mpmath>=0.19 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from sympy->pymatgen->WPEMPhase) (1.3.0)\n",
|
| 150 |
+
"Requirement already satisfied: pydantic>=1.10.2 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from emmet-core<=0.50.0->mp-api>=0.27.3->pymatgen->WPEMPhase) (1.10.7)\n",
|
| 151 |
+
"Requirement already satisfied: zipp>=3.1.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from importlib-resources>=3.2.0->matplotlib->plot->WPEMPhase) (3.9.0)\n",
|
| 152 |
+
"Downloading WPEMPhase-0.1.0-py3-none-any.whl (710 kB)\n",
|
| 153 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m710.2/710.2 kB\u001b[0m \u001b[31m2.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0mm\n",
|
| 154 |
+
"\u001b[?25hInstalling collected packages: WPEMPhase\n",
|
| 155 |
+
"Successfully installed WPEMPhase-0.1.0\n",
|
| 156 |
+
"Note: you may need to restart the kernel to use updated packages.\n"
|
| 157 |
+
]
|
| 158 |
+
}
|
| 159 |
+
],
|
| 160 |
+
"source": [
|
| 161 |
+
"pip install WPEMPhase"
|
| 162 |
+
]
|
| 163 |
+
},
|
| 164 |
+
{
|
| 165 |
+
"cell_type": "code",
|
| 166 |
+
"execution_count": 3,
|
| 167 |
+
"id": "8e1680a6",
|
| 168 |
+
"metadata": {},
|
| 169 |
+
"outputs": [],
|
| 170 |
+
"source": [
|
| 171 |
+
"from WPEMPhase import CPICANN"
|
| 172 |
+
]
|
| 173 |
+
},
|
| 174 |
+
{
|
| 175 |
+
"cell_type": "code",
|
| 176 |
+
"execution_count": 4,
|
| 177 |
+
"id": "7625eb66",
|
| 178 |
+
"metadata": {},
|
| 179 |
+
"outputs": [
|
| 180 |
+
{
|
| 181 |
+
"name": "stdout",
|
| 182 |
+
"output_type": "stream",
|
| 183 |
+
"text": [
|
| 184 |
+
"This is the first time CPICANN is being executed on your computer, configuring...\n",
|
| 185 |
+
"Downloading: 3% [24690688 / 776454342] bytes"
|
| 186 |
+
]
|
| 187 |
+
},
|
| 188 |
+
{
|
| 189 |
+
"name": "stderr",
|
| 190 |
+
"output_type": "stream",
|
| 191 |
+
"text": [
|
| 192 |
+
"IOPub message rate exceeded.\n",
|
| 193 |
+
"The notebook server will temporarily stop sending output\n",
|
| 194 |
+
"to the client in order to avoid crashing it.\n",
|
| 195 |
+
"To change this limit, set the config variable\n",
|
| 196 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 197 |
+
"\n",
|
| 198 |
+
"Current values:\n",
|
| 199 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 200 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 201 |
+
"\n"
|
| 202 |
+
]
|
| 203 |
+
},
|
| 204 |
+
{
|
| 205 |
+
"name": "stdout",
|
| 206 |
+
"output_type": "stream",
|
| 207 |
+
"text": [
|
| 208 |
+
"Downloading: 7% [61341696 / 776454342] bytes"
|
| 209 |
+
]
|
| 210 |
+
},
|
| 211 |
+
{
|
| 212 |
+
"name": "stderr",
|
| 213 |
+
"output_type": "stream",
|
| 214 |
+
"text": [
|
| 215 |
+
"IOPub message rate exceeded.\n",
|
| 216 |
+
"The notebook server will temporarily stop sending output\n",
|
| 217 |
+
"to the client in order to avoid crashing it.\n",
|
| 218 |
+
"To change this limit, set the config variable\n",
|
| 219 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 220 |
+
"\n",
|
| 221 |
+
"Current values:\n",
|
| 222 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 223 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 224 |
+
"\n"
|
| 225 |
+
]
|
| 226 |
+
},
|
| 227 |
+
{
|
| 228 |
+
"name": "stdout",
|
| 229 |
+
"output_type": "stream",
|
| 230 |
+
"text": [
|
| 231 |
+
"Downloading: 13% [107954176 / 776454342] bytes"
|
| 232 |
+
]
|
| 233 |
+
},
|
| 234 |
+
{
|
| 235 |
+
"name": "stderr",
|
| 236 |
+
"output_type": "stream",
|
| 237 |
+
"text": [
|
| 238 |
+
"IOPub message rate exceeded.\n",
|
| 239 |
+
"The notebook server will temporarily stop sending output\n",
|
| 240 |
+
"to the client in order to avoid crashing it.\n",
|
| 241 |
+
"To change this limit, set the config variable\n",
|
| 242 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 243 |
+
"\n",
|
| 244 |
+
"Current values:\n",
|
| 245 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 246 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 247 |
+
"\n"
|
| 248 |
+
]
|
| 249 |
+
},
|
| 250 |
+
{
|
| 251 |
+
"name": "stdout",
|
| 252 |
+
"output_type": "stream",
|
| 253 |
+
"text": [
|
| 254 |
+
"Downloading: 19% [148324352 / 776454342] bytes"
|
| 255 |
+
]
|
| 256 |
+
},
|
| 257 |
+
{
|
| 258 |
+
"name": "stderr",
|
| 259 |
+
"output_type": "stream",
|
| 260 |
+
"text": [
|
| 261 |
+
"IOPub message rate exceeded.\n",
|
| 262 |
+
"The notebook server will temporarily stop sending output\n",
|
| 263 |
+
"to the client in order to avoid crashing it.\n",
|
| 264 |
+
"To change this limit, set the config variable\n",
|
| 265 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 266 |
+
"\n",
|
| 267 |
+
"Current values:\n",
|
| 268 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 269 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 270 |
+
"\n"
|
| 271 |
+
]
|
| 272 |
+
},
|
| 273 |
+
{
|
| 274 |
+
"name": "stdout",
|
| 275 |
+
"output_type": "stream",
|
| 276 |
+
"text": [
|
| 277 |
+
"Downloading: 24% [189382656 / 776454342] bytes"
|
| 278 |
+
]
|
| 279 |
+
},
|
| 280 |
+
{
|
| 281 |
+
"name": "stderr",
|
| 282 |
+
"output_type": "stream",
|
| 283 |
+
"text": [
|
| 284 |
+
"IOPub message rate exceeded.\n",
|
| 285 |
+
"The notebook server will temporarily stop sending output\n",
|
| 286 |
+
"to the client in order to avoid crashing it.\n",
|
| 287 |
+
"To change this limit, set the config variable\n",
|
| 288 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 289 |
+
"\n",
|
| 290 |
+
"Current values:\n",
|
| 291 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 292 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 293 |
+
"\n"
|
| 294 |
+
]
|
| 295 |
+
},
|
| 296 |
+
{
|
| 297 |
+
"name": "stdout",
|
| 298 |
+
"output_type": "stream",
|
| 299 |
+
"text": [
|
| 300 |
+
"Downloading: 28% [221265920 / 776454342] bytes"
|
| 301 |
+
]
|
| 302 |
+
},
|
| 303 |
+
{
|
| 304 |
+
"name": "stderr",
|
| 305 |
+
"output_type": "stream",
|
| 306 |
+
"text": [
|
| 307 |
+
"IOPub message rate exceeded.\n",
|
| 308 |
+
"The notebook server will temporarily stop sending output\n",
|
| 309 |
+
"to the client in order to avoid crashing it.\n",
|
| 310 |
+
"To change this limit, set the config variable\n",
|
| 311 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 312 |
+
"\n",
|
| 313 |
+
"Current values:\n",
|
| 314 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 315 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 316 |
+
"\n"
|
| 317 |
+
]
|
| 318 |
+
},
|
| 319 |
+
{
|
| 320 |
+
"name": "stdout",
|
| 321 |
+
"output_type": "stream",
|
| 322 |
+
"text": [
|
| 323 |
+
"Downloading: 33% [262488064 / 776454342] bytes"
|
| 324 |
+
]
|
| 325 |
+
},
|
| 326 |
+
{
|
| 327 |
+
"name": "stderr",
|
| 328 |
+
"output_type": "stream",
|
| 329 |
+
"text": [
|
| 330 |
+
"IOPub message rate exceeded.\n",
|
| 331 |
+
"The notebook server will temporarily stop sending output\n",
|
| 332 |
+
"to the client in order to avoid crashing it.\n",
|
| 333 |
+
"To change this limit, set the config variable\n",
|
| 334 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 335 |
+
"\n",
|
| 336 |
+
"Current values:\n",
|
| 337 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 338 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 339 |
+
"\n"
|
| 340 |
+
]
|
| 341 |
+
},
|
| 342 |
+
{
|
| 343 |
+
"name": "stdout",
|
| 344 |
+
"output_type": "stream",
|
| 345 |
+
"text": [
|
| 346 |
+
"Downloading: 39% [304799744 / 776454342] bytes"
|
| 347 |
+
]
|
| 348 |
+
},
|
| 349 |
+
{
|
| 350 |
+
"name": "stderr",
|
| 351 |
+
"output_type": "stream",
|
| 352 |
+
"text": [
|
| 353 |
+
"IOPub message rate exceeded.\n",
|
| 354 |
+
"The notebook server will temporarily stop sending output\n",
|
| 355 |
+
"to the client in order to avoid crashing it.\n",
|
| 356 |
+
"To change this limit, set the config variable\n",
|
| 357 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 358 |
+
"\n",
|
| 359 |
+
"Current values:\n",
|
| 360 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 361 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 362 |
+
"\n"
|
| 363 |
+
]
|
| 364 |
+
},
|
| 365 |
+
{
|
| 366 |
+
"name": "stdout",
|
| 367 |
+
"output_type": "stream",
|
| 368 |
+
"text": [
|
| 369 |
+
"Downloading: 44% [346030080 / 776454342] bytes"
|
| 370 |
+
]
|
| 371 |
+
},
|
| 372 |
+
{
|
| 373 |
+
"name": "stderr",
|
| 374 |
+
"output_type": "stream",
|
| 375 |
+
"text": [
|
| 376 |
+
"IOPub message rate exceeded.\n",
|
| 377 |
+
"The notebook server will temporarily stop sending output\n",
|
| 378 |
+
"to the client in order to avoid crashing it.\n",
|
| 379 |
+
"To change this limit, set the config variable\n",
|
| 380 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 381 |
+
"\n",
|
| 382 |
+
"Current values:\n",
|
| 383 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 384 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 385 |
+
"\n"
|
| 386 |
+
]
|
| 387 |
+
},
|
| 388 |
+
{
|
| 389 |
+
"name": "stdout",
|
| 390 |
+
"output_type": "stream",
|
| 391 |
+
"text": [
|
| 392 |
+
"Downloading: 50% [388333568 / 776454342] bytes"
|
| 393 |
+
]
|
| 394 |
+
},
|
| 395 |
+
{
|
| 396 |
+
"name": "stderr",
|
| 397 |
+
"output_type": "stream",
|
| 398 |
+
"text": [
|
| 399 |
+
"IOPub message rate exceeded.\n",
|
| 400 |
+
"The notebook server will temporarily stop sending output\n",
|
| 401 |
+
"to the client in order to avoid crashing it.\n",
|
| 402 |
+
"To change this limit, set the config variable\n",
|
| 403 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 404 |
+
"\n",
|
| 405 |
+
"Current values:\n",
|
| 406 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 407 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 408 |
+
"\n"
|
| 409 |
+
]
|
| 410 |
+
},
|
| 411 |
+
{
|
| 412 |
+
"name": "stdout",
|
| 413 |
+
"output_type": "stream",
|
| 414 |
+
"text": [
|
| 415 |
+
"Downloading: 55% [429015040 / 776454342] bytes"
|
| 416 |
+
]
|
| 417 |
+
},
|
| 418 |
+
{
|
| 419 |
+
"name": "stderr",
|
| 420 |
+
"output_type": "stream",
|
| 421 |
+
"text": [
|
| 422 |
+
"IOPub message rate exceeded.\n",
|
| 423 |
+
"The notebook server will temporarily stop sending output\n",
|
| 424 |
+
"to the client in order to avoid crashing it.\n",
|
| 425 |
+
"To change this limit, set the config variable\n",
|
| 426 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 427 |
+
"\n",
|
| 428 |
+
"Current values:\n",
|
| 429 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 430 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 431 |
+
"\n"
|
| 432 |
+
]
|
| 433 |
+
},
|
| 434 |
+
{
|
| 435 |
+
"name": "stdout",
|
| 436 |
+
"output_type": "stream",
|
| 437 |
+
"text": [
|
| 438 |
+
"Downloading: 60% [470278144 / 776454342] bytes"
|
| 439 |
+
]
|
| 440 |
+
},
|
| 441 |
+
{
|
| 442 |
+
"name": "stderr",
|
| 443 |
+
"output_type": "stream",
|
| 444 |
+
"text": [
|
| 445 |
+
"IOPub message rate exceeded.\n",
|
| 446 |
+
"The notebook server will temporarily stop sending output\n",
|
| 447 |
+
"to the client in order to avoid crashing it.\n",
|
| 448 |
+
"To change this limit, set the config variable\n",
|
| 449 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 450 |
+
"\n",
|
| 451 |
+
"Current values:\n",
|
| 452 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 453 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 454 |
+
"\n"
|
| 455 |
+
]
|
| 456 |
+
},
|
| 457 |
+
{
|
| 458 |
+
"name": "stdout",
|
| 459 |
+
"output_type": "stream",
|
| 460 |
+
"text": [
|
| 461 |
+
"Downloading: 65% [507609088 / 776454342] bytes"
|
| 462 |
+
]
|
| 463 |
+
},
|
| 464 |
+
{
|
| 465 |
+
"name": "stderr",
|
| 466 |
+
"output_type": "stream",
|
| 467 |
+
"text": [
|
| 468 |
+
"IOPub message rate exceeded.\n",
|
| 469 |
+
"The notebook server will temporarily stop sending output\n",
|
| 470 |
+
"to the client in order to avoid crashing it.\n",
|
| 471 |
+
"To change this limit, set the config variable\n",
|
| 472 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 473 |
+
"\n",
|
| 474 |
+
"Current values:\n",
|
| 475 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 476 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 477 |
+
"\n"
|
| 478 |
+
]
|
| 479 |
+
},
|
| 480 |
+
{
|
| 481 |
+
"name": "stdout",
|
| 482 |
+
"output_type": "stream",
|
| 483 |
+
"text": [
|
| 484 |
+
"Downloading: 70% [549601280 / 776454342] bytes"
|
| 485 |
+
]
|
| 486 |
+
},
|
| 487 |
+
{
|
| 488 |
+
"name": "stderr",
|
| 489 |
+
"output_type": "stream",
|
| 490 |
+
"text": [
|
| 491 |
+
"IOPub message rate exceeded.\n",
|
| 492 |
+
"The notebook server will temporarily stop sending output\n",
|
| 493 |
+
"to the client in order to avoid crashing it.\n",
|
| 494 |
+
"To change this limit, set the config variable\n",
|
| 495 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 496 |
+
"\n",
|
| 497 |
+
"Current values:\n",
|
| 498 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 499 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 500 |
+
"\n"
|
| 501 |
+
]
|
| 502 |
+
},
|
| 503 |
+
{
|
| 504 |
+
"name": "stdout",
|
| 505 |
+
"output_type": "stream",
|
| 506 |
+
"text": [
|
| 507 |
+
"Downloading: 75% [587497472 / 776454342] bytes"
|
| 508 |
+
]
|
| 509 |
+
},
|
| 510 |
+
{
|
| 511 |
+
"name": "stderr",
|
| 512 |
+
"output_type": "stream",
|
| 513 |
+
"text": [
|
| 514 |
+
"IOPub message rate exceeded.\n",
|
| 515 |
+
"The notebook server will temporarily stop sending output\n",
|
| 516 |
+
"to the client in order to avoid crashing it.\n",
|
| 517 |
+
"To change this limit, set the config variable\n",
|
| 518 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 519 |
+
"\n",
|
| 520 |
+
"Current values:\n",
|
| 521 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 522 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 523 |
+
"\n"
|
| 524 |
+
]
|
| 525 |
+
},
|
| 526 |
+
{
|
| 527 |
+
"name": "stdout",
|
| 528 |
+
"output_type": "stream",
|
| 529 |
+
"text": [
|
| 530 |
+
"Downloading: 80% [622919680 / 776454342] bytes"
|
| 531 |
+
]
|
| 532 |
+
},
|
| 533 |
+
{
|
| 534 |
+
"name": "stderr",
|
| 535 |
+
"output_type": "stream",
|
| 536 |
+
"text": [
|
| 537 |
+
"IOPub message rate exceeded.\n",
|
| 538 |
+
"The notebook server will temporarily stop sending output\n",
|
| 539 |
+
"to the client in order to avoid crashing it.\n",
|
| 540 |
+
"To change this limit, set the config variable\n",
|
| 541 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 542 |
+
"\n",
|
| 543 |
+
"Current values:\n",
|
| 544 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 545 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 546 |
+
"\n"
|
| 547 |
+
]
|
| 548 |
+
},
|
| 549 |
+
{
|
| 550 |
+
"name": "stdout",
|
| 551 |
+
"output_type": "stream",
|
| 552 |
+
"text": [
|
| 553 |
+
"Downloading: 86% [668491776 / 776454342] bytes"
|
| 554 |
+
]
|
| 555 |
+
},
|
| 556 |
+
{
|
| 557 |
+
"name": "stderr",
|
| 558 |
+
"output_type": "stream",
|
| 559 |
+
"text": [
|
| 560 |
+
"IOPub message rate exceeded.\n",
|
| 561 |
+
"The notebook server will temporarily stop sending output\n",
|
| 562 |
+
"to the client in order to avoid crashing it.\n",
|
| 563 |
+
"To change this limit, set the config variable\n",
|
| 564 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 565 |
+
"\n",
|
| 566 |
+
"Current values:\n",
|
| 567 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 568 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 569 |
+
"\n"
|
| 570 |
+
]
|
| 571 |
+
},
|
| 572 |
+
{
|
| 573 |
+
"name": "stdout",
|
| 574 |
+
"output_type": "stream",
|
| 575 |
+
"text": [
|
| 576 |
+
"Downloading: 89% [698474496 / 776454342] bytes"
|
| 577 |
+
]
|
| 578 |
+
},
|
| 579 |
+
{
|
| 580 |
+
"name": "stderr",
|
| 581 |
+
"output_type": "stream",
|
| 582 |
+
"text": [
|
| 583 |
+
"IOPub message rate exceeded.\n",
|
| 584 |
+
"The notebook server will temporarily stop sending output\n",
|
| 585 |
+
"to the client in order to avoid crashing it.\n",
|
| 586 |
+
"To change this limit, set the config variable\n",
|
| 587 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 588 |
+
"\n",
|
| 589 |
+
"Current values:\n",
|
| 590 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 591 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 592 |
+
"\n"
|
| 593 |
+
]
|
| 594 |
+
},
|
| 595 |
+
{
|
| 596 |
+
"name": "stdout",
|
| 597 |
+
"output_type": "stream",
|
| 598 |
+
"text": [
|
| 599 |
+
"Downloading: 91% [713318400 / 776454342] bytes"
|
| 600 |
+
]
|
| 601 |
+
},
|
| 602 |
+
{
|
| 603 |
+
"name": "stderr",
|
| 604 |
+
"output_type": "stream",
|
| 605 |
+
"text": [
|
| 606 |
+
"IOPub message rate exceeded.\n",
|
| 607 |
+
"The notebook server will temporarily stop sending output\n",
|
| 608 |
+
"to the client in order to avoid crashing it.\n",
|
| 609 |
+
"To change this limit, set the config variable\n",
|
| 610 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 611 |
+
"\n",
|
| 612 |
+
"Current values:\n",
|
| 613 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 614 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 615 |
+
"\n"
|
| 616 |
+
]
|
| 617 |
+
},
|
| 618 |
+
{
|
| 619 |
+
"name": "stdout",
|
| 620 |
+
"output_type": "stream",
|
| 621 |
+
"text": [
|
| 622 |
+
"Downloading: 96% [746487808 / 776454342] bytes"
|
| 623 |
+
]
|
| 624 |
+
},
|
| 625 |
+
{
|
| 626 |
+
"name": "stderr",
|
| 627 |
+
"output_type": "stream",
|
| 628 |
+
"text": [
|
| 629 |
+
"IOPub message rate exceeded.\n",
|
| 630 |
+
"The notebook server will temporarily stop sending output\n",
|
| 631 |
+
"to the client in order to avoid crashing it.\n",
|
| 632 |
+
"To change this limit, set the config variable\n",
|
| 633 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 634 |
+
"\n",
|
| 635 |
+
"Current values:\n",
|
| 636 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 637 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 638 |
+
"\n"
|
| 639 |
+
]
|
| 640 |
+
},
|
| 641 |
+
{
|
| 642 |
+
"name": "stdout",
|
| 643 |
+
"output_type": "stream",
|
| 644 |
+
"text": [
|
| 645 |
+
"Downloading: 100% [776454342 / 776454342] bytes ____ ____ ___ ____ _ _ _ _ _ \n",
|
| 646 |
+
" / ___|| _ \\ |_ _| / ___| / \\ | \\ | || \\ | |\n",
|
| 647 |
+
"| | | |_) | | | | | / _ \\ | \\| || \\| |\n",
|
| 648 |
+
"| |___ | __/ | | | |___ / ___ \\ | |\\ || |\\ |\n",
|
| 649 |
+
" \\____||_| |___| \\____|/_/ \\_\\|_| \\_||_| \\_|\n",
|
| 650 |
+
" \n",
|
| 651 |
+
"\n",
|
| 652 |
+
"The phase identification module of WPEM\n",
|
| 653 |
+
"URL : https://github.com/WPEM/CPICANN\n",
|
| 654 |
+
"Executed on : 2024-04-21 14:14:25 | Have a great day.\n",
|
| 655 |
+
"================================================================================\n",
|
| 656 |
+
"loaded model from /Users/jacob/miniconda3/lib/python3.9/site-packages/WPEMPhase/pretrained/CPICANN_single-phase_back3.pth\n",
|
| 657 |
+
"\n",
|
| 658 |
+
">>>>>> RUNNING: ./testdata/.DS_Store\n",
|
| 659 |
+
"\n",
|
| 660 |
+
">>>>>> RUNNING: ./testdata/PbSO4.csv\n",
|
| 661 |
+
"pred cls_id : 2475 confidence : 98.89%\n",
|
| 662 |
+
"pred cod_id : 9009622 formula : Pb2 S2 O6\n",
|
| 663 |
+
"pred space group No: 11 space group : P2_1/m\n",
|
| 664 |
+
"\n",
|
| 665 |
+
"inference result saved in infResults_testdata.csv\n",
|
| 666 |
+
"inference figures saved at figs/\n",
|
| 667 |
+
"THE END\n"
|
| 668 |
+
]
|
| 669 |
+
},
|
| 670 |
+
{
|
| 671 |
+
"data": {
|
| 672 |
+
"text/plain": [
|
| 673 |
+
"True"
|
| 674 |
+
]
|
| 675 |
+
},
|
| 676 |
+
"execution_count": 4,
|
| 677 |
+
"metadata": {},
|
| 678 |
+
"output_type": "execute_result"
|
| 679 |
+
}
|
| 680 |
+
],
|
| 681 |
+
"source": [
|
| 682 |
+
"# Here, illustrate the system requirements and how to initialize the system files at the first time of execution.\n",
|
| 683 |
+
"\n",
|
| 684 |
+
"CPICANN.PhaseIdentifier(FilePath='./testdata',Model='bca_model',Task='single-phase',Device='cpu',)"
|
| 685 |
+
]
|
| 686 |
+
},
|
| 687 |
+
{
|
| 688 |
+
"cell_type": "code",
|
| 689 |
+
"execution_count": 7,
|
| 690 |
+
"id": "304b62b5",
|
| 691 |
+
"metadata": {},
|
| 692 |
+
"outputs": [
|
| 693 |
+
{
|
| 694 |
+
"name": "stdout",
|
| 695 |
+
"output_type": "stream",
|
| 696 |
+
"text": [
|
| 697 |
+
" ____ ____ ___ ____ _ _ _ _ _ \n",
|
| 698 |
+
" / ___|| _ \\ |_ _| / ___| / \\ | \\ | || \\ | |\n",
|
| 699 |
+
"| | | |_) | | | | | / _ \\ | \\| || \\| |\n",
|
| 700 |
+
"| |___ | __/ | | | |___ / ___ \\ | |\\ || |\\ |\n",
|
| 701 |
+
" \\____||_| |___| \\____|/_/ \\_\\|_| \\_||_| \\_|\n",
|
| 702 |
+
" \n",
|
| 703 |
+
"\n",
|
| 704 |
+
"The phase identification module of WPEM\n",
|
| 705 |
+
"URL : https://github.com/WPEM/CPICANN\n",
|
| 706 |
+
"Executed on : 2024-04-21 14:14:53 | Have a great day.\n",
|
| 707 |
+
"================================================================================\n",
|
| 708 |
+
"loaded model from /Users/jacob/miniconda3/lib/python3.9/site-packages/WPEMPhase/pretrained/CPICANN_single-phase_noise3.pth\n",
|
| 709 |
+
"\n",
|
| 710 |
+
">>>>>> RUNNING: ./testdata/.DS_Store\n",
|
| 711 |
+
"\n",
|
| 712 |
+
">>>>>> RUNNING: ./testdata/PbSO4.csv\n",
|
| 713 |
+
"pred cls_id : 3378 confidence : 100.00%\n",
|
| 714 |
+
"pred cod_id : 9004484 formula : Pb4 S4 O16\n",
|
| 715 |
+
"pred space group No: 62 space group : Pnma\n",
|
| 716 |
+
"\n",
|
| 717 |
+
"inference result saved in infResults_testdata.csv\n",
|
| 718 |
+
"inference figures saved at figs/\n",
|
| 719 |
+
"THE END\n"
|
| 720 |
+
]
|
| 721 |
+
},
|
| 722 |
+
{
|
| 723 |
+
"data": {
|
| 724 |
+
"text/plain": [
|
| 725 |
+
"True"
|
| 726 |
+
]
|
| 727 |
+
},
|
| 728 |
+
"execution_count": 7,
|
| 729 |
+
"metadata": {},
|
| 730 |
+
"output_type": "execute_result"
|
| 731 |
+
}
|
| 732 |
+
],
|
| 733 |
+
"source": [
|
| 734 |
+
"from WPEMPhase import CPICANN\n",
|
| 735 |
+
"# Here, illustrate the system requirements and how to initialize the system files at the first time of execution.\n",
|
| 736 |
+
"\n",
|
| 737 |
+
"CPICANN.PhaseIdentifier(FilePath='./testdata',Model='noise_model',Task='single-phase',ElementsContained='Pb_S_O',Device='cpu',)"
|
| 738 |
+
]
|
| 739 |
+
},
|
| 740 |
+
{
|
| 741 |
+
"cell_type": "raw",
|
| 742 |
+
"id": "977ede12",
|
| 743 |
+
"metadata": {},
|
| 744 |
+
"source": [
|
| 745 |
+
"For inquiries or assistance, please don't hesitate to contact us at bcao686@connect.hkust-gz.edu.cn (Dr. CAO Bin)."
|
| 746 |
+
]
|
| 747 |
+
},
|
| 748 |
+
{
|
| 749 |
+
"cell_type": "code",
|
| 750 |
+
"execution_count": null,
|
| 751 |
+
"id": "bb101480",
|
| 752 |
+
"metadata": {},
|
| 753 |
+
"outputs": [],
|
| 754 |
+
"source": []
|
| 755 |
+
}
|
| 756 |
+
],
|
| 757 |
+
"metadata": {
|
| 758 |
+
"kernelspec": {
|
| 759 |
+
"display_name": "Python 3 (ipykernel)",
|
| 760 |
+
"language": "python",
|
| 761 |
+
"name": "python3"
|
| 762 |
+
},
|
| 763 |
+
"language_info": {
|
| 764 |
+
"codemirror_mode": {
|
| 765 |
+
"name": "ipython",
|
| 766 |
+
"version": 3
|
| 767 |
+
},
|
| 768 |
+
"file_extension": ".py",
|
| 769 |
+
"mimetype": "text/x-python",
|
| 770 |
+
"name": "python",
|
| 771 |
+
"nbconvert_exporter": "python",
|
| 772 |
+
"pygments_lexer": "ipython3",
|
| 773 |
+
"version": "3.9.12"
|
| 774 |
+
}
|
| 775 |
+
},
|
| 776 |
+
"nbformat": 4,
|
| 777 |
+
"nbformat_minor": 5
|
| 778 |
+
}
|
src/inference&case/CPICANNcode.ipynb
ADDED
|
@@ -0,0 +1,778 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "f2299629",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# It is a template for applying CPICANN to X-ray powder diffraction phase identification"
|
| 9 |
+
]
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"cell_type": "markdown",
|
| 13 |
+
"id": "cad44131",
|
| 14 |
+
"metadata": {},
|
| 15 |
+
"source": [
|
| 16 |
+
"### 1: install WPEMPhase package "
|
| 17 |
+
]
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"cell_type": "markdown",
|
| 21 |
+
"id": "af59fa9d",
|
| 22 |
+
"metadata": {},
|
| 23 |
+
"source": [
|
| 24 |
+
"pip install WPEMPhase"
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"cell_type": "markdown",
|
| 29 |
+
"id": "7d9d6d02",
|
| 30 |
+
"metadata": {},
|
| 31 |
+
"source": [
|
| 32 |
+
"### 2: The first time you execute CPICANN on your computer, you should initialize the system documents. After that, you do not need to do any additional execution to run CPICANN."
|
| 33 |
+
]
|
| 34 |
+
},
|
| 35 |
+
{
|
| 36 |
+
"cell_type": "raw",
|
| 37 |
+
"id": "4787c1b4",
|
| 38 |
+
"metadata": {},
|
| 39 |
+
"source": [
|
| 40 |
+
"Signature:\n",
|
| 41 |
+
"CPICANN.PhaseIdentifier(\n",
|
| 42 |
+
" FilePath,\n",
|
| 43 |
+
" Task='single-phase',\n",
|
| 44 |
+
" Model='default',\n",
|
| 45 |
+
" ElementsSystem='',\n",
|
| 46 |
+
" ElementsContained='',\n",
|
| 47 |
+
" Device='cuda:0',\n",
|
| 48 |
+
")\n",
|
| 49 |
+
"Docstring:\n",
|
| 50 |
+
"CPICANN : Crystallographic Phase Identifier of Convolutional self-Attention Neural Network\n",
|
| 51 |
+
"\n",
|
| 52 |
+
"Contributors : Shouyang Zhang & Bin Cao\n",
|
| 53 |
+
"================================================================\n",
|
| 54 |
+
" Please feel free to open issues in the Github :\n",
|
| 55 |
+
" https://github.com/WPEM/CPICANN\n",
|
| 56 |
+
" or\n",
|
| 57 |
+
" contact Mr.Bin Cao (bcao686@connect.hkust-gz.edu.cn)\n",
|
| 58 |
+
" in case of any problems/comments/suggestions in using the code.\n",
|
| 59 |
+
"==================================================================\n",
|
| 60 |
+
"\n",
|
| 61 |
+
":param FilePath\n",
|
| 62 |
+
"\n",
|
| 63 |
+
":param Task, type=str, default='single-phase'\n",
|
| 64 |
+
" if Task = 'single-phase', CPICANN executes a single phase identification task\n",
|
| 65 |
+
" if Task = 'di-phase', CPICANN executes a dual phase identification task\n",
|
| 66 |
+
"\n",
|
| 67 |
+
":param Model, type=str, default='default'\n",
|
| 68 |
+
" if Model = 'noise_model', CPICANN executes a single phase identification by noise-contained model\n",
|
| 69 |
+
" if Model = 'bca_model', CPICANN executes a single phase identification by background-contained model\n",
|
| 70 |
+
"\n",
|
| 71 |
+
":param ElementsSystem, type=str, default=''\n",
|
| 72 |
+
" Specifies the elements to be included at least in the prediction, example: 'Fe'.\n",
|
| 73 |
+
"\n",
|
| 74 |
+
":param ElementsContained, type=str, default=''\n",
|
| 75 |
+
" Specifies the elements to be included, with at least one of them in the prediction, example: 'O_C_S'.\n",
|
| 76 |
+
"\n",
|
| 77 |
+
":param Device, type=str, default='cuda:0',\n",
|
| 78 |
+
" Which device to run the CPICANN, example: 'cuda:0', 'cpu'.\n",
|
| 79 |
+
"\n",
|
| 80 |
+
"examples:\n",
|
| 81 |
+
"from WPEMPhase import CPICANN\n",
|
| 82 |
+
"CPICANN.PhaseIdentifier(FilePath='./single-phase',Device='cpu')\n",
|
| 83 |
+
"File: ~/miniconda3/lib/python3.9/site-packages/WPEMPhase/CPICANN.py\n",
|
| 84 |
+
"Type: function"
|
| 85 |
+
]
|
| 86 |
+
},
|
| 87 |
+
{
|
| 88 |
+
"cell_type": "code",
|
| 89 |
+
"execution_count": 2,
|
| 90 |
+
"id": "0922c99d",
|
| 91 |
+
"metadata": {},
|
| 92 |
+
"outputs": [
|
| 93 |
+
{
|
| 94 |
+
"name": "stdout",
|
| 95 |
+
"output_type": "stream",
|
| 96 |
+
"text": [
|
| 97 |
+
"Collecting WPEMPhase\n",
|
| 98 |
+
" Downloading WPEMPhase-0.1.0-py3-none-any.whl.metadata (1.0 kB)\n",
|
| 99 |
+
"Requirement already satisfied: torch in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (2.0.0)\n",
|
| 100 |
+
"Requirement already satisfied: plot in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (0.6.5)\n",
|
| 101 |
+
"Requirement already satisfied: scipy in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (1.9.3)\n",
|
| 102 |
+
"Requirement already satisfied: pandas in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (1.5.1)\n",
|
| 103 |
+
"Requirement already satisfied: numpy in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (1.23.3)\n",
|
| 104 |
+
"Requirement already satisfied: art in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (6.1)\n",
|
| 105 |
+
"Requirement already satisfied: pymatgen in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (2023.3.23)\n",
|
| 106 |
+
"Requirement already satisfied: wget in /Users/jacob/miniconda3/lib/python3.9/site-packages (from WPEMPhase) (3.2)\n",
|
| 107 |
+
"Requirement already satisfied: python-dateutil>=2.8.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pandas->WPEMPhase) (2.8.2)\n",
|
| 108 |
+
"Requirement already satisfied: pytz>=2020.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pandas->WPEMPhase) (2022.5)\n",
|
| 109 |
+
"Requirement already satisfied: matplotlib in /Users/jacob/miniconda3/lib/python3.9/site-packages (from plot->WPEMPhase) (3.7.1)\n",
|
| 110 |
+
"Requirement already satisfied: typing in /Users/jacob/miniconda3/lib/python3.9/site-packages (from plot->WPEMPhase) (3.7.4.3)\n",
|
| 111 |
+
"Requirement already satisfied: pyyaml in /Users/jacob/miniconda3/lib/python3.9/site-packages (from plot->WPEMPhase) (6.0)\n",
|
| 112 |
+
"Requirement already satisfied: monty>=3.0.2 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (2023.4.10)\n",
|
| 113 |
+
"Requirement already satisfied: mp-api>=0.27.3 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (0.31.2)\n",
|
| 114 |
+
"Requirement already satisfied: networkx>=2.2 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (2.8.8)\n",
|
| 115 |
+
"Requirement already satisfied: palettable>=3.1.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (3.3.3)\n",
|
| 116 |
+
"Requirement already satisfied: plotly>=4.5.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (5.14.1)\n",
|
| 117 |
+
"Requirement already satisfied: pybtex in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (0.24.0)\n",
|
| 118 |
+
"Requirement already satisfied: requests in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (2.28.2)\n",
|
| 119 |
+
"Requirement already satisfied: ruamel.yaml>=0.17.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (0.17.21)\n",
|
| 120 |
+
"Requirement already satisfied: spglib>=2.0.2 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (2.0.2)\n",
|
| 121 |
+
"Requirement already satisfied: sympy in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (1.11.1)\n",
|
| 122 |
+
"Requirement already satisfied: tabulate in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (0.9.0)\n",
|
| 123 |
+
"Requirement already satisfied: tqdm in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (4.66.1)\n",
|
| 124 |
+
"Requirement already satisfied: uncertainties>=3.1.4 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pymatgen->WPEMPhase) (3.1.7)\n",
|
| 125 |
+
"Requirement already satisfied: filelock in /Users/jacob/miniconda3/lib/python3.9/site-packages (from torch->WPEMPhase) (3.10.7)\n",
|
| 126 |
+
"Requirement already satisfied: typing-extensions in /Users/jacob/miniconda3/lib/python3.9/site-packages (from torch->WPEMPhase) (4.11.0)\n",
|
| 127 |
+
"Requirement already satisfied: jinja2 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from torch->WPEMPhase) (3.1.2)\n",
|
| 128 |
+
"Requirement already satisfied: contourpy>=1.0.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (1.0.5)\n",
|
| 129 |
+
"Requirement already satisfied: cycler>=0.10 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (0.11.0)\n",
|
| 130 |
+
"Requirement already satisfied: fonttools>=4.22.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (4.38.0)\n",
|
| 131 |
+
"Requirement already satisfied: kiwisolver>=1.0.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (1.4.4)\n",
|
| 132 |
+
"Requirement already satisfied: packaging>=20.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (23.0)\n",
|
| 133 |
+
"Requirement already satisfied: pillow>=6.2.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (9.5.0)\n",
|
| 134 |
+
"Requirement already satisfied: pyparsing>=2.3.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (3.0.9)\n",
|
| 135 |
+
"Requirement already satisfied: importlib-resources>=3.2.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from matplotlib->plot->WPEMPhase) (5.12.0)\n",
|
| 136 |
+
"Requirement already satisfied: setuptools in /Users/jacob/miniconda3/lib/python3.9/site-packages (from mp-api>=0.27.3->pymatgen->WPEMPhase) (67.6.1)\n",
|
| 137 |
+
"Requirement already satisfied: msgpack in /Users/jacob/miniconda3/lib/python3.9/site-packages (from mp-api>=0.27.3->pymatgen->WPEMPhase) (1.0.5)\n",
|
| 138 |
+
"Requirement already satisfied: emmet-core<=0.50.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from mp-api>=0.27.3->pymatgen->WPEMPhase) (0.50.0)\n",
|
| 139 |
+
"Requirement already satisfied: tenacity>=6.2.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from plotly>=4.5.0->pymatgen->WPEMPhase) (8.2.2)\n",
|
| 140 |
+
"Requirement already satisfied: six>=1.5 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from python-dateutil>=2.8.1->pandas->WPEMPhase) (1.16.0)\n",
|
| 141 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from requests->pymatgen->WPEMPhase) (2.0.4)\n",
|
| 142 |
+
"Requirement already satisfied: idna<4,>=2.5 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from requests->pymatgen->WPEMPhase) (3.3)\n",
|
| 143 |
+
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from requests->pymatgen->WPEMPhase) (1.26.9)\n",
|
| 144 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from requests->pymatgen->WPEMPhase) (2022.12.7)\n",
|
| 145 |
+
"Requirement already satisfied: ruamel.yaml.clib>=0.2.6 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from ruamel.yaml>=0.17.0->pymatgen->WPEMPhase) (0.2.6)\n",
|
| 146 |
+
"Requirement already satisfied: future in /Users/jacob/miniconda3/lib/python3.9/site-packages (from uncertainties>=3.1.4->pymatgen->WPEMPhase) (0.18.3)\n",
|
| 147 |
+
"Requirement already satisfied: MarkupSafe>=2.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from jinja2->torch->WPEMPhase) (2.1.1)\n",
|
| 148 |
+
"Requirement already satisfied: latexcodec>=1.0.4 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from pybtex->pymatgen->WPEMPhase) (2.0.1)\n",
|
| 149 |
+
"Requirement already satisfied: mpmath>=0.19 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from sympy->pymatgen->WPEMPhase) (1.3.0)\n",
|
| 150 |
+
"Requirement already satisfied: pydantic>=1.10.2 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from emmet-core<=0.50.0->mp-api>=0.27.3->pymatgen->WPEMPhase) (1.10.7)\n",
|
| 151 |
+
"Requirement already satisfied: zipp>=3.1.0 in /Users/jacob/miniconda3/lib/python3.9/site-packages (from importlib-resources>=3.2.0->matplotlib->plot->WPEMPhase) (3.9.0)\n",
|
| 152 |
+
"Downloading WPEMPhase-0.1.0-py3-none-any.whl (710 kB)\n",
|
| 153 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m710.2/710.2 kB\u001b[0m \u001b[31m2.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0mm\n",
|
| 154 |
+
"\u001b[?25hInstalling collected packages: WPEMPhase\n",
|
| 155 |
+
"Successfully installed WPEMPhase-0.1.0\n",
|
| 156 |
+
"Note: you may need to restart the kernel to use updated packages.\n"
|
| 157 |
+
]
|
| 158 |
+
}
|
| 159 |
+
],
|
| 160 |
+
"source": [
|
| 161 |
+
"pip install WPEMPhase"
|
| 162 |
+
]
|
| 163 |
+
},
|
| 164 |
+
{
|
| 165 |
+
"cell_type": "code",
|
| 166 |
+
"execution_count": 3,
|
| 167 |
+
"id": "8e1680a6",
|
| 168 |
+
"metadata": {},
|
| 169 |
+
"outputs": [],
|
| 170 |
+
"source": [
|
| 171 |
+
"from WPEMPhase import CPICANN"
|
| 172 |
+
]
|
| 173 |
+
},
|
| 174 |
+
{
|
| 175 |
+
"cell_type": "code",
|
| 176 |
+
"execution_count": 4,
|
| 177 |
+
"id": "7625eb66",
|
| 178 |
+
"metadata": {},
|
| 179 |
+
"outputs": [
|
| 180 |
+
{
|
| 181 |
+
"name": "stdout",
|
| 182 |
+
"output_type": "stream",
|
| 183 |
+
"text": [
|
| 184 |
+
"This is the first time CPICANN is being executed on your computer, configuring...\n",
|
| 185 |
+
"Downloading: 3% [24690688 / 776454342] bytes"
|
| 186 |
+
]
|
| 187 |
+
},
|
| 188 |
+
{
|
| 189 |
+
"name": "stderr",
|
| 190 |
+
"output_type": "stream",
|
| 191 |
+
"text": [
|
| 192 |
+
"IOPub message rate exceeded.\n",
|
| 193 |
+
"The notebook server will temporarily stop sending output\n",
|
| 194 |
+
"to the client in order to avoid crashing it.\n",
|
| 195 |
+
"To change this limit, set the config variable\n",
|
| 196 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 197 |
+
"\n",
|
| 198 |
+
"Current values:\n",
|
| 199 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 200 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 201 |
+
"\n"
|
| 202 |
+
]
|
| 203 |
+
},
|
| 204 |
+
{
|
| 205 |
+
"name": "stdout",
|
| 206 |
+
"output_type": "stream",
|
| 207 |
+
"text": [
|
| 208 |
+
"Downloading: 7% [61341696 / 776454342] bytes"
|
| 209 |
+
]
|
| 210 |
+
},
|
| 211 |
+
{
|
| 212 |
+
"name": "stderr",
|
| 213 |
+
"output_type": "stream",
|
| 214 |
+
"text": [
|
| 215 |
+
"IOPub message rate exceeded.\n",
|
| 216 |
+
"The notebook server will temporarily stop sending output\n",
|
| 217 |
+
"to the client in order to avoid crashing it.\n",
|
| 218 |
+
"To change this limit, set the config variable\n",
|
| 219 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 220 |
+
"\n",
|
| 221 |
+
"Current values:\n",
|
| 222 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 223 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 224 |
+
"\n"
|
| 225 |
+
]
|
| 226 |
+
},
|
| 227 |
+
{
|
| 228 |
+
"name": "stdout",
|
| 229 |
+
"output_type": "stream",
|
| 230 |
+
"text": [
|
| 231 |
+
"Downloading: 13% [107954176 / 776454342] bytes"
|
| 232 |
+
]
|
| 233 |
+
},
|
| 234 |
+
{
|
| 235 |
+
"name": "stderr",
|
| 236 |
+
"output_type": "stream",
|
| 237 |
+
"text": [
|
| 238 |
+
"IOPub message rate exceeded.\n",
|
| 239 |
+
"The notebook server will temporarily stop sending output\n",
|
| 240 |
+
"to the client in order to avoid crashing it.\n",
|
| 241 |
+
"To change this limit, set the config variable\n",
|
| 242 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 243 |
+
"\n",
|
| 244 |
+
"Current values:\n",
|
| 245 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 246 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 247 |
+
"\n"
|
| 248 |
+
]
|
| 249 |
+
},
|
| 250 |
+
{
|
| 251 |
+
"name": "stdout",
|
| 252 |
+
"output_type": "stream",
|
| 253 |
+
"text": [
|
| 254 |
+
"Downloading: 19% [148324352 / 776454342] bytes"
|
| 255 |
+
]
|
| 256 |
+
},
|
| 257 |
+
{
|
| 258 |
+
"name": "stderr",
|
| 259 |
+
"output_type": "stream",
|
| 260 |
+
"text": [
|
| 261 |
+
"IOPub message rate exceeded.\n",
|
| 262 |
+
"The notebook server will temporarily stop sending output\n",
|
| 263 |
+
"to the client in order to avoid crashing it.\n",
|
| 264 |
+
"To change this limit, set the config variable\n",
|
| 265 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 266 |
+
"\n",
|
| 267 |
+
"Current values:\n",
|
| 268 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 269 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 270 |
+
"\n"
|
| 271 |
+
]
|
| 272 |
+
},
|
| 273 |
+
{
|
| 274 |
+
"name": "stdout",
|
| 275 |
+
"output_type": "stream",
|
| 276 |
+
"text": [
|
| 277 |
+
"Downloading: 24% [189382656 / 776454342] bytes"
|
| 278 |
+
]
|
| 279 |
+
},
|
| 280 |
+
{
|
| 281 |
+
"name": "stderr",
|
| 282 |
+
"output_type": "stream",
|
| 283 |
+
"text": [
|
| 284 |
+
"IOPub message rate exceeded.\n",
|
| 285 |
+
"The notebook server will temporarily stop sending output\n",
|
| 286 |
+
"to the client in order to avoid crashing it.\n",
|
| 287 |
+
"To change this limit, set the config variable\n",
|
| 288 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 289 |
+
"\n",
|
| 290 |
+
"Current values:\n",
|
| 291 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 292 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 293 |
+
"\n"
|
| 294 |
+
]
|
| 295 |
+
},
|
| 296 |
+
{
|
| 297 |
+
"name": "stdout",
|
| 298 |
+
"output_type": "stream",
|
| 299 |
+
"text": [
|
| 300 |
+
"Downloading: 28% [221265920 / 776454342] bytes"
|
| 301 |
+
]
|
| 302 |
+
},
|
| 303 |
+
{
|
| 304 |
+
"name": "stderr",
|
| 305 |
+
"output_type": "stream",
|
| 306 |
+
"text": [
|
| 307 |
+
"IOPub message rate exceeded.\n",
|
| 308 |
+
"The notebook server will temporarily stop sending output\n",
|
| 309 |
+
"to the client in order to avoid crashing it.\n",
|
| 310 |
+
"To change this limit, set the config variable\n",
|
| 311 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 312 |
+
"\n",
|
| 313 |
+
"Current values:\n",
|
| 314 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 315 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 316 |
+
"\n"
|
| 317 |
+
]
|
| 318 |
+
},
|
| 319 |
+
{
|
| 320 |
+
"name": "stdout",
|
| 321 |
+
"output_type": "stream",
|
| 322 |
+
"text": [
|
| 323 |
+
"Downloading: 33% [262488064 / 776454342] bytes"
|
| 324 |
+
]
|
| 325 |
+
},
|
| 326 |
+
{
|
| 327 |
+
"name": "stderr",
|
| 328 |
+
"output_type": "stream",
|
| 329 |
+
"text": [
|
| 330 |
+
"IOPub message rate exceeded.\n",
|
| 331 |
+
"The notebook server will temporarily stop sending output\n",
|
| 332 |
+
"to the client in order to avoid crashing it.\n",
|
| 333 |
+
"To change this limit, set the config variable\n",
|
| 334 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 335 |
+
"\n",
|
| 336 |
+
"Current values:\n",
|
| 337 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 338 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 339 |
+
"\n"
|
| 340 |
+
]
|
| 341 |
+
},
|
| 342 |
+
{
|
| 343 |
+
"name": "stdout",
|
| 344 |
+
"output_type": "stream",
|
| 345 |
+
"text": [
|
| 346 |
+
"Downloading: 39% [304799744 / 776454342] bytes"
|
| 347 |
+
]
|
| 348 |
+
},
|
| 349 |
+
{
|
| 350 |
+
"name": "stderr",
|
| 351 |
+
"output_type": "stream",
|
| 352 |
+
"text": [
|
| 353 |
+
"IOPub message rate exceeded.\n",
|
| 354 |
+
"The notebook server will temporarily stop sending output\n",
|
| 355 |
+
"to the client in order to avoid crashing it.\n",
|
| 356 |
+
"To change this limit, set the config variable\n",
|
| 357 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 358 |
+
"\n",
|
| 359 |
+
"Current values:\n",
|
| 360 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 361 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 362 |
+
"\n"
|
| 363 |
+
]
|
| 364 |
+
},
|
| 365 |
+
{
|
| 366 |
+
"name": "stdout",
|
| 367 |
+
"output_type": "stream",
|
| 368 |
+
"text": [
|
| 369 |
+
"Downloading: 44% [346030080 / 776454342] bytes"
|
| 370 |
+
]
|
| 371 |
+
},
|
| 372 |
+
{
|
| 373 |
+
"name": "stderr",
|
| 374 |
+
"output_type": "stream",
|
| 375 |
+
"text": [
|
| 376 |
+
"IOPub message rate exceeded.\n",
|
| 377 |
+
"The notebook server will temporarily stop sending output\n",
|
| 378 |
+
"to the client in order to avoid crashing it.\n",
|
| 379 |
+
"To change this limit, set the config variable\n",
|
| 380 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 381 |
+
"\n",
|
| 382 |
+
"Current values:\n",
|
| 383 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 384 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 385 |
+
"\n"
|
| 386 |
+
]
|
| 387 |
+
},
|
| 388 |
+
{
|
| 389 |
+
"name": "stdout",
|
| 390 |
+
"output_type": "stream",
|
| 391 |
+
"text": [
|
| 392 |
+
"Downloading: 50% [388333568 / 776454342] bytes"
|
| 393 |
+
]
|
| 394 |
+
},
|
| 395 |
+
{
|
| 396 |
+
"name": "stderr",
|
| 397 |
+
"output_type": "stream",
|
| 398 |
+
"text": [
|
| 399 |
+
"IOPub message rate exceeded.\n",
|
| 400 |
+
"The notebook server will temporarily stop sending output\n",
|
| 401 |
+
"to the client in order to avoid crashing it.\n",
|
| 402 |
+
"To change this limit, set the config variable\n",
|
| 403 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 404 |
+
"\n",
|
| 405 |
+
"Current values:\n",
|
| 406 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 407 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 408 |
+
"\n"
|
| 409 |
+
]
|
| 410 |
+
},
|
| 411 |
+
{
|
| 412 |
+
"name": "stdout",
|
| 413 |
+
"output_type": "stream",
|
| 414 |
+
"text": [
|
| 415 |
+
"Downloading: 55% [429015040 / 776454342] bytes"
|
| 416 |
+
]
|
| 417 |
+
},
|
| 418 |
+
{
|
| 419 |
+
"name": "stderr",
|
| 420 |
+
"output_type": "stream",
|
| 421 |
+
"text": [
|
| 422 |
+
"IOPub message rate exceeded.\n",
|
| 423 |
+
"The notebook server will temporarily stop sending output\n",
|
| 424 |
+
"to the client in order to avoid crashing it.\n",
|
| 425 |
+
"To change this limit, set the config variable\n",
|
| 426 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 427 |
+
"\n",
|
| 428 |
+
"Current values:\n",
|
| 429 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 430 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 431 |
+
"\n"
|
| 432 |
+
]
|
| 433 |
+
},
|
| 434 |
+
{
|
| 435 |
+
"name": "stdout",
|
| 436 |
+
"output_type": "stream",
|
| 437 |
+
"text": [
|
| 438 |
+
"Downloading: 60% [470278144 / 776454342] bytes"
|
| 439 |
+
]
|
| 440 |
+
},
|
| 441 |
+
{
|
| 442 |
+
"name": "stderr",
|
| 443 |
+
"output_type": "stream",
|
| 444 |
+
"text": [
|
| 445 |
+
"IOPub message rate exceeded.\n",
|
| 446 |
+
"The notebook server will temporarily stop sending output\n",
|
| 447 |
+
"to the client in order to avoid crashing it.\n",
|
| 448 |
+
"To change this limit, set the config variable\n",
|
| 449 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 450 |
+
"\n",
|
| 451 |
+
"Current values:\n",
|
| 452 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 453 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 454 |
+
"\n"
|
| 455 |
+
]
|
| 456 |
+
},
|
| 457 |
+
{
|
| 458 |
+
"name": "stdout",
|
| 459 |
+
"output_type": "stream",
|
| 460 |
+
"text": [
|
| 461 |
+
"Downloading: 65% [507609088 / 776454342] bytes"
|
| 462 |
+
]
|
| 463 |
+
},
|
| 464 |
+
{
|
| 465 |
+
"name": "stderr",
|
| 466 |
+
"output_type": "stream",
|
| 467 |
+
"text": [
|
| 468 |
+
"IOPub message rate exceeded.\n",
|
| 469 |
+
"The notebook server will temporarily stop sending output\n",
|
| 470 |
+
"to the client in order to avoid crashing it.\n",
|
| 471 |
+
"To change this limit, set the config variable\n",
|
| 472 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 473 |
+
"\n",
|
| 474 |
+
"Current values:\n",
|
| 475 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 476 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 477 |
+
"\n"
|
| 478 |
+
]
|
| 479 |
+
},
|
| 480 |
+
{
|
| 481 |
+
"name": "stdout",
|
| 482 |
+
"output_type": "stream",
|
| 483 |
+
"text": [
|
| 484 |
+
"Downloading: 70% [549601280 / 776454342] bytes"
|
| 485 |
+
]
|
| 486 |
+
},
|
| 487 |
+
{
|
| 488 |
+
"name": "stderr",
|
| 489 |
+
"output_type": "stream",
|
| 490 |
+
"text": [
|
| 491 |
+
"IOPub message rate exceeded.\n",
|
| 492 |
+
"The notebook server will temporarily stop sending output\n",
|
| 493 |
+
"to the client in order to avoid crashing it.\n",
|
| 494 |
+
"To change this limit, set the config variable\n",
|
| 495 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 496 |
+
"\n",
|
| 497 |
+
"Current values:\n",
|
| 498 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 499 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 500 |
+
"\n"
|
| 501 |
+
]
|
| 502 |
+
},
|
| 503 |
+
{
|
| 504 |
+
"name": "stdout",
|
| 505 |
+
"output_type": "stream",
|
| 506 |
+
"text": [
|
| 507 |
+
"Downloading: 75% [587497472 / 776454342] bytes"
|
| 508 |
+
]
|
| 509 |
+
},
|
| 510 |
+
{
|
| 511 |
+
"name": "stderr",
|
| 512 |
+
"output_type": "stream",
|
| 513 |
+
"text": [
|
| 514 |
+
"IOPub message rate exceeded.\n",
|
| 515 |
+
"The notebook server will temporarily stop sending output\n",
|
| 516 |
+
"to the client in order to avoid crashing it.\n",
|
| 517 |
+
"To change this limit, set the config variable\n",
|
| 518 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 519 |
+
"\n",
|
| 520 |
+
"Current values:\n",
|
| 521 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 522 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 523 |
+
"\n"
|
| 524 |
+
]
|
| 525 |
+
},
|
| 526 |
+
{
|
| 527 |
+
"name": "stdout",
|
| 528 |
+
"output_type": "stream",
|
| 529 |
+
"text": [
|
| 530 |
+
"Downloading: 80% [622919680 / 776454342] bytes"
|
| 531 |
+
]
|
| 532 |
+
},
|
| 533 |
+
{
|
| 534 |
+
"name": "stderr",
|
| 535 |
+
"output_type": "stream",
|
| 536 |
+
"text": [
|
| 537 |
+
"IOPub message rate exceeded.\n",
|
| 538 |
+
"The notebook server will temporarily stop sending output\n",
|
| 539 |
+
"to the client in order to avoid crashing it.\n",
|
| 540 |
+
"To change this limit, set the config variable\n",
|
| 541 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 542 |
+
"\n",
|
| 543 |
+
"Current values:\n",
|
| 544 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 545 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 546 |
+
"\n"
|
| 547 |
+
]
|
| 548 |
+
},
|
| 549 |
+
{
|
| 550 |
+
"name": "stdout",
|
| 551 |
+
"output_type": "stream",
|
| 552 |
+
"text": [
|
| 553 |
+
"Downloading: 86% [668491776 / 776454342] bytes"
|
| 554 |
+
]
|
| 555 |
+
},
|
| 556 |
+
{
|
| 557 |
+
"name": "stderr",
|
| 558 |
+
"output_type": "stream",
|
| 559 |
+
"text": [
|
| 560 |
+
"IOPub message rate exceeded.\n",
|
| 561 |
+
"The notebook server will temporarily stop sending output\n",
|
| 562 |
+
"to the client in order to avoid crashing it.\n",
|
| 563 |
+
"To change this limit, set the config variable\n",
|
| 564 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 565 |
+
"\n",
|
| 566 |
+
"Current values:\n",
|
| 567 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 568 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 569 |
+
"\n"
|
| 570 |
+
]
|
| 571 |
+
},
|
| 572 |
+
{
|
| 573 |
+
"name": "stdout",
|
| 574 |
+
"output_type": "stream",
|
| 575 |
+
"text": [
|
| 576 |
+
"Downloading: 89% [698474496 / 776454342] bytes"
|
| 577 |
+
]
|
| 578 |
+
},
|
| 579 |
+
{
|
| 580 |
+
"name": "stderr",
|
| 581 |
+
"output_type": "stream",
|
| 582 |
+
"text": [
|
| 583 |
+
"IOPub message rate exceeded.\n",
|
| 584 |
+
"The notebook server will temporarily stop sending output\n",
|
| 585 |
+
"to the client in order to avoid crashing it.\n",
|
| 586 |
+
"To change this limit, set the config variable\n",
|
| 587 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 588 |
+
"\n",
|
| 589 |
+
"Current values:\n",
|
| 590 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 591 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 592 |
+
"\n"
|
| 593 |
+
]
|
| 594 |
+
},
|
| 595 |
+
{
|
| 596 |
+
"name": "stdout",
|
| 597 |
+
"output_type": "stream",
|
| 598 |
+
"text": [
|
| 599 |
+
"Downloading: 91% [713318400 / 776454342] bytes"
|
| 600 |
+
]
|
| 601 |
+
},
|
| 602 |
+
{
|
| 603 |
+
"name": "stderr",
|
| 604 |
+
"output_type": "stream",
|
| 605 |
+
"text": [
|
| 606 |
+
"IOPub message rate exceeded.\n",
|
| 607 |
+
"The notebook server will temporarily stop sending output\n",
|
| 608 |
+
"to the client in order to avoid crashing it.\n",
|
| 609 |
+
"To change this limit, set the config variable\n",
|
| 610 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 611 |
+
"\n",
|
| 612 |
+
"Current values:\n",
|
| 613 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 614 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 615 |
+
"\n"
|
| 616 |
+
]
|
| 617 |
+
},
|
| 618 |
+
{
|
| 619 |
+
"name": "stdout",
|
| 620 |
+
"output_type": "stream",
|
| 621 |
+
"text": [
|
| 622 |
+
"Downloading: 96% [746487808 / 776454342] bytes"
|
| 623 |
+
]
|
| 624 |
+
},
|
| 625 |
+
{
|
| 626 |
+
"name": "stderr",
|
| 627 |
+
"output_type": "stream",
|
| 628 |
+
"text": [
|
| 629 |
+
"IOPub message rate exceeded.\n",
|
| 630 |
+
"The notebook server will temporarily stop sending output\n",
|
| 631 |
+
"to the client in order to avoid crashing it.\n",
|
| 632 |
+
"To change this limit, set the config variable\n",
|
| 633 |
+
"`--NotebookApp.iopub_msg_rate_limit`.\n",
|
| 634 |
+
"\n",
|
| 635 |
+
"Current values:\n",
|
| 636 |
+
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
|
| 637 |
+
"NotebookApp.rate_limit_window=3.0 (secs)\n",
|
| 638 |
+
"\n"
|
| 639 |
+
]
|
| 640 |
+
},
|
| 641 |
+
{
|
| 642 |
+
"name": "stdout",
|
| 643 |
+
"output_type": "stream",
|
| 644 |
+
"text": [
|
| 645 |
+
"Downloading: 100% [776454342 / 776454342] bytes ____ ____ ___ ____ _ _ _ _ _ \n",
|
| 646 |
+
" / ___|| _ \\ |_ _| / ___| / \\ | \\ | || \\ | |\n",
|
| 647 |
+
"| | | |_) | | | | | / _ \\ | \\| || \\| |\n",
|
| 648 |
+
"| |___ | __/ | | | |___ / ___ \\ | |\\ || |\\ |\n",
|
| 649 |
+
" \\____||_| |___| \\____|/_/ \\_\\|_| \\_||_| \\_|\n",
|
| 650 |
+
" \n",
|
| 651 |
+
"\n",
|
| 652 |
+
"The phase identification module of WPEM\n",
|
| 653 |
+
"URL : https://github.com/WPEM/CPICANN\n",
|
| 654 |
+
"Executed on : 2024-04-21 14:14:25 | Have a great day.\n",
|
| 655 |
+
"================================================================================\n",
|
| 656 |
+
"loaded model from /Users/jacob/miniconda3/lib/python3.9/site-packages/WPEMPhase/pretrained/CPICANN_single-phase_back3.pth\n",
|
| 657 |
+
"\n",
|
| 658 |
+
">>>>>> RUNNING: ./testdata/.DS_Store\n",
|
| 659 |
+
"\n",
|
| 660 |
+
">>>>>> RUNNING: ./testdata/PbSO4.csv\n",
|
| 661 |
+
"pred cls_id : 2475 confidence : 98.89%\n",
|
| 662 |
+
"pred cod_id : 9009622 formula : Pb2 S2 O6\n",
|
| 663 |
+
"pred space group No: 11 space group : P2_1/m\n",
|
| 664 |
+
"\n",
|
| 665 |
+
"inference result saved in infResults_testdata.csv\n",
|
| 666 |
+
"inference figures saved at figs/\n",
|
| 667 |
+
"THE END\n"
|
| 668 |
+
]
|
| 669 |
+
},
|
| 670 |
+
{
|
| 671 |
+
"data": {
|
| 672 |
+
"text/plain": [
|
| 673 |
+
"True"
|
| 674 |
+
]
|
| 675 |
+
},
|
| 676 |
+
"execution_count": 4,
|
| 677 |
+
"metadata": {},
|
| 678 |
+
"output_type": "execute_result"
|
| 679 |
+
}
|
| 680 |
+
],
|
| 681 |
+
"source": [
|
| 682 |
+
"# Here, illustrate the system requirements and how to initialize the system files at the first time of execution.\n",
|
| 683 |
+
"\n",
|
| 684 |
+
"CPICANN.PhaseIdentifier(FilePath='./testdata',Model='bca_model',Task='single-phase',Device='cpu',)"
|
| 685 |
+
]
|
| 686 |
+
},
|
| 687 |
+
{
|
| 688 |
+
"cell_type": "code",
|
| 689 |
+
"execution_count": 7,
|
| 690 |
+
"id": "304b62b5",
|
| 691 |
+
"metadata": {},
|
| 692 |
+
"outputs": [
|
| 693 |
+
{
|
| 694 |
+
"name": "stdout",
|
| 695 |
+
"output_type": "stream",
|
| 696 |
+
"text": [
|
| 697 |
+
" ____ ____ ___ ____ _ _ _ _ _ \n",
|
| 698 |
+
" / ___|| _ \\ |_ _| / ___| / \\ | \\ | || \\ | |\n",
|
| 699 |
+
"| | | |_) | | | | | / _ \\ | \\| || \\| |\n",
|
| 700 |
+
"| |___ | __/ | | | |___ / ___ \\ | |\\ || |\\ |\n",
|
| 701 |
+
" \\____||_| |___| \\____|/_/ \\_\\|_| \\_||_| \\_|\n",
|
| 702 |
+
" \n",
|
| 703 |
+
"\n",
|
| 704 |
+
"The phase identification module of WPEM\n",
|
| 705 |
+
"URL : https://github.com/WPEM/CPICANN\n",
|
| 706 |
+
"Executed on : 2024-04-21 14:14:53 | Have a great day.\n",
|
| 707 |
+
"================================================================================\n",
|
| 708 |
+
"loaded model from /Users/jacob/miniconda3/lib/python3.9/site-packages/WPEMPhase/pretrained/CPICANN_single-phase_noise3.pth\n",
|
| 709 |
+
"\n",
|
| 710 |
+
">>>>>> RUNNING: ./testdata/.DS_Store\n",
|
| 711 |
+
"\n",
|
| 712 |
+
">>>>>> RUNNING: ./testdata/PbSO4.csv\n",
|
| 713 |
+
"pred cls_id : 3378 confidence : 100.00%\n",
|
| 714 |
+
"pred cod_id : 9004484 formula : Pb4 S4 O16\n",
|
| 715 |
+
"pred space group No: 62 space group : Pnma\n",
|
| 716 |
+
"\n",
|
| 717 |
+
"inference result saved in infResults_testdata.csv\n",
|
| 718 |
+
"inference figures saved at figs/\n",
|
| 719 |
+
"THE END\n"
|
| 720 |
+
]
|
| 721 |
+
},
|
| 722 |
+
{
|
| 723 |
+
"data": {
|
| 724 |
+
"text/plain": [
|
| 725 |
+
"True"
|
| 726 |
+
]
|
| 727 |
+
},
|
| 728 |
+
"execution_count": 7,
|
| 729 |
+
"metadata": {},
|
| 730 |
+
"output_type": "execute_result"
|
| 731 |
+
}
|
| 732 |
+
],
|
| 733 |
+
"source": [
|
| 734 |
+
"from WPEMPhase import CPICANN\n",
|
| 735 |
+
"# Here, illustrate the system requirements and how to initialize the system files at the first time of execution.\n",
|
| 736 |
+
"\n",
|
| 737 |
+
"CPICANN.PhaseIdentifier(FilePath='./testdata',Model='noise_model',Task='single-phase',ElementsContained='Pb_S_O',Device='cpu',)"
|
| 738 |
+
]
|
| 739 |
+
},
|
| 740 |
+
{
|
| 741 |
+
"cell_type": "raw",
|
| 742 |
+
"id": "977ede12",
|
| 743 |
+
"metadata": {},
|
| 744 |
+
"source": [
|
| 745 |
+
"For inquiries or assistance, please don't hesitate to contact us at bcao686@connect.hkust-gz.edu.cn (Dr. CAO Bin)."
|
| 746 |
+
]
|
| 747 |
+
},
|
| 748 |
+
{
|
| 749 |
+
"cell_type": "code",
|
| 750 |
+
"execution_count": null,
|
| 751 |
+
"id": "bb101480",
|
| 752 |
+
"metadata": {},
|
| 753 |
+
"outputs": [],
|
| 754 |
+
"source": []
|
| 755 |
+
}
|
| 756 |
+
],
|
| 757 |
+
"metadata": {
|
| 758 |
+
"kernelspec": {
|
| 759 |
+
"display_name": "Python 3 (ipykernel)",
|
| 760 |
+
"language": "python",
|
| 761 |
+
"name": "python3"
|
| 762 |
+
},
|
| 763 |
+
"language_info": {
|
| 764 |
+
"codemirror_mode": {
|
| 765 |
+
"name": "ipython",
|
| 766 |
+
"version": 3
|
| 767 |
+
},
|
| 768 |
+
"file_extension": ".py",
|
| 769 |
+
"mimetype": "text/x-python",
|
| 770 |
+
"name": "python",
|
| 771 |
+
"nbconvert_exporter": "python",
|
| 772 |
+
"pygments_lexer": "ipython3",
|
| 773 |
+
"version": "3.9.12"
|
| 774 |
+
}
|
| 775 |
+
},
|
| 776 |
+
"nbformat": 4,
|
| 777 |
+
"nbformat_minor": 5
|
| 778 |
+
}
|
src/inference&case/config/elem_setting.csv
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FileName,must include elements,include at least one of,exclude elements
|
| 2 |
+
CdS.csv,Cd_S,,
|
| 3 |
+
MnS.csv,Mn_S,,
|
| 4 |
+
NiO2H2.csv,Ni_O_H,,
|
| 5 |
+
PbSO4.csv,Pb_O,S,
|
src/inference&case/figs/PbSO4.csv.png
ADDED
|
src/inference&case/infResults_testdata.csv
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
path,fileName,predRank,pred,codId,formula,spaceGroupNo,spaceGroup
|
| 2 |
+
./testdata,PbSO4.csv,1,3378,9004484.0,Pb4 S4 O16,62,Pnma
|
src/inference&case/testdata/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
src/inference&case/testdata/PbSO4.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
src/model/CPICANN.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch import nn
|
| 7 |
+
from torch.nn.init import trunc_normal_
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class CPICANN(nn.Module):
    """CNN + Transformer-encoder classifier over 1-D patterns.

    A convolutional stem (``ConvModule``) downsamples the input signal into a
    sequence of 64-dim feature vectors; a learnable [CLS] token is prepended,
    the sequence is refined by a stack of self-attention layers with a fixed
    sin/cos positional embedding, and an MLP head scores ``num_classes``
    candidate classes from the [CLS] position.

    Args:
        embed_dim: channel width of the token sequence fed to the encoder.
        nhead: number of attention heads per encoder layer.
        num_encoder_layers: depth of the self-attention stack.
        dim_feedforward: hidden width of each layer's feed-forward sub-block.
        dropout: dropout rate used in the encoder layers.
        activation: feed-forward activation name ("relu"/"gelu"/"glu").
        num_classes: size of the classification output.
    """

    def __init__(self, embed_dim=64, nhead=8, num_encoder_layers=6, dim_feedforward=1024,
                 dropout=0.1, activation="relu", num_classes=23073):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_classes = num_classes

        self.conv = ConvModule(drop_rate=dropout)

        # Learnable classification token, shape (1, 1, C).
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Fixed positional embedding; 142 must equal the conv stem's output
        # length + 1 ([CLS]) — assumes a fixed input pattern length, TODO confirm.
        self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 142))

        # -------------encoder----------------
        sa_layer = SelfAttnLayer(embed_dim, nhead, dim_feedforward, dropout, activation)
        self.encoder = SelfAttnModule(sa_layer, num_encoder_layers)
        # ------------------------------------

        self.norm_after = nn.LayerNorm(embed_dim)

        # Classification head applied to the [CLS] feature only.
        self.cls_head = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * 4)),
            nn.BatchNorm1d(int(embed_dim * 4)),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(int(embed_dim * 4), int(embed_dim * 4)),
            nn.BatchNorm1d(int(embed_dim * 4)),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(int(embed_dim * 4), num_classes)
        )

        # Xavier init first, then overwrite cls_token/pos_embed; the order
        # matters — init_weights() freezes and fills pos_embed last.
        self._reset_parameters()
        self.init_weights()

    def _reset_parameters(self):
        # Xavier-init every parameter with more than one dimension
        # (conv/linear weights); 1-D biases and norm scales are left as-is.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def init_weights(self):
        # Small random init for the [CLS] token.
        trunc_normal_(self.cls_token, std=.02)

        # Positional embedding is fixed (non-learnable) sin/cos, overwriting
        # whatever _reset_parameters() put there.
        self.pos_embed.requires_grad = False

        pos_embed = get_1d_sincos_pos_embed_from_grid(self.embed_dim, np.array(range(self.pos_embed.shape[2])))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).T.unsqueeze(0))

    def bce_fineTune_init_weights(self):
        """Prepare for fine-tuning: freeze the conv stem, re-initialise the
        encoder and classification head (weights with dim > 1 only)."""
        for p in self.conv.parameters():
            p.requires_grad = False

        for p in self.encoder.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for p in self.cls_head.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, x):
        """Classify a batch of patterns.

        Args:
            x: tensor of shape (N, C, L); when C == 2 only the second channel
               is kept (presumably channel 0 is the x-axis/2-theta values and
               channel 1 the intensities — TODO confirm against the dataset).

        Returns:
            Logits of shape (N, num_classes).
        """
        N = x.shape[0]
        if x.shape[1] == 2:
            x = x[:, 1:, :]

        # Normalise intensities — assumes inputs are scaled to [0, 100];
        # NOTE(review): confirm against the data pipeline.
        x = x / 100
        x = self.conv(x)

        # flatten NxCxL to LxNxC (sequence-first layout for nn.MultiheadAttention)
        x = x.permute(2, 0, 1).contiguous()

        cls_token = self.cls_token.expand(-1, N, -1)
        x = torch.cat((cls_token, x), dim=0)

        # Broadcast the fixed positional embedding over the batch.
        pos_embed = self.pos_embed.permute(2, 0, 1).contiguous().repeat(1, N, 1)
        feats = self.encoder(x, pos_embed)
        feats = self.norm_after(feats)
        # Classify from the [CLS] position (sequence index 0).
        logits = self.cls_head(feats[0])
        return logits
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class ConvModule(nn.Module):
    """Convolutional stem: one wide strided conv followed by two residual
    stages and a final max-pool, shrinking a (N, 1, L) signal into a
    (N, 128, L') feature sequence."""

    def __init__(self, drop_rate=0.):
        super().__init__()
        self.drop_rate = drop_rate

        # Wide first convolution gives a large receptive field on the raw signal.
        self.conv1 = nn.Conv1d(1, 64, kernel_size=35, stride=2, padding=17)
        self.bn1 = nn.BatchNorm1d(64)
        self.act1 = nn.ReLU()
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        # Two residual stages, each halving the temporal resolution.
        self.layer1 = Layer(64, 64, kernel_size=3, stride=2, downsample=True)
        self.layer2 = Layer(64, 128, kernel_size=3, stride=2, downsample=True)
        self.maxpool2 = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """Apply the stem and the residual stages in sequence."""
        pipeline = (
            self.conv1, self.bn1, self.act1, self.maxpool,
            self.layer1, self.layer2, self.maxpool2,
        )
        for stage in pipeline:
            x = stage(x)
        return x
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class SelfAttnModule(nn.Module):
    """A stack of ``num_layers`` independent copies of one encoder layer,
    applied sequentially, with an optional final normalisation."""

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, pos):
        """Run *src* through every layer in order; *pos* is passed unchanged
        to each layer."""
        out = src
        for layer in self.layers:
            out = layer(out, pos)
        return out if self.norm is None else self.norm(out)
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class SelfAttnLayer(nn.Module):
    """Post-norm Transformer encoder layer: a self-attention sub-block and a
    position-wise feed-forward sub-block, each wrapped in
    dropout + residual + LayerNorm."""

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu"):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Position-wise feed-forward network.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def forward(self, src, pos):
        """One encoder step; *pos* is added to queries and keys only
        (values attend without positional information)."""
        # --- self-attention sub-block (post-norm residual) ---
        q = k = with_pos_embed(src, pos)
        attn_out = self.self_attn(q, k, value=src)[0]
        src = self.norm1(src + self.dropout1(attn_out))
        # --- feed-forward sub-block ---
        ff_out = self.linear2(self.dropout(self.activation(self.linear1(src))))
        return self.norm2(src + self.dropout2(ff_out))
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class Layer(nn.Module):
    """Two stacked BasicBlocks: the first may change channels/stride
    (optionally downsampling), the second keeps the shape."""

    def __init__(self, inchannel, outchannel, kernel_size, stride, downsample):
        super(Layer, self).__init__()
        self.block1 = BasicBlock(inchannel, outchannel, kernel_size=kernel_size, stride=stride, downsample=downsample)
        self.block2 = BasicBlock(outchannel, outchannel, kernel_size=kernel_size, stride=1)

    def forward(self, x):
        """Apply both residual blocks in order."""
        return self.block2(self.block1(x))
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class BasicBlock(nn.Module):
    """Residual 1-D convolution block: conv-bn -> conv-bn plus a shortcut,
    followed by a ReLU.

    Fix: the shortcut projection previously hard-coded ``stride=2``; it now
    follows the block's ``stride`` argument. This is identical for every
    current call site (all downsampling blocks use stride=2) but makes the
    block correct for other strides.

    NOTE(review): ``act1`` is constructed but never applied in ``forward`` —
    a canonical ResNet block would apply it between ``bn1`` and ``conv2``.
    Left unchanged deliberately, since the released pretrained checkpoints
    were trained without it; inserting it would change model outputs.

    Args:
        inchannel: input channel count.
        outchannel: output channel count.
        kernel_size: conv kernel size (padding keeps length for stride=1).
        stride: stride of the first conv (and of the shortcut projection).
        downsample: if True, project the shortcut with a 1x1 strided conv.
    """

    def __init__(self, inchannel, outchannel, kernel_size, stride, downsample=False):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv1d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2)
        self.bn1 = nn.BatchNorm1d(outchannel)
        self.act1 = nn.ReLU(inplace=True)  # unused in forward; see class NOTE
        self.conv2 = nn.Conv1d(outchannel, outchannel, kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
        self.bn2 = nn.BatchNorm1d(outchannel)
        self.act2 = nn.ReLU(inplace=True)
        # Shortcut projection matches the main path's stride (was hard-coded 2).
        self.downsample = nn.Sequential(
            nn.Conv1d(inchannel, outchannel, kernel_size=1, stride=stride),
            nn.BatchNorm1d(outchannel)
        ) if downsample else None

    def forward(self, x):
        """Return act2(conv-bn-conv-bn(x) + shortcut(x))."""
        shortcut = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.downsample is not None:
            shortcut = self.downsample(shortcut)
        x += shortcut
        x = self.act2(x)
        return x
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def _get_clones(module, N):
|
| 208 |
+
return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def _get_activation_fn(activation):
|
| 212 |
+
"""Return an activation function given a string"""
|
| 213 |
+
if activation == "relu":
|
| 214 |
+
return F.relu
|
| 215 |
+
if activation == "gelu":
|
| 216 |
+
return F.gelu
|
| 217 |
+
if activation == "glu":
|
| 218 |
+
return F.glu
|
| 219 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def with_pos_embed(tensor, pos):
    """Add positional embedding *pos* to *tensor*; pass through when pos is None."""
    if pos is None:
        return tensor
    return tensor + pos
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # Geometric frequency ladder: freq_d = 1 / 10000^(d / (D/2)).
    freqs = 1. / 10000 ** (np.arange(half, dtype=np.float32) / (embed_dim / 2.))

    # Outer product of positions and frequencies gives the phase angles.
    angles = np.outer(pos.reshape(-1), freqs)  # (M, D/2)

    # First half sine, second half cosine, concatenated along the feature axis.
    return np.concatenate(
        [np.sin(angles).astype(np.float32), np.cos(angles).astype(np.float32)],
        axis=1,
    )
|
src/model/dataset.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from torch.utils.data import Dataset
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class XrdDataset(Dataset):
    """Single-phase XRD dataset: one CSV pattern plus an integer label per row.

    The annotation CSV has the sample id in column 0 and the class label in
    column 1; each sample lives at ``<data_dir>/<id>.csv``.
    """

    def __init__(self, data_dir, annotations_file):
        self.labels = pd.read_csv(annotations_file)
        self.data_dir = data_dir

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        row = self.labels.iloc[idx]
        pattern_path = os.path.join(self.data_dir, '{}.csv'.format(row.iloc[0]))
        # Transpose so channels come first: (num_columns, num_points).
        pattern = pd.read_csv(pattern_path).values.astype(np.float32).T
        return pattern, row.iloc[1]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class mixDataset_cls_dynamic(Dataset):
    """Bi-phase dataset: mixes two randomly chosen single-phase patterns on the fly.

    Each item pairs two random structures (each in a random simulation
    variant), a random mixing ratio summing to 100, and a soft two-hot
    label vector (0.4 weight per constituent phase).

    Fixes vs. the original:
      * the mode comparison is now case-insensitive -- the training script
        passes mode='Train' while the old code compared against 'train',
        so the training branch (variants 1..24) was never taken;
      * the number of classes is derived from the annotation file instead
        of being hard-coded to 23073;
      * when the same structure is drawn twice, its label weight now
        accumulates to 0.8 instead of being silently overwritten with 0.4.
    """

    def __init__(self, data_dir, anno_struc, mode):
        self.data_dir = data_dir
        # Column 0 of the structure annotation file holds the COD ids.
        self.codIdList = pd.read_csv(anno_struc).values[:, 0].astype(np.int32)
        self.mode = mode

    def __len__(self):
        # Samples are generated on the fly; one "epoch" is 1M random mixes.
        return 1000000

    def __getitem__(self, idx):
        num_classes = len(self.codIdList)
        choice1, choice2 = np.random.randint(0, num_classes, 2)
        if self.mode.lower() == 'train':
            # Training draws from simulation variants 1..24, validation 1..6.
            rand1, rand2 = np.random.randint(1, 25, 2)
        else:
            rand1, rand2 = np.random.randint(1, 7, 2)
        data_path1 = os.path.join(self.data_dir, '{}_{}.csv'.format(self.codIdList[choice1], rand1))
        data_path2 = os.path.join(self.data_dir, '{}_{}.csv'.format(self.codIdList[choice2], rand2))
        data1 = pd.read_csv(data_path1).values.astype(np.float32).T
        data2 = pd.read_csv(data_path2).values.astype(np.float32).T

        # Random mixing ratio in [20, 80] so neither phase fully dominates.
        ratio1 = np.random.randint(20, 81)
        ratio2 = 100 - ratio1

        # Soft label: 0.4 per constituent; accumulates to 0.8 on a collision.
        label = np.zeros(num_classes, dtype=np.float32)
        label[choice1] += 0.4
        label[choice2] += 0.4

        return data1, data2, ratio1, ratio2, label
|
src/model/focal_loss.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.autograd import Variable
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class FocalLoss(nn.Module):
    r"""Focal Loss (Lin et al., "Focal Loss for Dense Object Detection").

        Loss(x, class) = -\alpha (1 - softmax(x)[class])^\gamma \log(softmax(x)[class])

    The losses are averaged (or summed) across observations for each minibatch.

    Args:
        class_num: number of classes C.
        alpha: optional (C, 1) tensor of per-class scaling factors; defaults
            to all ones (no class re-weighting).
        gamma: focusing parameter (> 0); reduces the relative loss for
            well-classified examples, focusing on hard, misclassified ones.
        size_average: if True (default), return the batch mean; otherwise
            the batch sum.
        device: device *alpha* is moved to on the first call with CUDA inputs.

    Modernization: torch.autograd.Variable is a deprecated no-op wrapper
    since PyTorch 0.4; plain tensors behave identically, so the wrappers
    were removed without changing the computed value.
    """

    def __init__(self, class_num, alpha=None, gamma=2, size_average=True, device='cuda:0'):
        super(FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = torch.ones(class_num, 1)
        else:
            self.alpha = alpha
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average
        self.device = device

    def forward(self, inputs, targets):
        """Focal loss of logits *inputs* (N, C) w.r.t. class indices *targets* (N,)."""
        P = F.softmax(inputs, dim=1)

        # One-hot mask selecting the target class of each sample.
        class_mask = torch.zeros_like(inputs)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids, 1.)

        # Lazily move alpha to the configured device on the first GPU call.
        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.to(self.device)
        alpha = self.alpha[ids.view(-1)]

        # Probability assigned to the true class of each sample, shape (N, 1).
        probs = (P * class_mask).sum(1).view(-1, 1)
        log_p = probs.log()

        batch_loss = -alpha * (torch.pow((1 - probs), self.gamma)) * log_p

        return batch_loss.mean() if self.size_average else batch_loss.sum()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# import torch
|
| 72 |
+
# import torch.nn as nn
|
| 73 |
+
#
|
| 74 |
+
#
|
| 75 |
+
# class FocalLoss(nn.Module):
|
| 76 |
+
#
|
| 77 |
+
# def __init__(self, gamma=0, eps=1e-7):
|
| 78 |
+
# super(FocalLoss, self).__init__()
|
| 79 |
+
# self.gamma = gamma
|
| 80 |
+
# self.eps = eps
|
| 81 |
+
# self.ce = torch.nn.CrossEntropyLoss()
|
| 82 |
+
#
|
| 83 |
+
# def forward(self, input, target):
|
| 84 |
+
# logp = self.ce(input, target)
|
| 85 |
+
# p = torch.exp(-logp)
|
| 86 |
+
# loss = (1 - p) ** self.gamma * logp
|
| 87 |
+
# return loss.mean()
|
src/othermodels/ATTENTIONonly.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch import nn
|
| 7 |
+
from torch.nn.init import trunc_normal_
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class VIT(nn.Module):
    """Attention-only baseline: patchify a 1-D diffraction pattern with a
    strided conv, prepend a CLS token, run a self-attention encoder stack,
    and classify from the CLS embedding.
    """

    def __init__(self, embed_dim=64, nhead=8, num_encoder_layers=6, dim_feedforward=1024,
                 dropout=0.1, activation="relu", num_classes=23073):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_classes = num_classes

        # Non-overlapping "patch embedding": 1 input channel -> embed_dim,
        # patch size 32.
        self.conv = torch.nn.Conv1d(1, embed_dim, kernel_size=32, stride=32, padding=0)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # 141 positions = 1 CLS token + patch tokens (presumably 140 patches,
        # i.e. a 4480-point input -- verify against the data pipeline).
        # Filled with a frozen sinusoidal table in init_weights().
        self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 141))

        # -------------encoder----------------
        sa_layer = SelfAttnLayer(embed_dim, nhead, dim_feedforward, dropout, activation)
        self.encoder = SelfAttnModule(sa_layer, num_encoder_layers)
        # ------------------------------------

        self.norm_after = nn.LayerNorm(embed_dim)

        # MLP classification head applied to the CLS embedding.
        self.cls_head = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * 4)),
            nn.BatchNorm1d(int(embed_dim * 4)),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(int(embed_dim * 4), int(embed_dim * 4)),
            nn.BatchNorm1d(int(embed_dim * 4)),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(int(embed_dim * 4), num_classes)
        )

        # Order matters: xavier-init everything first, then overwrite the
        # CLS token and positional table with their dedicated schemes.
        self._reset_parameters()
        self.init_weights()

    def _reset_parameters(self):
        # Xavier-uniform init for every weight matrix (dim > 1 skips biases).
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def init_weights(self):
        trunc_normal_(self.cls_token, std=.02)

        # Positional table is fixed (sin/cos), not learned.
        self.pos_embed.requires_grad = False

        pos_embed = get_1d_sincos_pos_embed_from_grid(self.embed_dim, np.array(range(self.pos_embed.shape[2])))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).T.unsqueeze(0))

    def bce_fineTune_init_weights(self):
        # Bi-phase fine-tuning: freeze the patch embedding and re-initialize
        # the encoder and classification head.
        for p in self.conv.parameters():
            p.requires_grad = False

        for p in self.encoder.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for p in self.cls_head.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, x):
        N = x.shape[0]
        # Accept (N, 2, L) inputs with a leading 2-theta channel: keep only
        # the intensity channel.
        if x.shape[1] == 2:
            x = x[:, 1:, :]

        # Intensities are scaled to [0, 100] upstream; rescale to [0, 1].
        x = x / 100
        x = self.conv(x)

        # flatten NxCxL to LxNxC (sequence-first layout for nn.MultiheadAttention)
        x = x.permute(2, 0, 1).contiguous()

        cls_token = self.cls_token.expand(-1, N, -1)
        x = torch.cat((cls_token, x), dim=0)

        pos_embed = self.pos_embed.permute(2, 0, 1).contiguous().repeat(1, N, 1)
        feats = self.encoder(x, pos_embed)
        feats = self.norm_after(feats)
        # Classify from the CLS token (sequence position 0).
        logits = self.cls_head(feats[0])
        return logits
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class ConvModule(nn.Module):
    """Convolutional stem: wide strided conv + two residual stages + pooling.

    NOTE(review): drop_rate is stored but never used in this module.
    """
    def __init__(self, drop_rate=0.):
        super().__init__()
        self.drop_rate = drop_rate  # unused, see NOTE above

        self.conv1 = nn.Conv1d(1, 64, kernel_size=35, stride=2, padding=17)
        self.bn1 = nn.BatchNorm1d(64)
        self.act1 = nn.ReLU()
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        self.layer1 = Layer(64, 64, kernel_size=3, stride=2, downsample=True)
        self.layer2 = Layer(64, 128, kernel_size=3, stride=2, downsample=True)
        # self.layer3 = Layer(256, 256, kernel_size=3, stride=2, downsample=True)
        self.maxpool2 = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        # Stem: conv -> bn -> relu -> pool; every stride-2 step halves length.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        # x = self.layer3(x)
        x = self.maxpool2(x)
        return x
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class SelfAttnModule(nn.Module):
    """Stack of *num_layers* independent copies of a self-attention layer."""

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        # Deep copies: the layers do not share weights.
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm  # optional final normalization

    def forward(self, src, pos):
        output = src

        # The same positional embedding is re-added inside every layer.
        for layer in self.layers:
            output = layer(output, pos)

        if self.norm is not None:
            output = self.norm(output)

        return output
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class SelfAttnLayer(nn.Module):
    """Post-norm transformer encoder layer with additive positional
    embeddings applied to queries and keys only (DETR-style)."""

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu"):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def forward(self, src, pos):
        # Position information goes into q/k but not into the values.
        q = k = with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src)[0]
        # Residual + post-norm around attention ...
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        # ... and around the feed-forward block.
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class Layer(nn.Module):
    """A ResNet-style stage: two stacked BasicBlocks.

    The first block may change channel count / stride (with optional
    downsample shortcut); the second block is shape-preserving.
    """
    def __init__(self, inchannel, outchannel, kernel_size, stride, downsample):
        super(Layer, self).__init__()
        # block1 performs the (optional) channel/stride change.
        self.block1 = BasicBlock(inchannel, outchannel, kernel_size=kernel_size, stride=stride, downsample=downsample)
        # block2 keeps channels and length (stride 1).
        self.block2 = BasicBlock(outchannel, outchannel, kernel_size=kernel_size, stride=1)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        return x
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class BasicBlock(nn.Module):
    """1-D residual block: conv-bn -> conv-bn plus shortcut, then ReLU.

    NOTE(review): self.act1 is constructed but never applied in forward(),
    so there is no activation between bn1 and conv2 (a standard ResNet
    BasicBlock would apply it). Fixing this would change the numerics of
    any pretrained checkpoint, so it is only flagged here.
    NOTE(review): the downsample shortcut hard-codes stride=2; it only
    matches the main path when the block is built with stride=2, which is
    how all visible callers use it.
    """
    def __init__(self, inchannel, outchannel, kernel_size, stride, downsample=False):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv1d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2)
        self.bn1 = nn.BatchNorm1d(outchannel)
        self.act1 = nn.ReLU(inplace=True)  # unused in forward(), see NOTE above
        self.conv2 = nn.Conv1d(outchannel, outchannel, kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
        self.bn2 = nn.BatchNorm1d(outchannel)
        self.act2 = nn.ReLU(inplace=True)
        # 1x1 conv shortcut to match channels/length when downsampling.
        self.downsample = nn.Sequential(
            nn.Conv1d(inchannel, outchannel, kernel_size=1, stride=2),
            nn.BatchNorm1d(outchannel)
        ) if downsample else None

    def forward(self, x):
        shortcut = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.downsample is not None:
            shortcut = self.downsample(shortcut)
        x += shortcut
        x = self.act2(x)
        return x
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def _get_clones(module, N):
|
| 208 |
+
return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def _get_activation_fn(activation):
|
| 212 |
+
"""Return an activation function given a string"""
|
| 213 |
+
if activation == "relu":
|
| 214 |
+
return F.relu
|
| 215 |
+
if activation == "gelu":
|
| 216 |
+
return F.gelu
|
| 217 |
+
if activation == "glu":
|
| 218 |
+
return F.glu
|
| 219 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def with_pos_embed(tensor, pos):
    """Add positional embedding *pos* to *tensor*; pass through when pos is None."""
    if pos is None:
        return tensor
    return tensor + pos
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # Geometric frequency ladder: freq_d = 1 / 10000^(d / (D/2)).
    freqs = 1. / 10000 ** (np.arange(half, dtype=np.float32) / (embed_dim / 2.))

    # Outer product of positions and frequencies gives the phase angles.
    angles = np.outer(pos.reshape(-1), freqs)  # (M, D/2)

    # First half sine, second half cosine, concatenated along the feature axis.
    return np.concatenate(
        [np.sin(angles).astype(np.float32), np.cos(angles).astype(np.float32)],
        axis=1,
    )
|
src/othermodels/CNNonly.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch import nn
|
| 7 |
+
from torch.nn.init import trunc_normal_
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class CNN(nn.Module):
    """CNN-only baseline: convolutional stem, sequence-axis projection, MLP head.

    NOTE(review): nhead, num_encoder_layers, dim_feedforward and activation
    are accepted only for signature parity with the attention models; they
    are never used here.
    NOTE(review): the head's first Linear consumes embed_dim features, but
    the stem (ConvModule) outputs 128 channels -- the model presumably must
    be constructed with embed_dim=128; verify against the training script.
    """

    def __init__(self, embed_dim=64, nhead=8, num_encoder_layers=6, dim_feedforward=1024,
                 dropout=0.1, activation="relu", num_classes=23073):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_classes = num_classes

        self.conv = ConvModule(drop_rate=dropout)

        # Collapse the sequence axis (the stem's output length, expected to
        # be 141 to match this Linear) down to one value per channel.
        self.proj = nn.Linear(141, 1)

        self.cls_head = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * 4)),
            nn.BatchNorm1d(int(embed_dim * 4)),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(int(embed_dim * 4), int(embed_dim * 4)),
            nn.BatchNorm1d(int(embed_dim * 4)),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(int(embed_dim * 4), num_classes)
        )

        self._reset_parameters()

    def _reset_parameters(self):
        # Xavier-uniform init for every weight matrix (dim > 1 skips biases).
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, x):
        N = x.shape[0]  # NOTE(review): unused in this model's forward
        # Accept (N, 2, L) inputs with a leading 2-theta channel: keep only
        # the intensity channel.
        if x.shape[1] == 2:
            x = x[:, 1:, :]

        # Intensities are scaled to [0, 100] upstream; rescale to [0, 1].
        x = x / 100
        x = self.conv(x)

        # flatten NxCxL to LxNxC
        # x = x.permute(2, 0, 1).contiguous()
        x = self.proj(x).flatten(1)

        logits = self.cls_head(x)
        return logits
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class ConvModule(nn.Module):
    """Convolutional stem: wide strided conv + two residual stages + pooling.

    NOTE(review): drop_rate is stored but never used in this module.
    """
    def __init__(self, drop_rate=0.):
        super().__init__()
        self.drop_rate = drop_rate  # unused, see NOTE above

        self.conv1 = nn.Conv1d(1, 64, kernel_size=35, stride=2, padding=17)
        self.bn1 = nn.BatchNorm1d(64)
        self.act1 = nn.ReLU()
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        self.layer1 = Layer(64, 64, kernel_size=3, stride=2, downsample=True)
        self.layer2 = Layer(64, 128, kernel_size=3, stride=2, downsample=True)
        # self.layer3 = Layer(256, 256, kernel_size=3, stride=2, downsample=True)
        self.maxpool2 = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        # Stem: conv -> bn -> relu -> pool; every stride-2 step halves length.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        # x = self.layer3(x)
        x = self.maxpool2(x)
        return x
|
| 85 |
+
|
| 86 |
+
class Layer(nn.Module):
    """A ResNet-style stage: two stacked BasicBlocks.

    The first block may change channel count / stride (with optional
    downsample shortcut); the second block is shape-preserving.
    """
    def __init__(self, inchannel, outchannel, kernel_size, stride, downsample):
        super(Layer, self).__init__()
        # block1 performs the (optional) channel/stride change.
        self.block1 = BasicBlock(inchannel, outchannel, kernel_size=kernel_size, stride=stride, downsample=downsample)
        # block2 keeps channels and length (stride 1).
        self.block2 = BasicBlock(outchannel, outchannel, kernel_size=kernel_size, stride=1)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        return x
|
| 96 |
+
|
| 97 |
+
class BasicBlock(nn.Module):
    """1-D residual block: conv-bn -> conv-bn plus shortcut, then ReLU.

    NOTE(review): self.act1 is constructed but never applied in forward(),
    so there is no activation between bn1 and conv2 (a standard ResNet
    BasicBlock would apply it). Fixing this would change the numerics of
    any pretrained checkpoint, so it is only flagged here.
    NOTE(review): the downsample shortcut hard-codes stride=2; it only
    matches the main path when the block is built with stride=2, which is
    how all visible callers use it.
    """
    def __init__(self, inchannel, outchannel, kernel_size, stride, downsample=False):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv1d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2)
        self.bn1 = nn.BatchNorm1d(outchannel)
        self.act1 = nn.ReLU(inplace=True)  # unused in forward(), see NOTE above
        self.conv2 = nn.Conv1d(outchannel, outchannel, kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
        self.bn2 = nn.BatchNorm1d(outchannel)
        self.act2 = nn.ReLU(inplace=True)
        # 1x1 conv shortcut to match channels/length when downsampling.
        self.downsample = nn.Sequential(
            nn.Conv1d(inchannel, outchannel, kernel_size=1, stride=2),
            nn.BatchNorm1d(outchannel)
        ) if downsample else None

    def forward(self, x):
        shortcut = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.downsample is not None:
            shortcut = self.downsample(shortcut)
        x += shortcut
        x = self.act2(x)
        return x
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def _get_clones(module, N):
|
| 125 |
+
return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def _get_activation_fn(activation):
|
| 129 |
+
"""Return an activation function given a string"""
|
| 130 |
+
if activation == "relu":
|
| 131 |
+
return F.relu
|
| 132 |
+
if activation == "gelu":
|
| 133 |
+
return F.gelu
|
| 134 |
+
if activation == "glu":
|
| 135 |
+
return F.glu
|
| 136 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def with_pos_embed(tensor, pos):
    """Add positional embedding *pos* to *tensor*; pass through when pos is None."""
    if pos is None:
        return tensor
    return tensor + pos
|
src/pretrained/# place pretrained .pth files here
ADDED
|
File without changes
|
src/train_bi-phase.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import argparse
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.distributed as dist
|
| 9 |
+
from torch import optim, nn
|
| 10 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
from model.CPICANN import CPICANN
|
| 15 |
+
from model.dataset import mixDataset_cls_dynamic
|
| 16 |
+
from util.logger import Logger
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def run_one_epoch(model, dataloader, criterion, optimizer, epoch, mode):
    """Run one training or evaluation epoch and return the mean batch loss.

    Relies on the module-level globals ``args`` (CLI flags, incl.
    args.progress_bar) and ``device``.  mode == 'Train' enables gradient
    updates and per-iteration LR scheduling; any other value evaluates
    under torch.no_grad().
    """
    if mode == 'Train':
        model.train()
        criterion.train()
        desc = 'Training... '
    else:
        model.eval()
        criterion.eval()
        desc = 'Evaluating... '

    # NOTE(review): cls_acc is never accumulated or returned.
    epoch_loss, cls_acc = 0, 0
    if args.progress_bar:
        # NOTE(review): pbar is never closed at the end of the epoch.
        pbar = tqdm(total=len(dataloader.dataset), desc=desc, unit='data')
    iters = len(dataloader)
    for i, batch in enumerate(dataloader):
        data1 = batch[0].to(device)
        data2 = batch[1].to(device)
        ratio1 = batch[2].to(device)
        ratio2 = batch[3].to(device)
        label_cls = batch[4].to(device)

        # Mix the two single-phase patterns with their per-sample ratios,
        # then min-max rescale every mixed pattern back to [0, 100].
        data = torch.einsum('ijk,i->ijk', data1, ratio1) + torch.einsum('ijk,i->ijk', data2, ratio2)
        min_i = data.min(dim=2, keepdim=True)[0]
        max_i = data.max(dim=2, keepdim=True)[0]
        data = (data - min_i) / (max_i - min_i) * 100

        if mode == 'Train':
            # Cosine schedule with warmup, advanced fractionally every batch.
            adjust_learning_rate_withWarmup(optimizer, epoch + i / iters, args)

            logits = model(data)
            loss = criterion(logits, label_cls)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        else:
            with torch.no_grad():
                logits = model(data)
                loss = criterion(logits, label_cls)

        epoch_loss += loss.item()
        if args.progress_bar:
            pbar.update(len(data))
            pbar.set_postfix(**{'loss': loss.item()})

    return epoch_loss / iters
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def print_log(epoch, loss_train, loss_val, lr):
    """Print per-epoch losses via the global logger and push them (plus the
    current learning rate) to the train/val TensorBoard writers."""
    log.printlog('---------------- Epoch {} ----------------'.format(epoch))

    log.printlog('loss_train : {}'.format(round(loss_train, 6)))
    log.printlog('loss_val : {}'.format(round(loss_val, 6)))

    log.train_writer.add_scalar('mix_loss', loss_train, epoch)
    log.val_writer.add_scalar('mix_loss', loss_val, epoch)

    # LR is only logged to the train writer (it is a training-side quantity).
    log.train_writer.add_scalar('lr', lr, epoch)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def save_checkpoint(state, is_best, filepath, filename):
    """Persist a training checkpoint every 10 epochs (and at epoch 1).

    Fix: the checkpoint path is now built with os.path.join, so *filepath*
    no longer needs a trailing path separator (the old ``filepath +
    filename`` concatenation silently produced a wrong path without one).

    NOTE(review): the is_best copy is only written on checkpoint epochs;
    a best epoch in between is not saved -- behavior preserved as-is.
    """
    if (state['epoch']) % 10 == 0 or state['epoch'] == 1:
        os.makedirs(filepath, exist_ok=True)
        torch.save(state, os.path.join(filepath, filename))
        log.printlog('checkpoint saved!')
        if is_best:
            torch.save(state, os.path.join(filepath, 'model_best.pth'))
            log.printlog('best model saved!')
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def adjust_learning_rate(optimizer, epoch, schedule):
    """Decay the learning rate based on schedule"""
    # Start from the optimizer's base LR and apply a 10x decay for every
    # milestone already reached.
    lr = optimizer.defaults['lr']
    for milestone in schedule:
        if epoch >= milestone:
            lr *= 0.1
    for group in optimizer.param_groups:
        group['lr'] = lr
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def adjust_learning_rate_withWarmup(optimizer, epoch, args):
    """Decays the learning rate with half-cycle cosine after warmup"""
    if epoch < args.warmup_epochs:
        # Linear ramp from 0 up to args.lr during warmup.
        lr = args.lr * epoch / args.warmup_epochs
    else:
        # Half-cycle cosine from args.lr down to 0 over the remaining epochs.
        progress = (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)
        lr = args.lr * 0.5 * (1. + math.cos(math.pi * progress))
    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def main():
    """Fine-tune the pretrained single-phase CPICANN model for bi-phase identification."""
    print('>>>> Running on {} <<<<'.format(device))

    model = CPICANN(embed_dim=128, num_classes=args.num_classes)

    # LOAD PRETRAINED MODEL
    loaded = torch.load(args.load_path)
    model.load_state_dict(loaded['model'])

    # re-initialise the classification head for the bi-phase fine-tuning stage
    model.bce_fineTune_init_weights()
    model.to(device)
    if rank == 0:
        log.printlog(model)

    trainset = mixDataset_cls_dynamic(args.data_dir_train, args.anno_struc, mode='Train')
    valset = mixDataset_cls_dynamic(args.data_dir_val, args.anno_struc, mode='Eval')

    if distributed:
        # each process sees a disjoint shard; shuffling is delegated to the sampler
        train_sampler = torch.utils.data.distributed.DistributedSampler(trainset, shuffle=True)
        val_sampler = torch.utils.data.distributed.DistributedSampler(valset, shuffle=True)

        train_loader = DataLoader(trainset, batch_size=512, num_workers=16, pin_memory=True, drop_last=True, sampler=train_sampler)
        val_loader = DataLoader(valset, batch_size=512, num_workers=16, pin_memory=True, drop_last=True, sampler=val_sampler)

        model = DDP(model, device_ids=[device], output_device=local_rank, find_unused_parameters=False)
    else:
        train_loader = DataLoader(trainset, batch_size=512, num_workers=16, pin_memory=True, shuffle=True)
        val_loader = DataLoader(valset, batch_size=512, num_workers=16, pin_memory=True, shuffle=True)

    criterion = nn.CrossEntropyLoss()

    optimizer = optim.AdamW(model.parameters(), args.lr, weight_decay=1e-4)
    start_epoch = 0

    for epoch in range(start_epoch + 1, args.epochs + 1):
        if distributed:
            # reseed the samplers so each epoch draws a different shuffle
            train_sampler.set_epoch(epoch)
            val_sampler.set_epoch(epoch)

        loss_train = run_one_epoch(model, train_loader, criterion, optimizer, epoch, mode='Train')

        loss_val = run_one_epoch(model, val_loader, criterion, optimizer, epoch, mode='Eval')

        if rank == 0:
            # only rank 0 logs and writes checkpoints, to avoid duplicate files
            print_log(epoch, loss_train, loss_val, optimizer.param_groups[0]['lr'])
            save_checkpoint({'epoch': epoch,
                             'model': model.module.state_dict() if distributed else model.state_dict(),
                             'optimizer': optimizer}, is_best=False,
                            filepath='{}/checkpoints/'.format(log.get_path()),
                            filename='checkpoint_{:04d}.pth'.format(epoch))
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
if __name__ == '__main__':
    # Distributed setup: torchrun/launch utilities export RANK and WORLD_SIZE.
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        # NOTE(review): binds the CUDA device from the *global* rank; on multi-node
        # runs local_rank is the usual choice -- confirm single-node use is intended.
        torch.cuda.set_device(rank % torch.cuda.device_count())
        dist.init_process_group(backend="nccl")
        device = torch.device("cuda", local_rank)
        print(f"[init] == local rank: {local_rank}, global rank: {rank} ==")

        distributed = True
    else:
        rank = 0
        device = 'cuda:0'
        distributed = False

    parser = argparse.ArgumentParser()
    # NOTE(review): single-dash flag and type=bool -- argparse converts any non-empty
    # string (including "False") to True; the single-phase script uses --progress_bar.
    parser.add_argument("-progress_bar", type=bool, default=True)

    parser.add_argument('--epochs', default=200, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--warmup-epochs', default=20, type=int, metavar='N',
                        help='number of warmup epochs')
    parser.add_argument('--lr', '--learning-rate', default=8e-4, type=float,
                        metavar='LR', help='initial (base) learning rate', dest='lr')

    parser.add_argument('--load_path', default='pretrained/single-phase_checkpoint_0200.pth', type=str,
                        help='path to load pretrained single-phase identification model')
    parser.add_argument('--data_dir_train', default='data/train', type=str)
    parser.add_argument('--data_dir_val', default='data/val', type=str)
    parser.add_argument('--anno_struc', default='annotation/anno_struc.csv', type=str,
                        help='path to annotation file for structures')
    parser.add_argument('--num_classes', default=23073, type=int, metavar='N')

    args = parser.parse_args()

    # logger (and its timestamped output directory) only on the main process
    if rank == 0:
        log = Logger(val=True)

    main()
    print('THE END')
|
src/train_single-phase.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import math
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.distributed as dist
|
| 7 |
+
from torch import optim
|
| 8 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 9 |
+
from torch.utils.data import DataLoader
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
from model.dataset import XrdDataset
|
| 13 |
+
from model.CPICANN import CPICANN
|
| 14 |
+
from model.focal_loss import FocalLoss
|
| 15 |
+
from util.logger import Logger
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_acc(cls, label):
    """Return top-1 accuracy of logits *cls* (N, C) against integer labels (N,)."""
    predictions = cls.argmax(1)
    hits = sum(predictions == label.int())
    return hits / cls.shape[0]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def run_one_epoch(model, dataloader, criterion, optimizer, epoch, mode):
    """Run one full pass over *dataloader*.

    mode == 'Train' applies per-iteration lr scheduling and backprop; any other
    mode evaluates under torch.no_grad(). Returns (mean loss, mean accuracy %).
    """
    if mode == 'Train':
        model.train()
        criterion.train()
        desc = 'Training... '
    else:
        model.eval()
        criterion.eval()
        desc = 'Evaluating... '

    epoch_loss, cls_acc = 0, 0
    if args.progress_bar:
        pbar = tqdm(total=len(dataloader.dataset), desc=desc, unit='data')
    iters = len(dataloader)
    for i, batch in enumerate(dataloader):
        data = batch[0].to(device)
        label_cls = batch[1].to(device)

        if mode == 'Train':
            # fractional epoch (epoch + i/iters) gives a smooth warmup/cosine curve
            adjust_learning_rate_withWarmup(optimizer, epoch + i / iters, args)

            logits = model(data)
            loss = criterion(logits, label_cls.long())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        else:
            with torch.no_grad():
                logits = model(data)
                loss = criterion(logits, label_cls.long())

        epoch_loss += loss.item()
        if args.progress_bar:
            pbar.update(len(data))
            pbar.set_postfix(**{'loss': loss.item()})

        # accumulate batch-level accuracy; averaged over iterations on return
        _cls_acc = get_acc(logits, label_cls)
        cls_acc += _cls_acc.item()

    return epoch_loss / iters, cls_acc * 100 / iters
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def print_log(epoch, loss_train, loss_val, acc_train, acc_val, lr):
    """Report one epoch's losses/accuracies and learning rate to console and TensorBoard."""
    log.printlog('---------------- Epoch {} ----------------'.format(epoch))

    log.printlog('loss_train : {}'.format(round(loss_train, 4)))
    log.printlog('loss_val : {}'.format(round(loss_val, 4)))
    log.printlog('acc_train : {}%'.format(round(acc_train, 4)))
    log.printlog('acc_val : {}%'.format(round(acc_val, 4)))

    # same scalars into the train/val tensorboard writers
    for tag, train_v, val_v in (('loss', loss_train, loss_val), ('acc', acc_train, acc_val)):
        log.train_writer.add_scalar(tag, train_v, epoch)
        log.val_writer.add_scalar(tag, val_v, epoch)
    log.train_writer.add_scalar('lr', lr, epoch)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def save_checkpoint(state, is_best, filepath, filename):
    """Write a checkpoint on epoch 1 and every 10th epoch; optionally save the best model."""
    epoch = state['epoch']
    periodic = (epoch == 1) or (epoch % 10 == 0)
    if periodic:
        os.makedirs(filepath, exist_ok=True)
        torch.save(state, filepath + filename)
        log.printlog('checkpoint saved!')
    if is_best:
        best_path = '{}/model_best.pth'.format(filepath)
        torch.save(state, best_path)
        log.printlog('best model saved!')
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def adjust_learning_rate_withWarmup(optimizer, epoch, args):
    """Apply linear warmup followed by half-cycle cosine decay; returns the new lr."""
    if epoch < args.warmup_epochs:
        # warmup: scale linearly from 0 to the base learning rate
        lr = args.lr * epoch / args.warmup_epochs
    else:
        # cosine decay over the post-warmup portion of training
        lr = args.lr * 0.5 * (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
    for pg in optimizer.param_groups:
        pg['lr'] = lr
    return lr
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def main():
    """Train the CPICANN single-phase classifier, optionally under DDP."""
    print('>>>> Running on {} <<<<'.format(device))

    model = CPICANN(embed_dim=128, num_classes=args.num_classes)
    model.to(device)
    if rank == 0:
        log.printlog(model)

    trainset = XrdDataset(args.data_dir_train, args.anno_train)
    valset = XrdDataset(args.data_dir_val, args.anno_val)

    if distributed:
        # shard the datasets across processes; shuffling handled by the samplers
        train_sampler = torch.utils.data.distributed.DistributedSampler(trainset, shuffle=True)
        val_sampler = torch.utils.data.distributed.DistributedSampler(valset, shuffle=True)

        train_loader = DataLoader(trainset, batch_size=128, num_workers=16, pin_memory=True, drop_last=True, sampler=train_sampler)
        val_loader = DataLoader(valset, batch_size=128, num_workers=16, pin_memory=True, drop_last=True, sampler=val_sampler)

        model = DDP(model, device_ids=[device], output_device=local_rank, find_unused_parameters=False)
    else:
        train_loader = DataLoader(trainset, batch_size=128, num_workers=16, pin_memory=True, shuffle=True)
        val_loader = DataLoader(valset, batch_size=128, num_workers=16, pin_memory=True, shuffle=True)

    criterion = FocalLoss(class_num=args.num_classes, device=device)

    optimizer = optim.AdamW(model.parameters(), args.lr, weight_decay=1e-4)
    start_epoch = 0

    for epoch in range(start_epoch + 1, args.epochs + 1):
        if distributed:
            # reseed samplers so each epoch gets a fresh shuffle
            train_sampler.set_epoch(epoch)
            val_sampler.set_epoch(epoch)

        loss_train, acc_train = run_one_epoch(model, train_loader, criterion, optimizer, epoch, mode='Train')

        loss_val, acc_val = run_one_epoch(model, val_loader, criterion, optimizer, epoch, mode='Eval')

        if rank == 0:
            # only rank 0 logs and writes checkpoints, to avoid duplicate files
            print_log(epoch, loss_train, loss_val, acc_train, acc_val, optimizer.param_groups[0]['lr'])
            save_checkpoint({'epoch': epoch,
                             'model': model.module.state_dict() if distributed else model.state_dict(),
                             'optimizer': optimizer}, is_best=False,
                            filepath='{}/checkpoints/'.format(log.get_path()),
                            filename='checkpoint_{:04d}.pth'.format(epoch))
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
if __name__ == '__main__':
    # Distributed setup: torchrun/launch utilities export RANK and WORLD_SIZE.
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        # NOTE(review): binds the CUDA device from the *global* rank; on multi-node
        # runs local_rank is the usual choice -- confirm single-node use is intended.
        torch.cuda.set_device(rank % torch.cuda.device_count())
        dist.init_process_group(backend="nccl")
        device = torch.device("cuda", local_rank)
        print(f"[init] == local rank: {local_rank}, global rank: {rank} ==")
        distributed = True
    else:
        rank = 0
        device = 'cuda:0'
        distributed = False

    parser = argparse.ArgumentParser()
    # NOTE(review): type=bool treats any non-empty string (including "False") as True;
    # omit the flag to keep the default.
    parser.add_argument("--progress_bar", type=bool, default=True)

    parser.add_argument('--epochs', default=200, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--warmup-epochs', default=20, type=int, metavar='N',
                        help='number of warmup epochs')
    parser.add_argument('--lr', '--learning-rate', default=8e-5, type=float,
                        metavar='LR', help='initial (base) learning rate', dest='lr')

    parser.add_argument('--data_dir_train', default='data/train/', type=str)
    parser.add_argument('--data_dir_val', default='data/val/', type=str)
    parser.add_argument('--anno_train', default='annotation/anno_train.csv', type=str,
                        help='path to annotation file for training data')
    parser.add_argument('--anno_val', default='annotation/anno_val.csv', type=str,
                        help='path to annotation file for validation data')
    parser.add_argument('--num_classes', default=23073, type=int, metavar='N')

    args = parser.parse_args()

    # logger (and its timestamped output directory) only on the main process
    if rank == 0:
        log = Logger(val=True)

    main()
    print('THE END')
|
src/util/logger.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
from tensorboardX import SummaryWriter
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Logger(object):
    """Console + file + TensorBoard logger.

    Creates a timestamped directory under output/ on construction, mirrors
    every message to stdout and to a log file inside that directory, and
    exposes tensorboardX writers rooted at <path>/tb.
    """

    def __init__(self, val=True, filename="print.log"):
        """Create the output directory and initialise the writers.

        val=True creates separate train/val writers; val=False a single one.
        """
        self.Time = datetime.now().strftime('%Y-%m-%d_%H%M')
        self.path = 'output/' + self.Time
        self.log_filename = filename
        # exist_ok=True replaces the original check-then-create, which was
        # race-prone and used a conditional expression as a statement
        os.makedirs(self.path, exist_ok=True)
        self.run_path = '{}/{}'.format(self.path, 'tb')

        # common log
        self.terminal = sys.stdout
        self.terminal.write(self.path)

        # init tensorboardX
        self.train_writer = None
        self.val_writer = None
        self.tensorboard_init(val)

    def printlog(self, message):
        """Echo *message* to stdout and append it to the log file."""
        message = str(message)
        self.terminal.write(message + '\n')

        # open-append-close per call (as in the original) so no handle is held;
        # `with` guarantees the close even if the write raises
        with open(os.path.join(self.path, self.log_filename), "a", encoding='utf8') as fh:
            fh.write(message + '\n')

    def tensorboard_init(self, val=True):
        """Create the TensorBoard writer(s): split train/val dirs when val=True."""
        if val:
            self.train_writer = SummaryWriter(self.run_path + '/train')
            self.val_writer = SummaryWriter(self.run_path + '/val')
        else:
            self.train_writer = SummaryWriter(self.run_path)

    def get_path(self):
        """Return this run's output directory."""
        return self.path
|
src/val_bi-phase.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import torch
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
from model.CPICANN import CPICANN
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def getAnnoMap():
    """Load args.anno_struc and index it by class id (column 1).

    Returns (annos, elems): annos maps class id -> full annotation row;
    elems maps class id -> set of element symbols split from column 3.
    """
    rows = pd.read_csv(args.anno_struc).values
    annos = {row[1]: row for row in rows}
    elems = {row[1]: set(row[3].split(' ')) for row in rows}
    return annos, elems
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def filter_by_elem(logits, elemMap, elem):
    """Suppress (in place) logits of classes whose elements are not a subset of *elem*."""
    for cls_idx, cls_elems in elemMap.items():
        if not cls_elems.issubset(elem):
            # effectively -inf: this class can never win the softmax
            logits[:, cls_idx] = -10 ** 9
    return logits
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def main():
    """Monte-Carlo evaluation of bi-phase identification on synthetic mixtures.

    Repeatedly mixes two random single-phase validation patterns with different
    crystal systems, runs the model once per mixture, and reports how often
    BOTH ground-truth phases fall inside the top-k predictions (k = 2..10).
    """
    annoMap, elemMap = getAnnoMap()

    model = CPICANN(embed_dim=128, num_classes=args.num_classes)

    loaded = torch.load(args.load_path)
    model.load_state_dict(loaded['model'])
    model.to(args.device)
    model.eval()
    print('loaded model from {}'.format(args.load_path))
    print(model)

    if args.elem_filtration:
        print('elem_filtration activated!')
    else:
        print('elem_filtration deactivated!')

    lst = pd.read_csv(args.anno_val).values

    # top10Hits[k-1] counts mixtures where both phases appear in the top-k
    top10Hits = np.array([0] * 10, dtype=np.int32)

    dataLen = len(lst)
    pbar = tqdm(range(args.infTimes))
    for i in range(args.infTimes):
        # resample until the two picks belong to different crystal systems (column 6)
        while True:
            c1, c2 = np.random.randint(0, dataLen, 2)
            anno1, anno2 = lst[c1], lst[c2]
            if anno1[6] != anno2[6]:
                break

        # id1, id2 = int(lst[c1][0].split('_')[0]), int(lst[c2][0].split('_')[0])
        # formula1, formula2 = lst[c1][2], lst[c2][2]
        data1 = pd.read_csv(os.path.join(args.data_dir, f'{lst[c1][0]}.csv')).values
        data2 = pd.read_csv(os.path.join(args.data_dir, f'{lst[c2][0]}.csv')).values

        # random mixing ratio between 20:80 and 80:20
        mixRate1 = np.random.randint(20, 81)
        mixRate2 = 100 - mixRate1

        data = mixRate1 * data1 + mixRate2 * data2
        # union of element symbols of both phases (column 3)
        elem = set(lst[c2][3].strip().split(' ')) | set(lst[c1][3].strip().split(' '))

        def runFile(v):
            # min-max normalise the mixed pattern to [0, 100]
            min_i, scale = min(v), max(v) - min(v)
            v = (v - min_i) / scale * 100

            v = torch.tensor(v, dtype=torch.float32).reshape(1, 1, -1)
            v = v.to(args.device)
            with torch.no_grad():
                logits = model(v)

            # filter by elements
            if args.elem_filtration:
                logits = filter_by_elem(logits, elemMap, elem)

            _pred = torch.nn.functional.softmax(logits.squeeze(), dim=0)
            return _pred.topk(10)

        top10 = runFile(data)

        # m[no] marks which ground-truth phase (1 or 2) prediction *no* matches
        m = [0] * 10
        for no, (indice, rate) in enumerate(zip(top10.indices, top10.values)):
            pred = annoMap[top10.indices[no].item()]

            if pred[0] == int(anno1[0][:7]):
                m[no] = 1
            elif pred[0] == int(anno2[0][:7]):
                m[no] = 2

        # find the smallest k with both phases in the top-k, then credit every
        # top-j for j >= k (slice assignment makes the counts cumulative)
        if 1 in m[:2] and 2 in m[:2]:
            top10Hits[1:] += 1
        elif 1 in m[:3] and 2 in m[:3]:
            top10Hits[2:] += 1
        elif 1 in m[:4] and 2 in m[:4]:
            top10Hits[3:] += 1
        elif 1 in m[:5] and 2 in m[:5]:
            top10Hits[4:] += 1
        elif 1 in m[:6] and 2 in m[:6]:
            top10Hits[5:] += 1
        elif 1 in m[:7] and 2 in m[:7]:
            top10Hits[6:] += 1
        elif 1 in m[:8] and 2 in m[:8]:
            top10Hits[7:] += 1
        elif 1 in m[:9] and 2 in m[:9]:
            top10Hits[8:] += 1
        elif 1 in m[:10] and 2 in m[:10]:
            top10Hits[9:] += 1

        pbar.update(1)
    pbar.close()

    for i in range(1, 10):
        print('top{}Hits: {}%'.format(i + 1, round(top10Hits[i] / args.infTimes * 100, 2)))
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--device', default='cuda:0', type=str)
    parser.add_argument('--data_dir', default='data/val/', type=str)
    parser.add_argument('--infTimes', default=1000, type=int, help='number of mixed pattern to be inferenced')
    parser.add_argument('--load_path', default='pretrained/bi-phase_checkpoint_2000.pth', type=str,
                        help='path to load pretrained single-phase identification model')
    parser.add_argument('--anno_struc', default='annotation/anno_struc.csv', type=str,
                        help='path to annotation file for training data')
    parser.add_argument('--anno_val', default='annotation/anno_val.csv', type=str,
                        help='path to annotation file for validation data')
    parser.add_argument('--num_classes', default=23073, type=int, metavar='N')

    # NOTE(review): type=bool converts any non-empty string (e.g. "False") to True;
    # omit the flag to keep elem_filtration disabled.
    parser.add_argument('--elem_filtration', default=False, type=bool)

    args = parser.parse_args()

    main()
    print('THE END')
|
src/val_single-phase.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
from model.CPICANN import CPICANN
|
| 10 |
+
from model.dataset import XrdDataset
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_cs_anno():
    """Map class id (column 1) to crystal-system index (column 6) from args.anno_struc."""
    rows = pd.read_csv(args.anno_struc).values
    return {row[1]: row[6] for row in rows}
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_acc(cls, label):
    """Return (top-1 accuracy, number of correct predictions) for logits *cls* (N, C)."""
    hits = sum(cls.argmax(1) == label.int())
    return hits / cls.shape[0], hits
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def run_one_epoch(model, dataloader):
    """Evaluate *model* over *dataloader* and tally per-crystal-system statistics.

    Returns (mean loss, mean accuracy %, correct count, total count,
    7x7 crystal-system confusion matrix, per-system correct, per-system total).
    NOTE(review): no criterion is applied here, so epoch_loss stays 0 and the
    first returned value is always 0.0.
    """
    model.eval()

    csAnno = get_cs_anno()

    # 7 crystal systems: per-system correct/total counters and confusion matrix
    csCorrect = [0 for _ in range(7)]
    csTotal = [0 for _ in range(7)]
    cMtrx = [[0 for _ in range(7)] for _ in range(7)]
    epoch_loss, cls_acc = 0, 0
    correct_cnt, total_cnt = 0, 0
    pbar = tqdm(total=len(dataloader.dataset), desc='Evaluating... ', unit='data')
    iters = len(dataloader)
    for i, batch in enumerate(dataloader):

        data = batch[0].to(args.device)
        label_cls = batch[1].to(args.device)

        with torch.no_grad():
            logits = model(data)
        # NOTE(review): Tensor.to() is not in-place; this statement discards its
        # result and therefore has no effect.
        logits.to(args.device)

        pbar.update(len(data))

        _cls_acc, correct = get_acc(logits, label_cls)
        cls_acc += _cls_acc.item()

        correct_cnt += correct.item()
        total_cnt += len(data)

        # map each prediction and ground truth to its crystal system and tally
        preds = logits.argmax(1)
        for gt, pred in zip(label_cls, preds):
            cs_gt = csAnno[gt.item()]
            cMtrx[cs_gt][csAnno[pred.item()]] += 1
            csTotal[cs_gt] += 1
            if gt == pred:
                csCorrect[cs_gt] += 1

    return epoch_loss / iters, cls_acc * 100 / iters, correct_cnt, total_cnt, cMtrx, csCorrect, csTotal
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def main():
    """Load the pretrained single-phase model and report validation accuracy
    plus a per-crystal-system confusion matrix."""
    model = CPICANN(embed_dim=128, num_classes=args.num_classes)

    loaded = torch.load(args.load_path)
    model.load_state_dict(loaded['model'])
    model.to(args.device)
    model.eval()
    print('loaded model from {}'.format(args.load_path))

    print(model)

    valset = XrdDataset(args.data_dir, args.anno_val)
    val_loader = DataLoader(valset, batch_size=128, num_workers=16, pin_memory=True, shuffle=True)

    loss_val, acc_val, correct_cnt, total_cnt, cMtrx, csCorrect, csTotal = run_one_epoch(model, val_loader)

    print("loss_val: ", loss_val)
    print("acc_val: ", acc_val)
    print("{}% ({}/{})".format(round(correct_cnt / total_cnt, 5) * 100, correct_cnt, total_cnt))

    # print the confusion matrix with row-normalised percentages per ground-truth system
    sums = np.array(cMtrx).sum(axis=1)
    for i, row in enumerate(cMtrx):
        buf = ""
        for j, v in enumerate(row):
            buf += "{}({}%) ".format(v, round(v / sums[i] * 100, 2))
        print(buf)

    print("csCorrect: ", csCorrect)
    print("csTotal: ", csTotal)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--device', default='cuda:0', type=str)
    parser.add_argument('--data_dir', default='data/val/', type=str)
    parser.add_argument('--load_path', default='pretrained/single-phase_checkpoint_0200.pth', type=str,
                        help='path to load pretrained single-phase identification model')
    parser.add_argument('--anno_struc', default='annotation/anno_struc.csv', type=str,
                        help='path to annotation file for training data')
    parser.add_argument('--anno_val', default='annotation/anno_val.csv', type=str,
                        help='path to annotation file for validation data')
    parser.add_argument('--num_classes', default=23073, type=int, metavar='N')

    args = parser.parse_args()

    main()

    print('THE END')
|