miyuki2026 committed on
Commit
df0647f
·
0 Parent(s):
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +35 -0
  2. .gitignore +28 -0
  3. Dockerfile +21 -0
  4. examples/download/download_space.py +50 -0
  5. examples/hub_download.py +40 -0
  6. examples/playground/chat.py +118 -0
  7. install.sh +64 -0
  8. log.py +257 -0
  9. main.py +16 -0
  10. project_settings.py +25 -0
  11. script/install_nvidia_driver.sh +184 -0
  12. script/install_python.sh +129 -0
  13. toolbox/__init__.py +5 -0
  14. toolbox/cv2/__init__.py +6 -0
  15. toolbox/cv2/misc.py +137 -0
  16. toolbox/json/__init__.py +6 -0
  17. toolbox/json/misc.py +63 -0
  18. toolbox/minimind/__init__.py +5 -0
  19. toolbox/minimind/model/__init__.py +5 -0
  20. toolbox/minimind/model/configuration_minimind.py +80 -0
  21. toolbox/minimind/model/modeling_minimind.py +386 -0
  22. toolbox/os/__init__.py +6 -0
  23. toolbox/os/command.py +59 -0
  24. toolbox/os/environment.py +114 -0
  25. toolbox/os/other.py +9 -0
  26. toolbox/torch/__init__.py +5 -0
  27. toolbox/torch/modules/__init__.py +6 -0
  28. toolbox/torch/modules/gaussian_mixture.py +173 -0
  29. toolbox/torch/modules/highway.py +30 -0
  30. toolbox/torch/modules/loss.py +738 -0
  31. toolbox/torch/training/__init__.py +6 -0
  32. toolbox/torch/training/metrics/__init__.py +6 -0
  33. toolbox/torch/training/metrics/categorical_accuracy.py +82 -0
  34. toolbox/torch/training/metrics/verbose_categorical_accuracy.py +128 -0
  35. toolbox/torch/training/trainer/__init__.py +5 -0
  36. toolbox/torch/training/trainer/trainer.py +5 -0
  37. toolbox/torch/utils/__init__.py +5 -0
  38. toolbox/torch/utils/data/__init__.py +5 -0
  39. toolbox/torch/utils/data/dataset/__init__.py +5 -0
  40. toolbox/torch/utils/data/dataset/wave_classifier_excel_dataset.py +98 -0
  41. toolbox/torch/utils/data/vocabulary.py +211 -0
  42. toolbox/torchaudio/__init__.py +5 -0
  43. toolbox/torchaudio/configuration_utils.py +63 -0
  44. toolbox/torchaudio/models/__init__.py +5 -0
  45. toolbox/torchaudio/models/cnn_audio_classifier/__init__.py +5 -0
  46. toolbox/torchaudio/models/cnn_audio_classifier/configuration_cnn_audio_classifier.py +24 -0
  47. toolbox/torchaudio/models/cnn_audio_classifier/examples/conv2d_classifier.yaml +45 -0
  48. toolbox/torchaudio/models/cnn_audio_classifier/modeling_cnn_audio_classifier.py +403 -0
  49. toolbox/torchaudio/models/lstm_audio_classifier/__init__.py +6 -0
  50. toolbox/torchaudio/models/lstm_audio_classifier/configuration_lstm_audio_classifier.py +22 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ .DS_Store
3
+ .git/
4
+ .idea/
5
+
6
+ **/file_dir
7
+ **/flagged/
8
+ **/log/
9
+ **/logs/
10
+ **/__pycache__/
11
+
12
+ /data/
13
+ /docs/
14
+ /dotenv/
15
+ /examples/**/*.wav
16
+ /hub_datasets/
17
+ /trained_models*/
18
+ /pretrained_models/
19
+ /temp/
20
+
21
+ **/*.csv
22
+ **/*.onnx
23
+ **/*.pdf
24
+ **/*.md
25
+ #**/*.wav
26
+ **/*.xlsx
27
+ **/*.jsonl
28
+ **/*.zip
Dockerfile ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.12
2
+
3
+ WORKDIR /code
4
+
5
+ COPY . /code
6
+
7
+ RUN pip install --upgrade pip
8
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
9
+
10
+ RUN useradd -m -u 1000 user
11
+
12
+ USER user
13
+
14
+ ENV HOME=/home/user \
15
+ PATH=/home/user/.local/bin:$PATH
16
+
17
+ WORKDIR $HOME/app
18
+
19
+ COPY --chown=user . $HOME/app
20
+
21
+ CMD ["python3", "main.py"]
examples/download/download_space.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ import platform
6
+
7
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
8
+
9
+ from huggingface_hub import snapshot_download
10
+
11
+ from project_settings import project_path
12
+
13
+
14
def get_args():
    """Parse CLI arguments: the repo to download and the local target dir."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--repo_id", default="intelli-zen/music_comment", type=str)
    parser.add_argument(
        "--local_dir",
        # alternate targets used during development:
        # default=(project_path / "temp/models" / "sft_llama2_stack_exchange").as_posix(),
        # default=(project_path / "temp/spaces" / "keep_alive_a").as_posix(),
        default=(project_path / "temp/datasets" / "music_comment").as_posix(),
        type=str
    )
    return parser.parse_args()
26
+
27
+
28
def main():
    """Download a full dataset repository snapshot to a local directory.

    Assumes HF_ENDPOINT has been pointed at the mirror at module import time.
    Equivalent shell usage:
        pip install huggingface-hub
        huggingface-cli download <repo_id> --local-dir <dir>
    """
    args = get_args()

    snapshot_download(
        # repo_type alternatives: "model", "space"
        repo_type="dataset",
        repo_id=args.repo_id,
        local_dir=args.local_dir,
        # ignore_patterns=["*.msgpack", "*.h5", "*.ot"],
    )
    return
47
+
48
+
49
+ if __name__ == "__main__":
50
+ main()
examples/hub_download.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import platform
5
+
6
+ from huggingface_hub import snapshot_download
7
+
8
+ from project_settings import project_path
9
+
10
+
11
def get_args():
    """Parse CLI arguments: model repo id and the local download directory."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--repo_id", default="jingyaogong/MiniMind2", type=str)
    parser.add_argument(
        "--local_dir",
        default=(project_path / "pretrained_models" / "MiniMind2").as_posix(),
        type=str
    )
    return parser.parse_args()
21
+
22
+
23
def main():
    """Download a pretrained model repository, skipping non-PyTorch weights.

    Equivalent shell usage:
        pip install huggingface-hub
        huggingface-cli download <repo_id> --local-dir ./model
    """
    args = get_args()

    snapshot_download(
        repo_id=args.repo_id,
        local_dir=args.local_dir,
        # skip alternative weight formats (flax / tf / rust)
        ignore_patterns=["*.msgpack", "*.h5", "*.ot"],
    )
    return
37
+
38
+
39
+ if __name__ == "__main__":
40
+ main()
examples/playground/chat.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/jingyaogong/minimind/blob/master/eval_llm.py
5
+ """
6
+ import argparse
7
+ import time
8
+
9
+ import torch
10
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
11
+
12
+ from project_settings import project_path
13
+
14
+
15
def get_args():
    """Parse CLI arguments for the interactive MiniMind chat demo."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model_name_or_path",
        # default="jingyaogong/MiniMind2",
        default=(project_path / "pretrained_models/MiniMind2"),
        type=str
    )

    parser.add_argument(
        "--max_new_tokens",
        default=8192,  # 8192, 128
        type=int,
        help="最大生成长度(注意:并非模型实际长文本能力)"
    )
    parser.add_argument("--top_p", default=0.85, type=float, help="nucleus采样阈值(0-1)")
    parser.add_argument("--temperature", default=0.85, type=float, help="生成温度,控制随机性(0-1,越大越随机)")

    parser.add_argument(
        "--show_speed",
        default=1,  # 1, 0
        type=int,
        help="显示decode速度(tokens/s)"
    )

    return parser.parse_args()
40
+
41
+
42
def main():
    """Run a streaming chat loop (scripted prompts or manual input) against a causal LM."""
    args = get_args()

    # Device selection; MPS is deliberately routed to CPU here.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        # device = "mps"
        device = "cpu"
    else:
        device = "cpu"
    print(f"device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(args.pretrained_model_name_or_path)
    model = model.eval().to(device)

    # Canned prompts consumed one by one in auto-test mode.
    prompts = [
        "你有什么特长?",
        "为什么天空是蓝色的",
        "请用Python写一个计算斐波那契数列的函数",
        '解释一下"光合作用"的基本过程',
        "如果明天下雨,我应该如何出门",
        "比较一下猫和狗作为宠物的优缺点",
        "解释什么是机器学习",
        "推荐一些中国的美食"
    ]
    input_mode = int(input("[0] 自动测试\n[1] 手动输入\n"))

    # Stream decoded tokens to stdout as they are generated.
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    conversation = [
        {"role": "system", "content": "You are a helpful assistant"}
    ]
    while True:
        if input_mode == 0:
            if not prompts:
                break
            user_input = prompts.pop(0)
            print(f"💬: {user_input}")
        else:
            user_input = input("💬: ")
        user_input = str(user_input).strip()
        conversation.append({"role": "user", "content": user_input})

        # Render the running conversation into the model's chat template.
        chat_text = tokenizer.apply_chat_template(
            conversation=conversation,
            tokenize=False,
            add_generation_prompt=True
        )
        inputs = tokenizer(
            chat_text,
            return_tensors="pt",
            truncation=True
        ).to(device)

        print("🤖: ", end="")
        st = time.time()
        generated_ids = model.generate(
            inputs=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=args.max_new_tokens,
            do_sample=True,
            streamer=streamer,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            top_p=args.top_p,
            temperature=args.temperature,
            repetition_penalty=1.0,
        )

        # Strip the prompt tokens so only the new completion is recorded.
        prompt_len = len(inputs["input_ids"][0])
        response = tokenizer.decode(generated_ids[0][prompt_len:], skip_special_tokens=True)
        conversation.append({"role": "assistant", "content": response})

        gen_tokens = len(generated_ids[0]) - prompt_len
        if args.show_speed:
            print(f"\n[Speed]: {gen_tokens / (time.time() - st):.2f} tokens/s\n\n")
        else:
            print("\n\n")

    return
115
+
116
+
117
+ if __name__ == "__main__":
118
+ main()
install.sh ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash

# bash install.sh --stage 2 --stop_stage 2 --system_version centos


python_version=3.8.10
system_version="centos";

verbose=true;
stage=-1
stop_stage=0


# parse options
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
      # Reject options that do not correspond to a predeclared variable.
      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
      # Read the variable's current value.
      # FIX: the original line assigned the literal string "(eval echo \$name)"
      # (missing command substitution), so old_value was never "true"/"false"
      # and the boolean validation below could never fire.
      eval "old_value=\"\$${name}\"";
      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
        was_bool=true;
      else
        was_bool=false;
      fi

      # Set the variable to the right value-- the escaped quotes make it work if
      # the option had spaces, like --cmd "queue.pl -sync y"
      eval "${name}=\"$2\"";

      # Check that Boolean-valued arguments are really Boolean.
      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
        exit 1;
      fi
      shift 2;
      ;;

    *) break;
  esac
done

work_dir="$(pwd)"


if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  $verbose && echo "stage 1: install python"
  cd "${work_dir}" || exit 1;

  sh ./script/install_python.sh --python_version "${python_version}" --system_version "${system_version}"
fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  $verbose && echo "stage 2: create virtualenv"

  # /usr/local/python-3.9.9/bin/virtualenv cc_audio_8
  # source /data/local/bin/cc_audio_8/bin/activate
  /usr/local/python-${python_version}/bin/pip3 install virtualenv
  mkdir -p /data/local/bin
  cd /data/local/bin || exit 1;
  /usr/local/python-${python_version}/bin/virtualenv cc_audio_8

fi
log.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from datetime import datetime
4
+ import logging
5
+ from logging.handlers import RotatingFileHandler, TimedRotatingFileHandler
6
+ import os
7
+ from zoneinfo import ZoneInfo # Python 3.9+ 自带,无需安装
8
+
9
+
10
def get_converter(tz_info: str = "Asia/Shanghai"):
    """Build a logging-compatible time converter for the given IANA timezone.

    The returned callable maps an epoch timestamp to a ``time.struct_time``
    localized to *tz_info*, suitable for ``logging.Formatter.converter``.
    """
    def converter(timestamp):
        # Localize first, then hand logging the struct_time it expects.
        return datetime.fromtimestamp(timestamp, ZoneInfo(tz_info)).timetuple()

    return converter
16
+
17
+
18
def setup_stream(tz_info: str = "Asia/Shanghai"):
    """Attach one timezone-aware console handler to the main/http/api loggers.

    The root logger is configured at DEBUG level with no handlers of its own.
    """
    formatter = logging.Formatter(
        fmt="%(asctime)s|%(name)s|%(levelname)s|%(filename)s|%(lineno)d|%(message)s",
        datefmt="%Y-%m-%d %H:%M:%S %z"
    )
    # Timestamps are rendered in the requested timezone, not the host's.
    formatter.converter = get_converter(tz_info)

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(formatter)

    # The same console handler is shared by every application logger.
    for logger_name in ("main", "http", "api"):
        logging.getLogger(logger_name).addHandler(stream_handler)

    logging.basicConfig(
        level=logging.DEBUG,
        datefmt="%a, %d %b %Y %H:%M:%S",
        handlers=[]
    )
    return
51
+
52
+
53
def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
    """Configure size-rotated file logging plus console output.

    Named loggers main/http/api each get the shared console handler plus a
    rotating file; alarm gets a file only; the root logger gets three
    level-split files (debug/info/error). All timestamps use *tz_info*.
    """
    formatter = logging.Formatter(
        fmt="%(asctime)s|%(name)s|%(levelname)s|%(filename)s|%(lineno)d|%(message)s",
        datefmt="%Y-%m-%d %H:%M:%S %z"
    )
    formatter.converter = get_converter(tz_info)

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(formatter)

    def _file_handler(basename, max_bytes, level):
        # One UTF-8 rotating file per target, keeping two backups.
        handler = RotatingFileHandler(
            filename=os.path.join(log_directory, basename),
            maxBytes=max_bytes,
            encoding="utf-8",
            backupCount=2,
        )
        handler.setLevel(level)
        handler.setFormatter(formatter)
        return handler

    # main: console + 100MB INFO file
    main_logger = logging.getLogger("main")
    main_logger.addHandler(stream_handler)
    main_logger.addHandler(_file_handler("main.log", 100 * 1024 * 1024, logging.INFO))

    # http: console + 100MB DEBUG file
    http_logger = logging.getLogger("http")
    http_logger.addHandler(stream_handler)
    http_logger.addHandler(_file_handler("http.log", 100 * 1024 * 1024, logging.DEBUG))

    # api: console + 10MB DEBUG file
    api_logger = logging.getLogger("api")
    api_logger.addHandler(stream_handler)
    api_logger.addHandler(_file_handler("api.log", 10 * 1024 * 1024, logging.DEBUG))

    # alarm: file only, no console output
    logging.getLogger("alarm").addHandler(
        _file_handler("alarm.log", 1 * 1024 * 1024, logging.DEBUG)
    )

    # Root logger: level-split files for everything else.
    logging.basicConfig(
        level=logging.DEBUG,
        datefmt="%a, %d %b %Y %H:%M:%S",
        handlers=[
            _file_handler("debug.log", 1 * 1024 * 1024, logging.DEBUG),
            _file_handler("info.log", 1 * 1024 * 1024, logging.INFO),
            _file_handler("error.log", 1 * 1024 * 1024, logging.ERROR),
        ]
    )
153
+
154
+
155
def setup_time_rotating(log_directory: str):
    """Configure time-rotated file logging plus console output.

    main/http/api/alarm rotate at midnight; the root logger's level-split
    files (debug/info/error) rotate daily. Seven backups are kept each.
    """
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(logging.Formatter(fmt))

    def _file_handler(basename, when, level):
        # One UTF-8 timed-rotating file per target, keeping seven backups.
        handler = TimedRotatingFileHandler(
            filename=os.path.join(log_directory, basename),
            encoding="utf-8",
            when=when,
            interval=1,
            backupCount=7
        )
        handler.setLevel(level)
        handler.setFormatter(logging.Formatter(fmt))
        return handler

    # main: console + nightly INFO file
    main_logger = logging.getLogger("main")
    main_logger.addHandler(stream_handler)
    main_logger.addHandler(_file_handler("main.log", "midnight", logging.INFO))

    # http/api/alarm: nightly DEBUG files, no console handler
    logging.getLogger("http").addHandler(_file_handler("http.log", "midnight", logging.DEBUG))
    logging.getLogger("api").addHandler(_file_handler("api.log", "midnight", logging.DEBUG))
    logging.getLogger("alarm").addHandler(_file_handler("alarm.log", "midnight", logging.DEBUG))

    # Root logger: daily level-split files.
    logging.basicConfig(
        level=logging.DEBUG,
        datefmt="%a, %d %b %Y %H:%M:%S",
        handlers=[
            _file_handler("debug.log", "D", logging.DEBUG),
            _file_handler("info.log", "D", logging.INFO),
            _file_handler("error.log", "D", logging.ERROR),
        ]
    )
254
+
255
+
256
+ if __name__ == "__main__":
257
+ pass
main.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# PyCharm's default starter script (kept as a placeholder entry point).


def print_hi(name):
    """Print a greeting for *name*."""
    print(f'Hi, {name}')


if __name__ == '__main__':
    print_hi('PyCharm')

# See https://www.jetbrains.com/help/pycharm/ for PyCharm help
project_settings.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os
from pathlib import Path

from toolbox.os.environment import EnvironmentManager


# Repository root: the directory containing this file.
project_path = Path(os.path.abspath(os.path.dirname(__file__)))

# Runtime output directories, created eagerly at import time.
log_directory = project_path / "logs"
log_directory.mkdir(parents=True, exist_ok=True)

temp_directory = project_path / "temp"
temp_directory.mkdir(parents=True, exist_ok=True)

# Environment-specific configuration loaded from ./dotenv; the active
# environment is chosen by the `environment` env var (defaults to "dev").
environment = EnvironmentManager(
    path=os.path.join(project_path, "dotenv"),
    env=os.environ.get("environment", "dev"),
)


if __name__ == "__main__":
    pass
script/install_nvidia_driver.sh ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ #GPU驱动安装需要先将原有的显示关闭, 重启机器, 再进行安装.
3
+ #参考链接:
4
+ #https://blog.csdn.net/kingschan/article/details/19033595
5
+ #https://blog.csdn.net/HaixWang/article/details/90408538
6
+ #
7
+ #>>> yum install -y pciutils
8
+ #查看 linux 机器上是否有 GPU
9
+ #lspci |grep -i nvidia
10
+ #
11
+ #>>> lspci |grep -i nvidia
12
+ #00:08.0 3D controller: NVIDIA Corporation TU104GL [Tesla T4] (rev a1)
13
+ #
14
+ #
15
+ #NVIDIA 驱动程序下载
16
+ #先在 pytorch 上查看应该用什么 cuda 版本, 再安装对应的 cuda-toolkit cuda.
17
+ #再根据 gpu 版本下载安装对应的 nvidia 驱动
18
+ #
19
+ ## pytorch 版本
20
+ #https://pytorch.org/get-started/locally/
21
+ #
22
+ ## CUDA 下载 (好像不需要这个)
23
+ #https://developer.nvidia.com/cuda-toolkit-archive
24
+ #
25
+ ## nvidia 驱动
26
+ #https://www.nvidia.cn/Download/index.aspx?lang=cn
27
+ #http://www.nvidia.com/Download/index.aspx
28
+ #
29
+ #在下方的下拉列表中进行选择,针对您的 NVIDIA 产品确定合适的驱动。
30
+ #产品类型:
31
+ #Data Center / Tesla
32
+ #产品系列:
33
+ #T-Series
34
+ #产品家族:
35
+ #Tesla T4
36
+ #操作系统:
37
+ #Linux 64-bit
38
+ #CUDA Toolkit:
39
+ #10.2
40
+ #语言:
41
+ #Chinese (Simpleified)
42
+ #
43
+ #
44
+ #>>> mkdir -p /data/tianxing
45
+ #>>> cd /data/tianxing
46
+ #>>> wget https://cn.download.nvidia.com/tesla/440.118.02/NVIDIA-Linux-x86_64-440.118.02.run
47
+ #>>> sh NVIDIA-Linux-x86_64-440.118.02.run
48
+ #
49
+ ## 异常:
50
+ #ERROR: The Nouveau kernel driver is currently in use by your system. This driver is incompatible with the NVIDIA driver, and must be disabled before proceeding. Please consult the NVIDIA driver README and your
51
+ #Linux distribution's documentation for details on how to correctly disable the Nouveau kernel driver.
52
+ #[OK]
53
+ #
54
+ #For some distributions, Nouveau can be disabled by adding a file in the modprobe configuration directory. Would you like nvidia-installer to attempt to create this modprobe file for you?
55
+ #[NO]
56
+ #
57
+ #ERROR: Installation has failed. Please see the file '/var/log/nvidia-installer.log' for details. You may find suggestions on fixing installation problems in the README available on the Linux driver download
58
+ #page at www.nvidia.com.
59
+ #[OK]
60
+ #
61
+ ## 参考链接:
62
+ #https://blog.csdn.net/kingschan/article/details/19033595
63
+ #
64
+ ## 禁用原有的显卡驱动 nouveau
65
+ #>>> echo -e "blacklist nouveau\noptions nouveau modeset=0\n" > /etc/modprobe.d/blacklist-nouveau.conf
66
+ #>>> sudo dracut --force
67
+ ## 重启
68
+ #>>> reboot
69
+ #
70
+ #>>> init 3
71
+ #>>> sh NVIDIA-Linux-x86_64-440.118.02.run
72
+ #
73
+ ## 异常
74
+ #ERROR: Unable to find the kernel source tree for the currently running kernel. Please make sure you have installed the kernel source files for your kernel and that they are properly configured; on Red Hat Linux systems, for example, be sure you have the 'kernel-source' or 'kernel-devel' RPM installed. If you know the correct kernel source files are installed, you may specify the kernel source path with the '--kernel-source-path' command line option.
75
+ #[OK]
76
+ #ERROR: Installation has failed. Please see the file '/var/log/nvidia-installer.log' for details. You may find suggestions on fixing installation problems in the README available on the Linux driver download
77
+ #page at www.nvidia.com.
78
+ #[OK]
79
+ #
80
+ ## 参考链接
81
+ ## https://blog.csdn.net/HaixWang/article/details/90408538
82
+ #
83
+ #>>> uname -r
84
+ #3.10.0-1160.49.1.el7.x86_64
85
+ #>>> yum install kernel-devel kernel-headers -y
86
+ #>>> yum info kernel-devel kernel-headers
87
+ #>>> yum install -y "kernel-devel-uname-r == $(uname -r)"
88
+ #>>> yum -y distro-sync
89
+ #
90
+ #>>> sh NVIDIA-Linux-x86_64-440.118.02.run
91
+ #
92
+ ## 安装成功
93
+ #WARNING: nvidia-installer was forced to guess the X library path '/usr/lib64' and X module path '/usr/lib64/xorg/modules'; these paths were not queryable from the system. If X fails to find the NVIDIA X driver
94
+ #module, please install the `pkg-config` utility and the X.Org SDK/development package for your distribution and reinstall the driver.
95
+ #[OK]
96
+ #Install NVIDIA's 32-bit compatibility libraries?
97
+ #[YES]
98
+ #Installation of the kernel module for the NVIDIA Accelerated Graphics Driver for Linux-x86_64 (version 440.118.02) is now complete.
99
+ #[OK]
100
+ #
101
+ #
102
+ ## 查看 GPU 使用情况; watch -n 1 -d nvidia-smi 每1秒刷新一次.
103
+ #>>> nvidia-smi
104
+ #Thu Mar 9 12:00:37 2023
105
+ #+-----------------------------------------------------------------------------+
106
+ #| NVIDIA-SMI 440.118.02 Driver Version: 440.118.02 CUDA Version: 10.2 |
107
+ #|-------------------------------+----------------------+----------------------+
108
+ #| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
109
+ #| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
110
+ #|===============================+======================+======================|
111
+ #| 0 Tesla T4 Off | 00000000:00:08.0 Off | Off |
112
+ #| N/A 54C P0 22W / 70W | 0MiB / 16127MiB | 0% Default |
113
+ #+-------------------------------+----------------------+----------------------+
114
+ #
115
+ #+-----------------------------------------------------------------------------+
116
+ #| Processes: GPU Memory |
117
+ #| GPU PID Type Process name Usage |
118
+ #|=============================================================================|
119
+ #| No running processes found |
120
+ #+-----------------------------------------------------------------------------+
121
+ #
122
+ #
123
+
124
+ # params
125
+ stage=1
126
+ nvidia_driver_filename=https://cn.download.nvidia.com/tesla/440.118.02/NVIDIA-Linux-x86_64-440.118.02.run
127
+
128
+ # parse options
129
+ while true; do
130
+ [ -z "${1:-}" ] && break; # break if there are no arguments
131
+ case "$1" in
132
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
133
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
134
+ old_value="(eval echo \\$$name)";
135
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
136
+ was_bool=true;
137
+ else
138
+ was_bool=false;
139
+ fi
140
+
141
+ # Set the variable to the right value-- the escaped quotes make it work if
142
+ # the option had spaces, like --cmd "queue.pl -sync y"
143
+ eval "${name}=\"$2\"";
144
+
145
+ # Check that Boolean-valued arguments are really Boolean.
146
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
147
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
148
+ exit 1;
149
+ fi
150
+ shift 2;
151
+ ;;
152
+
153
+ *) break;
154
+ esac
155
+ done
156
+
157
+ echo "stage: ${stage}";
158
+
159
+ yum -y install wget
160
+ yum -y install sudo
161
+
162
+ if [ ${stage} -eq 0 ]; then
163
+ mkdir -p /data/dep
164
+ cd /data/dep || echo 1;
165
+ wget -P /data/dep ${nvidia_driver_filename}
166
+
167
+ echo -e "blacklist nouveau\noptions nouveau modeset=0\n" > /etc/modprobe.d/blacklist-nouveau.conf
168
+ sudo dracut --force
169
+ # 重启
170
+ reboot
171
+ elif [ ${stage} -eq 1 ]; then
172
+ init 3
173
+
174
+ yum install -y kernel-devel kernel-headers
175
+ yum info kernel-devel kernel-headers
176
+ yum install -y "kernel-devel-uname-r == $(uname -r)"
177
+ yum -y distro-sync
178
+
179
+ cd /data/dep || echo 1;
180
+
181
+ # 安装时, 需要回车三下.
182
+ sh NVIDIA-Linux-x86_64-440.118.02.run
183
+ nvidia-smi
184
+ fi
script/install_python.sh ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash

# params
python_version="3.6.5";
system_version="centos";


# parse options
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
      # Fix: the original read old_value="(eval echo \$$name)" — a literal
      # string (the command substitution was missing its leading "$"), so
      # was_bool was never true and the boolean check below never ran.
      old_value="$(eval echo \$$name)";
      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
        was_bool=true;
      else
        was_bool=false;
      fi

      # Set the variable to the right value-- the escaped quotes make it work if
      # the option had spaces, like --cmd "queue.pl -sync y"
      eval "${name}=\"$2\"";

      # Check that Boolean-valued arguments are really Boolean.
      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
        exit 1;
      fi
      shift 2;
      ;;

    *) break;
  esac
done

echo "python_version: ${python_version}";
echo "system_version: ${system_version}";


if [ ${system_version} = "centos" ]; then
  # install the Python build toolchain and headers
  yum -y groupinstall "Development tools"
  yum -y install zlib-devel bzip2-devel openssl-devel ncurses-devel sqlite-devel readline-devel tk-devel gdbm-devel db4-devel libpcap-devel xz-devel
  yum install libffi-devel -y
  yum install -y wget
  yum install -y make

  mkdir -p /data/dep
  cd /data/dep || exit 1;
  if [ ! -e Python-${python_version}.tgz ]; then
    wget -P /data/dep https://www.python.org/ftp/python/${python_version}/Python-${python_version}.tgz
  fi

  cd /data/dep || exit 1;
  if [ ! -d Python-${python_version} ]; then
    tar -zxvf Python-${python_version}.tgz
  fi
  # Fix: the original only cd'ed into the source tree inside the
  # "not yet extracted" branch, so reruns configured in /data/dep.
  cd /data/dep/Python-${python_version} || exit 1;

  # -p so reruns do not fail on an existing prefix directory
  mkdir -p /usr/local/python-${python_version}
  ./configure --prefix=/usr/local/python-${python_version}
  make && make install

  /usr/local/python-${python_version}/bin/python3 -V
  /usr/local/python-${python_version}/bin/pip3 -V

  # repoint the python3/pip3 symlinks at the new installation
  rm -rf /usr/local/bin/python3
  rm -rf /usr/local/bin/pip3
  ln -s /usr/local/python-${python_version}/bin/python3 /usr/local/bin/python3
  ln -s /usr/local/python-${python_version}/bin/pip3 /usr/local/bin/pip3

  python3 -V
  pip3 -V

elif [ ${system_version} = "ubuntu" ]; then
  # install the Python build toolchain
  # https://zhuanlan.zhihu.com/p/506491209

  # refresh the package index
  sudo apt update
  # list available upgrades
  sudo apt list --upgradable
  # apply upgrades if the previous step reported any
  sudo apt -y upgrade
  # install the GCC compiler
  sudo apt install gcc
  # verify the installation succeeded
  gcc -v

  # build dependencies
  sudo apt install -y build-essential zlib1g-dev libncurses5-dev libgdbm-dev libnss3-dev libssl-dev libreadline-dev libffi-dev libbz2-dev liblzma-dev sqlite3 libsqlite3-dev tk-dev uuid-dev libgdbm-compat-dev

  mkdir -p /data/dep
  cd /data/dep || exit 1;
  if [ ! -e Python-${python_version}.tgz ]; then
    sudo wget -P /data/dep https://www.python.org/ftp/python/${python_version}/Python-${python_version}.tgz
  fi

  cd /data/dep || exit 1;
  if [ ! -d Python-${python_version} ]; then
    tar -zxvf Python-${python_version}.tgz
  fi
  # Fix: cd moved out of the extraction branch (see centos note above).
  cd /data/dep/Python-${python_version} || exit 1;

  mkdir -p /usr/local/python-${python_version}

  # check dependencies and configure the build
  sudo ./configure --prefix=/usr/local/python-${python_version} --enable-optimizations --with-lto --enable-shared
  cpu_count=$(cat /proc/cpuinfo | grep processor | wc -l)
  sudo make -j "${cpu_count}"
  # Fix: the original never installed the build, so the version checks and
  # symlinks below operated on a non-existent prefix.
  sudo make install

  /usr/local/python-${python_version}/bin/python3 -V
  /usr/local/python-${python_version}/bin/pip3 -V

  # repoint the python3/pip3 symlinks at the new installation
  rm -rf /usr/local/bin/python3
  rm -rf /usr/local/bin/pip3
  ln -s /usr/local/python-${python_version}/bin/python3 /usr/local/bin/python3
  ln -s /usr/local/python-${python_version}/bin/pip3 /usr/local/bin/pip3

  python3 -V
  pip3 -V
fi
toolbox/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Top-level package marker for the ``toolbox`` utilities."""

if __name__ == '__main__':
    pass
toolbox/cv2/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Package marker for ``toolbox.cv2`` (morphology-style label helpers)."""


if __name__ == '__main__':
    pass
toolbox/cv2/misc.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from typing import List, Union
4
+
5
+
6
def erode(labels: List[Union[str, int]], erode_label: Union[str, int], n: int = 1):
    """Erode runs of ``erode_label`` in ``labels`` by ``n`` elements.

    The first ``n`` positions of every contiguous run of ``erode_label`` are
    replaced with the preceding output label (the very first position of the
    sequence keeps its label), and when a run ends, up to ``n`` of its trailing
    positions are overwritten with the label that follows the run. Returns a
    new list of the same length; the input is not mutated.
    """
    result = list()
    inside_run = False
    consumed = 0
    for label in labels:
        if label == erode_label:
            if not inside_run:
                inside_run = True
                consumed = 0
            if consumed < n:
                # swallow the leading n elements of the run by repeating the
                # previous output label (or keeping the label at index 0)
                result.append(result[-1] if result else label)
                consumed += 1
            else:
                result.append(label)
        else:
            if inside_run:
                inside_run = False
                # swallow up to n trailing elements of the run that just ended
                for back in range(min(len(result), n)):
                    result[-back - 1] = label
            result.append(label)
    return result
42
+
43
+
44
def dilate(labels: List[Union[str, int]], dilate_label: Union[str, int], n: int = 1):
    """Dilate runs of ``dilate_label`` in ``labels`` by ``n`` elements.

    Up to ``n`` labels immediately before and after every contiguous run of
    ``dilate_label`` are overwritten with ``dilate_label``. Returns a new list
    of the same length; the input is not mutated.
    """
    result = list()
    inside_run = False
    trailing = float('inf')  # trailing positions already dilated after a run
    for label in labels:
        if trailing < n:
            # still extending the run that just ended: overwrite this label
            result.append(dilate_label)
            trailing += 1
            continue
        if label == dilate_label:
            if not inside_run:
                inside_run = True
                # extend the run backwards over up to n previous outputs
                for back in range(min(len(result), n)):
                    result[-back - 1] = label
            result.append(label)
        elif inside_run:
            inside_run = False
            # first trailing position after the run
            result.append(dilate_label)
            trailing = 1
        else:
            result.append(label)
    return result
78
+
79
+
80
def demo1():
    """Show erode() shrinking the 'voice' runs of a small sample by one."""
    sample = [
        'voice', 'mute', 'mute', 'voice', 'voice', 'voice', 'voice', 'bell', 'bell', 'bell', 'mute', 'mute', 'mute', 'voice',
    ]

    eroded = erode(
        labels=sample,
        erode_label='voice',
        n=1,
    )
    print(len(sample))
    print(len(eroded))
    print(eroded)
    return
95
+
96
+
97
def demo2():
    """Show dilate() growing the 'voice' runs of a small sample by two."""
    sample = [
        'voice', 'mute', 'mute', 'voice', 'voice', 'voice', 'voice', 'bell', 'bell', 'bell', 'mute', 'mute', 'mute', 'voice',
    ]

    dilated = dilate(
        labels=sample,
        dilate_label='voice',
        n=2,
    )
    print(len(sample))
    print(len(dilated))
    print(dilated)
    return
113
+
114
+
115
def demo3():
    """Time a chained erode/dilate smoothing pass on a longer label track.

    Note: the 'music' passes are no-ops on this sample ('music' never occurs
    in the list); only the 'voice' passes change anything.
    """
    import time
    labels = ['mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'voice', 'bell', 'bell', 'bell', 'bell', 'bell', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'bell', 'bell', 'bell', 'bell', 'bell', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'bell', 'bell', 'bell', 'bell', 'bell', 'bell', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute', 'mute']

    started_at = time.time()
    labels = erode(labels, erode_label='music', n=1)
    labels = dilate(labels, dilate_label='music', n=1)

    labels = dilate(labels, dilate_label='voice', n=2)
    labels = erode(labels, erode_label='voice', n=2)
    labels = erode(labels, erode_label='voice', n=1)
    labels = dilate(labels, dilate_label='voice', n=3)

    cost = time.time() - started_at
    print(cost)
    print(labels)
    return


if __name__ == '__main__':
    # demo1()
    # demo2()
    demo3()
toolbox/json/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Package marker for ``toolbox.json`` (JSON-structure helpers)."""


if __name__ == '__main__':
    pass
toolbox/json/misc.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from typing import Callable
4
+
5
+
6
def traverse(js, callback: Callable, *args, **kwargs):
    """Recursively walk a JSON-like structure and rebuild it.

    ``callback(leaf, *args, **kwargs)`` is applied to every ``int`` and ``str``
    leaf (dict keys included; note ``bool`` is an ``int`` subclass and is also
    passed through the callback). Any other leaf type is returned unchanged.
    The input structure is not mutated.
    """
    if isinstance(js, list):
        return [traverse(item, callback, *args, **kwargs) for item in js]
    if isinstance(js, tuple):
        return tuple(traverse(item, callback, *args, **kwargs) for item in js)
    if isinstance(js, dict):
        return {
            traverse(key, callback, *args, **kwargs): traverse(value, callback, *args, **kwargs)
            for key, value in js.items()
        }
    if isinstance(js, (int, str)):
        return callback(js, *args, **kwargs)
    return js
32
+
33
+
34
def demo1():
    """Resolve "$placeholder" strings in a nested config via traverse()."""
    config = {
        "env": "ppe",
        "mysql_connect": {
            "host": "$mysql_connect_host",
            "port": 3306,
            "user": "callbot",
            "password": "NxcloudAI2021!",
            "database": "callbot_ppe",
            "charset": "utf8"
        },
        "es_connect": {
            "hosts": ["10.20.251.8"],
            "http_auth": ["elastic", "ElasticAI2021!"],
            "port": 9200
        }
    }

    def strip_placeholder(value):
        # placeholders are strings starting with "$"; drop the marker
        if isinstance(value, str) and value.startswith('$'):
            return value[1:]
        return value

    resolved = traverse(config, callback=strip_placeholder)
    print(resolved)
    return


if __name__ == '__main__':
    demo1()
toolbox/minimind/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Package marker for ``toolbox.minimind``."""

if __name__ == "__main__":
    pass
toolbox/minimind/model/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Package marker for ``toolbox.minimind.model``."""

if __name__ == "__main__":
    pass
toolbox/minimind/model/configuration_minimind.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from transformers import PretrainedConfig
4
+
5
+
6
class MiniMindConfig(PretrainedConfig):
    """Configuration for the MiniMind causal language model.

    Every constructor argument is stored as an attribute so that
    ``PretrainedConfig`` can (de)serialize it via ``config.json``.
    """

    model_type = "minimind"

    def __init__(
            self,
            dropout: float = 0.0,
            bos_token_id: int = 1,
            eos_token_id: int = 2,
            hidden_act: str = 'silu',
            hidden_size: int = 512,
            intermediate_size: int = None,  # None -> derived from hidden_size by FeedForward
            max_position_embeddings: int = 32768,
            num_attention_heads: int = 8,
            num_hidden_layers: int = 8,
            num_key_value_heads: int = 2,
            vocab_size: int = 6400,
            rms_norm_eps: float = 1e-05,
            rope_theta: float = 1000000.0,  # annotation fixed: value is a float
            inference_rope_scaling: bool = False,
            flash_attn: bool = True,
            ####################################################
            # Here are the specific configurations of MOE
            # When use_moe is false, the following is invalid
            ####################################################
            use_moe: bool = False,
            num_experts_per_tok: int = 2,
            n_routed_experts: int = 4,
            n_shared_experts: int = 1,
            scoring_func: str = 'softmax',
            aux_loss_alpha: float = 0.01,
            seq_aux: bool = True,
            norm_topk_prob: bool = True,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.dropout = dropout
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.hidden_act = hidden_act
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.num_key_value_heads = num_key_value_heads
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.inference_rope_scaling = inference_rope_scaling
        # extrapolated length = factor * original_max_position_embeddings = 32768
        self.rope_scaling = {
            "beta_fast": 32,
            "beta_slow": 1,
            "factor": 16,
            "original_max_position_embeddings": 2048,
            "attention_factor": 1.0,
            "type": "yarn"
        } if self.inference_rope_scaling else None
        self.flash_attn = flash_attn
        ####################################################
        # Here are the specific configurations of MOE
        # When use_moe is false, the following is invalid
        ####################################################
        self.use_moe = use_moe
        self.num_experts_per_tok = num_experts_per_tok  # number of experts selected per token
        self.n_routed_experts = n_routed_experts  # total number of routed experts
        self.n_shared_experts = n_shared_experts  # number of always-on shared experts
        self.scoring_func = scoring_func  # gate scoring function, defaults to 'softmax'
        self.aux_loss_alpha = aux_loss_alpha  # weight of the auxiliary load-balancing loss
        self.seq_aux = seq_aux  # whether the auxiliary loss is computed per sequence
        self.norm_topk_prob = norm_topk_prob  # whether to normalize the top-k probabilities


if __name__ == "__main__":
    pass
toolbox/minimind/model/modeling_minimind.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import math
4
+ import torch
5
+ import torch.nn.init as init
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from transformers.activations import ACT2FN
9
+ from typing import Optional, Tuple, List, Union
10
+ from transformers import PreTrainedModel, GenerationMixin, PretrainedConfig
11
+ from transformers.modeling_outputs import CausalLMOutputWithPast
12
+ from toolbox.minimind.model.configuration_minimind import MiniMindConfig
13
+
14
+
15
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer normalization: no mean subtraction, no bias."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # scale each vector by the reciprocal RMS over the last dimension
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        # normalize in float32 for numerical stability, cast back to input dtype
        normalized = self._norm(x.float()).type_as(x)
        return self.weight * normalized
+
27
+
28
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6,
                         rope_scaling: Optional[dict] = None):
    """Precompute RoPE cos/sin tables of shape ``(end, dim)``.

    When ``rope_scaling`` is given and ``end`` exceeds the original training
    length, YaRN frequency interpolation is applied for length extrapolation.
    """
    freqs, attn_factor = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)), 1.0
    if rope_scaling is not None:
        orig_max, factor, beta_fast, beta_slow, attn_factor = (
            rope_scaling.get("original_max_position_embeddings", 2048), rope_scaling.get("factor", 16),
            rope_scaling.get("beta_fast", 32.0), rope_scaling.get("beta_slow", 1.0), rope_scaling.get("attention_factor", 1.0)
        )
        # only scale when the requested length actually exceeds the original
        if end / orig_max > 1.0:
            # YaRN: f'(i) = f(i)((1-γ) + γ/s), where γ∈[0,1] is linear ramp
            inv_dim = lambda b: (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))
            low, high = max(math.floor(inv_dim(beta_fast)), 0), min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1)
            ramp = torch.clamp((torch.arange(dim // 2, device=freqs.device).float() - low) / max(high - low, 0.001), 0, 1)
            freqs = freqs * (1 - ramp + ramp / factor)

    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    # halves are duplicated to match the rotate_half convention used by
    # apply_rotary_pos_emb
    freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1) * attn_factor
    freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1) * attn_factor
    return freqs_cos, freqs_sin
+
49
+
50
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
51
+ def rotate_half(x):
52
+ return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)
53
+
54
+ q_embed = (q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))
55
+ k_embed = (k * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(k) * sin.unsqueeze(unsqueeze_dim))
56
+ return q_embed, k_embed
57
+
58
+
59
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Equivalent to ``torch.repeat_interleave(x, dim=2, repeats=n_rep)``.

    Expands the KV-head dimension so grouped KV heads line up with the query
    heads; returns ``x`` unchanged when no repetition is needed.
    """
    if n_rep == 1:
        return x
    bs, slen, n_kv_heads, head_dim = x.shape
    expanded = x[:, :, :, None, :].expand(bs, slen, n_kv_heads, n_rep, head_dim)
    return expanded.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
+
68
+
69
class Attention(nn.Module):
    """Multi-head attention with grouped-query KV heads, RoPE and a KV cache."""

    def __init__(self, args: MiniMindConfig):
        super().__init__()
        self.num_key_value_heads = args.num_attention_heads if args.num_key_value_heads is None else args.num_key_value_heads
        assert args.num_attention_heads % self.num_key_value_heads == 0
        self.n_local_heads = args.num_attention_heads
        self.n_local_kv_heads = self.num_key_value_heads
        # how many query heads share each KV head (grouped-query attention)
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.hidden_size // args.num_attention_heads
        self.q_proj = nn.Linear(args.hidden_size, args.num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(args.num_attention_heads * self.head_dim, args.hidden_size, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        # use PyTorch fused SDPA when available and enabled in the config
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
        # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")

    def forward(self,
                x: torch.Tensor,
                position_embeddings: Tuple[torch.Tensor, torch.Tensor],  # (cos, sin) RoPE tables for the current positions
                past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                use_cache=False,
                attention_mask: Optional[torch.Tensor] = None):
        """Return ``(output, past_kv)``; ``past_kv`` is None unless ``use_cache``."""
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

        cos, sin = position_embeddings
        xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)

        # KV cache: prepend cached keys/values along the sequence dimension
        if past_key_value is not None:
            xk = torch.cat([past_key_value[0], xk], dim=1)
            xv = torch.cat([past_key_value[1], xv], dim=1)
        past_kv = (xk, xv) if use_cache else None

        # repeat KV heads up to the query-head count, then move heads to dim 1
        xq, xk, xv = (
            xq.transpose(1, 2),
            repeat_kv(xk, self.n_rep).transpose(1, 2),
            repeat_kv(xv, self.n_rep).transpose(1, 2)
        )

        # fused path only for the unpadded, no-cache prefill case
        if self.flash and (seq_len > 1) and (past_key_value is None) and (attention_mask is None or torch.all(attention_mask == 1)):
            output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            # additive causal mask over the last seq_len key positions
            scores[:, :, :, -seq_len:] += torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=scores.device), diagonal=1)

            if attention_mask is not None:
                # padding mask: 1 = keep, 0 = large negative bias before softmax
                extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
                extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
                scores = scores + extended_attention_mask

            # softmax in float32 for stability, then cast back
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv

        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.o_proj(output))
        return output, past_kv
+
134
+
135
class FeedForward(nn.Module):
    """Gated feed-forward block: ``down(act(gate(x)) * up(x))`` with dropout."""

    def __init__(self, config: MiniMindConfig):
        super().__init__()
        if config.intermediate_size is None:
            # default hidden width: ~8/3 * hidden_size, rounded up to a
            # multiple of 64. NOTE(review): this mutates the shared config
            # object in place as a side effect of construction.
            intermediate_size = int(config.hidden_size * 8 / 3)
            config.intermediate_size = 64 * ((intermediate_size + 64 - 1) // 64)
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.dropout = nn.Dropout(config.dropout)
        # activation resolved by name from transformers' ACT2FN registry
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        return self.dropout(self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)))
+
150
+
151
class MoEGate(nn.Module):
    """Top-k gating network for the MoE layers.

    Scores every token against ``n_routed_experts`` experts, keeps the top-k,
    and, during training, computes an auxiliary load-balancing loss weighted
    by ``aux_loss_alpha``.
    """

    def __init__(self, config: MiniMindConfig):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts

        self.scoring_func = config.scoring_func
        self.alpha = config.aux_loss_alpha
        self.seq_aux = config.seq_aux

        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        """Return ``(topk_idx, topk_weight, aux_loss)`` for a (B, S, H) input."""
        bsz, seq_len, h = hidden_states.shape
        # flatten to one score row per token
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')

        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        # renormalize the selected weights so they sum to 1 per token
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        # auxiliary load-balancing loss (training only)
        if self.training and self.alpha > 0.0:
            scores_for_aux = scores
            aux_topk = self.top_k
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
            if self.seq_aux:
                # sequence-level balance: expert usage counted per sequence
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
                ce.scatter_add_(1, topk_idx_for_aux_loss,
                                torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
                    seq_len * aux_topk / self.n_routed_experts)
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
            else:
                # batch-level balance: fraction of tokens routed to each expert
                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
                ce = mask_ce.float().mean(0)
                Pi = scores_for_aux.mean(0)
                fi = ce * self.n_routed_experts
                aux_loss = (Pi * fi).sum() * self.alpha
        else:
            aux_loss = scores.new_zeros(1).squeeze()
        return topk_idx, topk_weight, aux_loss
+
206
+
207
class MOEFeedForward(nn.Module):
    """Mixture-of-experts feed-forward layer.

    Routed experts are selected per token by ``MoEGate``; optional shared
    experts are applied to every token. The gate's auxiliary loss is stashed
    on ``self.aux_loss`` after each forward pass.
    """

    def __init__(self, config: MiniMindConfig):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([
            FeedForward(config)
            for _ in range(config.n_routed_experts)
        ])
        self.gate = MoEGate(config)
        if config.n_shared_experts > 0:
            self.shared_experts = nn.ModuleList([
                FeedForward(config)
                for _ in range(config.n_shared_experts)
            ])

    def forward(self, x):
        identity = x
        orig_shape = x.shape
        bsz, seq_len, _ = x.shape
        # select experts for each token via the gating network
        topk_idx, topk_weight, aux_loss = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            # training path: duplicate each token once per selected expert
            x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
            y = torch.empty_like(x, dtype=x.dtype)
            for i, expert in enumerate(self.experts):
                expert_out = expert(x[flat_topk_idx == i])
                if expert_out.shape[0] > 0: y[flat_topk_idx == i] = expert_out.to(y.dtype)
                # zero-token expert: keep its parameters in the autograd graph
                else: y[flat_topk_idx == i] = expert_out.to(y.dtype) + 0 * sum(p.sum() for p in expert.parameters())
            # weighted sum over the top-k expert outputs per token
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
        else:
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        if self.config.n_shared_experts > 0:
            for expert in self.shared_experts:
                y = y + expert(identity)
        # stashed so the model can add it to the LM loss
        self.aux_loss = aux_loss
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        """Inference dispatch: group tokens by expert, run each expert once on
        its batch, and scatter the weighted outputs back into place."""
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        token_idxs = idxs // self.config.num_experts_per_tok
        # Example: with tokens_per_expert = [6, 15, 20, 26] (so 4 experts) and
        # token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12, ...]:
        # token_idxs[:6] -> [3, 7, 19, 21, 24, 25] are the 6 tokens handled by
        # expert 0 (a token may be routed to several experts, depending on
        # num_experts_per_tok); the next 9 positions token_idxs[6:15] belong
        # to expert 1, and so on.
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                # no tokens routed to this expert
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens).to(expert_cache.dtype)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            # accumulate because a token may receive output from several experts
            expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)

        return expert_cache
+
270
+
271
class MiniMindBlock(nn.Module):
    """One decoder block: pre-norm attention and pre-norm MLP/MoE, both residual."""

    def __init__(self, layer_id: int, config: MiniMindConfig):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.self_attn = Attention(config)

        self.layer_id = layer_id
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # dense FFN by default; MoE FFN when the config enables it
        self.mlp = FeedForward(config) if not config.use_moe else MOEFeedForward(config)

    def forward(self, hidden_states, position_embeddings, past_key_value=None, use_cache=False, attention_mask=None):
        """Return ``(hidden_states, present_key_value)`` for this layer."""
        residual = hidden_states
        hidden_states, present_key_value = self.self_attn(
            self.input_layernorm(hidden_states), position_embeddings,
            past_key_value, use_cache, attention_mask
        )
        # NOTE(review): in-place add on the attention output tensor
        hidden_states += residual
        hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states, present_key_value
+
294
+
295
class MiniMindModel(nn.Module):
    """Decoder-only transformer backbone: embedding, N blocks, final RMSNorm."""

    def __init__(self, config: MiniMindConfig):
        super().__init__()
        self.config = config
        self.vocab_size, self.num_hidden_layers = config.vocab_size, config.num_hidden_layers
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)
        self.layers = nn.ModuleList([MiniMindBlock(l, config) for l in range(self.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # RoPE tables precomputed up to max_position_embeddings; registered as
        # non-persistent buffers so they are not stored in checkpoints
        freqs_cos, freqs_sin = precompute_freqs_cis(dim=config.hidden_size // config.num_attention_heads,
                                                    end=config.max_position_embeddings, rope_base=config.rope_theta,
                                                    rope_scaling=config.rope_scaling)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                **kwargs):
        """Return ``(hidden_states, presents, aux_loss)``."""
        batch_size, seq_length = input_ids.shape
        # a transformers Cache object (identified by its .layers attribute) is
        # not supported by this per-layer tuple cache; discard it
        if hasattr(past_key_values, 'layers'): past_key_values = None
        past_key_values = past_key_values or [None] * len(self.layers)
        # number of already-cached positions = offset into the RoPE tables
        start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0

        hidden_states = self.dropout(self.embed_tokens(input_ids))

        position_embeddings = (
            self.freqs_cos[start_pos:start_pos + seq_length],
            self.freqs_sin[start_pos:start_pos + seq_length]
        )

        presents = []
        for layer_idx, (layer, past_key_value) in enumerate(zip(self.layers, past_key_values)):
            hidden_states, present = layer(
                hidden_states,
                position_embeddings,
                past_key_value=past_key_value,
                use_cache=use_cache,
                attention_mask=attention_mask
            )
            presents.append(present)

        hidden_states = self.norm(hidden_states)

        # total auxiliary (load-balancing) loss accumulated over MoE layers
        aux_loss = sum([l.mlp.aux_loss for l in self.layers if isinstance(l.mlp, MOEFeedForward)], hidden_states.new_zeros(1).squeeze())
        return hidden_states, presents, aux_loss
+
345
+
346
class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
    """MiniMind backbone plus a tied LM head, exposed via the HF model API."""

    config_class = MiniMindConfig

    def __init__(self, config: MiniMindConfig = None):
        self.config = config or MiniMindConfig()
        super().__init__(self.config)
        self.model = MiniMindModel(self.config)
        self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
        # weight tying: input embeddings share the LM head matrix
        self.model.embed_tokens.weight = self.lm_head.weight

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                **args):
        """Return a ``CausalLMOutputWithPast`` with an extra ``aux_loss`` attribute."""
        hidden_states, past_key_values, aux_loss = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            **args
        )
        # keep logits only for the last `logits_to_keep` positions (0 = all)
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            # standard next-token objective: shift logits left, labels right
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-100)

        output = CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)
        # expose the MoE auxiliary loss so the training loop can add it
        output.aux_loss = aux_loss
        return output


if __name__ == "__main__":
    pass
toolbox/os/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/os/command.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os
4
+
5
+
6
class Command(object):
    """
    Tiny shell-command helper.

    Commands listed in `custom_command` are dispatched to same-named
    classmethods (so e.g. `cd` affects this process instead of a child shell);
    everything else runs through `os.popen`.

    NOTE(review): commands are executed by the shell verbatim — never pass
    untrusted input to `popen` / `system`.
    """
    custom_command = [
        "cd"
    ]

    @staticmethod
    def _get_cmd(command):
        """Split a command line into (cmd, args); returns None for a blank line."""
        command = str(command).strip()
        if command == "":
            return None
        cmd_and_args = command.split(sep=" ")
        cmd = cmd_and_args[0]
        args = " ".join(cmd_and_args[1:])
        return cmd, args

    @classmethod
    def popen(cls, command):
        """Run *command* and return its stdout (empty string for a blank command)."""
        parsed = cls._get_cmd(command)
        if parsed is None:
            # BUG FIX: a blank command used to crash here unpacking None.
            return ""
        cmd, args = parsed
        if cmd in cls.custom_command:
            method = getattr(cls, cmd)
            return method(args)
        resp = os.popen(command)
        try:
            result = resp.read()
        finally:
            # close the pipe even if read() raises
            resp.close()
        return result

    @classmethod
    def cd(cls, args):
        """Change the working directory of this process (absolute or relative path)."""
        if args.startswith("/"):
            os.chdir(args)
        else:
            pwd = os.getcwd()
            path = os.path.join(pwd, args)
            os.chdir(path)

    @classmethod
    def system(cls, command):
        """Run *command* via os.system; returns the exit status."""
        return os.system(command)

    def __init__(self):
        # stateless helper; kept for backward compatibility with callers
        # that instantiate it
        pass
48
+
49
+
50
def ps_ef_grep(keyword: str):
    """Return the `ps -ef` lines containing *keyword*, excluding the grep itself."""
    output = Command.popen("ps -ef | grep {}".format(keyword))
    lines = str(output).split("\n")
    return [line for line in lines if keyword in line and "grep" not in line]
56
+
57
+
58
+ if __name__ == "__main__":
59
+ pass
toolbox/os/environment.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import json
4
+ import os
5
+
6
+ from dotenv import load_dotenv
7
+ from dotenv.main import DotEnv
8
+
9
+ from toolbox.json.misc import traverse
10
+
11
+
12
class EnvironmentManager(object):
    """Loads a `<env>.env` dotenv file and reads typed values from os.environ."""

    def __init__(self, path, env, override=False):
        self.filename = os.path.join(path, '{}.env'.format(env))

        load_dotenv(
            dotenv_path=self.filename,
            override=override
        )

        # record of every value handed out via get(), for debugging/inspection
        self._environ = dict()

    def open_dotenv(self, filename: str = None):
        """Parse a dotenv file (default: the one given at construction) into a dict."""
        dotenv = DotEnv(
            dotenv_path=filename or self.filename,
            stream=None,
            verbose=False,
            interpolate=False,
            override=False,
            encoding="utf-8",
        )
        return dotenv.dict()

    def get(self, key, default=None, dtype=str):
        """Look up *key* in os.environ and cast it with *dtype*; fall back to *default*."""
        raw = os.environ.get(key)
        result = default if raw is None else dtype(raw)
        self._environ[key] = result
        return result
48
+
49
+
50
# Casts available to `$<dtype>:<key>` config strings; keys are the literal
# tag that may appear right after the `$` (used by JsonConfig.sanitize).
_DEFAULT_DTYPE_MAP = {
    'int': int,
    'float': float,
    'str': str,
    'json.loads': json.loads
}
56
+
57
+
58
class JsonConfig(object):
    """
    Resolves environment placeholders in a JSON document.

    A value of the form `$float:threshold` means: look up `threshold` in the
    environment and cast the result to float.
    """
    def __init__(self, dtype_map: dict = None, environment: EnvironmentManager = None):
        self.dtype_map = dtype_map or _DEFAULT_DTYPE_MAP
        self.environment = environment or os.environ

    def sanitize_by_filename(self, filename: str):
        """Load a JSON file and resolve every `$dtype:key` placeholder in it."""
        with open(filename, 'r', encoding='utf-8') as f:
            document = json.load(f)
        return self.sanitize_by_json(document)

    def sanitize_by_json(self, js):
        """Resolve placeholders throughout an already-parsed JSON structure."""
        return traverse(
            js,
            callback=self.sanitize,
            environment=self.environment
        )

    def sanitize(self, string, environment):
        """Resolve a single `$`-prefixed placeholder; other values pass through."""
        if not (isinstance(string, str) and string.startswith('$')):
            return string

        dtype_tag, key = string[1:].split(':')
        cast = self.dtype_map[dtype_tag]

        raw = environment.get(key)
        if raw is None:
            raise AssertionError('environment not exist. key: {}'.format(key))

        return cast(raw)
96
+
97
+
98
def demo1():
    """Manual check: load the dev dotenv config and print a parsed value."""
    import json

    from project_settings import project_path

    # NOTE(review): depends on a project-local dotenv directory; only runs
    # from a checkout containing server/callbot_server/dotenv.
    environment = EnvironmentManager(
        path=os.path.join(project_path, 'server/callbot_server/dotenv'),
        env='dev',
    )
    init_scenes = environment.get(key='init_scenes', dtype=json.loads)
    print(init_scenes)
    print(environment._environ)
    return
111
+
112
+
113
+ if __name__ == '__main__':
114
+ demo1()
toolbox/os/other.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import inspect
3
+
4
+
5
def pwd():
    """Return the directory of the source file that calls this function.

    Uses the caller frame's filename directly instead of the previous
    `inspect.getmodule(...)` approach, which returned None (and crashed on
    `module.__file__`) for frames without an importable module, e.g. code
    run via exec() or an interactive session.
    """
    caller = inspect.stack()[1]
    try:
        return os.path.dirname(os.path.abspath(caller.filename))
    finally:
        # break the reference cycle held by the frame record
        del caller
toolbox/torch/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ if __name__ == '__main__':
5
+ pass
toolbox/torch/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torch/modules/gaussian_mixture.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/georgepar/gmmhmm-pytorch/blob/master/gmm.py
5
+ https://github.com/ldeecke/gmm-torch
6
+ """
7
+ import math
8
+
9
+ from sklearn import cluster
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+
14
class GaussianMixtureModel(nn.Module):
    """
    Base class for EM-trained Gaussian mixture models.

    https://github.com/georgepar/gmmhmm-pytorch/blob/master/gmm.py
    https://github.com/ldeecke/gmm-torch

    Subclasses define the covariance parametrization by implementing
    `reset_sigma`, `estimate_precisions` and `log_prob`.
    """
    def __init__(self,
                 n_mixtures: int,
                 n_features: int,
                 init: str = "random",
                 device: str = 'cpu',
                 n_iter: int = 1000,
                 delta: float = 1e-3,
                 warm_start: bool = False,
                 ):
        """
        :param n_mixtures: number of Gaussian components
        :param n_features: dimensionality of each sample
        :param init: 'random' or 'kmeans' initialization of the means
        :param device: torch device string the parameters live on
        :param n_iter: maximum EM iterations
        :param delta: convergence tolerance on the log-likelihood
        :param warm_start: reuse previous parameters across fit calls
        """
        super(GaussianMixtureModel, self).__init__()
        if init not in ('kmeans', 'random'):
            raise AssertionError
        self.n_mixtures = n_mixtures
        self.n_features = n_features
        self.init = init
        self.device = device
        # NOTE: the original assigned n_iter/delta/warm_start twice; the
        # duplicates were removed.
        self.n_iter = n_iter
        self.delta = delta
        self.warm_start = warm_start

        # component means, shape=[n_mixtures, n_features]; updated by EM,
        # never by autograd
        self.mu = nn.Parameter(
            torch.Tensor(n_mixtures, n_features),
            requires_grad=False,
        )

        # covariance parameters: supplied by the subclass
        self.sigma = None

        # mixing weight of each gaussian, shape=[n_mixtures]
        self.pi = nn.Parameter(
            torch.Tensor(n_mixtures),
            requires_grad=False
        )

        self.converged_ = False
        self.eps = 1e-6

    def reset_sigma(self):
        raise NotImplementedError

    def estimate_precisions(self):
        raise NotImplementedError

    def log_prob(self, x):
        raise NotImplementedError

    def weighted_log_prob(self, x):
        """log p(x | component k) + log pi_k, shape=[n_samples, n_mixtures]."""
        return self.log_prob(x) + torch.log(self.pi)

    def log_likelihood(self, x):
        """Total log-likelihood of the batch under the mixture."""
        weighted_log_prob = self.weighted_log_prob(x)
        per_sample_log_likelihood = torch.logsumexp(weighted_log_prob, dim=1)
        return torch.sum(per_sample_log_likelihood)

    def e_step(self, x):
        """Posterior log-responsibilities q(k | x), shape=[n_samples, n_mixtures]."""
        weighted_log_prob = self.weighted_log_prob(x)
        weighted_log_prob = weighted_log_prob.unsqueeze(dim=-1)
        log_likelihood = torch.logsumexp(weighted_log_prob, dim=1, keepdim=True)
        q = weighted_log_prob - log_likelihood
        return q.squeeze()

    def m_step(self, x, q):
        # NOTE(review): incomplete in the original source — it only reshaped x
        # and returned None; kept as a stub to preserve the public surface.
        x = x.unsqueeze(dim=1)
        return

    def estimate_mu(self, x, pi, responsibilities):
        """M-step mean update: responsibility-weighted average of x."""
        nk = pi * x.size(0)
        mu = torch.sum(responsibilities * x, dim=0, keepdim=True) / nk
        return mu

    def estimate_pi(self, x, responsibilities):
        """M-step mixing-weight update (eps keeps empty components alive)."""
        pi = torch.sum(responsibilities, dim=0, keepdim=True) + self.eps
        pi = pi / x.size(0)
        return pi

    def reset_parameters(self, x=None):
        """(Re-)initialize mu, sigma and pi, optionally from data x (kmeans)."""
        if self.init == 'random' or x is None:
            self.mu.normal_()
            self.reset_sigma()
            self.pi.fill_(1.0 / self.n_mixtures)
        elif self.init == 'kmeans':
            centroids = cluster.KMeans(n_clusters=self.n_mixtures, n_init=1).fit(x).cluster_centers_
            centroids = torch.tensor(centroids).to(self.device)
            # BUG FIX: the original called self.update_(mu=centroids), a method
            # that does not exist (AttributeError); copy the centroids into mu
            # and initialize the remaining parameters directly.
            self.mu.data.copy_(centroids.to(self.mu.dtype))
            self.reset_sigma()
            self.pi.fill_(1.0 / self.n_mixtures)
        else:
            raise NotImplementedError
108
+
109
+
110
class DiagonalCovarianceGMM(GaussianMixtureModel):
    """
    Gaussian mixture with per-component diagonal covariance.

    `sigma` holds per-dimension variances, shape=[n_mixtures, n_features].
    """
    def __init__(self,
                 n_mixtures: int,
                 n_features: int,
                 init: str = "random",
                 device: str = 'cpu',
                 n_iter: int = 1000,
                 delta: float = 1e-3,
                 warm_start: bool = False,
                 ):
        super(DiagonalCovarianceGMM, self).__init__(
            n_mixtures=n_mixtures,
            n_features=n_features,
            init=init,
            device=device,
            n_iter=n_iter,
            delta=delta,
            warm_start=warm_start,
        )
        self.sigma = nn.Parameter(
            torch.Tensor(n_mixtures, n_features), requires_grad=False
        )
        self.reset_parameters()
        self.to(self.device)

    def reset_sigma(self):
        # unit variance for every dimension of every component
        self.sigma.fill_(1)

    def estimate_precisions(self):
        # 1 / sqrt(sigma): inverse standard deviations
        return torch.rsqrt(self.sigma)

    def log_prob(self, x):
        """
        Per-component log density of each sample, shape=[n_samples, n_mixtures].
        """
        precisions = self.estimate_precisions()

        x = x.unsqueeze(1)
        mu = self.mu.unsqueeze(0)
        precisions = precisions.unsqueeze(0)

        # (x - mu)^2 expanded, scaled by the squared precisions
        exp_term = torch.sum(
            (mu * mu + x * x - 2 * x * mu) * (precisions ** 2), dim=2, keepdim=True
        )
        log_det = torch.sum(torch.log(precisions), dim=2, keepdim=True)

        # BUG FIX: the original used torch.log(2 * math.pi), but torch.log
        # requires a Tensor and raises TypeError on a Python float; the scalar
        # constant belongs to math.log.
        logp = -0.5 * (self.n_features * math.log(2 * math.pi) + exp_term) + log_det

        return logp.squeeze()

    def estimate_sigma(self, x, mu, pi, responsibilities):
        """M-step variance update: E[x^2] - 2*mu*E[x] + mu^2 (+ eps)."""
        nk = pi * x.size(0)
        x2 = (responsibilities * x * x).sum(0, keepdim=True) / nk
        mu2 = mu * mu
        xmu = (responsibilities * mu * x).sum(0, keepdim=True) / nk
        sigma = x2 - 2 * xmu + mu2 + self.eps

        return sigma
166
+
167
+
168
+ def demo1():
169
+ return
170
+
171
+
172
+ if __name__ == '__main__':
173
+ demo1()
toolbox/torch/modules/highway.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import torch.nn as nn
4
+
5
+
6
class Highway(nn.Module):
    """
    Highway layer: y = H(x) * T(x) + x * (1 - T(x)).

    https://arxiv.org/abs/1505.00387
    [Submitted on 3 May 2015 (v1), last revised 3 Nov 2015 (this version, v2)]

    discuss of Highway and ResNet
    https://www.zhihu.com/question/279426970
    """
    def __init__(self, in_size, out_size):
        super(Highway, self).__init__()
        # transform branch: bias starts at zero
        self.H = nn.Linear(in_size, out_size)
        self.H.bias.data.zero_()
        # gate branch: bias starts at -1 so the gate is mostly closed initially
        self.T = nn.Linear(in_size, out_size)
        self.T.bias.data.fill_(-1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        transformed = self.relu(self.H(inputs))
        gate = self.sigmoid(self.T(inputs))
        carried = inputs * (1.0 - gate)
        return transformed * gate + carried
27
+
28
+
29
+ if __name__ == '__main__':
30
+ pass
toolbox/torch/modules/loss.py ADDED
@@ -0,0 +1,738 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import math
4
+ from typing import List, Optional
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.nn.modules.loss import _Loss
11
+ from torch.autograd import Variable
12
+
13
+
14
class ClassBalancedLoss(_Loss):
    """
    Class-Balanced Loss: re-weights per-sample losses by the inverse
    "effective number of samples" of each class, (1 - beta^n) / (1 - beta).

    https://arxiv.org/abs/1901.05555
    """
    @staticmethod
    def demo1():
        batch_loss: torch.FloatTensor = torch.randn(size=(2, 1), dtype=torch.float32)
        targets: torch.LongTensor = torch.tensor([1, 2], dtype=torch.long)

        class_balanced_loss = ClassBalancedLoss(
            num_classes=3,
            num_samples_each_class=[300, 433, 50],
            reduction='mean',
        )
        loss = class_balanced_loss.forward(batch_loss=batch_loss, targets=targets)
        print(loss)
        return

    @staticmethod
    def demo2():
        # combine with a per-sample loss such as FocalLoss(reduction='none')
        inputs: torch.FloatTensor = torch.randn(size=(2, 3), dtype=torch.float32)
        targets: torch.LongTensor = torch.tensor([1, 2], dtype=torch.long)

        focal_loss = FocalLoss(
            num_classes=3,
            reduction='none',
        )
        batch_loss = focal_loss.forward(inputs, targets)
        print(batch_loss)

        class_balanced_loss = ClassBalancedLoss(
            num_classes=3,
            num_samples_each_class=[300, 433, 50],
            reduction='mean',
        )
        loss = class_balanced_loss.forward(batch_loss=batch_loss, targets=targets)
        print(loss)

        return

    def __init__(self,
                 num_classes: int,
                 num_samples_each_class: List[int],
                 beta: float = 0.999,
                 reduction: str = 'mean') -> None:
        """
        :param num_classes: number of classes
        :param num_samples_each_class: per-class sample counts, len=num_classes
        :param beta: effective-number hyper-parameter in [0, 1); larger values
            weight rare classes more aggressively
        :param reduction: (`none`, `mean`, `sum`) available.
        """
        super(ClassBalancedLoss, self).__init__(None, None, reduction)

        effective_num = 1.0 - np.power(beta, num_samples_each_class)
        weights = (1.0 - beta) / np.array(effective_num)
        # normalized so the weights sum to num_classes (mean weight of 1)
        self.weights = weights / np.sum(weights) * num_classes

    def forward(self, batch_loss: torch.FloatTensor, targets: torch.LongTensor):
        """
        :param batch_loss: per-sample losses, shape=[batch_size, 1]
        :param targets: shape=[batch_size,]
        :return: re-weighted loss (scalar for 'mean'/'sum', tensor for 'none')
        """
        # Robustness fix: index the weight table with the target tensor itself
        # instead of round-tripping through numpy — `targets.numpy()` crashed
        # for CUDA tensors and always built the weights on the CPU.
        weight_table = torch.as_tensor(self.weights, dtype=batch_loss.dtype, device=batch_loss.device)
        weights = weight_table[targets].unsqueeze(dim=-1)
        batch_loss = weights * batch_loss

        if self.reduction == 'mean':
            loss = batch_loss.mean()
        elif self.reduction == 'sum':
            loss = batch_loss.sum()
        else:
            loss = batch_loss
        return loss
88
+
89
+
90
class EqualizationLoss(_Loss):
    """
    For sigmoid-based multi-label classification (as used in image
    recognition), with one extra `background` class beyond the
    `num_classes` real classes.
    Equalization Loss
    https://arxiv.org/abs/2003.05176
    Equalization Loss v2
    https://arxiv.org/abs/2012.08548
    """

    @staticmethod
    def demo1():
        logits: torch.FloatTensor = torch.randn(size=(3, 3), dtype=torch.float32)
        targets: torch.LongTensor = torch.tensor([1, 2, 3], dtype=torch.long)

        equalization_loss = EqualizationLoss(
            num_samples_each_class=[300, 433, 50],
            threshold=100,
            reduction='mean',
        )
        loss = equalization_loss.forward(logits=logits, targets=targets)
        print(loss)
        return

    def __init__(self,
                 num_samples_each_class: List[int],
                 threshold: int = 100,
                 reduction: str = 'mean') -> None:
        # threshold: classes with fewer samples than this count as "tail" classes
        super(EqualizationLoss, self).__init__(None, None, reduction)
        self.num_samples_each_class = np.array(num_samples_each_class, dtype=np.int32)
        self.threshold = threshold

    def forward(self,
                logits: torch.FloatTensor,
                targets: torch.LongTensor
                ):
        """
        Class id `num_classes` (the extra num_classes + 1-th category) denotes background.
        :param logits: shape=[batch_size, num_classes]
        :param targets: shape=[batch_size]
        :return: scalar loss (for 'mean'/'sum')
        """
        batch_size, num_classes = logits.size()

        # one-hot over num_classes + 1 slots, then drop the background column
        one_hot_targets = F.one_hot(targets, num_classes=num_classes + 1)
        one_hot_targets = one_hot_targets[:, :-1]

        # 1 for foreground rows, 0 for rows whose target is background
        exclude = self.exclude_func(
            num_classes=num_classes,
            targets=targets
        )
        # 1 for rare (tail) classes with fewer than `threshold` samples
        is_tail = self.threshold_func(
            num_classes=num_classes,
            num_samples_each_class=self.num_samples_each_class,
            threshold=self.threshold,
        )

        # zero out the negative-label terms of tail classes on foreground rows
        weights = 1 - exclude * is_tail * (1 - one_hot_targets)

        batch_loss = F.binary_cross_entropy_with_logits(
            logits,
            one_hot_targets.float(),
            reduction='none'
        )

        batch_loss = weights * batch_loss

        if self.reduction == 'mean':
            loss = batch_loss.mean()
        elif self.reduction == 'sum':
            loss = batch_loss.sum()
        else:
            loss = batch_loss

        loss = loss / num_classes
        return loss

    @staticmethod
    def exclude_func(num_classes: int, targets: torch.LongTensor):
        """
        The last class id is the background.
        :param num_classes: int,
        :param targets: shape=[batch_size,]
        :return: weight, shape=[batch_size, num_classes]
        """
        batch_size = targets.shape[0]
        weight = (targets != num_classes).float()
        weight = weight.view(batch_size, 1).expand(batch_size, num_classes)
        return weight

    @staticmethod
    def threshold_func(num_classes: int, num_samples_each_class: np.ndarray, threshold: int):
        """
        :param num_classes: int,
        :param num_samples_each_class: shape=[num_classes]
        :param threshold: int,
        :return: weight, shape=[1, num_classes]
        """
        weight = torch.zeros(size=(num_classes,))
        weight[num_samples_each_class < threshold] = 1
        weight = torch.unsqueeze(weight, dim=0)
        return weight
191
+
192
+
193
class FocalLoss(_Loss):
    """
    Focal loss: cross entropy scaled by (1 - p_t)^gamma so that easy,
    well-classified examples contribute less.

    https://arxiv.org/abs/1708.02002
    """
    @staticmethod
    def demo1():
        # BUG FIX: this @staticmethod mistakenly declared a `self` parameter,
        # so FocalLoss.demo1() always raised a TypeError.
        inputs: torch.FloatTensor = torch.randn(size=(2, 3), dtype=torch.float32)
        targets: torch.LongTensor = torch.tensor([1, 2], dtype=torch.long)

        focal_loss = FocalLoss(
            num_classes=3,
            reduction='mean',
            # reduction='sum',
            # reduction='none',
        )
        loss = focal_loss.forward(inputs, targets)
        print(loss)
        return

    def __init__(self,
                 num_classes: int,
                 alpha: List[float] = None,
                 gamma: int = 2,
                 reduction: str = 'mean',
                 inputs_logits: bool = True) -> None:
        """
        :param num_classes: number of classes
        :param alpha: optional per-class weights, len=num_classes; defaults to all ones
        :param gamma: focusing exponent; gamma=0 recovers plain cross entropy
        :param reduction: (`none`, `mean`, `sum`) available.
        :param inputs_logits: if False, the inputs should be probs.
        """
        super(FocalLoss, self).__init__(None, None, reduction)
        if alpha is None:
            self.alpha = torch.ones(num_classes, 1)
        else:
            self.alpha = torch.tensor(alpha, dtype=torch.float32)
        self.gamma = gamma
        self.num_classes = num_classes
        self.inputs_logits = inputs_logits

    def forward(self,
                inputs: torch.FloatTensor,
                targets: torch.LongTensor):
        """
        :param inputs: logits (or probs when inputs_logits=False), shape=[batch_size, num_classes]
        :param targets: shape=[batch_size,]
        :return: scalar for 'mean'/'sum'; shape=[batch_size, 1] for 'none'
        """
        batch_size, num_classes = inputs.shape

        if self.inputs_logits:
            probs = F.softmax(inputs, dim=-1)
        else:
            probs = inputs

        # one-hot mask of the target class per row
        class_mask = torch.zeros(size=(batch_size, num_classes), dtype=inputs.dtype, device=inputs.device)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids.data, 1.)

        # lazily move the per-class weights to the GPU on the first CUDA batch
        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        alpha = self.alpha[ids.data.view(-1)]

        # p_t: probability assigned to the true class, shape=[batch_size, 1]
        probs = (probs * class_mask).sum(1).view(-1, 1)

        log_p = probs.log()

        batch_loss = -alpha * (torch.pow((1 - probs), self.gamma)) * log_p

        if self.reduction == 'mean':
            loss = batch_loss.mean()
        elif self.reduction == 'sum':
            loss = batch_loss.sum()
        else:
            loss = batch_loss
        return loss
272
+
273
+
274
class HingeLoss(_Loss):
    """Cross entropy with a per-class margin subtracted from the true-class logit."""
    @staticmethod
    def demo1():
        example_inputs: torch.FloatTensor = torch.randn(size=(2, 3), dtype=torch.float32)
        example_targets: torch.LongTensor = torch.tensor([1, 2], dtype=torch.long)

        criterion = HingeLoss(
            margin_list=[300, 433, 50],
            reduction='mean',
        )
        result = criterion.forward(inputs=example_inputs, targets=example_targets)
        print(result)
        return

    def __init__(self,
                 margin_list: List[float],
                 max_margin: float = 0.5,
                 scale: float = 1.0,
                 weight: Optional[torch.Tensor] = None,
                 reduction: str = 'mean') -> None:
        """
        :param margin_list: raw per-class margins, len=num_classes
        :param max_margin: margins are rescaled so the largest equals this value
        :param scale: multiplier applied to the adjusted logits
        :param weight: optional per-class weight forwarded to cross_entropy
        :param reduction: (`none`, `mean`, `sum`) available.
        """
        super(HingeLoss, self).__init__(None, None, reduction)

        self.max_margin = max_margin
        self.scale = scale
        self.weight = weight

        # rescale the raw margins so the largest one equals max_margin
        rescaled = np.array(margin_list)
        rescaled = rescaled * (max_margin / np.max(rescaled))
        self.margin_list = torch.tensor(rescaled, dtype=torch.float32)

    def forward(self,
                inputs: torch.FloatTensor,
                targets: torch.LongTensor
                ):
        """
        :param inputs: logits, shape=[batch_size, num_classes]
        :param targets: shape=[batch_size,]
        :return: loss reduced per self.reduction
        """
        _, num_classes = inputs.shape
        one_hot = F.one_hot(targets, num_classes=num_classes)
        # margin of each sample's true class, shape=[batch_size, 1]
        sample_margin = torch.sum(self.margin_list.unsqueeze(dim=0) * one_hot, dim=-1, keepdim=True)
        # shrink the true-class logit a little to carve out the margin
        adjusted = torch.where(one_hot > 0, inputs - sample_margin, inputs)
        return F.cross_entropy(
            input=self.scale * adjusted,
            target=targets,
            weight=self.weight,
            reduction=self.reduction,
        )
331
+
332
+
333
class HingeLinear(nn.Module):
    """
    Margin-layer version of `HingeLoss`: returns margin-adjusted, scaled
    logits during training (identity otherwise), so it can be combined with
    `FocalLoss` or other criteria.
    """
    def __init__(self,
                 margin_list: List[float],
                 max_margin: float = 0.5,
                 scale: float = 1.0,
                 weight: Optional[torch.Tensor] = None
                 ) -> None:
        """
        :param margin_list: raw per-class margins, len=num_classes
        :param max_margin: margins are rescaled so the largest equals this value
        :param scale: multiplier applied to the adjusted logits
        :param weight: stored for downstream use; not consumed here
        """
        super(HingeLinear, self).__init__()

        self.max_margin = max_margin
        self.scale = scale
        self.weight = weight

        # rescale the raw margins so the largest one equals max_margin
        rescaled = np.array(margin_list)
        rescaled = rescaled * (max_margin / np.max(rescaled))
        self.margin_list = torch.tensor(rescaled, dtype=torch.float32)

    def forward(self,
                inputs: torch.FloatTensor,
                targets: torch.LongTensor
                ):
        """
        :param inputs: logits, shape=[batch_size, num_classes]
        :param targets: shape=[batch_size,]; may be None
        :return: adjusted logits (training with targets) or the inputs unchanged
        """
        if not (self.training and targets is not None):
            return inputs

        _, num_classes = inputs.shape
        one_hot = F.one_hot(targets, num_classes=num_classes)
        # margin of each sample's true class, shape=[batch_size, 1]
        sample_margin = torch.sum(self.margin_list.unsqueeze(dim=0) * one_hot, dim=-1, keepdim=True)
        # carve the margin out of the true-class logit only
        adjusted = torch.where(one_hot > 0, inputs - sample_margin, inputs)
        return adjusted * self.scale
377
+
378
+
379
class LDAMLoss(_Loss):
    """
    Label-Distribution-Aware Margin loss: per-class margins proportional to
    n_c^(-1/4), subtracted from the true-class logit before cross entropy.

    https://arxiv.org/abs/1906.07413
    """
    @staticmethod
    def demo1():
        example_inputs: torch.FloatTensor = torch.randn(size=(2, 3), dtype=torch.float32)
        example_targets: torch.LongTensor = torch.tensor([1, 2], dtype=torch.long)

        criterion = LDAMLoss(
            num_samples_each_class=[300, 433, 50],
            reduction='mean',
        )
        result = criterion.forward(inputs=example_inputs, targets=example_targets)
        print(result)
        return

    def __init__(self,
                 num_samples_each_class: List[int],
                 max_margin: float = 0.5,
                 scale: float = 30.0,
                 weight: Optional[torch.Tensor] = None,
                 reduction: str = 'mean') -> None:
        """
        :param num_samples_each_class: per-class sample counts
        :param max_margin: margins rescaled so the largest equals this value
        :param scale: multiplier applied to the adjusted logits
        :param weight: optional per-class weight forwarded to cross_entropy
        :param reduction: (`none`, `mean`, `sum`) available.
        """
        super(LDAMLoss, self).__init__(None, None, reduction)

        # margin_c proportional to n_c^(-1/4), rescaled so max equals max_margin
        raw = np.power(num_samples_each_class, -0.25)
        raw = raw * (max_margin / np.max(raw))

        self.num_samples_each_class = num_samples_each_class
        self.margin_list = torch.tensor(raw, dtype=torch.float32)
        self.scale = scale
        self.weight = weight

    def forward(self,
                inputs: torch.FloatTensor,
                targets: torch.LongTensor
                ):
        """
        :param inputs: logits, shape=[batch_size, num_classes]
        :param targets: shape=[batch_size,]
        :return: loss reduced per self.reduction
        """
        _, num_classes = inputs.shape
        one_hot = F.one_hot(targets, num_classes=num_classes)
        # margin of each sample's true class, shape=[batch_size, 1]
        sample_margin = torch.sum(self.margin_list.unsqueeze(dim=0) * one_hot, dim=-1, keepdim=True)
        # shrink only the true-class logit to enforce the margin
        adjusted = torch.where(one_hot > 0, inputs - sample_margin, inputs)
        return F.cross_entropy(
            input=self.scale * adjusted,
            target=targets,
            weight=self.weight,
            reduction=self.reduction,
        )
439
+
440
+
441
class NegativeEntropy(_Loss):
    """
    Negative entropy of the predicted distribution: sum_k p_k * log(p_k).

    Can serve as a regularizer that rewards confident predictions (or,
    negated by the caller, penalizes them).
    """
    def __init__(self,
                 reduction: str = 'mean',
                 inputs_logits: bool = True) -> None:
        # NOTE: `reduction` is accepted for signature symmetry with the other
        # losses, but the result is always summed over all elements.
        super(NegativeEntropy, self).__init__(None, None, reduction)
        self.inputs_logits = inputs_logits

    def forward(self,
                inputs: torch.FloatTensor,
                targets: torch.LongTensor):
        """
        :param inputs: logits (or probabilities if inputs_logits=False), shape=[batch_size, num_classes]
        :param targets: unused; kept for a loss-like call signature
        :return: scalar, sum over all elements of p * log(p)
        """
        if self.inputs_logits:
            probs = F.softmax(inputs, dim=-1)
            # BUG FIX: the original computed log_softmax(probs) — a double
            # softmax, which is not log(probs); take the log-probabilities
            # directly from the logits instead.
            log_probs = F.log_softmax(inputs, dim=-1)
        else:
            probs = inputs
            log_probs = torch.log(probs)

        weighted_negative_likelihood = - log_probs * probs

        loss = - weighted_negative_likelihood.sum()
        return loss
462
+
463
+
464
class LargeMarginSoftMaxLoss(_Loss):
    """
    Alias: L-Softmax

    https://arxiv.org/abs/1612.02295
    https://github.com/wy1iu/LargeMargin_Softmax_Loss
    https://github.com/amirhfarzaneh/lsoftmax-pytorch/blob/master/lsoftmax.py

    Reference (Chinese):
    https://www.jianshu.com/p/06cc3f84aa85

    The paper argues that the softmax + cross entropy combination does not
    explicitly encourage discriminative feature learning.

    NOTE(review): stub only — no forward() is implemented yet.
    """
    def __init__(self,
                 reduction: str = 'mean') -> None:
        super(LargeMarginSoftMaxLoss, self).__init__(None, None, reduction)
481
+
482
+
483
class AngularSoftMaxLoss(_Loss):
    """
    Alias: A-Softmax

    https://arxiv.org/abs/1704.08063

    https://github.com/woshildh/a-softmax_pytorch/blob/master/a_softmax.py

    Reference (Chinese):
    https://www.jianshu.com/p/06cc3f84aa85

    The authors appear to treat face embeddings as lying on a hypersphere,
    so projecting the vectors onto a sphere is believed to help.

    NOTE(review): stub only — no forward() is implemented yet.
    """
    def __init__(self,
                 reduction: str = 'mean') -> None:
        super(AngularSoftMaxLoss, self).__init__(None, None, reduction)
499
+
500
+
501
class AdditiveMarginSoftMax(_Loss):
    """
    Alias: AM-Softmax

    https://arxiv.org/abs/1801.05599

    Large Margin Cosine Loss
    https://arxiv.org/abs/1801.09414

    Reference (Chinese):
    https://www.jianshu.com/p/06cc3f84aa85

    Summary:
    Compared with a plain softmax over the logits, it subtracts m from the
    logit of the true class, pushing the model to make that value larger.
    It also multiplies every logit by s, which controls the relative
    magnitude of the logits. Quite similar to HingeLoss.

    NOTE(review): stub only — no forward() is implemented yet.
    """
    def __init__(self,
                 reduction: str = 'mean') -> None:
        super(AdditiveMarginSoftMax, self).__init__(None, None, reduction)
522
+
523
+
524
class AdditiveAngularMarginSoftMax(_Loss):
    """
    Alias: ArcFace, AAM-Softmax

    ArcFace: Additive Angular Margin Loss for Deep Face Recognition
    https://arxiv.org/abs/1801.07698

    Reference implementation:
    https://github.com/huangkeju/AAMSoftmax-OpenMax/blob/main/AAMSoftmax%2BOvA/metrics.py

    """
    @staticmethod
    def demo1():
        """
        Degree / radian conversion:
        pi / 180 is one degree,
        pi / 180 = 0.01745
        """

        # degrees -> radians
        degree = 10
        result = degree * math.pi / 180
        print(result)

        # radians -> degrees
        radian = 0.2
        result = radian / (math.pi / 180)
        print(result)

        return

    def __init__(self,
                 hidden_size: int,
                 num_labels: int,
                 margin: float = 0.2,
                 scale: float = 10.0,
                 ):
        """
        :param hidden_size: feature dimension of the inputs
        :param num_labels: number of classes
        :param margin: angular margin in radians; suggested range 10-30 degrees, i.e. [0.1745, 0.5236]
        :param scale: multiplier applied to the cosine logits
        """
        super(AdditiveAngularMarginSoftMax, self).__init__()
        self.margin = margin
        self.scale = scale
        self.weight = torch.nn.Parameter(torch.FloatTensor(num_labels, hidden_size), requires_grad=True)
        nn.init.xavier_uniform_(self.weight)

        self.cos_margin = math.cos(self.margin)
        self.sin_margin = math.sin(self.margin)

        # sin(a-b) = sin(a)cos(b) - cos(a)sin(b)
        # sin(pi - a) = sin(a)

        self.loss = nn.CrossEntropyLoss()

    def forward(self,
                inputs: torch.Tensor,
                label: torch.LongTensor = None
                ):
        """
        :param inputs: shape=[batch_size, ..., hidden_size]
        :param label: class ids, shape=[batch_size,]
        :return: scalar cross-entropy loss over the (margin-adjusted) cosine logits

        NOTE(review): despite the default, `label` must not be None — the final
        `self.loss(logits, label)` runs unconditionally, even in eval mode;
        confirm callers always pass labels.
        """
        # project features and class weights onto the unit hypersphere
        x = F.normalize(inputs)
        weight = F.normalize(self.weight)
        cosine = F.linear(x, weight)

        if self.training:

            # sin^2 + cos^2 = 1
            sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1))

            # cos(a+b) = cos(a)cos(b) - sin(a)sin(b)
            cosine_theta_margin = cosine * self.cos_margin - sine * self.sin_margin

            # when the `cosine > - self.cos_margin` there is enough space to add margin on theta.
            cosine_theta_margin = torch.where(cosine > - self.cos_margin, cosine_theta_margin, cosine - (self.margin * self.sin_margin))

            one_hot = torch.zeros_like(cosine)
            one_hot.scatter_(1, label.view(-1, 1), 1)

            # apply the margin only at the true-class positions
            logits = torch.where(one_hot == 1, cosine_theta_margin, cosine)
            logits = logits * self.scale
        else:
            logits = cosine

        loss = self.loss(logits, label)
        # prec1 = accuracy(output.detach(), label.detach(), topk=(1,))[0]
        return loss
617
+
618
+
619
class AdditiveAngularMarginLinear(nn.Module):
    """
    Alias: ArcFace, AAM-Softmax

    ArcFace: Additive Angular Margin Loss for Deep Face Recognition
    https://arxiv.org/abs/1801.07698

    Reference implementation:
    https://github.com/huangkeju/AAMSoftmax-OpenMax/blob/main/AAMSoftmax%2BOvA/metrics.py

    Produces cosine-similarity logits; during training the angle of the true
    class is enlarged by a fixed margin before scaling.
    """
    @staticmethod
    def demo1():
        """
        Degree / radian conversion: pi / 180 is one degree (~0.01745).
        """

        # degrees -> radians
        degree = 10
        result = degree * math.pi / 180
        print(result)

        # radians -> degrees
        radian = 0.2
        result = radian / (math.pi / 180)
        print(result)

        return

    @staticmethod
    def demo2():

        return

    def __init__(self,
                 hidden_size: int,
                 num_labels: int,
                 margin: float = 0.2,
                 scale: float = 10.0,
                 ):
        """
        :param hidden_size: feature dimension of the inputs
        :param num_labels: number of classes
        :param margin: angular margin in radians; suggested range 10-30 degrees, i.e. [0.1745, 0.5236]
        :param scale: multiplier applied to the adjusted cosine logits
        """
        super(AdditiveAngularMarginLinear, self).__init__()
        self.margin = margin
        self.scale = scale
        self.weight = torch.nn.Parameter(torch.FloatTensor(num_labels, hidden_size), requires_grad=True)
        nn.init.xavier_uniform_(self.weight)

        # precomputed cos(m) / sin(m) used by the angle-addition formula
        self.cos_margin = math.cos(self.margin)
        self.sin_margin = math.sin(self.margin)

    def forward(self,
                inputs: torch.Tensor,
                targets: torch.LongTensor = None
                ):
        """
        :param inputs: shape=[batch_size, ..., hidden_size]
        :param targets: class ids, shape=[batch_size,]; may be None at inference
        :return: logits, shape=[batch_size, num_labels]
        """
        normalized = F.normalize(inputs)
        class_directions = F.normalize(self.weight)
        cosine = F.linear(normalized, class_directions)

        if not (self.training and targets is not None):
            return cosine

        # sin(theta) recovered from cos(theta) via sin^2 + cos^2 = 1
        sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1))
        # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m)
        with_margin = cosine * self.cos_margin - sine * self.sin_margin
        # only add the margin while theta + m still stays below pi
        with_margin = torch.where(cosine > - self.cos_margin, with_margin, cosine - (self.margin * self.sin_margin))

        target_mask = torch.zeros_like(cosine)
        target_mask.scatter_(1, targets.view(-1, 1), 1)

        logits = torch.where(target_mask == 1, with_margin, cosine)
        return logits * self.scale
710
+
711
+
712
def demo1():
    """Smoke-test the HingeLoss demo."""
    HingeLoss.demo1()
    return
715
+
716
+
717
def demo2():
    """Smoke-test AdditiveAngularMarginSoftMax on a tiny, all-ones batch."""
    AdditiveAngularMarginSoftMax.demo1()

    inputs = torch.ones(size=(2, 5), dtype=torch.float32)
    label: torch.LongTensor = torch.tensor(data=[0, 1], dtype=torch.long)

    # margin=1 rad (~57 degrees) and scale=1 are demo values, larger than the
    # ranges suggested in the class docstring
    aam_softmax = AdditiveAngularMarginSoftMax(
        hidden_size=5,
        num_labels=2,
        margin=1,
        scale=1
    )

    outputs = aam_softmax.forward(inputs, label)
    print(outputs)

    return
734
+
735
+
736
+ if __name__ == '__main__':
737
+ # demo1()
738
+ demo2()
toolbox/torch/training/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torch/training/metrics/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torch/training/metrics/categorical_accuracy.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ from overrides import overrides
4
+ import torch
5
+
6
+
7
+ class CategoricalAccuracy(object):
8
+ def __init__(self, top_k: int = 1, tie_break: bool = False) -> None:
9
+ if top_k > 1 and tie_break:
10
+ raise AssertionError("Tie break in Categorical Accuracy "
11
+ "can be done only for maximum (top_k = 1)")
12
+ if top_k <= 0:
13
+ raise AssertionError("top_k passed to Categorical Accuracy must be > 0")
14
+ self._top_k = top_k
15
+ self._tie_break = tie_break
16
+ self.correct_count = 0.
17
+ self.total_count = 0.
18
+
19
+ def __call__(self,
20
+ predictions: torch.Tensor,
21
+ gold_labels: torch.Tensor,
22
+ mask: Optional[torch.Tensor] = None):
23
+
24
+ # predictions, gold_labels, mask = self.unwrap_to_tensors(predictions, gold_labels, mask)
25
+
26
+ # Some sanity checks.
27
+ num_classes = predictions.size(-1)
28
+ if gold_labels.dim() != predictions.dim() - 1:
29
+ raise AssertionError("gold_labels must have dimension == predictions.size() - 1 but "
30
+ "found tensor of shape: {}".format(predictions.size()))
31
+ if (gold_labels >= num_classes).any():
32
+ raise AssertionError("A gold label passed to Categorical Accuracy contains an id >= {}, "
33
+ "the number of classes.".format(num_classes))
34
+
35
+ predictions = predictions.view((-1, num_classes))
36
+ gold_labels = gold_labels.view(-1).long()
37
+ if not self._tie_break:
38
+ # Top K indexes of the predictions (or fewer, if there aren't K of them).
39
+ # Special case topk == 1, because it's common and .max() is much faster than .topk().
40
+ if self._top_k == 1:
41
+ top_k = predictions.max(-1)[1].unsqueeze(-1)
42
+ else:
43
+ top_k = predictions.topk(min(self._top_k, predictions.shape[-1]), -1)[1]
44
+
45
+ # This is of shape (batch_size, ..., top_k).
46
+ correct = top_k.eq(gold_labels.unsqueeze(-1)).float()
47
+ else:
48
+ # prediction is correct if gold label falls on any of the max scores. distribute score by tie_counts
49
+ max_predictions = predictions.max(-1)[0]
50
+ max_predictions_mask = predictions.eq(max_predictions.unsqueeze(-1))
51
+ # max_predictions_mask is (rows X num_classes) and gold_labels is (batch_size)
52
+ # ith entry in gold_labels points to index (0-num_classes) for ith row in max_predictions
53
+ # For each row check if index pointed by gold_label is was 1 or not (among max scored classes)
54
+ correct = max_predictions_mask[torch.arange(gold_labels.numel()).long(), gold_labels].float()
55
+ tie_counts = max_predictions_mask.sum(-1)
56
+ correct /= tie_counts.float()
57
+ correct.unsqueeze_(-1)
58
+
59
+ if mask is not None:
60
+ correct *= mask.view(-1, 1).float()
61
+ self.total_count += mask.sum()
62
+ else:
63
+ self.total_count += gold_labels.numel()
64
+ self.correct_count += correct.sum()
65
+
66
+ def get_metric(self, reset: bool = False):
67
+ """
68
+ Returns
69
+ -------
70
+ The accumulated accuracy.
71
+ """
72
+ if self.total_count > 1e-12:
73
+ accuracy = float(self.correct_count) / float(self.total_count)
74
+ else:
75
+ accuracy = 0.0
76
+ if reset:
77
+ self.reset()
78
+ return {'accuracy': accuracy}
79
+
80
+ def reset(self):
81
+ self.correct_count = 0.0
82
+ self.total_count = 0.0
toolbox/torch/training/metrics/verbose_categorical_accuracy.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from typing import Dict, List, Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+
9
class CategoricalAccuracyVerbose(object):
    """
    Top-k categorical accuracy that additionally tracks per-label accuracy,
    keyed by each label's display string from ``index_to_token``.
    """

    def __init__(self,
                 index_to_token: Dict[int, str],
                 label_namespace: str = "labels",
                 top_k: int = 1,
                 ) -> None:
        """
        :param index_to_token: maps label ids to their display strings.
        :param label_namespace: vocabulary namespace name (stored for reference).
        :param top_k: a prediction counts as correct when the gold label is
            among the k highest-scoring classes.
        """
        if top_k <= 0:
            raise AssertionError("top_k passed to Categorical Accuracy must be > 0")
        self._index_to_token = index_to_token
        self._label_namespace = label_namespace
        self._top_k = top_k
        # running totals, kept as plain python numbers (see __call__)
        self.correct_count = 0.
        self.total_count = 0.
        self.label_correct_count = dict()
        self.label_total_count = dict()

    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 mask: Optional[torch.Tensor] = None):
        """
        :param predictions: tensor of class scores, shape [..., num_classes].
        :param gold_labels: integer tensor matching predictions without the
            final class dimension.
        :param mask: optional mask; zero entries are excluded from the
            overall totals (per-label counts still use all gold labels).
        """
        num_classes = predictions.size(-1)
        if gold_labels.dim() != predictions.dim() - 1:
            raise AssertionError("gold_labels must have dimension == predictions.size() - 1 but "
                                 "found tensor of shape: {}".format(predictions.size()))
        if (gold_labels >= num_classes).any():
            raise AssertionError("A gold label passed to Categorical Accuracy contains an id >= {}, "
                                 "the number of classes.".format(num_classes))

        predictions = predictions.view((-1, num_classes))
        gold_labels = gold_labels.view(-1).long()

        # Top K indexes of the predictions (or fewer, if there aren't K of them).
        # Special case top_k == 1, because .max() is much faster than .topk().
        if self._top_k == 1:
            top_k = predictions.max(-1)[1].unsqueeze(-1)
        else:
            top_k = predictions.topk(min(self._top_k, predictions.shape[-1]), -1)[1]

        # This is of shape (batch_size, top_k).
        correct = top_k.eq(gold_labels.unsqueeze(-1)).float()

        if mask is not None:
            correct *= mask.view(-1, 1).float()
            # fix: accumulate python numbers so the counters don't silently
            # become tensors (retained across batches) after the first update.
            self.total_count += float(mask.sum().item())
        else:
            self.total_count += gold_labels.numel()
        self.correct_count += float(correct.sum().item())

        labels: List[int] = np.unique(gold_labels.cpu().numpy()).tolist()
        for label in labels:
            label_mask = (gold_labels == label)

            label_correct = correct * label_mask.view(-1, 1).float()
            label_correct = int(label_correct.sum())
            label_count = int(label_mask.sum())

            label_str = self._index_to_token[label]
            self.label_correct_count[label_str] = (
                self.label_correct_count.get(label_str, 0) + label_correct
            )
            self.label_total_count[label_str] = (
                self.label_total_count.get(label_str, 0) + label_count
            )

    def get_metric(self, reset: bool = False):
        """
        Returns
        -------
        A dict with the accumulated overall 'accuracy' plus one entry per
        seen label string with that label's accuracy.
        """
        result = dict()
        if self.total_count > 1e-12:
            accuracy = float(self.correct_count) / float(self.total_count)
        else:
            accuracy = 0.0
        result['accuracy'] = accuracy

        for label in self.label_total_count.keys():
            total = self.label_total_count[label]
            correct = self.label_correct_count.get(label, 0.0)
            # total > 0 by construction: an entry exists only after the
            # label occurred at least once.
            result[label] = correct / total

        if reset:
            self.reset()
        return result

    def reset(self):
        """Clear all accumulated statistics."""
        self.correct_count = 0.0
        self.total_count = 0.0
        self.label_correct_count = dict()
        self.label_total_count = dict()
105
+
106
def demo1():
    """Smoke-test CategoricalAccuracyVerbose on random predictions."""
    categorical_accuracy_verbose = CategoricalAccuracyVerbose(
        index_to_token={0: '0', 1: '1'},
        top_k=2,
    )

    predictions = torch.randn(size=(2, 3), dtype=torch.float32)
    gold_labels = torch.ones(size=(2,), dtype=torch.long)

    categorical_accuracy_verbose(
        predictions=predictions,
        gold_labels=gold_labels,
    )
    metric = categorical_accuracy_verbose.get_metric()
    print(metric)
125
+
126
+
127
# run the demo when executed as a script
if __name__ == '__main__':
    demo1()
toolbox/torch/training/trainer/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ if __name__ == '__main__':
5
+ pass
toolbox/torch/training/trainer/trainer.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ if __name__ == '__main__':
5
+ pass
toolbox/torch/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ if __name__ == '__main__':
5
+ pass
toolbox/torch/utils/data/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ if __name__ == '__main__':
5
+ pass
toolbox/torch/utils/data/dataset/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ if __name__ == '__main__':
5
+ pass
toolbox/torch/utils/data/dataset/wave_classifier_excel_dataset.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os
4
+
5
+ import librosa
6
+ import numpy as np
7
+ import pandas as pd
8
+ from scipy.io import wavfile
9
+ import torch
10
+ import torchaudio
11
+ from torch.utils.data import Dataset
12
+ from tqdm import tqdm
13
+
14
+ from toolbox.torch.utils.data.vocabulary import Vocabulary
15
+
16
+
17
class WaveClassifierExcelDataset(Dataset):
    """
    Dataset that reads (filename, label) rows from an excel sheet and yields
    ``{"waveform", "label"}`` dicts for wave classification.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 excel_file: str,
                 expected_sample_rate: int,
                 resample: bool = False,
                 root_path: str = None,
                 category: str = None,
                 category_field: str = "category",
                 label_field: str = "labels",
                 max_wave_value: float = 1.0,
                 ) -> None:
        """
        :param vocab: vocabulary used to map label strings to ids.
        :param excel_file: sheet with a "filename" column and a label column.
        :param expected_sample_rate: sample rate every wave must have.
        :param resample: when True, load with librosa at the expected rate;
            otherwise read raw samples and require the rate to match.
        :param root_path: optional directory prepended to each filename.
        :param category: when given, keep only rows whose category matches,
            and use it as the label namespace.
        :param max_wave_value: divisor applied to raw samples (no-resample path).
        """
        self.vocab = vocab
        self.excel_file = excel_file

        self.expected_sample_rate = expected_sample_rate
        self.resample = resample
        self.root_path = root_path
        self.category = category
        self.category_field = category_field
        self.label_field = label_field
        self.max_wave_value = max_wave_value

        df = pd.read_excel(excel_file)

        selected = list()
        for _, row in tqdm(df.iterrows(), total=len(df)):
            filename = row["filename"]
            label = row[self.label_field]

            if self.category is not None and self.category != row[self.category_field]:
                continue

            selected.append({
                "filename": filename,
                "label": label,
            })
        self.samples = selected

    def __getitem__(self, index):
        sample = self.samples[index]
        filename = sample["filename"]
        label = sample["label"]

        if self.root_path is not None:
            filename = os.path.join(self.root_path, filename)

        waveform = self.filename_to_waveform(filename)

        # the label namespace defaults to the label column name unless a
        # category was fixed at construction time
        namespace = self.category if self.category is not None else self.label_field
        token_to_index = self.vocab.get_token_to_index_vocabulary(namespace=namespace)
        label_idx: int = token_to_index[label]

        return {
            "waveform": waveform,
            "label": torch.tensor(label_idx, dtype=torch.int64),
        }

    def __len__(self):
        return len(self.samples)

    def filename_to_waveform(self, filename: str):
        """Load a wave file and return it as a float32 tensor (raising
        AssertionError when the sample rate does not match)."""
        try:
            if self.resample:
                # librosa resamples to the expected rate while loading
                waveform, sample_rate = librosa.load(filename, sr=self.expected_sample_rate)
            else:
                sample_rate, waveform = wavfile.read(filename)
                waveform = waveform / self.max_wave_value
        except ValueError as e:
            print(filename)
            raise e
        if sample_rate != self.expected_sample_rate:
            raise AssertionError

        return torch.tensor(waveform, dtype=torch.float32)
95
+
96
+
97
+ if __name__ == "__main__":
98
+ pass
toolbox/torch/utils/data/vocabulary.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from collections import defaultdict, OrderedDict
4
+ import os
5
+ from typing import Any, Callable, Dict, Iterable, List, Set
6
+
7
+
8
def namespace_match(pattern: str, namespace: str):
    """
    Match a namespace against a pattern.  A leading ``*`` makes the pattern
    a suffix wildcard: ``*tags`` matches ``passage_tags`` and
    ``question_tags``, while ``tokens`` matches only ``tokens`` (not
    ``stemmed_tokens``).
    """
    if pattern[0] == '*':
        # wildcard: suffix match (a pattern equal to the namespace would
        # also end with its own suffix, so no separate equality check needed)
        return namespace.endswith(pattern[1:])
    return pattern == namespace
19
+
20
+
21
+ class _NamespaceDependentDefaultDict(defaultdict):
22
+ def __init__(self,
23
+ non_padded_namespaces: Set[str],
24
+ padded_function: Callable[[], Any],
25
+ non_padded_function: Callable[[], Any]) -> None:
26
+ self._non_padded_namespaces = set(non_padded_namespaces)
27
+ self._padded_function = padded_function
28
+ self._non_padded_function = non_padded_function
29
+ super(_NamespaceDependentDefaultDict, self).__init__()
30
+
31
+ def __missing__(self, key: str):
32
+ if any(namespace_match(pattern, key) for pattern in self._non_padded_namespaces):
33
+ value = self._non_padded_function()
34
+ else:
35
+ value = self._padded_function()
36
+ dict.__setitem__(self, key, value)
37
+ return value
38
+
39
+ def add_non_padded_namespaces(self, non_padded_namespaces: Set[str]):
40
+ # add non_padded_namespaces which weren't already present
41
+ self._non_padded_namespaces.update(non_padded_namespaces)
42
+
43
+
44
class _TokenToIndexDefaultDict(_NamespaceDependentDefaultDict):
    """token -> index map; padded namespaces start as {pad: 0, oov: 1}."""

    def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
        super(_TokenToIndexDefaultDict, self).__init__(
            non_padded_namespaces,
            lambda: {padding_token: 0, oov_token: 1},
            lambda: {},
        )
49
+
50
+
51
class _IndexToTokenDefaultDict(_NamespaceDependentDefaultDict):
    """index -> token map; padded namespaces start as {0: pad, 1: oov}."""

    def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
        super(_IndexToTokenDefaultDict, self).__init__(
            non_padded_namespaces,
            lambda: {0: padding_token, 1: oov_token},
            lambda: {},
        )
56
+
57
+
58
+ DEFAULT_NON_PADDED_NAMESPACES = ("*tags", "*labels")
59
+ DEFAULT_PADDING_TOKEN = '[PAD]'
60
+ DEFAULT_OOV_TOKEN = '[UNK]'
61
+ NAMESPACE_PADDING_FILE = 'non_padded_namespaces.txt'
62
+
63
+
64
class Vocabulary(object):
    """
    Namespaced token <-> index mappings.  Namespaces matching one of the
    non-padded patterns (e.g. ``*labels``) start empty; every other
    namespace is created with the padding token at index 0 and the OOV
    token at index 1.
    """

    def __init__(self, non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES):
        self._non_padded_namespaces = set(non_padded_namespaces)
        self._padding_token = DEFAULT_PADDING_TOKEN
        self._oov_token = DEFAULT_OOV_TOKEN
        self._token_to_index = _TokenToIndexDefaultDict(self._non_padded_namespaces,
                                                        self._padding_token,
                                                        self._oov_token)
        self._index_to_token = _IndexToTokenDefaultDict(self._non_padded_namespaces,
                                                        self._padding_token,
                                                        self._oov_token)

    def add_token_to_namespace(self, token: str, namespace: str = 'tokens') -> int:
        """Add ``token`` to ``namespace`` (idempotent) and return its index."""
        mapping = self._token_to_index[namespace]
        if token in mapping:
            return mapping[token]
        index = len(mapping)
        mapping[token] = index
        self._index_to_token[namespace][index] = token
        return index

    def get_index_to_token_vocabulary(self, namespace: str = 'tokens') -> Dict[int, str]:
        return self._index_to_token[namespace]

    def get_token_to_index_vocabulary(self, namespace: str = 'tokens') -> Dict[str, int]:
        return self._token_to_index[namespace]

    def get_token_index(self, token: str, namespace: str = 'tokens') -> int:
        """Index of ``token``; unknown tokens map to the OOV token's index."""
        mapping = self._token_to_index[namespace]
        if token in mapping:
            return mapping[token]
        return mapping[self._oov_token]

    def get_token_from_index(self, index: int, namespace: str = 'tokens'):
        return self._index_to_token[namespace][index]

    def get_vocab_size(self, namespace: str = 'tokens') -> int:
        return len(self._token_to_index[namespace])

    def save_to_files(self, directory: str):
        """Write one token-per-line file per namespace, plus the list of
        non-padded namespace patterns."""
        os.makedirs(directory, exist_ok=True)
        with open(os.path.join(directory, NAMESPACE_PADDING_FILE), 'w', encoding='utf-8') as f:
            for namespace_str in self._non_padded_namespaces:
                f.write('{}\n'.format(namespace_str))

        for namespace, token_to_index in self._token_to_index.items():
            filename = os.path.join(directory, '{}.txt'.format(namespace))
            with open(filename, 'w', encoding='utf-8') as f:
                # tokens are written in insertion order, i.e. by index
                for token in token_to_index:
                    f.write('{}\n'.format(token))

    @classmethod
    def from_files(cls, directory: str) -> 'Vocabulary':
        """Rebuild a vocabulary previously written by ``save_to_files``."""
        with open(os.path.join(directory, NAMESPACE_PADDING_FILE), 'r', encoding='utf-8') as f:
            non_padded_namespaces = [namespace_str.strip() for namespace_str in f]

        vocab = cls(non_padded_namespaces=non_padded_namespaces)

        for namespace_filename in os.listdir(directory):
            if namespace_filename == NAMESPACE_PADDING_FILE:
                continue
            if namespace_filename.startswith("."):
                continue
            namespace = namespace_filename.replace('.txt', '')
            is_padded = not any(
                namespace_match(pattern, namespace) for pattern in non_padded_namespaces
            )
            vocab.set_from_file(os.path.join(directory, namespace_filename),
                                is_padded,
                                namespace=namespace)

        return vocab

    def set_from_file(self,
                      filename: str,
                      is_padded: bool = True,
                      oov_token: str = DEFAULT_OOV_TOKEN,
                      namespace: str = "tokens"
                      ):
        """Load one namespace from a token-per-line file, replacing any
        existing content of that namespace."""
        if is_padded:
            # padded namespaces reserve index 0 for the padding token
            self._token_to_index[namespace] = {self._padding_token: 0}
            self._index_to_token[namespace] = {0: self._padding_token}
        else:
            self._token_to_index[namespace] = {}
            self._index_to_token[namespace] = {}

        with open(filename, 'r', encoding='utf-8') as f:
            start = 1 if is_padded else 0
            for index, row in enumerate(f, start=start):
                token = str(row).strip()
                if token == oov_token:
                    token = self._oov_token
                self._token_to_index[namespace][token] = index
                self._index_to_token[namespace][index] = token

    def convert_tokens_to_ids(self, tokens: List[str], namespace: str = "tokens"):
        """Map tokens to ids, sending unknown tokens to the OOV id."""
        mapping = self._token_to_index[namespace]
        result = list()
        for token in tokens:
            idx = mapping.get(token)
            if idx is None:
                idx = mapping[self._oov_token]
            result.append(idx)
        return result

    def convert_ids_to_tokens(self, ids: List[int], namespace: str = "tokens"):
        """Map ids back to their tokens (raises KeyError on unknown ids)."""
        mapping = self._index_to_token[namespace]
        return [mapping[idx] for idx in ids]

    def pad_or_truncate_ids_by_max_length(self, ids: List[int], max_length: int, namespace: str = "tokens"):
        """Truncate ``ids`` to ``max_length`` or right-pad with the pad id."""
        pad_idx = self._token_to_index[namespace][self._padding_token]

        if len(ids) > max_length:
            return ids[:max_length]
        return ids + [pad_idx] * (max_length - len(ids))
185
+
186
+
187
def demo1():
    """Tokenize a Chinese sentence with jieba and round-trip it through the
    vocabulary (unknown tokens map to the OOV id)."""
    import jieba

    vocabulary = Vocabulary()
    vocabulary.add_token_to_namespace('白天', 'tokens')
    vocabulary.add_token_to_namespace('晚上', 'tokens')

    text = '不是在白天, 就是在晚上'
    tokens = jieba.lcut(text)
    print(tokens)

    ids = vocabulary.convert_tokens_to_ids(tokens)
    print(ids)

    padded_idx = vocabulary.pad_or_truncate_ids_by_max_length(ids, 10)
    print(padded_idx)

    tokens = vocabulary.convert_ids_to_tokens(padded_idx)
    print(tokens)
208
+
209
+
210
# run the demo when executed as a script
if __name__ == '__main__':
    demo1()
toolbox/torchaudio/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ if __name__ == '__main__':
5
+ pass
toolbox/torchaudio/configuration_utils.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import copy
4
+ import os
5
+ from typing import Any, Dict, Union
6
+
7
+ import yaml
8
+
9
+
10
+ CONFIG_FILE = "config.yaml"
11
+
12
+
13
class PretrainedConfig(object):
    """
    Minimal huggingface-style config base class: load a config dict from a
    yaml file (or a directory containing ``config.yaml``) and build a config
    object from it.
    """

    def __init__(self, **kwargs):
        pass

    @classmethod
    def _dict_from_yaml_file(cls, yaml_file: Union[str, os.PathLike]):
        """Parse a yaml file into a plain dict."""
        with open(yaml_file, encoding="utf-8") as f:
            config_dict = yaml.safe_load(f)
        return config_dict

    @classmethod
    def get_config_dict(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike]
    ) -> Dict[str, Any]:
        """
        Resolve a path (a directory containing ``config.yaml``, or a yaml
        file itself) and return its parsed config dict.
        """
        if os.path.isdir(pretrained_model_name_or_path):
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_FILE)
        else:
            config_file = pretrained_model_name_or_path
        return cls._dict_from_yaml_file(config_file)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
        """
        Build a config from ``config_dict``; kwargs override entries whose
        key already exists in the dict (unknown kwargs are ignored).
        """
        # fix: copy before overriding so the caller's dict is not mutated.
        config_dict = dict(config_dict)
        for k, v in kwargs.items():
            if k in config_dict:
                config_dict[k] = v
        return cls(**config_dict)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        **kwargs,
    ):
        """Load the config dict from disk and instantiate the config."""
        config_dict = cls.get_config_dict(pretrained_model_name_or_path)
        return cls.from_dict(config_dict, **kwargs)

    def to_dict(self):
        """Return a deep copy of the instance attributes as a dict."""
        return copy.deepcopy(self.__dict__)

    def to_yaml_file(self, yaml_file_path: Union[str, os.PathLike]):
        """Serialize ``to_dict()`` to a yaml file."""
        config_dict = self.to_dict()

        with open(yaml_file_path, "w", encoding="utf-8") as writer:
            yaml.safe_dump(config_dict, writer)
60
+
61
+
62
+ if __name__ == '__main__':
63
+ pass
toolbox/torchaudio/models/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ if __name__ == '__main__':
5
+ pass
toolbox/torchaudio/models/cnn_audio_classifier/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ if __name__ == '__main__':
5
+ pass
toolbox/torchaudio/models/cnn_audio_classifier/configuration_cnn_audio_classifier.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from typing import Any, Dict, List, Tuple, Union
4
+
5
+ from toolbox.torchaudio.configuration_utils import PretrainedConfig
6
+
7
+
8
class CnnAudioClassifierConfig(PretrainedConfig):
    """Configuration for the CNN audio classifier model."""

    def __init__(self,
                 mel_spectrogram_param: dict,
                 cls_head_param: dict,
                 conv1d_block_param_list: List[dict] = None,
                 conv2d_block_param_list: List[dict] = None,
                 **kwargs
                 ):
        """
        :param mel_spectrogram_param: kwargs for the mel-spectrogram frontend.
        :param cls_head_param: kwargs for the classification head.
        :param conv1d_block_param_list: one kwargs dict per Conv1d block.
        :param conv2d_block_param_list: one kwargs dict per Conv2d block.
        """
        super(CnnAudioClassifierConfig, self).__init__(**kwargs)
        self.mel_spectrogram_param = mel_spectrogram_param
        self.cls_head_param = cls_head_param
        self.conv1d_block_param_list = conv1d_block_param_list
        self.conv2d_block_param_list = conv2d_block_param_list
21
+
22
+
23
+ if __name__ == "__main__":
24
+ pass
toolbox/torchaudio/models/cnn_audio_classifier/examples/conv2d_classifier.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
model_name: "cnn_audio_classifier"

mel_spectrogram_param:
  sample_rate: 8000
  n_fft: 512
  win_length: 200
  hop_length: 80
  f_min: 10
  f_max: 3800
  window_fn: hamming
  n_mels: 80

conv2d_block_param_list:
  - batch_norm: true
    in_channels: 1
    out_channels: 4
    kernel_size: 3
    stride: 1
    dilation: 3
    activation: relu
    dropout: 0.1
  - in_channels: 4
    out_channels: 4
    kernel_size: 5
    stride: 2
    dilation: 3
    activation: relu
    dropout: 0.1
  - in_channels: 4
    out_channels: 4
    kernel_size: 3
    stride: 1
    dilation: 2
    activation: relu
    dropout: 0.1

# fix: renamed from `cls_head` -- CnnAudioClassifierConfig's constructor
# takes `cls_head_param` (there is no `cls_head` parameter).
cls_head_param:
  input_dim: 352
  num_layers: 2
  hidden_dims:
    - 128
    - 32
  activations: relu
  dropout: 0.1
  num_labels: 3
toolbox/torchaudio/models/cnn_audio_classifier/modeling_cnn_audio_classifier.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os
4
+ from typing import Dict, List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torchaudio
9
+
10
+ from toolbox.torchaudio.models.cnn_audio_classifier.configuration_cnn_audio_classifier import CnnAudioClassifierConfig
11
+ from toolbox.torchaudio.configuration_utils import CONFIG_FILE
12
+
13
+
14
+ MODEL_FILE = "model.pt"
15
+
16
+
17
+ name2activation = {
18
+ "relu": nn.ReLU,
19
+ }
20
+
21
+
22
class Conv1dBlock(nn.Module):
    """
    Optional BatchNorm -> Conv1d -> optional activation -> optional dropout.

    Input and output are [batch_size, seq_length, spec_dim]; the block
    transposes to channel-first internally for the convolution.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int,
                 padding: int = 0,
                 dilation: int = 1,
                 batch_norm: bool = False,
                 activation: str = None,
                 dropout: float = None,
                 ):
        """
        Note: the original annotations (``stride: Tuple[int, int]``,
        ``padding: str = 0``) were wrong for a 1-d convolution and have
        been corrected to plain ints; runtime behavior is unchanged.

        :param activation: key into ``name2activation`` (e.g. "relu"), or None.
        :param dropout: dropout probability, or None to disable.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size

        # optional normalization over the input channels
        if batch_norm:
            self.batch_norm = nn.BatchNorm1d(in_channels)
        else:
            self.batch_norm = None

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size,),
            stride=stride,
            padding=padding,
            dilation=(dilation,),
        )

        if activation is None:
            self.activation = None
        else:
            self.activation = name2activation[activation]()

        if dropout is not None:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = None

    def forward(self, x):
        # x: [batch_size, seq_length, spec_dim] -> channel-first
        x = torch.transpose(x, dim0=-1, dim1=-2)

        # x: [batch_size, spec_dim, seq_length]
        if self.batch_norm is not None:
            x = self.batch_norm(x)

        x = self.conv(x)

        if self.activation is not None:
            x = self.activation(x)

        if self.dropout is not None:
            x = self.dropout(x)

        # back to [batch_size, seq_length, spec_dim]
        x = torch.transpose(x, dim0=-1, dim1=-2)
        return x
82
+
83
+
84
class Conv2dBlock(nn.Module):
    """
    Optional BatchNorm -> Conv2d -> optional activation -> optional dropout.

    Operates on [batch_size, channels, height, width] tensors.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]],
                 padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 batch_norm: bool = False,
                 activation: str = None,
                 dropout: float = None,
                 ):
        """
        :param activation: key into ``name2activation`` (e.g. "relu"), or None.
        :param dropout: dropout probability, or None to disable.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size: Tuple[int, int] = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)

        if batch_norm:
            self.batch_norm = nn.BatchNorm2d(in_channels)
        else:
            self.batch_norm = None

        # fix: pass padding/dilation through unchanged -- the original
        # wrapped them in 1-element tuples ((padding,), (dilation,)),
        # which nn.Conv2d rejects: it expects an int or a 2-tuple.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
        )

        if activation is None:
            self.activation = None
        else:
            self.activation = name2activation[activation]()

        if dropout is not None:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = None

    def forward(self, x):
        # x: [batch_size, in_channels, height, width]
        if self.batch_norm is not None:
            x = self.batch_norm(x)

        x = self.conv(x)

        if self.activation is not None:
            x = self.activation(x)

        if self.dropout is not None:
            x = self.dropout(x)

        return x
139
+
140
+
141
class FeedForward(nn.Module):
    """
    A stack of Linear -> activation -> dropout layers (AllenNLP-style).
    Scalar ``hidden_dims`` / ``activations`` / ``dropout`` values are
    broadcast to every layer.
    """

    def __init__(self,
                 input_dim: int,
                 num_layers: int,
                 hidden_dims: Union[int, List[int]],
                 activations: Union[str, List[str]],
                 dropout: Union[float, List[float]] = 0.0) -> None:
        super(FeedForward, self).__init__()
        if not isinstance(hidden_dims, list):
            hidden_dims = [hidden_dims] * num_layers  # type: ignore
        if not isinstance(activations, list):
            activations = [activations] * num_layers  # type: ignore
        if not isinstance(dropout, list):
            dropout = [dropout] * num_layers  # type: ignore

        if len(hidden_dims) != num_layers:
            raise AssertionError("len(hidden_dims) (%d) != num_layers (%d)" %
                                 (len(hidden_dims), num_layers))
        if len(activations) != num_layers:
            raise AssertionError("len(activations) (%d) != num_layers (%d)" %
                                 (len(activations), num_layers))
        if len(dropout) != num_layers:
            raise AssertionError("len(dropout) (%d) != num_layers (%d)" %
                                 (len(dropout), num_layers))

        self._activations = torch.nn.ModuleList(
            [name2activation[name]() for name in activations]
        )

        # each layer's input is the previous layer's output
        layer_inputs = [input_dim] + hidden_dims[:-1]
        self._linear_layers = torch.nn.ModuleList(
            [torch.nn.Linear(d_in, d_out) for d_in, d_out in zip(layer_inputs, hidden_dims)]
        )
        self._dropout = torch.nn.ModuleList(
            [torch.nn.Dropout(p=p) for p in dropout]
        )
        self.output_dim = hidden_dims[-1]
        self.input_dim = input_dim

    def get_output_dim(self):
        """Dimension of the final layer's output."""
        return self.output_dim

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        output = inputs
        for linear, activation, dropout in zip(self._linear_layers, self._activations, self._dropout):
            output = dropout(activation(linear(output)))
        return output
185
+
186
+
187
class SpectrogramEncoder(nn.Module):
    """
    Stack of Conv1dBlock and/or Conv2dBlock layers applied to a spectrogram.
    At least one of the two block-parameter lists must be provided.
    """

    def __init__(self,
                 conv1d_block_param_list: List[dict] = None,
                 conv2d_block_param_list: List[dict] = None,
                 ):
        super(SpectrogramEncoder, self).__init__()
        if conv1d_block_param_list is None and conv2d_block_param_list is None:
            raise AssertionError(
                "At least one of the `conv1d_block_param_list` and `conv2d_block_param_list` is required."
            )

        self.conv1d_block_list = None
        if conv1d_block_param_list is not None:
            blocks = [Conv1dBlock(**param) for param in conv1d_block_param_list]
            self.conv1d_block_list = nn.ModuleList(modules=blocks)

        self.conv2d_block_list = None
        if conv2d_block_param_list is not None:
            blocks = [Conv2dBlock(**param) for param in conv2d_block_param_list]
            self.conv2d_block_list = nn.ModuleList(modules=blocks)

    def forward(self,
                inputs: torch.Tensor,
                ):
        # NOTE(review): the Conv1d path expects [batch, seq_length, spec_dim]
        # layout (Conv1dBlock transposes internally) -- confirm against callers.
        x = inputs

        if self.conv1d_block_list is not None:
            for block in self.conv1d_block_list:
                x = block(x)

        if self.conv2d_block_list is not None:
            # add a channel dimension: [batch, 1, seq_length, spec_dim]
            x = torch.unsqueeze(x, dim=1)
            for block in self.conv2d_block_list:
                x = block(x)

            # [batch, channel, seq, spec] -> [batch, seq, channel * spec]
            x = torch.transpose(x, dim0=1, dim1=2)
            batch_size, seq_length, _, _ = x.shape
            x = torch.reshape(x, shape=(batch_size, seq_length, -1))

        # x: [batch_size, seq_length, feature_dim]
        return x
238
+
239
+
240
class WaveEncoder(nn.Module):
    """Turn a raw waveform batch into a sequence of feature vectors.

    The waveform is converted (without gradient) to a log mel spectrogram,
    mean-normalized along the time axis, and then passed through a
    SpectrogramEncoder built from the given conv block parameter lists.
    """
    def __init__(self,
                 mel_spectrogram_param: dict,
                 conv1d_block_param_list: List[dict] = None,
                 conv2d_block_param_list: List[dict] = None,
                 ):
        super(WaveEncoder, self).__init__()
        if conv1d_block_param_list is None and conv2d_block_param_list is None:
            raise AssertionError(
                "At least one of the `conv1d_block_param_list` and `conv2d_block_param_list` is required."
            )

        # Only the literal "hamming" selects a Hamming window; every other
        # value falls back to Hann.
        window_fn = torch.hamming_window if mel_spectrogram_param["window_fn"] == "hamming" else torch.hann_window
        self.wave_to_mel_spectrogram = torch.nn.Sequential(
            torchaudio.transforms.MelSpectrogram(
                sample_rate=mel_spectrogram_param["sample_rate"],
                n_fft=mel_spectrogram_param["n_fft"],
                win_length=mel_spectrogram_param["win_length"],
                hop_length=mel_spectrogram_param["hop_length"],
                f_min=mel_spectrogram_param["f_min"],
                f_max=mel_spectrogram_param["f_max"],
                window_fn=window_fn,
                n_mels=mel_spectrogram_param["n_mels"],
            ),
        )

        self.spectrogram_encoder = SpectrogramEncoder(
            conv1d_block_param_list=conv1d_block_param_list,
            conv2d_block_param_list=conv2d_block_param_list,
        )

    def forward(self, inputs: torch.Tensor):
        """Encode a waveform batch; returns [batch_size, seq_length, feature_dim]."""
        x = inputs

        # The front-end is treated as a fixed feature extractor: no gradients
        # flow through the mel-spectrogram computation.
        with torch.no_grad():
            # x: [batch_size, n_mels, seq_length]; the epsilon keeps log() finite.
            x = self.wave_to_mel_spectrogram(x) + 1e-6
            x = x.log()
            # Per-bin mean normalization along time.
            x = x - torch.mean(x, dim=-1, keepdim=True)

        x = x.transpose(1, 2)

        features = self.spectrogram_encoder.forward(x)
        return features
285
+
286
+
287
class ClsHead(nn.Module):
    """Classification head: a FeedForward over per-frame features, mean-pooled
    over the time axis, followed by a linear projection to label logits."""
    def __init__(self,
                 input_dim: int,
                 num_layers: int,
                 hidden_dims: Union[int, List[int]],
                 activations: Union[str, List[str]],
                 num_labels: int,
                 dropout: Union[float, List[float]] = 0.0
                 ):
        super(ClsHead, self).__init__()

        self.feedforward = FeedForward(
            input_dim=input_dim,
            num_layers=num_layers,
            hidden_dims=hidden_dims,
            activations=activations,
            dropout=dropout,
        )
        hidden_size = self.feedforward.get_output_dim()
        self.output_project_layer = nn.Linear(hidden_size, num_labels)

    def forward(self, inputs: torch.Tensor):
        """Map [batch_size, seq_length, feature_dim] to [batch_size, num_labels] logits."""
        hidden = self.feedforward(inputs)
        # hidden: [batch_size, seq_length, hidden_size]

        # Average-pool across time before projecting to the label space.
        pooled = torch.mean(hidden, dim=1)
        # pooled: [batch_size, hidden_size]

        logits = self.output_project_layer.forward(pooled)
        return logits
319
+
320
+
321
class WaveClassifier(nn.Module):
    """Compose a WaveEncoder and a ClsHead into an end-to-end waveform classifier."""
    def __init__(self,
                 wave_encoder: WaveEncoder,
                 cls_head: ClsHead,
                 ):
        super(WaveClassifier, self).__init__()
        self.wave_encoder = wave_encoder
        self.cls_head = cls_head

    def forward(self, inputs: torch.Tensor):
        """Encode the waveform batch, then classify the resulting feature sequence."""
        features = self.wave_encoder.forward(inputs)
        # features: [batch_size, seq_length, feature_dim]

        logits = self.cls_head.forward(features)
        # logits: [batch_size, num_labels]
        return logits
341
+
342
+
343
class WaveClassifierPretrainedModel(WaveClassifier):
    """WaveClassifier assembled from a CnnAudioClassifierConfig, with
    save/load helpers mirroring the transformers pretrained-model API."""
    def __init__(self,
                 config: CnnAudioClassifierConfig,
                 ):
        encoder = WaveEncoder(
            mel_spectrogram_param=config.mel_spectrogram_param,
            conv1d_block_param_list=config.conv1d_block_param_list,
            conv2d_block_param_list=config.conv2d_block_param_list,
        )
        head = ClsHead(
            input_dim=config.cls_head_param["input_dim"],
            num_layers=config.cls_head_param["num_layers"],
            hidden_dims=config.cls_head_param["hidden_dims"],
            activations=config.cls_head_param["activations"],
            num_labels=config.cls_head_param["num_labels"],
            dropout=config.cls_head_param["dropout"],
        )
        super(WaveClassifierPretrainedModel, self).__init__(
            wave_encoder=encoder,
            cls_head=head,
        )
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build a model from either a checkpoint directory (containing the
        config plus MODEL_FILE) or the path of a single checkpoint file.

        ``kwargs`` are forwarded to the config loader only.
        """
        config = CnnAudioClassifierConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
        else:
            ckpt_file = pretrained_model_name_or_path

        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        with open(ckpt_file, "rb") as f:
            state_dict = torch.load(f, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        """Write the model weights (MODEL_FILE) and config (CONFIG_FILE) into
        ``save_directory``, creating it if needed; returns the directory."""
        if state_dict is None:
            state_dict = self.state_dict()

        os.makedirs(save_directory, exist_ok=True)

        # Weights first, then the YAML config alongside them.
        model_file = os.path.join(save_directory, MODEL_FILE)
        torch.save(state_dict, model_file)

        config_file = os.path.join(save_directory, CONFIG_FILE)
        self.config.to_yaml_file(config_file)
        return save_directory
400
+
401
+
402
if __name__ == "__main__":
    pass
toolbox/torchaudio/models/lstm_audio_classifier/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
#!/usr/bin/python3
# -*- coding: utf-8 -*-


if __name__ == "__main__":
    pass
toolbox/torchaudio/models/lstm_audio_classifier/configuration_lstm_audio_classifier.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from toolbox.torchaudio.configuration_utils import PretrainedConfig
4
+
5
+
6
class WaveClassifierConfig(PretrainedConfig):
    """Configuration for the LSTM wave classifier.

    Groups the kwargs dicts for the mel-spectrogram front-end, the LSTM layer,
    the pooling layer, and the classification head; remaining keyword
    arguments are forwarded to PretrainedConfig.
    """
    def __init__(self,
                 mel_spectrogram_param: dict,
                 lstm_layer_param: dict,
                 pooling_layer_param: dict,
                 cls_head_param: dict,
                 **kwargs
                 ):
        super(WaveClassifierConfig, self).__init__(**kwargs)
        # Each parameter dict is stored verbatim; consumers unpack them.
        self.mel_spectrogram_param = mel_spectrogram_param
        self.lstm_layer_param = lstm_layer_param
        self.pooling_layer_param = pooling_layer_param
        self.cls_head_param = cls_head_param
19
+
20
+
21
if __name__ == "__main__":
    pass