Miaode commited on
Commit
8499730
·
verified ·
1 Parent(s): 2917cfc

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +14 -0
  2. DataFlow/OpenSeek_CoT/Openseek_Math_Chain_of_Thoughts_pipeline.py.py +137 -0
  3. DataFlow/dataflow/__init__.py +14 -0
  4. DataFlow/dataflow/cli.py +80 -0
  5. DataFlow/dataflow/core/LLMServing.py +27 -0
  6. DataFlow/dataflow/core/Operator.py +31 -0
  7. DataFlow/dataflow/core/__init__.py +7 -0
  8. DataFlow/dataflow/core/__pycache__/LLMServing.cpython-310.pyc +0 -0
  9. DataFlow/dataflow/core/__pycache__/Operator.cpython-310.pyc +0 -0
  10. DataFlow/dataflow/core/__pycache__/__init__.cpython-310.pyc +0 -0
  11. DataFlow/dataflow/logger.py +38 -0
  12. DataFlow/dataflow/operators/__init__.py +4 -0
  13. DataFlow/dataflow/operators/eval/AgenticRAG/statistics/f1_scorer.py +108 -0
  14. DataFlow/dataflow/operators/eval/GeneralText/APIcaller/__pycache__/perspective_scorer.cpython-310.pyc +0 -0
  15. DataFlow/dataflow/operators/eval/GeneralText/APIcaller/meta_scorer.py +70 -0
  16. DataFlow/dataflow/operators/eval/GeneralText/APIcaller/treeinstruct_scorer.py +53 -0
  17. DataFlow/dataflow/operators/eval/GeneralText/__init__.py +55 -0
  18. DataFlow/dataflow/operators/eval/GeneralText/diversity/__pycache__/task2vec_scorer.cpython-310.pyc +0 -0
  19. DataFlow/dataflow/operators/eval/GeneralText/diversity/__pycache__/vendi_scorer.cpython-310.pyc +0 -0
  20. DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec/__pycache__/task2vec.cpython-310.pyc +0 -0
  21. DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec/__pycache__/task_similarity.cpython-310.pyc +0 -0
  22. DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec/__pycache__/utils.cpython-310.pyc +0 -0
  23. DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec/task2vec.py +544 -0
  24. DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec/task_similarity.py +485 -0
  25. DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec/utils.py +76 -0
  26. DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec_scorer.py +76 -0
  27. DataFlow/dataflow/operators/eval/GeneralText/diversity/vendi_scorer.py +36 -0
  28. DataFlow/dataflow/operators/eval/GeneralText/gen/__pycache__/bert_scorer.cpython-310.pyc +0 -0
  29. DataFlow/dataflow/operators/eval/GeneralText/gen/__pycache__/bleu_scorer.cpython-310.pyc +0 -0
  30. DataFlow/dataflow/operators/eval/GeneralText/gen/__pycache__/cider_scorer.cpython-310.pyc +0 -0
  31. DataFlow/dataflow/operators/eval/GeneralText/gen/bert_scorer.py +46 -0
  32. DataFlow/dataflow/operators/eval/GeneralText/gen/bleu/__init__.py +0 -0
  33. DataFlow/dataflow/operators/eval/GeneralText/gen/bleu/__pycache__/__init__.cpython-310.pyc +0 -0
  34. DataFlow/dataflow/operators/eval/GeneralText/gen/bleu/__pycache__/bleu.cpython-310.pyc +0 -0
  35. DataFlow/dataflow/operators/eval/GeneralText/gen/bleu/bleu.py +236 -0
  36. DataFlow/dataflow/operators/eval/GeneralText/gen/bleu_scorer.py +47 -0
  37. DataFlow/dataflow/operators/eval/GeneralText/gen/cider/__init__.py +0 -0
  38. DataFlow/dataflow/operators/eval/GeneralText/gen/cider/__pycache__/__init__.cpython-310.pyc +0 -0
  39. DataFlow/dataflow/operators/eval/GeneralText/gen/cider/__pycache__/cider.cpython-310.pyc +0 -0
  40. DataFlow/dataflow/operators/eval/GeneralText/gen/cider/cider.py +134 -0
  41. DataFlow/dataflow/operators/eval/GeneralText/gen/cider_scorer.py +60 -0
  42. DataFlow/dataflow/operators/eval/GeneralText/models/Kenlm/__pycache__/model.cpython-310.pyc +0 -0
  43. DataFlow/dataflow/operators/eval/GeneralText/models/Kenlm/model.py +161 -0
  44. DataFlow/dataflow/operators/eval/GeneralText/models/Qurating/__pycache__/qurater_annotate.cpython-310.pyc +0 -0
  45. DataFlow/dataflow/operators/eval/GeneralText/models/Qurating/modeling/__pycache__/modeling_flash_llama.cpython-310.pyc +0 -0
  46. DataFlow/dataflow/operators/eval/GeneralText/models/Qurating/modeling/modeling_flash_llama.py +853 -0
  47. DataFlow/dataflow/operators/eval/GeneralText/models/Qurating/qurater_annotate.py +190 -0
  48. DataFlow/dataflow/operators/eval/GeneralText/models/Superfiltering/__pycache__/data_analysis.cpython-310.pyc +0 -0
  49. DataFlow/dataflow/operators/eval/GeneralText/models/Superfiltering/data_analysis.py +53 -0
  50. DataFlow/dataflow/operators/eval/GeneralText/models/__pycache__/debertav3_scorer.cpython-310.pyc +0 -0
.gitattributes CHANGED
@@ -34,3 +34,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  baidu.zip filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  baidu.zip filter=lfs diff=lfs merge=lfs -text
37
+ Qwen2.5-Math/evaluation/latex2sympy/antlr-4.11.1-complete.jar filter=lfs diff=lfs merge=lfs -text
38
+ report.pdf filter=lfs diff=lfs merge=lfs -text
39
+ Qwen2.5-Math/evaluation/data/tabmwp/test.jsonl filter=lfs diff=lfs merge=lfs -text
40
+ Reproducibility/evaluation_log/result/ernie_openseek/math/test_qwen25-math-cot_-1_seed0_t0.0_s0_e-1.jsonl filter=lfs diff=lfs merge=lfs -text
41
+ Reproducibility/cleaned_data/for_erniekit_training/sft_dataflow_finemath.jsonl filter=lfs diff=lfs merge=lfs -text
42
+ Qwen2.5-Math/evaluation/latex2sympy/gen/__pycache__/PSLexer.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
43
+ Reproducibility/cleaned_data/for_erniekit_training/sft_dataflow_dolmino.jsonl filter=lfs diff=lfs merge=lfs -text
44
+ Reproducibility/cleaned_data/dataflow_finemath.jsonl filter=lfs diff=lfs merge=lfs -text
45
+ Qwen2.5-Math/evaluation/latex2sympy/gen/__pycache__/PSLexer.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
46
+ Reproducibility/evaluation_log/result/ernie_dataflow/math/test_qwen25-math-cot_-1_seed0_t0.0_s0_e-1.jsonl filter=lfs diff=lfs merge=lfs -text
47
+ Qwen2.5-Math/evaluation/latex2sympy/gen/__pycache__/PSParser.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
48
+ Qwen2.5-Math/evaluation/latex2sympy/gen/__pycache__/PSParser.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
49
+ Reproducibility/cleaned_data/dataflow_dolmino.jsonl filter=lfs diff=lfs merge=lfs -text
50
+ ernie/ERNIE/examples/pre-training/demo_data/data-1-part1.idx filter=lfs diff=lfs merge=lfs -text
DataFlow/OpenSeek_CoT/Openseek_Math_Chain_of_Thoughts_pipeline.py.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from dataflow.operators.generate import (
    QuestionCategoryClassifier,
    QuestionDifficultyClassifier,
    QuestionGenerator,
    AnswerGenerator,
)

from dataflow.operators.filter import (
    QuestionFilter,
    AnswerPipelineRoot,
    AnswerFormatterFilter,
    AnswerTokenLengthFilter,
    AnswerGroundTruthFilter,
    AnswerNgramFilter,
)
from dataflow.utils.storage import FileStorage
from dataflow.serving import APILLMServing_request, LocalModelLLMServing


class ReasoningPipeline():
    """Chain-of-thought pipeline over a math corpus.

    Steps: filter bad questions, classify difficulty and category,
    generate CoT answers, then filter the answers by format, token
    length and n-gram repetition.  Each step reads from and writes to
    the step-wise file cache held by ``self.storage``.
    """

    def __init__(self):
        # Step-wise jsonl cache over the raw input corpus.
        self.storage = FileStorage(
            first_entry_file_name="/cpfs/user/boyuan/verl_workspace/baidu/data/fulldata/math/second_half_math.jsonl",
            cache_path="./second_half_math",
            file_name_prefix="dataflow_cache_step",
            cache_type="jsonl",
        )

        # Remote API endpoint shared by every LLM-backed operator below.
        llm_serving = APILLMServing_request(
            api_url="http://10.39.1.99:23456/v1/chat/completions",
            model_name="ernie300",
            max_workers=50
        )

        # ---- question-side operators ----
        self.question_filter_step3 = QuestionFilter(
            system_prompt="You are an expert in evaluating mathematical problems. Follow the user's instructions strictly and output your final judgment in the required JSON format.",
            llm_serving=llm_serving
        )
        self.question_difficulty_classifier_step4 = QuestionDifficultyClassifier(
            llm_serving=llm_serving
        )
        self.question_category_classifier_step5 = QuestionCategoryClassifier(
            llm_serving=llm_serving
        )

        # ---- answer-side operators ----
        self.answer_generator_step7 = AnswerGenerator(
            llm_serving=llm_serving
        )
        self.answer_format_filter_step8 = AnswerFormatterFilter()
        self.answer_token_length_filter_step9 = AnswerTokenLengthFilter(
            max_answer_token_length=8192,
            tokenizer_dir="/cpfs/user/boyuan/verl_workspace/baidu/models300/qwen3",
        )
        # Instantiated but not run in forward(); kept for optional
        # ground-truth comparison runs.
        self.answer_groundtruth_filter_step10 = AnswerGroundTruthFilter()
        self.answer_ngram_filter_step11 = AnswerNgramFilter(
            min_score=0.1,
            max_score=1.0,
            ngrams=5
        )

    def forward(self):
        """Run the enabled pipeline steps in order."""
        self.question_filter_step3.run(
            storage=self.storage.step(),
            input_key="instruction",
        )

        self.question_difficulty_classifier_step4.run(
            storage=self.storage.step(),
            input_key="instruction",
            output_key="question_difficulty"
        )

        self.question_category_classifier_step5.run(
            storage=self.storage.step(),
            input_key="instruction",
            output_key="question_category"
        )

        self.answer_generator_step7.run(
            storage=self.storage.step(),
            input_key="instruction",
            output_key="generated_cot"
        )
        self.answer_format_filter_step8.run(
            storage=self.storage.step(),
            input_key="generated_cot",
        )
        self.answer_token_length_filter_step9.run(
            storage=self.storage.step(),
            input_key="generated_cot"
        )
        self.answer_ngram_filter_step11.run(
            storage=self.storage.step(),
            question_key="instruction",
            answer_key="generated_cot"
        )


if __name__ == "__main__":
    model = ReasoningPipeline()
    model.forward()
DataFlow/dataflow/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from .utils import *
from .version import __version__, version_info
from .logger import get_logger
from .operators import *

# BUG FIX: `hello` is a public function defined below but was missing from
# `__all__`, so `from dataflow import *` did not export it.
__all__ = [
    '__version__',
    'version_info',
    'get_logger',
    'hello',
]


def hello():
    """Return a short greeting; used as an install smoke test."""
    return "Hello from open-dataflow!"
DataFlow/dataflow/cli.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import argparse
import shutil
import requests

from colorama import init, Fore, Style

# from dataflow.utils.paths import BencoPath
from dataflow.cli_funcs import cli_env, cli_init
import importlib.metadata

PYPI_API_URL = 'https://pypi.org/pypi/open-dataflow/json'
from dataflow.version import __version__


def _terminal_width() -> int:
    """Width of the attached terminal in columns.

    BUG FIX: the original used ``os.get_terminal_size()``, which raises
    OSError when stdout is not a tty (piped output, CI).  ``shutil``'s
    variant honours the COLUMNS env var and falls back to 80.
    """
    return shutil.get_terminal_size().columns


def version_and_check_for_updates():
    """Print the local version and compare against the newest PyPI release."""
    # Draw a full-width separator bar.
    print(Fore.BLUE + "=" * _terminal_width() + Style.RESET_ALL)
    print(f'open-dataflow codebase version: {__version__}')
    try:
        response = requests.get(PYPI_API_URL, timeout=5)
        response.raise_for_status()  # raise on any non-2xx status code
        pypi_data = response.json()
        cloud_version = pypi_data['info']['version']  # newest published version
        print("\tChecking for updates...")
        print("\tLocal version: ", __version__)
        print("\tPyPI newest version: ", cloud_version)

        local_version = __version__

        if cloud_version != local_version:
            print(Fore.YELLOW + f"New version available: {cloud_version}. Your version: {local_version}." + Style.RESET_ALL)
            print("Run 'pip install --upgrade open-dataflow' to upgrade.")
        else:
            print(Fore.GREEN + f"You are using the latest version: {local_version}." + Style.RESET_ALL)
    except requests.exceptions.RequestException as e:
        # Network failure is non-fatal: report and continue.
        print(Fore.RED + "Failed to check for updates from PyPI. Please check your internet connection." + Style.RESET_ALL)
        print(f"Error: {e}")
    print(Fore.BLUE + "=" * _terminal_width() + Style.RESET_ALL)


def main():
    """Entry point for the DataFlow command line interface."""
    parser = argparse.ArgumentParser(description='Command line interface for DataFlow, with codebase version: ' + __version__)

    # Version flag; the PyPI update check only runs when the user asks for it.
    parser.add_argument('-v', '--version', action='store_true', help="Show the version of the tool")

    subparsers = parser.add_subparsers(dest='command', required=False)

    # `init` command: scaffold scripts and configs into the current directory.
    parser_init = subparsers.add_parser('init', help='Initialize the scripts and configs in a directory')
    init_subparsers = parser_init.add_subparsers(dest='subcommand', required=False)

    # `init all`
    parser_init_all = init_subparsers.add_parser('all', help='Initialize all components')
    parser_init_all.set_defaults(subcommand='all')

    # `init reasoning`
    parser_init_reasoning = init_subparsers.add_parser('reasoning', help='Initialize reasoning components')
    parser_init_reasoning.set_defaults(subcommand='reasoning')

    # `env` command: report environment information.
    parser_env = subparsers.add_parser('env', help='Show environment information')

    args = parser.parse_args()
    if args.version:
        version_and_check_for_updates()

    if args.command == 'init':
        if args.subcommand is None:
            args.subcommand = 'base'  # default component set when none is given
        cli_init(subcommand=args.subcommand)
        from dataflow.cli_funcs.paths import DataFlowPath
    elif args.command == 'env':
        cli_env()


if __name__ == '__main__':
    main()
DataFlow/dataflow/core/LLMServing.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from abc import ABC, abstractmethod
from typing import Any, List


class LLMServingABC(ABC):
    """Interface for LLM serving backends (local models or remote APIs).

    Operators depend on this contract to turn batches of prompts into
    batches of generated text, regardless of where the model runs.
    """

    @abstractmethod
    def generate_from_input(self, user_inputs: List[str], system_prompt: str) -> List[str]:
        """Produce one generated string per entry of ``user_inputs``."""
        pass

    @abstractmethod
    def cleanup(self):
        """Tear the backend down and reclaim any GPU/CPU memory it holds."""
        pass

    def load_model(self, model_name_or_path: str, **kwargs: Any):
        """Load model weights from ``model_name_or_path``.

        Optional hook: backends that manage local weights override this;
        the default signals that loading is unsupported.
        """
        raise NotImplementedError("This method should be implemented by subclasses.")
DataFlow/dataflow/core/Operator.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from abc import ABC, abstractmethod
from dataflow.logger import get_logger


class OperatorABC(ABC):
    """Base class every DataFlow operator derives from."""

    @abstractmethod
    def run(self) -> None:
        """Main function to run the operator."""
        pass


def get_operator(operator_name, args) -> OperatorABC:
    """Look up ``operator_name`` in the operator registry and instantiate it.

    Args:
        operator_name: registered name of the operator class.
        args: constructor arguments forwarded to the operator class.

    Returns:
        The instantiated operator.

    Raises:
        KeyError: if no operator is registered under ``operator_name``.
    """
    # Local import mirrors the original code (avoids a circular import
    # between dataflow.core and dataflow.utils).
    from dataflow.utils import OPERATOR_REGISTRY
    logger = get_logger()
    operator_cls = OPERATOR_REGISTRY.get(operator_name)
    # BUG FIX: the original did ``OPERATOR_REGISTRY.get(name)(args)`` in one
    # step, so an unknown name raised a bare TypeError ("'NoneType' object is
    # not callable") and the error-logging / assert path below it never ran.
    if operator_cls is None:
        logger.error(f"operator {operator_name} is not found")
        raise KeyError(f"operator {operator_name} is not registered")
    operator = operator_cls(args)
    logger.info(f"Successfully get operator {operator_name}, args {args}")
    return operator
DataFlow/dataflow/core/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .Operator import OperatorABC, get_operator
2
+ from .LLMServing import LLMServingABC
3
+ __all__ = [
4
+ 'OperatorABC',
5
+ 'get_operator',
6
+ 'LLMServingABC',
7
+ ]
DataFlow/dataflow/core/__pycache__/LLMServing.cpython-310.pyc ADDED
Binary file (1.5 kB). View file
 
DataFlow/dataflow/core/__pycache__/Operator.cpython-310.pyc ADDED
Binary file (1.05 kB). View file
 
DataFlow/dataflow/core/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (317 Bytes). View file
 
DataFlow/dataflow/logger.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import logging
import colorlog

# Custom log level sitting between INFO (20) and WARNING (30).
SUCCESS_LEVEL_NUM = 25
logging.addLevelName(SUCCESS_LEVEL_NUM, "SUCCESS")


def success(self, message, *args, **kwargs):
    """Log ``message`` at the custom SUCCESS level."""
    if self.isEnabledFor(SUCCESS_LEVEL_NUM):
        self._log(SUCCESS_LEVEL_NUM, message, args, **kwargs)


# Attach the helper so every Logger instance gains ``.success(...)``.
logging.Logger.success = success


def get_logger(level=logging.INFO) -> logging.Logger:
    """Return the shared "DataFlow" logger, configuring it on first use.

    Handler/formatter setup runs only when the logger has no handlers yet,
    so repeated calls reuse the existing configuration (and later ``level``
    arguments are ignored after the first call).
    """
    logger = logging.getLogger("DataFlow")
    if not logger.handlers:
        logger.setLevel(level)
        # Console handler with a colourised single-line format carrying
        # file/function/line and process/thread context.
        console_handler = logging.StreamHandler()
        console_handler.setLevel(level)
        color_formatter = colorlog.ColoredFormatter(
            '%(log_color)s %(asctime)s | %(filename)-20s- %(module)-20s- %(funcName)-20s- %(lineno)5d - %(name)-10s | %(levelname)8s | Processno %(process)5d - Threadno %(thread)-15d : %(message)s',
            log_colors={
                'DEBUG': 'cyan',
                # 'INFO': 'white',
                'SUCCESS': 'green',
                'WARNING': 'yellow',
                'ERROR': 'red',
                'CRITICAL': 'red,bg_white',
            }
        )
        console_handler.setFormatter(color_formatter)
        logger.addHandler(console_handler)
    return logger
DataFlow/dataflow/operators/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # from .eval import *
2
+ # from .generate import *
3
+ # from .filter import *
4
+ # from .refine import *
DataFlow/dataflow/operators/eval/AgenticRAG/statistics/f1_scorer.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import re
import string
from collections import Counter
from tqdm import tqdm
import pandas as pd
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger


@OPERATOR_REGISTRY.register()
class F1Scorer(OperatorABC):
    """Token-level F1 between a predicted answer and one or more references.

    Each row scores as the maximum F1 over all of its ground-truth answers,
    using the standard SQuAD/HotpotQA normalization (lowercase, strip
    punctuation and articles, collapse whitespace).
    """

    def __init__(self, prediction_key, ground_truth_key):
        self.logger = get_logger()
        self.logger.info(f"Initializing {self.__class__.__name__}...")
        self.prediction_key = prediction_key        # column holding the model prediction
        self.ground_truth_key = ground_truth_key    # column holding the reference answer(s)
        self.logger.info(f"{self.__class__.__name__} initialized.")

    @staticmethod
    def get_desc(lang: str = "zh"):
        # BUG FIX: the original ignored `lang` and always returned the
        # Chinese description.
        if lang == "zh":
            return "用于评估预测答案与多个参考答案之间的 F1 分数"
        return "Evaluates the F1 score between a predicted answer and multiple reference answers"

    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """Check required input columns exist and the output column does not."""
        required_keys = [self.prediction_key, self.ground_truth_key]
        forbidden_keys = [self.output_key]

        missing = [k for k in required_keys if k not in dataframe.columns]
        conflict = [k for k in forbidden_keys if k in dataframe.columns]

        if missing:
            raise ValueError(f"Missing required column(s): {missing}")
        if conflict:
            raise ValueError(f"The following column(s) already exist and would be overwritten: {conflict}")

    def normalize_answer(self, s: str) -> str:
        """SQuAD-style answer normalization: lowercase, drop punctuation,
        drop articles (a/an/the), collapse whitespace."""
        def remove_articles(text):
            return re.sub(r"\b(a|an|the)\b", " ", text)

        def white_space_fix(text):
            return " ".join(text.split())

        def remove_punc(text):
            exclude = set(string.punctuation)
            return "".join(ch for ch in text if ch not in exclude)

        def lower(text):
            return text.lower()

        return white_space_fix(remove_articles(remove_punc(lower(s))))

    def compute_f1(self, prediction: str, ground_truths) -> float:
        """Best token-level F1 of `prediction` against each reference in
        `ground_truths` (a string or a list of strings); 0.0 when either
        side is missing."""
        if prediction is None or ground_truths is None:
            return 0.0

        if isinstance(ground_truths, str):
            ground_truths = [ground_truths]

        # Loop-invariant hoist: normalize the prediction once instead of
        # once per reference (the original recomputed it in the loop).
        normalized_prediction = self.normalize_answer(prediction)
        pred_tokens = normalized_prediction.split()

        max_f1 = 0.0
        for ground_truth in ground_truths:
            if ground_truth is None:
                continue

            normalized_ground_truth = self.normalize_answer(ground_truth)

            # yes/no/noanswer questions count only as exact matches.
            if normalized_prediction in ["yes", "no", "noanswer"] or normalized_ground_truth in ["yes", "no", "noanswer"]:
                if normalized_prediction != normalized_ground_truth:
                    continue

            gold_tokens = normalized_ground_truth.split()
            common = Counter(pred_tokens) & Counter(gold_tokens)
            num_same = sum(common.values())

            if num_same == 0:
                continue

            precision = num_same / len(pred_tokens)
            recall = num_same / len(gold_tokens)
            f1 = (2 * precision * recall) / (precision + recall)
            max_f1 = max(max_f1, f1)

        return max_f1

    def eval(self, dataframe: pd.DataFrame) -> list:
        """Compute one F1 score per dataframe row."""
        self.logger.info(f"Evaluating {self.output_key}...")
        f1_scores = []

        for _, row in tqdm(dataframe.iterrows(), total=len(dataframe), desc="F1Scorer Evaluating..."):
            prediction = row.get(self.prediction_key, None)
            ground_truths = row.get(self.ground_truth_key, None)
            score = self.compute_f1(prediction, ground_truths)
            f1_scores.append(score)

        self.logger.info("Evaluation complete!")
        return f1_scores

    def run(self, storage: DataFlowStorage, output_key):
        """Read the dataframe from storage, score every row into `output_key`,
        and write the result back."""
        dataframe = storage.read("dataframe")
        self.output_key = output_key
        self._validate_dataframe(dataframe)
        scores = self.eval(dataframe)
        dataframe[self.output_key] = scores
        storage.write(dataframe)
DataFlow/dataflow/operators/eval/GeneralText/APIcaller/__pycache__/perspective_scorer.cpython-310.pyc ADDED
Binary file (2.3 kB). View file
 
DataFlow/dataflow/operators/eval/GeneralText/APIcaller/meta_scorer.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
import pandas as pd
from dataflow.core import LLMServingABC
from dataflow.prompts.general_text import MetaPrompt
import ast


@OPERATOR_REGISTRY.register()
class MetaScorer(OperatorABC):
    """Rates text on six fixed quality dimensions using an LLM judge.

    The judge is expected to end its reply with a Python-style list of six
    numbers; unparseable replies score as NaN on every dimension.
    """

    def __init__(self, llm_serving: LLMServingABC = None):
        self.logger = get_logger()
        self.logger.info(f'Initializing {self.__class__.__name__}...')
        self.llm_serving = llm_serving
        self.score_name = 'MetaScore'
        self.prompt = MetaPrompt()
        self.logger.info(f'{self.__class__.__name__} initialized.')

        # Fixed column order; must line up with the 6-element score lists.
        self.output_columns = [
            "Text Structure",
            "Diversity & Complexity",
            "Fluency & Understandability",
            "Safety",
            "Educational Value",
            "Content Accuracy & Effectiveness"
        ]

    def get_score(self, samples, input_key):
        """Query the judge for every sample and parse the six scores each."""
        system_prompt = self.prompt.build_system_prompt()
        # System prompt is prepended to each user prompt in a single message.
        user_prompts = [
            system_prompt + "\n" + self.prompt.build_user_prompt(sample.get(input_key, ''))
            for sample in samples
        ]

        responses = self.llm_serving.generate_from_input(user_inputs=user_prompts)

        scores = []
        for i, response in enumerate(responses):
            try:
                # The list of six numbers is expected on the reply's last line.
                last_line = response.strip().split("\n")[-1].strip()
                parsed_scores = ast.literal_eval(last_line)
                if not (isinstance(parsed_scores, list) and len(parsed_scores) == 6):
                    raise ValueError("Score format invalid")
                scores.append(parsed_scores)
            except Exception as e:
                # Best-effort: keep row alignment by filling with NaNs.
                self.logger.warning(f"Failed to extract score from response {i}: {e}")
                scores.append([float('nan')] * 6)

        return scores

    def eval(self, dataframe: pd.DataFrame, input_key: str):
        """Score every row of `dataframe`; returns a list of 6-element lists."""
        samples = dataframe.to_dict(orient='records')
        self.logger.info(f"Evaluating {self.score_name}...")
        scores = self.get_score(samples, input_key)
        self.logger.info("Evaluation complete!")
        return scores

    def run(self, storage: DataFlowStorage, input_key: str):
        """Read the dataframe, append the six score columns, write it back."""
        self.input_key = input_key
        dataframe = storage.read("dataframe")
        scores = self.eval(dataframe, self.input_key)
        # Expand the per-row score lists into the six fixed-name columns.
        score_df = pd.DataFrame(scores, columns=self.output_columns)
        dataframe = pd.concat([dataframe, score_df], axis=1)
        storage.write(dataframe)
DataFlow/dataflow/operators/eval/GeneralText/APIcaller/treeinstruct_scorer.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
import pandas as pd
from dataflow.core import LLMServingABC
from dataflow.prompts.general_text import TreeinstructPrompt


@OPERATOR_REGISTRY.register()
class TreeinstructScorer(OperatorABC):
    """Scores instruction complexity via an LLM judge (Tree-Instruct style).

    The judge is expected to put the numeric score as the first token of
    its reply's last line; unparseable replies score as NaN.
    """

    def __init__(self, llm_serving: LLMServingABC = None):
        self.logger = get_logger()
        self.logger.info(f'Initializing {self.__class__.__name__}...')
        self.llm_serving = llm_serving
        self.score_name = 'TreeinstructScore'
        self.prompt = TreeinstructPrompt()
        self.logger.info(f'{self.__class__.__name__} initialized.')

    def get_score(self, samples, input_instruction_key):
        """Build one prompt per sample and parse a numeric score from each reply."""
        system_prompts = []
        user_prompts = []
        for sample in samples:
            # BUG FIX: the original defaulted a missing instruction to the
            # list [''], handing a list to a prompt builder that takes a string.
            instruction = sample.get(input_instruction_key, '')
            system_prompts.append(self.prompt.build_system_prompt(instruction))
            user_prompts.append(self.prompt.build_user_prompt())

        inputs = [system + "\n" + user for system, user in zip(system_prompts, user_prompts)]
        responses = self.llm_serving.generate_from_input(user_inputs=inputs)

        scores = []
        for i, response in enumerate(responses):
            # BUG FIX: the parse was unguarded, so one malformed response
            # aborted the whole batch.  Fall back to NaN and log a warning,
            # consistent with MetaScorer's best-effort handling.
            try:
                score_line = response.strip().split("\n")[-1]
                score = float(score_line.split()[0])
            except Exception as e:
                self.logger.warning(f"Failed to extract score from response {i}: {e}")
                score = float('nan')
            scores.append(score)

        return scores

    def eval(self, dataframe: pd.DataFrame, input_instruction_key: str):
        """Score every row of `dataframe`; returns a list of floats."""
        self.logger.info(f"Evaluating {self.score_name}...")
        samples = dataframe.to_dict(orient='records')
        scores = self.get_score(samples, input_instruction_key)
        self.logger.info("Evaluation complete!")
        return scores

    def run(self, storage: DataFlowStorage, input_instruction_key: str, output_key: str = 'TreeinstructScore'):
        """Read the dataframe, append the score column, and write it back."""
        self.input_instruction_key = input_instruction_key
        self.output_key = output_key
        dataframe = storage.read("dataframe")
        scores = self.eval(dataframe, self.input_instruction_key)
        dataframe[self.output_key] = scores
        storage.write(dataframe)
DataFlow/dataflow/operators/eval/GeneralText/__init__.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from .statistics.ngram_scorer import NgramScorer
2
+ # from .statistics.lexical_diversity_scorer import LexicalDiversityScorer
3
+ # from .statistics.langkit_scorer import LangkitScorer
4
+
5
+ # from .models.deita_quality_scorer import DeitaQualityScorer
6
+ # from .models.instag_scorer import InstagScorer
7
+ # from .models.debertav3_scorer import DebertaV3Scorer
8
+ # from .models.deita_complexity_scorer import DeitaComplexityScorer
9
+ # from .models.fineweb_edu_scorer import FineWebEduScorer
10
+ # from .models.pair_qual_scorer import PairQualScorer
11
+ # from .models.presidio_scorer import PresidioScorer
12
+ # from .models.rm_scorer import RMScorer
13
+ # from .models.textbook_scorer import TextbookScorer
14
+ # from .models.superfiltering_scorer import SuperfilteringScorer
15
+ # from .models.qurating_scorer import QuratingScorer
16
+ # from .models.perplexity_scorer import PerplexityScorer
17
+
18
+ # from .APIcaller.alpagasus_scorer import AlpagasusScorer
19
+ # from .APIcaller.treeinstruct_scorer import TreeinstructScorer
20
+ # from .APIcaller.perspective_scorer import PerspectiveScorer
21
+ # from .APIcaller.meta_scorer import MetaScorer
22
+
23
+ # from .diversity.vendi_scorer import VendiScorer
24
+ # from .diversity.task2vec_scorer import Task2VecScorer
25
+
26
+ # from .gen.bleu_scorer import BleuScorer
27
+ # from .gen.cider_scorer import CiderScorer
28
+ # from .gen.bert_scorer import BERTScorer
29
+
30
+ # __all__ = [
31
+ # 'NgramScorer',
32
+ # 'LexicalDiversityScorer',
33
+ # 'LangkitScorer',
34
+ # 'DeitaQualityScorer',
35
+ # 'InstagScorer',
36
+ # 'DebertaV3Scorer',
37
+ # 'DeitaComplexityScorer',
38
+ # 'FineWebEduScorer',
39
+ # 'PairQualScorer',
40
+ # 'PresidioScorer',
41
+ # 'RMScorer',
42
+ # 'TextbookScorer',
43
+ # 'SuperfilteringScorer',
44
+ # 'QuratingScorer',
45
+ # 'PerplexityScorer',
46
+ # 'AlpagasusScorer',
47
+ # 'TreeinstructScorer',
48
+ # 'PerspectiveScorer',
49
+ # "MetaScorer",
50
+ # 'VendiScorer',
51
+ # 'Task2VecScorer',
52
+ # 'BleuScorer',
53
+ # 'CiderScorer',
54
+ # 'BERTScorer'
55
+ # ]
DataFlow/dataflow/operators/eval/GeneralText/diversity/__pycache__/task2vec_scorer.cpython-310.pyc ADDED
Binary file (4.26 kB). View file
 
DataFlow/dataflow/operators/eval/GeneralText/diversity/__pycache__/vendi_scorer.cpython-310.pyc ADDED
Binary file (1.83 kB). View file
 
DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec/__pycache__/task2vec.cpython-310.pyc ADDED
Binary file (17.6 kB). View file
 
DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec/__pycache__/task_similarity.cpython-310.pyc ADDED
Binary file (13.7 kB). View file
 
DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec/__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.62 kB). View file
 
DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec/task2vec.py ADDED
@@ -0,0 +1,544 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+
14
+ import itertools
15
+ import math
16
+ import random
17
+ from abc import ABC, abstractmethod
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ from torch.optim import Optimizer
23
+ import numpy as np
24
+ from tqdm.auto import tqdm, trange
25
+ import logging
26
+ from torch.utils.data import DataLoader, Dataset
27
+
28
+ from .utils import AverageMeter, get_error, get_device
29
+
30
+ ## LLM DIV
31
def set_seed(seed):
    """Seed every RNG in play (python ``random``, numpy, torch CPU and all CUDA
    devices) so that runs are reproducible."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
36
+
37
+ ## LLM DIV
38
def get_loss(logits: torch.tensor, targets: torch.tensor, ignore_index=None) -> torch.tensor:
    """
    Computes the cross-entropy loss for either sequence classification or generation.
    """
    # Next-token prediction: drop the last time-step's logits, drop the first
    # target token, then move the vocab axis to where CrossEntropyLoss expects
    # the class dimension.
    assert logits.dim() == 3 and ignore_index is not None
    criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)
    shifted_logits = logits[:, :-1, :].transpose(1, 2)  # (batch, vocab, seq_len - 1)
    shifted_targets = targets[:, 1:]
    return criterion(shifted_logits, shifted_targets)
49
+
50
class Embedding:
    """
    task_embedding = diagonal of the FIM for the filters of size [F_total, 1] total filters for a network.

    Notes:
        - the diagonal of the Fisher Information Matrix for each layer.
        - embedding size should be the size of the total number of filters for the network.
    """

    def __init__(self, hessian, scale, meta=None):
        # Copy the inputs into numpy arrays; `meta` is free-form and kept as-is.
        self.meta = meta
        self.hessian = np.array(hessian)
        self.scale = np.array(scale)

    def __repr__(self):
        return str(self.hessian)
66
+
67
+
68
class ProbeNetwork(ABC, nn.Module):
    """Abstract class that all probe networks should inherit from.

    This is a standard torch.nn.Module but needs to expose a classifier property that returns the final classification
    module (e.g., the last fully connected layer).
    """

    @property
    @abstractmethod
    def classifier(self):
        # Subclasses must return the head module; Task2Vec excludes this module's
        # parameters when extracting the embedding.
        raise NotImplementedError("Override the classifier property to return the submodules of the network that"
                                  " should be interpreted as the classifier")

    @classifier.setter
    @abstractmethod
    def classifier(self, val):
        # Setter counterpart so callers can swap the classification head in place.
        raise NotImplementedError("Override the classifier setter to set the submodules of the network that"
                                  " should be interpreted as the classifier")
86
+
87
+
88
class Task2Vec:
    """Task2Vec embedding generator: a task (dataset) is represented by the
    diagonal of the Fisher Information Matrix (FIM) of a probe network's
    weights, estimated on that dataset.

    Two operating modes:
      - mode='autoregressive' (## LLM DIV fork): the probe is a HuggingFace-style
        causal LM exposing ``lm_head`` and returning an object with ``.logits``;
        only the montecarlo FIM estimator is used.
      - otherwise (classification): the probe is a ProbeNetwork exposing
        ``layers`` and ``classifier``; 'montecarlo' or 'variational' estimators.
    """

    def __init__(self, model: ProbeNetwork, skip_layers=0, max_samples=None, classifier_opts=None,
                 method='montecarlo', method_opts=None, loader_opts=None, bernoulli=False, mode='autoregressive'): ## LLM DIV
        """
        :param model: probe network to embed tasks with.
        :param skip_layers: number of initial layers to skip when caching features
            (classification mode only).
        :param max_samples: cap on the number of samples used when caching features.
        :param classifier_opts: options for fitting/finetuning the classifier head
            (learning_rate, weight_decay, seed, epochs, finetune, break_early, ...).
        :param method: FIM approximation, 'variational' or 'montecarlo'.
        :param method_opts: kwargs forwarded to the chosen fisher routine.
        :param loader_opts: DataLoader options (batch_size, num_workers, ...).
        :param bernoulli: use Bernoulli sampling / BCE loss instead of categorical.
        :param mode: 'autoregressive' for causal LMs, anything else for classification.
        """
        if classifier_opts is None:
            classifier_opts = {}
        if method_opts is None:
            method_opts = {}
        if loader_opts is None:
            loader_opts = {}
        assert method in ('variational', 'montecarlo')
        assert skip_layers >= 0

        self.model = model
        # Fix batch norm running statistics (i.e., put batch_norm layers in eval mode)
        # NOTE(review): train() is called although the comment says eval mode --
        # presumably the probe network overrides train() to freeze BN; confirm.
        self.model.train()
        self.device = get_device(self.model)
        self.skip_layers = skip_layers
        self.max_samples = max_samples
        self.classifier_opts = classifier_opts
        self.method = method
        self.method_opts = method_opts
        self.loader_opts = loader_opts
        self.bernoulli = bernoulli
        self.mode = mode
        if self.mode == "autoregressive":
            # get_loss is a plain function (token-shifted cross-entropy); it has
            # no .to(), so only the nn.Module losses below are moved to device.
            self.loss_fn = get_loss
        else:
            self.loss_fn = nn.CrossEntropyLoss() if not self.bernoulli else nn.BCEWithLogitsLoss()
            self.loss_fn = self.loss_fn.to(self.device)

    def embed(self, dataset: Dataset, epochs: int = 5):
        """Embed ``dataset``: (optionally) fit/finetune the classifier head,
        compute the FIM diagonal, and extract the embedding.

        :returns: autoregressive mode -> ``(embedding, final_finetune_loss)``;
            classification mode -> ``embedding``.
        """
        ## LLM DIV
        # Cache the last layer features (needed to train the classifier) and (if needed) the intermediate layer features
        # so that we can skip the initial layers when computing the embedding
        # dataset.train()
        if self.mode == "autoregressive":
            loss = None
            print(f'{self.classifier_opts=}')
            if self.classifier_opts:  # is it something truthy? e.g., dict with something in it?
                if self.classifier_opts.get('finetune', False):  # finetune only if specified True, else no finetuning if not specified or False.
                    # NOTE(review): setting epochs=0 when finetune is True looks
                    # inverted w.r.t. the comment above and the warning below --
                    # confirm the intended semantics of 'finetune'.
                    epochs = 0
                    print(f'Warning: classifier_opts doesnt specify finetune or break early, thus no finetuning is being done. See: {self.classifier_opts=} {epochs=}')
                    loss = self._finetune_classifier(dataset, loader_opts=self.loader_opts, classifier_opts=self.classifier_opts, max_samples=self.max_samples, epochs=epochs)
                else:
                    loss = self._finetune_classifier(dataset, loader_opts=self.loader_opts, classifier_opts=self.classifier_opts, max_samples=self.max_samples, epochs=epochs)
            else:  # self.classifier_opts might be None or {}
                loss = self._finetune_classifier(dataset, loader_opts=self.loader_opts, classifier_opts=self.classifier_opts, max_samples=self.max_samples, epochs=epochs)
            print(f'{loss=} (after fine tune, if not done it will be None)')
            assert loss is not None, f'Err: {loss=}'
            self.compute_fisher(dataset)
            embedding = self.extract_embedding(self.model)
            return embedding, loss
        else:
            if self.skip_layers > 0:
                self._cache_features(dataset, indexes=(self.skip_layers, -1), loader_opts=self.loader_opts,
                                     max_samples=self.max_samples)
            else:
                self._cache_features(dataset, max_samples=self.max_samples)
            # Fits the last layer classifier using cached features
            self._fit_classifier(**self.classifier_opts)

            if self.skip_layers > 0:
                dataset = torch.utils.data.TensorDataset(self.model.layers[self.skip_layers].input_features,
                                                         self.model.layers[-1].targets)

            # dataset.eval() # I added this so that the embedding is computed on the val set
            self.compute_fisher(dataset)
            embedding = self.extract_embedding(self.model)
            # dataset.train() # returns to using the support set
            return embedding

    ### LLM DIV
    def _finetune_classifier(self, dataset: Dataset, loader_opts: dict = None, classifier_opts: dict = None, max_samples=None, epochs = 5, learning_rate = 5e-5, adam_epsilon = 1e-8):
        """Fits the last layer of the HuggingFace transformer probe network.

        Only ``self.model.lm_head`` parameters are optimized (linear-probe style
        finetuning); values in ``classifier_opts`` override the keyword defaults.
        NOTE(review): ``max_samples`` is accepted but never used in this method.

        :returns: the last batch's loss value (float).
        """
        logging.info("Finetune classifier...")
        if loader_opts is None:
            loader_opts = {}
        if classifier_opts is None:
            classifier_opts = {}
        data_loader = DataLoader(dataset, shuffle=False, batch_size=loader_opts.get('batch_size', 8),
                                 num_workers=loader_opts.get('num_workers', 0), drop_last=False)

        device = next(self.model.parameters()).device
        print("MODEL DEVICE: ", device)

        # num_examples = int(classifier_opts.get("task_batch_size", 256) / loader_opts.get('batch_size', 8))
        num_examples = len(list(data_loader)) # not ideal but it's quicker in dev time, usually we won't feed the entire data set to task2vec so this should be fine
        n_batches = num_examples

        # Only the LM head is trainable; everything else keeps its weights.
        optimizer_grouped_parameters = [
            {'params': [p for p in self.model.lm_head.parameters()],
             'weight_decay': classifier_opts.get("weight_decay",0.0001)},
        ]

        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=classifier_opts.get("learning_rate",learning_rate), eps=classifier_opts.get("adam_epsilon",adam_epsilon))

        # Train!
        logging.info("***** Running training *****")
        # logging.info("  Num examples = %d", num_examples)
        logging.info("  Num Epochs = %d", epochs)
        logging.info("  Batch size = %d", loader_opts.get('batch_size', 8))

        train_iterator = trange(classifier_opts.get("epochs", epochs), desc="Epoch", leave=False)
        set_seed(classifier_opts.get("seed", 42))  # Added here for reproductibility (even between python 2 and 3)

        self.model.train()
        for epoch in train_iterator:
            metrics = AverageMeter()
            epoch_iterator = tqdm(data_loader, desc="Iteration", total=n_batches, leave=False)
            for step, batch in enumerate(epoch_iterator):
                optimizer.zero_grad()
                inputs = {'input_ids': batch['input_ids'].to(device),
                          'attention_mask': batch['attention_mask'].to(device)}
                logits = self.model(**inputs, labels=inputs["input_ids"]).logits
                # ignore_index=50256 -- presumably the GPT-2 eos/pad token id;
                # TODO confirm it matches the tokenizer in use.
                loss = self.loss_fn(logits, inputs["input_ids"], ignore_index=50256)
                print(f'\nInitial loss {loss.item()} ({step=} {epoch=})') if step == 0 else None
                error = get_error(logits, inputs['input_ids'], ignore_index=50256)
                loss.backward()
                optimizer.step()

                metrics.update(n=batch['input_ids'].shape[0], loss=loss.item(), error=error)
                epoch_iterator.update(1)

                if classifier_opts.get("break_early", False):
                    print("----> breaking early")
                    break
            # break_early also exits the epoch loop, not just the batch loop.
            if classifier_opts.get("break_early", False):
                break
            logging.info(f"[epoch {epoch}]: " + "\t".join(f"{k}: {v}" for k, v in metrics.avg.items()))
        print(f'\nfinal loss {step=} {epoch=} of final layer loss {loss.item()} (note we are not recomputing loss after a step so this loss printed is larger than it should be/one off)')
        return loss.item()

    ### LLM DIV
    def montecarlo_fisher_autoregressive(self, dataset: Dataset, epochs: int = 1):
        """Monte-Carlo estimate of the FIM diagonal for a causal LM.

        For each batch, the target sequence is sampled from the model's own
        output distribution (as the FIM definition requires, not from the
        dataset labels); squared per-parameter gradients are accumulated in
        ``p.grad2_acc`` and averaged over ``p.grad_counter`` at the end.
        """
        logging.info("Using montecarlo Fisher")
        if self.loader_opts is None:
            loader_opts = {}
        else:
            loader_opts = self.loader_opts

        data_loader = DataLoader(dataset, shuffle=False, batch_size=loader_opts.get('batch_size', 8),
                                 num_workers=loader_opts.get('num_workers', 0), drop_last=False)
        device = get_device(self.model)

        # num_examples = int(classifier_opts.get("task_batch_size", 256) / loader_opts.get('batch_size', 8))
        num_examples = len(list(data_loader)) # not idea but it's quicker in dev time, usually we won't feed the entire data set to task2vec so this should be fine
        n_batches = num_examples

        logging.info("Computing Fisher...")
        for p in self.model.parameters():
            p.grad2_acc = torch.zeros_like(p.data)
            p.grad_counter = 0

        for k in range(epochs):
            logging.info(f"\tepoch {k + 1}/{epochs}")

            epoch_iterator = tqdm(data_loader, desc="Iteration", total=n_batches, leave=False)
            for step, batch in enumerate(epoch_iterator):
                inputs = {'input_ids': batch['input_ids'].to(device),
                          'attention_mask': batch['attention_mask'].to(device)}
                logits = self.model(**inputs, labels=inputs["input_ids"]).logits

                # The gradients used to compute the FIM needs to be for y sampled from
                # the model distribution y ~ p_w(y|x), not for y from the dataset
                if self.bernoulli:
                    target = torch.bernoulli(F.sigmoid(logits[:,:-1,:])).detach()
                else:
                    # Sample one token per position, per sequence in the batch.
                    softmax_output = F.softmax(logits, dim=-1)
                    lst = [torch.multinomial(softmax_output[i,:,:], 1).detach().view(-1) for i in range(len(softmax_output))]
                    target = torch.stack(lst, dim=0)

                loss = self.loss_fn(logits, target, ignore_index=50256)
                self.model.zero_grad()
                loss.backward()
                for p in self.model.parameters():
                    if p.grad is not None:
                        p.grad2_acc += p.grad.data ** 2
                        p.grad_counter += 1
                if self.classifier_opts.get("break_early", False):
                    break # for debugging faster, otherwise FIM is really slow
            if self.classifier_opts.get("break_early", False):
                break # for debugging faster, otherwise FIM is really slow
        # Average the accumulated squared gradients; drop params that never got one.
        for p in self.model.parameters():
            if p.grad_counter == 0:
                del p.grad2_acc
            else:
                p.grad2_acc /= p.grad_counter
        logging.info("done")

    def montecarlo_fisher(self, dataset: Dataset, epochs: int = 1):
        """Monte-Carlo FIM diagonal for the classification probe network
        (same accumulation scheme as the autoregressive variant)."""
        logging.info("Using montecarlo Fisher")
        if self.skip_layers > 0:
            # Replace the dataset with the cached intermediate features so that
            # the forward pass can start from self.skip_layers.
            dataset = torch.utils.data.TensorDataset(self.model.layers[self.skip_layers].input_features,
                                                     self.model.layers[-1].targets)
        data_loader = _get_loader(dataset, **self.loader_opts)
        device = get_device(self.model)
        logging.info("Computing Fisher...")

        for p in self.model.parameters():
            p.grad2_acc = torch.zeros_like(p.data)
            p.grad_counter = 0
        for k in range(epochs):
            logging.info(f"\tepoch {k + 1}/{epochs}")
            for i, (data, target) in enumerate(tqdm(data_loader, leave=False, desc="Computing Fisher")):
                data = data.to(device)
                output = self.model(data, start_from=self.skip_layers)
                # The gradients used to compute the FIM needs to be for y sampled from
                # the model distribution y ~ p_w(y|x), not for y from the dataset
                if self.bernoulli:
                    target = torch.bernoulli(F.sigmoid(output)).detach()
                else:
                    target = torch.multinomial(F.softmax(output, dim=-1), 1).detach().view(-1)
                loss = self.loss_fn(output, target)
                self.model.zero_grad()
                loss.backward()
                for p in self.model.parameters():
                    if p.grad is not None:
                        p.grad2_acc += p.grad.data ** 2
                        p.grad_counter += 1
        for p in self.model.parameters():
            if p.grad_counter == 0:
                del p.grad2_acc
            else:
                p.grad2_acc /= p.grad_counter
        logging.info("done")

    def _run_epoch(self, data_loader: DataLoader, model: ProbeNetwork, loss_fn,
                   optimizer: Optimizer, epoch: int, train: bool = True,
                   add_compression_loss: bool = False, skip_layers=0, beta=1.0e-7):
        """One pass over ``data_loader`` for the variational method; optionally
        adds the beta-weighted compression loss and takes optimizer steps.

        NOTE(review): ``variational`` is not imported in this module; this path
        will raise NameError unless a `variational` module is in scope -- confirm.
        NOTE(review): the summary print reads metrics.avg["data_time"/"batch_time"],
        which are never updated here -- verify AverageMeter tolerates missing keys.
        """
        metrics = AverageMeter()
        device = get_device(model)

        for i, (input, target) in enumerate(tqdm(data_loader, leave=False, desc="Computing Fisher")):
            input = input.to(device)
            target = target.to(device)
            output = model(input, start_from=skip_layers)

            loss = loss_fn(output, target)
            lz = beta * variational.get_compression_loss(model) if add_compression_loss else torch.zeros_like(loss)
            loss += lz

            error = get_error(output, target)

            metrics.update(n=input.size(0), loss=loss.item(), lz=lz.item(), error=error)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        # logging.info(
        print(
            "{}: [{epoch}] ".format('Epoch' if train else '', epoch=epoch) +
            "Data/Batch: {:.3f}/{:.3f} ".format(metrics.avg["data_time"], metrics.avg["batch_time"]) +
            "Loss {:.3f} Lz: {:.3f} ".format(metrics.avg["loss"], metrics.avg["lz"]) +
            "Error: {:.2f}".format(metrics.avg["error"])
        )
        return metrics.avg

    def variational_fisher(self, dataset: Dataset, epochs=1, beta=1e-7):
        """Variational FIM approximation: learns per-parameter noise variances
        (via the ``variational`` helpers) whose inverse approximates the hessian.
        """
        logging.info("Training variational fisher...")
        parameters = []
        for layer in self.model.layers[self.skip_layers:-1]:
            if isinstance(layer, nn.Module):  # Skip lambda functions
                variational.make_variational(layer)
                parameters += variational.get_variational_vars(layer)
        bn_params = []
        # Allows batchnorm parameters to change
        for m in self.model.modules():
            if isinstance(m, nn.BatchNorm2d):
                bn_params += list(m.parameters())
        # Avoids computing the gradients wrt to the weights to save time and memory
        for p in self.model.parameters():
            if p not in set(parameters) and p not in set(self.model.classifier.parameters()):
                p.old_requires_grad = p.requires_grad
                p.requires_grad = False

        optimizer = torch.optim.Adam([
            {'params': parameters},
            {'params': bn_params, 'lr': 5e-4},
            {'params': self.model.classifier.parameters(), 'lr': 5e-4}],
            lr=1e-2, betas=(.9, 0.999))
        if self.skip_layers > 0:
            dataset = torch.utils.data.TensorDataset(self.model.layers[self.skip_layers].input_features,
                                                     self.model.layers[-1].targets)
        train_loader = _get_loader(dataset, **self.loader_opts)

        for epoch in range(epochs):
            self._run_epoch(train_loader, self.model, self.loss_fn, optimizer, epoch, beta=beta,
                            add_compression_loss=True, train=True)

        # Resets original value of requires_grad
        for p in self.model.parameters():
            if hasattr(p, 'old_requires_grad'):
                p.requires_grad = p.old_requires_grad
                del p.old_requires_grad

    def compute_fisher(self, dataset: Dataset):
        """
        Computes the Fisher Information of the weights of the model wrt the model output on the dataset and stores it.

        The Fisher Information Matrix is defined as:
            F = E_{x ~ dataset} E_{y ~ p_w(y|x)} [\nabla_w log p_w(y|x) \nabla_w log p_w(y|x)^t]
        where p_w(y|x) is the output probability vector of the network and w are the weights of the network.
        Notice that the label y is sampled from the model output distribution and not from the dataset.

        This code only approximate the diagonal of F. The result is stored in the model layers and can be extracted
        using the `get_fisher` method. Different approximation methods of the Fisher information matrix are available,
        and can be selected in the __init__.

        :param dataset: dataset with the task to compute the Fisher on
        """
        # Dispatch: autoregressive mode forces the autoregressive montecarlo path.
        if self.mode == 'autoregressive' and self.method == 'montecarlo':
            fisher_fn = self.montecarlo_fisher_autoregressive
        elif self.method == 'variational':
            fisher_fn = self.variational_fisher
        elif self.method == 'montecarlo':
            fisher_fn = self.montecarlo_fisher
        else:
            raise ValueError(f"Invalid Fisher method {self.method}")
        fisher_fn(dataset, **self.method_opts)

    def _cache_features(self, dataset: Dataset, indexes=(-1,), max_samples=None, loader_opts: dict = None):
        """Runs the dataset through the model once, recording the inputs to the
        layers at ``indexes`` (via forward pre-hooks) and the targets, so later
        passes can start from an intermediate layer.
        """
        logging.info("Caching features...")
        if loader_opts is None:
            loader_opts = {}
        data_loader = DataLoader(dataset, shuffle=False, batch_size=loader_opts.get('batch_size', 64),
                                 num_workers=loader_opts.get('num_workers', 0), drop_last=False)

        device = next(self.model.parameters()).device

        def _hook(layer, inputs):
            # Accumulate (on CPU) every input batch seen by this layer.
            if not hasattr(layer, 'input_features'):
                layer.input_features = []
            layer.input_features.append(inputs[0].data.cpu().clone())

        hooks = [self.model.layers[index].register_forward_pre_hook(_hook)
                 for index in indexes]
        if max_samples is not None:
            n_batches = min(
                math.floor(max_samples / data_loader.batch_size) - 1, len(data_loader))
        else:
            n_batches = len(data_loader)
        targets = []

        for i, (input, target) in tqdm(enumerate(itertools.islice(data_loader, 0, n_batches)), total=n_batches,
                                       leave=False,
                                       desc="Caching features"):
            targets.append(target.clone())
            self.model(input.to(device))
        for hook in hooks:
            hook.remove()
        for index in indexes:
            self.model.layers[index].input_features = torch.cat(self.model.layers[index].input_features)
        self.model.layers[-1].targets = torch.cat(targets)

    def _fit_classifier(self, optimizer='adam', learning_rate=0.0004, weight_decay=0.0001,
                        epochs=10):
        """Fits the last layer of the network using the cached features.

        :raises ValueError: if features were not cached first, or for an
            unsupported optimizer name.
        """
        logging.info("Fitting final classifier...")
        if not hasattr(self.model.classifier, 'input_features'):
            raise ValueError("You need to run `cache_features` on model before running `fit_classifier`")
        targets = self.model.classifier.targets.to(self.device)
        features = self.model.classifier.input_features.to(self.device)

        dataset = torch.utils.data.TensorDataset(features, targets)
        data_loader = _get_loader(dataset, **self.loader_opts)

        if optimizer == 'adam':
            optimizer = torch.optim.Adam(self.model.fc.parameters(), lr=learning_rate, weight_decay=weight_decay)
        elif optimizer == 'sgd':
            optimizer = torch.optim.SGD(self.model.fc.parameters(), lr=learning_rate, weight_decay=weight_decay)
        else:
            raise ValueError(f'Unsupported optimizer {optimizer}')

        loss_fn = nn.CrossEntropyLoss()
        for epoch in tqdm(range(epochs), desc="Fitting classifier", leave=False):
            metrics = AverageMeter()
            for data, target in data_loader:
                optimizer.zero_grad()
                # NOTE(review): the classifier is run twice per batch (once for
                # `output`, once inside the loss) -- redundant forward pass.
                output = self.model.classifier(data)
                loss = loss_fn(self.model.classifier(data), target)
                error = get_error(output, target)
                loss.backward()
                optimizer.step()
                metrics.update(n=data.size(0), loss=loss.item(), error=error)
            logging.info(f"[epoch {epoch}]: " + "\t".join(f"{k}: {v}" for k, v in metrics.avg.items()))
        print(f'\nfinal loss after fitting final layer {loss=}')

    def extract_embedding(self, model: ProbeNetwork):
        """
        Reads the values stored by `compute_fisher` and returns them in a common format that describes the diagonal of the
        Fisher Information Matrix for each layer.

        :param model: probe network whose parameters hold ``grad2_acc`` (and, for
            the variational method, ``logvar0``/``loglambda2``) attributes.
        :return: an Embedding with per-filter hessian diagonals and scales,
            excluding the classifier / lm_head module.
        """
        if self.mode == 'autoregressive':
            hess, scale = [], []
            for name, module in model.named_modules():
                # Skip the LM head: its FIM belongs to the task-specific head,
                # not the shared representation.
                if module is model.lm_head:
                    continue
                # The other Fisher approximation methods directly approximate the hessian at the minimum
                if hasattr(module, 'weight') and hasattr(module.weight, 'grad2_acc'):
                    grad2 = module.weight.grad2_acc.cpu().detach().numpy()
                    # Average each filter's (output row's) squared gradients.
                    filterwise_hess = grad2.reshape(grad2.shape[0], -1).mean(axis=1)
                    hess.append(filterwise_hess)
                    scale.append(np.ones_like(filterwise_hess))
        else:
            hess, scale = [], []
            for name, module in model.named_modules():
                if module is model.classifier:
                    continue
                # The variational Fisher approximation estimates the variance of noise that can be added to the weights
                # without increasing the error more than a threshold. The inverse of this is proportional to an
                # approximation of the hessian in the local minimum.
                if hasattr(module, 'logvar0') and hasattr(module, 'loglambda2'):
                    logvar = module.logvar0.view(-1).detach().cpu().numpy()
                    hess.append(np.exp(-logvar))
                    loglambda2 = module.loglambda2.detach().cpu().numpy()
                    scale.append(np.exp(-loglambda2).repeat(logvar.size))
                # The other Fisher approximation methods directly approximate the hessian at the minimum
                elif hasattr(module, 'weight') and hasattr(module.weight, 'grad2_acc'):
                    grad2 = module.weight.grad2_acc.cpu().detach().numpy()
                    filterwise_hess = grad2.reshape(grad2.shape[0], -1).mean(axis=1)
                    hess.append(filterwise_hess)
                    scale.append(np.ones_like(filterwise_hess))
        return Embedding(hessian=np.concatenate(hess), scale=np.concatenate(scale), meta=None)
515
+
516
+
517
def _get_loader(trainset, testset=None, batch_size=64, num_workers=0, num_samples=10000, drop_last=True):
    """Builds a class-balanced DataLoader for ``trainset`` using a
    WeightedRandomSampler (sample weights inversely proportional to class
    frequency), plus a plain sequential loader for ``testset`` when given.

    :returns: ``trainloader`` or ``(trainloader, testloader)``.
    :raises ValueError: for multi-label datasets.
    """
    if getattr(trainset, 'is_multi_label', False):
        raise ValueError("Multi-label datasets not supported")
    # TODO: Find a way to standardize this
    if hasattr(trainset, 'labels'):
        labels = trainset.labels
    elif hasattr(trainset, 'targets'):
        labels = trainset.targets
    else:
        labels = list(trainset.tensors[1].cpu().numpy())
    num_classes = int(getattr(trainset, 'num_classes', max(labels) + 1))
    # One-hot sum -> per-class counts; each sample weighted by 1/(count * num_classes).
    class_count = np.eye(num_classes)[labels].sum(axis=0)
    weights = 1. / class_count[labels] / num_classes
    weights /= weights.sum()

    sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples=num_samples)
    # No need for mutli-threaded loading if everything is already in memory,
    # and would raise an error if TensorDataset is on CUDA
    num_workers = num_workers if not isinstance(trainset, torch.utils.data.TensorDataset) else 0
    trainloader = torch.utils.data.DataLoader(trainset, sampler=sampler, batch_size=batch_size,
                                              num_workers=num_workers, drop_last=drop_last)

    if testset is None:
        return trainloader
    else:
        testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, pin_memory=True, shuffle=False,
                                                 num_workers=num_workers)
        return trainloader, testloader
DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec/task_similarity.py ADDED
@@ -0,0 +1,485 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
6
+ # may not use this file except in compliance with the License. A copy of
7
+ # the License is located at
8
+ #
9
+ # http://aws.amazon.com/apache2.0/
10
+ #
11
+ # or in the "license" file accompanying this file. This file is
12
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
13
+ # ANY KIND, either express or implied. See the License for the specific
14
+ # language governing permissions and limitations under the License.
15
+
16
+ import itertools
17
+ from typing import Tuple
18
+
19
+ import scipy.spatial.distance as distance
20
+ import numpy as np
21
+ import copy
22
+ import pickle
23
+
24
+ # import uutils
25
+
26
+ _DISTANCES = {}
27
+
28
+
29
+ # TODO: Remove methods that do not perform well
30
+
31
+ def _register_distance(distance_fn):
32
+ _DISTANCES[distance_fn.__name__] = distance_fn
33
+ return distance_fn
34
+
35
+
36
def is_excluded(k):
    """True when the parameter/layer name ``k`` refers to an excluded
    (classifier-head) layer, i.e. contains 'fc' or 'linear'."""
    return any(token in k for token in ('fc', 'linear'))
39
+
40
+
41
def load_embedding(filename):
    """Unpickle and return a task embedding stored at ``filename``."""
    with open(filename, 'rb') as fh:
        return pickle.load(fh)
45
+
46
+
47
def get_trivial_embedding_from(e):
    """Deep-copy the layer-wise embedding dict ``e`` and flatten every layer's
    per-filter log-variances to that layer's prior value ``filter_lambda2``
    (i.e. the embedding of a 'trivial' task)."""
    trivial_embedding = copy.deepcopy(e)
    for layer in trivial_embedding['layers']:
        filled = np.array(layer['filter_logvar'])
        filled[:] = layer['filter_lambda2']
        layer['filter_logvar'] = list(filled)
    return trivial_embedding
54
+
55
+
56
def binary_entropy(p):
    """Entropy (in nats) of a Bernoulli(p) variable; ``xlogy`` makes the
    endpoints p in {0, 1} evaluate to zero instead of NaN."""
    from scipy.special import xlogy
    return -(xlogy(p, p) + xlogy(1. - p, 1. - p))
59
+
60
+
61
def get_layerwise_variance(e, normalized=False):
    """Per-layer variances ``exp(filter_logvar)`` from a layer-wise embedding
    dict; optionally L2-normalized within each layer."""
    variances = [np.exp(layer['filter_logvar']) for layer in e['layers']]
    if normalized:
        variances = [v / np.linalg.norm(v) for v in variances]
    return variances
66
+
67
+
68
def get_variance(e, normalized=False):
    """Per-parameter variance 1/h from the embedding's Fisher diagonal; when
    ``normalized``, divide by the prior variance 1/scale."""
    var = 1. / np.array(e.hessian)
    if not normalized:
        return var
    prior_var = 1. / np.array(e.scale)
    return var / prior_var
74
+
75
+
76
def get_variances(*embeddings, normalized=False):
    """Apply :func:`get_variance` to each embedding in turn."""
    return [get_variance(emb, normalized=normalized) for emb in embeddings]
78
+
79
+
80
def get_hessian(e, normalized=False):
    """Fisher diagonal of ``e``; when ``normalized``, divide elementwise by the
    embedding's scale."""
    hess = np.array(e.hessian)
    if not normalized:
        return hess
    return hess / np.array(e.scale)
86
+
87
+
88
def get_hessians(*embeddings, normalized=False):
    """Apply :func:`get_hessian` to each embedding in turn."""
    return [get_hessian(emb, normalized=normalized) for emb in embeddings]
90
+
91
+
92
def get_scaled_hessian(e0, e1):
    """Relative magnitudes of the two Fisher diagonals: each element of the
    returned pair sums to ~1 (the epsilon keeps the division finite)."""
    h0, h1 = get_hessians(e0, e1, normalized=False)
    denom = h0 + h1 + 1e-8
    return h0 / denom, h1 / denom
95
+
96
+
97
def get_full_kl(e0, e1):
    """Elementwise directed KL divergences between the per-parameter diagonal
    Gaussians of e0 and e1, returned as (KL(e0||e1), KL(e1||e0)) vectors."""
    def directed(a, b):
        # KL between two zero-mean Gaussians with variances a and b.
        return .5 * (a / b - 1 + np.log(b) - np.log(a))
    var0, var1 = get_variance(e0), get_variance(e1)
    return directed(var0, var1), directed(var1, var0)
102
+
103
+
104
def layerwise_kl(e0, e1):
    """Summed KL(e0 || e1) per layer, over the layer-wise variances."""
    layers0, layers1 = get_layerwise_variance(e0), get_layerwise_variance(e1)
    return [np.sum(.5 * (v0 / v1 - 1 + np.log(v1) - np.log(v0)))
            for v0, v1 in zip(layers0, layers1)]
110
+
111
+
112
def layerwise_cosine(e0, e1):
    """Per-layer cosine distance between the (per-layer normalized) variances."""
    layers0 = get_layerwise_variance(e0, normalized=True)
    layers1 = get_layerwise_variance(e1, normalized=True)
    return [distance.cosine(v0, v1) for v0, v1 in zip(layers0, layers1)]
118
+
119
+
120
@_register_distance
def kl(e0, e1):
    """Symmetrized KL between the tasks' per-parameter Gaussians: the
    elementwise max of the two directed KLs, summed."""
    v0, v1 = get_variance(e0), get_variance(e1)
    d01 = .5 * (v0 / v1 - 1 + np.log(v1) - np.log(v0))
    d10 = .5 * (v1 / v0 - 1 + np.log(v0) - np.log(v1))
    return np.maximum(d01, d10).sum()
126
+
127
+
128
@_register_distance
def asymmetric_kl(e0, e1):
    """Directed KL(e0 || e1), summed over parameters.

    Fix: the reverse direction (``kl1``) was computed but never used; the dead
    computation is removed without changing the returned value.
    """
    var0, var1 = get_variance(e0), get_variance(e1)
    kl0 = .5 * (var0 / var1 - 1 + np.log(var1) - np.log(var0))
    return kl0.sum()
134
+
135
+
136
@_register_distance
def jsd(e0, e1):
    """Jensen-Shannon-style divergence: mean of the two KLs from each task's
    variance vector to the midpoint variance."""
    v0, v1 = get_variance(e0), get_variance(e1)
    mid = .5 * (v0 + v1)
    kl_to_mid0 = .5 * (v0 / mid - 1 + np.log(mid) - np.log(v0))
    kl_to_mid1 = .5 * (v1 / mid - 1 + np.log(mid) - np.log(v1))
    return (.5 * (kl_to_mid0 + kl_to_mid1)).mean()
143
+
144
+
145
@_register_distance
def cosine(e0, e1):
    """Cosine distance between the relative (scaled) Fisher diagonals -- the
    standard Task2Vec distance."""
    s0, s1 = get_scaled_hessian(e0, e1)
    return distance.cosine(s0, s1)
149
+
150
+
151
@_register_distance
def normalized_cosine(e0, e1):
    """Cosine distance between the variances, each normalized by its prior scale."""
    v0, v1 = get_variances(e0, e1, normalized=True)
    return distance.cosine(v0, v1)
155
+
156
+
157
@_register_distance
def correlation(e0, e1):
    """Correlation distance between the raw (unnormalized) variances."""
    v0, v1 = get_variances(e0, e1, normalized=False)
    return distance.correlation(v0, v1)
161
+
162
+
163
@_register_distance
def entropy(e0, e1):
    """log(2) minus the mean binary entropy of the scaled hessian (the scaled
    pair sums to ~1 elementwise, so one component suffices)."""
    h0, h1 = get_scaled_hessian(e0, e1)
    return np.log(2) - binary_entropy(h0).mean()
167
+
168
+
169
def get_normalized_embeddings(embeddings, normalization=None):
    """Stack Fisher precisions (1/variance) into a matrix, substituting a zero
    row for each missing (None) embedding, and divide by an RMS normalization
    (computed over rows here when not supplied).

    :returns: ``(matrix, normalization)``.
    """
    precisions = [1. / get_variance(e, normalized=False) if e is not None else None
                  for e in embeddings]
    template = np.zeros_like([x for x in precisions if x is not None][0])
    mat = np.array([x if x is not None else template for x in precisions])
    # FIXME: compute variance using only valid embeddings
    if normalization is None:
        normalization = np.sqrt((mat ** 2).mean(axis=0, keepdims=True))
    mat /= normalization
    return mat, normalization
178
+
179
+
180
def pdist(embeddings, distance='cosine') -> np.ndarray:
    """Pairwise distance matrix over `embeddings` using the named registered distance.

    Symmetric metrics fill the upper triangle once and mirror it; the
    'asymmetric_kl' metric evaluates every ordered pair (including i == j).
    """
    metric = _DISTANCES[distance]
    n = len(embeddings)
    out = np.zeros([n, n])
    if distance == 'asymmetric_kl':
        for i, e1 in enumerate(embeddings):
            for j, e2 in enumerate(embeddings):
                out[i, j] = metric(e1, e2)
    else:
        for (i, e1), (j, e2) in itertools.combinations(enumerate(embeddings), 2):
            d = metric(e1, e2)
            out[i, j] = d
            out[j, i] = d
    return out
193
+
194
+
195
def cross_pdist(embeddings1, embeddings2, distance='cosine') -> np.ndarray:
    """
    Compute pairwise distance between embeddings1 and embeddings2.

    ref: https://chat.openai.com/share/a5ca38dc-3393-4cfd-971c-4a29b0c56b63
    """
    # The original branched on `distance != 'asymmetric_kl'`, but both branches
    # contained the identical double loop: a cross-distance matrix has no
    # symmetry to exploit, so a single loop handles every metric.
    distance_fn = _DISTANCES[distance]
    distance_matrix = np.zeros([len(embeddings1), len(embeddings2)])
    for i, e1 in enumerate(embeddings1):
        for j, e2 in enumerate(embeddings2):
            distance_matrix[i, j] = distance_fn(e1, e2)
    return distance_matrix
214
+
215
+
216
def cdist(from_embeddings, to_embeddings, distance='cosine'):
    """Distance matrix between two embedding lists; pairs where either side is None stay 0."""
    metric = _DISTANCES[distance]
    out = np.zeros([len(from_embeddings), len(to_embeddings)])
    for i, src in enumerate(from_embeddings):
        if src is None:
            continue
        for j, dst in enumerate(to_embeddings):
            if dst is None:
                continue
            out[i, j] = metric(src, dst)
    return out
225
+
226
+
227
def plot_distance_matrix(embeddings, labels=None, distance='cosine', show_plot=True):
    """Hierarchically-clustered heatmap of the pairwise distance matrix of `embeddings`."""
    import seaborn as sns
    import pandas as pd
    import matplotlib.pyplot as plt
    from scipy.cluster.hierarchy import linkage
    from scipy.spatial.distance import squareform

    dmat = pdist(embeddings, distance=distance)
    # Cluster on the condensed form of the (symmetric) matrix.
    condensed = squareform(dmat, checks=False)
    link = linkage(condensed, method='complete', optimal_ordering=True)
    if labels is not None:
        dmat = pd.DataFrame(dmat, index=labels, columns=labels)
    sns.clustermap(dmat, row_linkage=link, col_linkage=link, cmap='viridis_r')
    if show_plot:
        plt.show()
241
+
242
+ ## LLM DIV
243
def plot_distance_matrix_heatmap_only(embeddings, labels=None, distance='cosine', show_plot=True, title=None, save_file=None):
    """Plain (unclustered) heatmap of the pairwise distance matrix of `embeddings`."""
    import seaborn as sns
    import pandas as pd
    import matplotlib.pyplot as plt

    dmat = pdist(embeddings, distance=distance)
    if labels is not None:
        dmat = pd.DataFrame(dmat, index=labels, columns=labels)
    sns.heatmap(dmat, cmap='viridis_r')
    if title:
        plt.title(title)
    if save_file:
        # Assumes a local `plots/` directory exists.
        _ = plt.savefig("plots/" + save_file + ".png", bbox_inches='tight')
    if show_plot:
        plt.show()
257
+
258
+ ## LLM DIV
259
def plot_distance_matrix_from_distance_matrix(distance_matrix, labels=None, show_plot=True, title=None, save_file=None, cluster=False, plot_multi=False):
    """Plot a precomputed distance matrix as a heatmap (or clustermap).

    Args:
        distance_matrix: square np.ndarray, or a list of them when plot_multi=True.
        labels: optional axis labels (single-matrix case only).
        cluster: hierarchically cluster rows/cols (single-matrix case only).
        plot_multi: render a fixed 3x2 grid of heatmaps from a list of matrices.
    """
    import seaborn as sns
    from scipy.cluster.hierarchy import linkage
    from scipy.spatial.distance import squareform
    import pandas as pd
    import matplotlib.pyplot as plt

    if plot_multi and not cluster:
        # `distance_matrix` is a list of square matrices on this path.
        num_rows, num_cols = 3, 2
        f, ax = plt.subplots(num_rows, num_cols)  # , figsize=(12, 15))
        i = 0
        # BUG FIX: the original iterated range(len(num_rows)) / range(len(num_cols)),
        # which raises TypeError because both are plain ints.
        for row_ind in range(num_rows):
            for col_ind in range(num_cols):
                sns.heatmap(distance_matrix[i], cmap='viridis_r', ax=ax[row_ind, col_ind])
                i += 1
    else:
        # Linkage is only meaningful for a single square matrix; the original
        # computed it unconditionally, which also crashed on the plot_multi path.
        cond_distance_matrix = squareform(distance_matrix, checks=False)
        linkage_matrix = linkage(cond_distance_matrix, method='complete', optimal_ordering=True)
        if labels is not None:
            distance_matrix = pd.DataFrame(distance_matrix, index=labels, columns=labels)
        if cluster:
            sns.clustermap(distance_matrix, row_linkage=linkage_matrix, col_linkage=linkage_matrix, cmap='viridis_r')
        else:
            sns.heatmap(distance_matrix, cmap='viridis_r')

    if title:
        plt.title(title)
    if save_file:
        _ = plt.savefig("plots/" + save_file + ".png", bbox_inches='tight')
    if show_plot:
        plt.show()
293
+
294
+ ## LLM DIV
295
+ # plot multiple subplots in one figure
296
+ # distance_matrix passed in is a list of distance_matrix np.arrays
297
def plot_multi_distance_matrix_from_distance_matrix_list(distance_matrix_lst, title_lst, labels, main_title=None, show_plot=True, title=None, save_file=None, vmin=None, vmax=None):
    """Render a two-column grid of heatmaps from a list of square distance matrices.

    Args:
        distance_matrix_lst: list of np.ndarray distance matrices.
        title_lst: per-subplot titles (parallel to distance_matrix_lst).
        labels: per-matrix axis labels; labels[i] labels matrix i.
        vmin/vmax: shared color-scale bounds; both must be given to take effect.
    """
    import seaborn as sns
    import pandas as pd
    import matplotlib.pyplot as plt
    import math

    num_rows, num_cols = math.ceil(len(distance_matrix_lst) / 2), 2
    # NOTE: the original if/else assigned the same figsize in both branches;
    # collapsed to a single assignment. Unused linkage/squareform imports removed.
    figsize = (12, 10)
    f, ax = plt.subplots(num_rows, num_cols, figsize=figsize)
    i = 0
    for row_ind in range(num_rows):
        for col_ind in range(num_cols):
            if i >= len(distance_matrix_lst):
                break
            distance_matrix = distance_matrix_lst[i]
            distance_matrix = pd.DataFrame(distance_matrix, index=labels[i], columns=labels[i])
            if len(distance_matrix_lst) > 2:
                # Multiple grid rows: plt.subplots returns a 2-D axes array.
                ax[row_ind, col_ind].set_aspect('equal')
                if vmin is not None and vmax is not None:
                    sns.heatmap(distance_matrix, cmap='viridis_r', ax=ax[row_ind, col_ind], vmin=vmin, vmax=vmax)
                else:
                    sns.heatmap(distance_matrix, cmap='viridis_r', ax=ax[row_ind, col_ind])
                ax[row_ind, col_ind].set_title(title_lst[i])
            else:
                # Single grid row: axes array is 1-D.
                ax[col_ind].set_aspect('equal')
                sns.heatmap(distance_matrix, cmap='viridis_r', ax=ax[col_ind])
                ax[col_ind].set_title(title_lst[i])

            i += 1
    if len(distance_matrix_lst) % 2 == 1:
        # Odd matrix count: remove the unused bottom-right axes.
        f.delaxes(ax[num_rows - 1, 1])

    if main_title:
        f.suptitle(main_title)
        f.subplots_adjust(top=0.5)

    if len(distance_matrix_lst) % 2 == 1:
        plt.tight_layout(h_pad=2)
    else:
        plt.tight_layout(h_pad=2, w_pad=5)
    if save_file:
        _ = plt.savefig("plots/" + save_file + ".png", bbox_inches='tight')
    if show_plot:
        plt.show()
345
+
346
+ ## LLM DIV
347
def stats_of_distance_matrix(distance_matrix: np.ndarray,
                             remove_diagonal: bool = True,
                             variance_type: str = 'std',  # TODO: was ci_0.95. Changed to rid uutils call
                             get_total: bool = False,
                             ) -> Tuple[float, float]:
    """Mean and spread of the entries of a (symmetric) distance matrix.

    When remove_diagonal is True, only the strictly-nonzero upper-triangular
    entries are kept (note: this also drops any genuinely-zero distances).
    Returns (mu, var), or (mu, var, total) when get_total is True.
    """
    if remove_diagonal:
        # ref https://stackoverflow.com/questions/46736258/deleting-diagonal-elements-of-a-numpy-array
        upper: np.ndarray = np.triu(distance_matrix)
        lower: np.ndarray = np.tril(distance_matrix)
        # Keep upper-triangle entries, discarding the zeroed half and diagonal.
        values = upper[upper != 0.0]
    else:
        values = distance_matrix
    values = values.flatten()

    if variance_type == 'std':
        mu, var = values.mean(), values.std()
    else:
        raise ValueError(f'Invalid variance type, got: {variance_type=}')

    if remove_diagonal:
        # Sanity output: for a symmetric matrix both halves should agree.
        print('Lower tri sum', lower.sum(), ' / Upper tri sum', upper.sum(), '| These should be approx equal!!')
        print('Total mean', values.mean(), ' / Upper mean', upper[upper != 0.0].mean(), ' / Lower mean', lower[lower != 0.0].mean(), '| These should all be approx equal!!')
        print('mu (div coefficient)', mu, ' / Upper mean', upper[upper != 0.0].mean(), '| These should all be approx equal!!')
    if get_total:
        return mu, var, values.sum()
    return mu, var
387
+
388
+
389
def stats_cross_distance_matrix(distance_matrix: np.ndarray,
                                remove_diagonal: bool = False,
                                variance_type: str = 'std',  # TODO: was ci_0.95. Changed to rid uutils call
                                get_total: bool = False,
                                ) -> Tuple[float, float]:
    """Thin wrapper over stats_of_distance_matrix; remove_diagonal defaults to
    False because a cross-distance matrix has no meaningful diagonal."""
    return stats_of_distance_matrix(
        distance_matrix,
        remove_diagonal=remove_diagonal,
        variance_type=variance_type,
        get_total=get_total,
    )
395
+
396
+
397
def plot_histogram_of_distances(distance_matrix: np.ndarray, title, show_plot=True, save_file=None, bins_width=None, grid=True):
    """Histogram of the nonzero upper-triangular distances with a dashed mean line."""
    import matplotlib.pyplot as plt
    upper = np.triu(distance_matrix)
    distance_values = upper[upper != 0.0].flatten()

    if grid:
        plt.grid(zorder=0)
    plt.axvline(np.mean(distance_values), color='k', linestyle='dashed', linewidth=1, zorder=4)
    if bins_width is not None:
        bins = np.arange(min(distance_values), max(distance_values) + bins_width, bins_width)
        plt.hist(distance_values, edgecolor="black", bins=bins, zorder=3)
    else:
        plt.hist(distance_values, edgecolor="black", zorder=3)
    plt.title(title)
    plt.xlabel("Cosine Distance between Task Pairs")
    plt.ylabel("Frequency")

    plt.tight_layout()
    if save_file:
        _ = plt.savefig("plots/" + save_file + ".png", bbox_inches='tight')

    if show_plot:
        plt.show()
420
+
421
+
422
+ ## LLM DIV
423
+ # plot multiple subplots in one figure
424
+ # distance_matrix passed in is a list of distance_matrix (np.arrays)
425
def plot_multi_histogram_of_distances(distance_matrix_lst, title_lst, main_title=None, show_plot=True, save_file=None,
                                      xlabel="Cosine Distance between Task Pairs", grid=True, bins_width=None,
                                      num_cols=2, figsize=(12, 10)):
    """Grid of histograms of pairwise distances, one subplot per matrix.

    Args:
        distance_matrix_lst: list of square distance matrices.
        title_lst: per-subplot titles.
        bins_width: fixed bin width; when None matplotlib chooses bins.
        num_cols: grid columns (2-D axes indexing is used when the list has >2 entries).
    """
    import matplotlib.pyplot as plt
    import math

    num_rows = math.ceil(len(distance_matrix_lst) / num_cols)

    f, ax = plt.subplots(num_rows, num_cols, figsize=figsize)
    i = 0
    for row_ind in range(num_rows):
        for col_ind in range(num_cols):
            if i >= len(distance_matrix_lst):
                break
            triu = np.triu(distance_matrix_lst[i])
            distance_values = triu[triu != 0.0].flatten()

            if len(distance_matrix_lst) > 2:
                axis = ax[row_ind, col_ind]
                if grid:
                    axis.grid(zorder=0)
                if bins_width is not None:
                    axis.hist(distance_values, edgecolor="black", zorder=3,
                              bins=np.arange(min(distance_values), max(distance_values) + bins_width, bins_width))
                else:
                    axis.hist(distance_values, edgecolor="black", zorder=3)
                axis.axvline(np.mean(distance_values), color='k', linestyle='dashed', linewidth=1, zorder=4)
            else:
                axis = ax[col_ind]
                if grid:
                    axis.grid(zorder=0)
                # BUG FIX: the original drew an unconditional histogram here and
                # then a second one inside the bins_width conditional, overlaying
                # two histograms on the same axes; only one draw is intended.
                if bins_width is not None:
                    axis.hist(distance_values, edgecolor="black", zorder=3,
                              bins=np.arange(min(distance_values), max(distance_values) + bins_width, bins_width))
                else:
                    axis.hist(distance_values, edgecolor="black", zorder=3)
            axis.set_xlabel(xlabel)
            axis.set_ylabel("Frequency")
            axis.set_title(title_lst[i])
            i += 1
    if len(distance_matrix_lst) % 2 == 1 and num_cols == 2:
        f.delaxes(ax[num_rows - 1, 1])

    if main_title:
        f.suptitle(main_title)
        f.subplots_adjust(top=1)

    plt.grid(True)
    plt.tight_layout()
    if save_file:
        _ = plt.savefig("plots/" + save_file + ".png", bbox_inches='tight')
    if show_plot:
        plt.show()
DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec/utils.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+
14
+ from collections import defaultdict
15
+ import torch
16
+ import numpy as np
17
+
18
+
19
class AverageMeter(object):
    """Tracks the last value, running sum, count, and mean for named quantities."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all tracked statistics."""
        self.val = defaultdict(int)      # most recent value per key
        self.avg = defaultdict(float)    # running mean per key
        self.sum = defaultdict(int)      # weighted sum per key
        self.count = defaultdict(int)    # total weight per key

    def update(self, n=1, **val):
        """Record new values (each weighted by n) and refresh the running means."""
        for key, value in val.items():
            self.val[key] = value
            self.sum[key] += value * n
            self.count[key] += n
            self.avg[key] = self.sum[key] / self.count[key]
37
+
38
+
39
def set_batchnorm_mode(model, train=True):
    """Set every BatchNorm1d/2d layer to train or eval mode, independently of the model's mode."""
    def _set_mode(module):
        if isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d)):
            # Module.train(bool) covers both train() and eval().
            module.train(train)

    model.apply(_set_mode)
49
+
50
+ ### LLM DIV
51
def get_error(output, target, mode='autoregressive', ignore_index=None):
    """Error rate of predictions.

    autoregressive: `output` holds next-token logits; argmax at step t is
    compared against target token t+1, positions whose target equals
    `ignore_index` are masked out, and 1 - accuracy is returned as a tensor.
    (assumes output is (batch, seq, vocab) and target is (batch, seq) — TODO confirm)

    any other mode: classification error in percent (float) from (N, C) scores.
    """
    if mode == 'autoregressive':  # output = logits here
        assert ignore_index is not None
        logits = output[:, :-1, :]            # final step has no next-token target
        predicted_ids = torch.argmax(logits, dim=-1)
        shifted_target = target[:, 1:]        # targets are the inputs shifted by one
        if ignore_index is not None:
            acc = torch.eq(predicted_ids, shifted_target.unsqueeze(0))[:, shifted_target != ignore_index]
        else:
            acc = torch.eq(predicted_ids, shifted_target.unsqueeze(0))
        return 1 - acc.float().mean()
    else:
        pred = output.argmax(dim=1)
        correct = pred.eq(target).float().sum()
        return float((1. - correct / output.size(0)) * 100.)
67
+
68
+
69
def adjust_learning_rate(optimizer, epoch, optimizer_cfg):
    """Step-decay the learning rate: divide optimizer_cfg.lr by 10 once for each
    milestone in optimizer_cfg.schedule that is strictly less than `epoch`."""
    decay_steps = np.less(optimizer_cfg.schedule, epoch).sum()
    new_lr = optimizer_cfg.lr * (0.1 ** decay_steps)
    for group in optimizer.param_groups:
        group['lr'] = new_lr
73
+
74
+
75
def get_device(model: torch.nn.Module):
    """Return the device of the model's first parameter (assumes all parameters share one device)."""
    first_param = next(model.parameters())
    return first_param.device
DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec_scorer.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataflow.operators.eval.GeneralText.diversity.task2vec.task2vec import Task2Vec
2
+ from dataflow.operators.eval.GeneralText.diversity.task2vec import task_similarity
3
+ import torch
4
+ import random
5
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
6
+ from dataflow.utils.storage import DataFlowStorage
7
+ from dataflow.core import OperatorABC
8
+ from dataflow.utils.registry import OPERATOR_REGISTRY
9
+ from torch.utils.data import Dataset
10
+ from dataflow import get_logger
11
+ from typing import Optional
12
+ # Task2Vec dataset diversity evaluation
13
+ # Cited from: Beyond Scale: the Diversity Coefficient as a Data Quality Metric Demonstrates LLMs are Pre-trained on Formally Diverse Data
14
@OPERATOR_REGISTRY.register()
class Task2VecScorer(OperatorABC):
    """Estimates dataset diversity via Task2Vec embeddings of random sample batches.

    Repeatedly draws `sample_size` texts, embeds each batch with a GPT-2 probe
    network, then reports the mean pairwise cosine distance between batch
    embeddings (the diversity coefficient) and its spread.
    """

    def __init__(self, device='cuda', sample_nums=10, sample_size=1, method: Optional[str]='montecarlo', model_cache_dir='./dataflow_cache'):
        self.logger = get_logger()
        self.logger.info(f'Initializing {self.__class__.__name__}...')
        # evaluating diversity by extract sample_nums * sample_size samples
        self.sample_nums = sample_nums    # number of batches to embed
        self.sample_size = sample_size    # texts per batch
        self.device = device
        self.model_cache_dir = model_cache_dir
        self.score_name = 'Task2VecScore'
        self.method = method              # Task2Vec Fisher-estimation method
        if method not in ['montecarlo', 'variational']:
            raise ValueError(f"Invalid method '{method}'. Valid options are 'montecarlo' and 'variational'.")
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=self.model_cache_dir)
        self.probe_network = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir=self.model_cache_dir)
        # Falls back to CPU when CUDA is unavailable.
        self.device = torch.device(self.device if self.device and torch.cuda.is_available() else "cpu")
        self.probe_network = self.probe_network.to(self.device)
        self.logger.info(f'{self.__class__.__name__} initialized.')

    def preprocess(self, texts):
        """Tokenize `texts` (padded/truncated) and move the tensors to the scorer's device."""
        self.tokenizer.pad_token = self.tokenizer.eos_token
        tokenized_outputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        return {key: value.to(self.device) for key, value in tokenized_outputs.items()}

    def get_score(self, sentences):
        """Embed `sample_nums` random batches and return the diversity coefficient stats."""
        embeddings = []
        data_length = len(sentences)
        for sample_num in range(self.sample_nums):
            self.logger.info(f'--> Sample {sample_num + 1}/{self.sample_nums}')
            # NOTE(review): random.sample requires sample_size <= len(sentences).
            indices = random.sample(range(data_length), self.sample_size)
            texts = [sentences[i] for i in indices]
            tokenized_batch = self.preprocess(texts)
            tokenized_dataset = CustomTensorDataset(tokenized_batch)
            embedding, _ = Task2Vec(self.probe_network, method=self.method).embed(tokenized_dataset)
            embeddings.append(embedding)
        distance_matrix = task_similarity.pdist(embeddings, distance='cosine')
        div_coeff, conf_interval = task_similarity.stats_of_distance_matrix(distance_matrix)

        return {
            "Task2VecDiversityScore": div_coeff,
            "ConfidenceInterval": conf_interval
        }

    def run(self, storage: DataFlowStorage, input_key: str):
        """Score the `input_key` column of the stored dataframe and return the result dict."""
        dataframe = storage.read("dataframe")
        samples = dataframe[input_key].to_list()
        self.logger.info(f"Evaluating {self.score_name}...")
        task2vec_score = self.get_score(samples)
        self.logger.info("Evaluation complete!")
        self.logger.info(f"Task2Vec Diversity Score: {task2vec_score}")
        return task2vec_score
66
+
67
+
68
class CustomTensorDataset(Dataset):
    """Wraps a dict of equally-sized tensors (a tokenized batch) as a torch Dataset."""

    def __init__(self, tokenized_batch):
        self.tokenized_batch = tokenized_batch

    def __getitem__(self, index):
        """Return the index-th slice of every tensor, keyed as in the input dict."""
        return {name: tensor[index] for name, tensor in self.tokenized_batch.items()}

    def __len__(self):
        """Row count, taken from an arbitrary entry (all entries share the batch dim)."""
        return len(next(iter(self.tokenized_batch.values())))
DataFlow/dataflow/operators/eval/GeneralText/diversity/vendi_scorer.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from vendi_score import text_utils
2
+ from dataflow.utils.storage import DataFlowStorage
3
+ import pandas as pd
4
+ from dataflow.core import OperatorABC
5
+ from dataflow.utils.registry import OPERATOR_REGISTRY
6
+ from dataflow import get_logger
7
+
8
+ # VendiScore dataset diversity evaluation
9
+ # Cited from: The Vendi Score: A Diversity Evaluation Metric for Machine Learning
10
@OPERATOR_REGISTRY.register()
class VendiScorer(OperatorABC):
    """Dataset diversity via the Vendi Score over BERT and SimCSE sentence embeddings.

    Cited from: "The Vendi Score: A Diversity Evaluation Metric for Machine Learning".
    """
    def __init__(self, device='cuda'):
        self.logger = get_logger()
        self.logger.info(f'Initializing {self.__class__.__name__}...')
        self.bert_model_path = 'bert-base-uncased'
        self.simcse_model_path = 'princeton-nlp/unsup-simcse-bert-base-uncased'
        self.device = device
        self.score_name = 'VendiScore'
        self.logger.info(f'{self.__class__.__name__} initialized.')

    def get_score(self, sentences):
        """Return BERT- and SimCSE-embedding Vendi scores (rounded to 2 decimals)."""
        result = {}
        bert_vs = text_utils.embedding_vendi_score(sentences, model_path=self.bert_model_path, device=self.device)
        result["BERTVendiScore"] = round(bert_vs, 2)
        simcse_vs = text_utils.embedding_vendi_score(sentences, model_path=self.simcse_model_path, device=self.device)
        result["SimCSEVendiScore"] = round(simcse_vs, 2)
        return result

    def run(self, storage: DataFlowStorage, input_key: str):
        """Score the `input_key` column of the stored dataframe and return the score dict."""
        dataframe = storage.read("dataframe")
        samples = dataframe[input_key].to_list()
        self.logger.info(f"Evaluating {self.score_name}...")
        vendiscore = self.get_score(samples)
        self.logger.info("Evaluation complete!")
        self.logger.info(f"VendiScore: {vendiscore}")
        return vendiscore
DataFlow/dataflow/operators/eval/GeneralText/gen/__pycache__/bert_scorer.cpython-310.pyc ADDED
Binary file (2 kB). View file
 
DataFlow/dataflow/operators/eval/GeneralText/gen/__pycache__/bleu_scorer.cpython-310.pyc ADDED
Binary file (2.39 kB). View file
 
DataFlow/dataflow/operators/eval/GeneralText/gen/__pycache__/cider_scorer.cpython-310.pyc ADDED
Binary file (2.67 kB). View file
 
DataFlow/dataflow/operators/eval/GeneralText/gen/bert_scorer.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataflow.core import OperatorABC
2
+ from dataflow.utils.storage import DataFlowStorage
3
+ from dataflow.utils.registry import OPERATOR_REGISTRY
4
+ from dataflow import get_logger
5
+ import evaluate
6
+
7
@OPERATOR_REGISTRY.register()
class BERTScorer(OperatorABC):
    """Reference-based generation quality via HuggingFace `evaluate`'s BERTScore (F1)."""
    def __init__(self, lang='en', model_cache_dir='./dataflow_cache'):
        self.logger = get_logger()
        self.logger.info(f'Initializing {self.__class__.__name__}...')
        self.data_type = "text"
        self.score_name = "BERTScore"
        self.lang = lang
        self.model_type = "distilbert-base-uncased"   # backbone model used for scoring
        self.idf = False                              # no idf re-weighting
        self.rescale_with_baseline = False            # raw (unrescaled) scores
        self.bertscore = evaluate.load("bertscore", cache_dir=model_cache_dir)
        self.logger.info(f'{self.__class__.__name__} initialized.')

    def eval(self, dataframe, input_key, reference_key):
        """Return per-row BERTScore F1 of dataframe[input_key] against dataframe[reference_key]."""
        eval_data = dataframe[input_key].to_list()
        ref_data = dataframe[reference_key].to_list()
        self.logger.info(f"Evaluating {self.score_name}...")
        if ref_data is None:
            raise ValueError("Reference data must be provided for BERTScorer")
        results = self.bertscore.compute(
            predictions=eval_data,
            references=ref_data,
            lang=self.lang,
            model_type=self.model_type,
            idf=self.idf,
            rescale_with_baseline=self.rescale_with_baseline
        )
        scores = results["f1"]
        self.logger.info("Evaluation complete!")
        return scores

    def run(self, storage: DataFlowStorage, input_key: str, reference_key: str, output_key: str='BertScore'):
        """Score the dataframe, attach per-row F1 under `output_key`, and persist it."""
        self.input_key = input_key
        self.reference_key = reference_key
        self.output_key = output_key
        dataframe = storage.read("dataframe")
        scores = self.eval(dataframe, input_key, reference_key)
        dataframe[self.output_key] = scores
        storage.write(dataframe)
DataFlow/dataflow/operators/eval/GeneralText/gen/bleu/__init__.py ADDED
File without changes
DataFlow/dataflow/operators/eval/GeneralText/gen/bleu/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (197 Bytes). View file
 
DataFlow/dataflow/operators/eval/GeneralText/gen/bleu/__pycache__/bleu.cpython-310.pyc ADDED
Binary file (6.82 kB). View file
 
DataFlow/dataflow/operators/eval/GeneralText/gen/bleu/bleu.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import sys, math, re
3
+ from collections import defaultdict
4
+
5
+ import six
6
+ from six.moves import xrange as range
7
+
8
+
9
def precook(s, n=4, out=False):
    """Count all n-grams of orders 1..n in whitespace-tokenized `s`.

    Returns (token_count, ngram_counts). `out` is accepted for API
    compatibility but unused.
    """
    words = s.split()
    counts = defaultdict(int)
    for order in range(1, n + 1):
        for start in range(len(words) - order + 1):
            counts[tuple(words[start:start + order])] += 1
    return (len(words), counts)
18
+
19
def cook_refs(refs, eff=None, n=4):
    """Precompute reference stats for BLEU: reference length(s) and, for each
    n-gram, its maximum count across references.

    eff: 'shortest' or 'average' collapse the length list to a single number;
    any other value (e.g. 'closest') keeps the full list so length selection
    can happen per hypothesis in `cook_test`.
    """
    reflen = []
    maxcounts = {}
    for ref in refs:
        rl, counts = precook(ref, n)
        reflen.append(rl)
        # `six.iteritems` replaced with native dict iteration; this module
        # already runs under Python 3.
        for ngram, count in counts.items():
            maxcounts[ngram] = max(maxcounts.get(ngram, 0), count)

    if eff == "shortest":
        reflen = min(reflen)
    elif eff == "average":
        reflen = float(sum(reflen)) / len(reflen)

    return (reflen, maxcounts)
35
+
36
def cook_test(test, reflen_refmaxcounts, eff=None, n=4):
    """Compute per-hypothesis BLEU components against precooked references.

    Returns a dict with 'testlen', 'reflen' (selected per `eff`), per-order
    candidate n-gram totals ('guess') and clipped matched counts ('correct').
    """
    reflen, refmaxcounts = reflen_refmaxcounts
    testlen, counts = precook(test, n, True)

    result = {}

    # 'closest' picks the reference length nearest the hypothesis length;
    # otherwise the (possibly pre-collapsed) reflen is used as-is.
    if eff == "closest":
        result["reflen"] = min((abs(l - testlen), l) for l in reflen)[1]
    else:
        result["reflen"] = reflen

    result["testlen"] = testlen

    result["guess"] = [max(0, testlen - k + 1) for k in range(1, n + 1)]

    result['correct'] = [0] * n
    # `six.iteritems` replaced with native dict iteration (Python 3 module).
    for ngram, count in counts.items():
        # Clip each match count by the max count seen in any reference.
        result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count)

    return result
58
+
59
class Bleu(object):
    """Bleu scorer.

    Accumulates cooked (hypothesis, references) pairs and computes corpus-level
    BLEU-1..n with a brevity penalty, caching the result until invalidated.
    """

    __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"

    def copy(self):
        # Shallow-copy cooked refs/tests; the cached score is invalidated.
        new = Bleu(n=self.n)
        new.ctest = copy.copy(self.ctest)
        new.crefs = copy.copy(self.crefs)
        new._score = None
        return new

    def __init__(self, test=None, refs=None, n=4, special_reflen=None):
        # n: max n-gram order; special_reflen overrides the reference length when set.
        self.n = n
        self.crefs = []
        self.ctest = []
        self.cook_append(test, refs)
        self.special_reflen = special_reflen

    def cook_append(self, test, refs):
        """Cook and store one (hypothesis, references) pair; either may be None."""
        if refs is not None:
            self.crefs.append(cook_refs(refs))
            if test is not None:
                cooked_test = cook_test(test, self.crefs[-1])
                self.ctest.append(cooked_test) ## N.B.: -1
            else:
                self.ctest.append(None) # lens of crefs and ctest have to match

        self._score = None ## need to recompute

    def ratio(self, option=None):
        # NOTE(review): compute_score never assigns self._ratio; this would
        # raise AttributeError as written — confirm before relying on it.
        self.compute_score(option=option)
        return self._ratio

    def score_ratio(self, option=None):
        # NOTE(review): `fscore` is not defined on this class; calling
        # score_ratio would raise AttributeError as written.
        return (self.fscore(option=option), self.ratio(option=option))

    def score_ratio_str(self, option=None):
        return "%.4f (%.2f)" % self.score_ratio(option)

    def reflen(self, option=None):
        self.compute_score(option=option)
        return self._reflen

    def testlen(self, option=None):
        self.compute_score(option=option)
        return self._testlen

    def retest(self, new_test):
        """Replace stored hypotheses with `new_test` (str or list), invalidating the cache."""
        if type(new_test) is str:
            new_test = [new_test]
        assert len(new_test) == len(self.crefs), new_test
        self.ctest = []
        for t, rs in zip(new_test, self.crefs):
            self.ctest.append(cook_test(t, rs))
        self._score = None

        return self

    def rescore(self, new_test):
        ''' replace test(s) with new test(s), and returns the new score.'''

        return self.retest(new_test).compute_score()

    def size(self):
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        '''add an instance (e.g., from another sentence).'''

        if type(other) is tuple:
            ## avoid creating new BleuScorer instances
            self.cook_append(other[0], other[1])
        else:
            assert self.compatible(other), "incompatible BLEUs."
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)
            self._score = None ## need to recompute

        return self

    def compatible(self, other):
        # Two scorers can merge only if they share the same max n-gram order.
        return isinstance(other, Bleu) and self.n == other.n

    def single_reflen(self, option="average"):
        return self._single_reflen(self.crefs[0][0], option)

    def _single_reflen(self, reflens, option=None, testlen=None):
        """Collapse a list of reference lengths to one value per `option`."""
        if option == "shortest":
            reflen = min(reflens)
        elif option == "average":
            reflen = float(sum(reflens))/len(reflens)
        elif option == "closest":
            reflen = min((abs(l-testlen), l) for l in reflens)[1]
        else:
            assert False, "unsupported reflen option %s" % option

        return reflen

    def recompute_score(self, option=None, verbose=0):
        # Drop the cache and recompute.
        self._score = None
        return self.compute_score(option, verbose)

    def compute_score(self, option=None, verbose=0):
        """Compute corpus-level BLEU-1..n with brevity penalty.

        Returns (bleus, bleu_list): corpus scores per order and per-sentence
        scores per order.
        """
        n = self.n
        small = 1e-9
        tiny = 1e-15 ## so that if guess is 0 still return 0
        bleu_list = [[] for _ in range(n)]

        # NOTE(review): a cache hit returns only self._score, while a fresh
        # computation returns (score, bleu_list) — callers unpacking two
        # values must not rely on the cached path.
        if self._score is not None:
            return self._score

        if option is None:
            option = "average" if len(self.crefs) == 1 else "closest"

        self._testlen = 0
        self._reflen = 0
        totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n}

        # for each sentence
        for comps in self.ctest:
            testlen = comps['testlen']
            self._testlen += testlen

            if self.special_reflen is None: ## need computation
                reflen = self._single_reflen(comps['reflen'], option, testlen)
            else:
                reflen = self.special_reflen

            self._reflen += reflen

            for key in ['guess','correct']:
                for k in range(n):
                    totalcomps[key][k] += comps[key][k]

            # append per image bleu score
            bleu = 1.
            for k in range(n):
                bleu *= (float(comps['correct'][k]) + tiny) \
                        /(float(comps['guess'][k]) + small)
                bleu_list[k].append(bleu ** (1./(k+1)))
            ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division
            if ratio < 1:
                # Brevity penalty for hypotheses shorter than the reference.
                for k in range(n):
                    bleu_list[k][-1] *= math.exp(1 - 1/ratio)

            if verbose > 1:
                print(comps, reflen)

        totalcomps['reflen'] = self._reflen
        totalcomps['testlen'] = self._testlen

        bleus = []
        bleu = 1.
        for k in range(n):
            bleu *= float(totalcomps['correct'][k] + tiny) \
                    / (totalcomps['guess'][k] + small)
            bleus.append(bleu ** (1./(k+1)))
        ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division
        if ratio < 1:
            # Corpus-level brevity penalty.
            for k in range(n):
                bleus[k] *= math.exp(1 - 1/ratio)

        if verbose > 0:
            print(totalcomps)
            print("ratio:", ratio)

        self._score = bleus
        return self._score, bleu_list
DataFlow/dataflow/operators/eval/GeneralText/gen/bleu_scorer.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataflow.core import OperatorABC
2
+ from dataflow.utils.storage import DataFlowStorage
3
+ from dataflow.utils.registry import OPERATOR_REGISTRY
4
+ from dataflow import get_logger
5
+ from dataflow.operators.eval.GeneralText.gen.bleu.bleu import Bleu
6
+ from tqdm import tqdm
7
+
8
@OPERATOR_REGISTRY.register()
class BleuScorer(OperatorABC):
    """Operator that scores candidate text against a reference with BLEU (orders 1..n)."""

    def __init__(self, n=4, eff="average", special_reflen=None):
        self.logger = get_logger()
        self.logger.info(f'Initializing {self.__class__.__name__}...')
        self.score_name = 'BleuScore'
        valid_eff_options = ["shortest", "average", "longest"]
        if eff not in valid_eff_options:
            raise ValueError(f"Invalid value for 'eff'. Must be one of {valid_eff_options}, but got '{eff}'.")
        self.n = n                            # maximum n-gram order
        self.eff = eff                        # reference-length strategy: shortest / average / longest
        self.special_reflen = special_reflen  # fixed reference length override, if any
        self.logger.info(f'{self.__class__.__name__} initialized.')

    def _score_func(self, eval_text, ref_text):
        """Score one candidate against a single reference; returns the BLEU-1 component."""
        scorer = Bleu(
            test=eval_text,
            refs=[ref_text],
            n=self.n,
            special_reflen=self.special_reflen,
        )
        scores, _ = scorer.compute_score(option=self.eff)
        return scores[0]

    def eval(self, dataframe, input_key, reference_key):
        """Compute a BLEU score for every (candidate, reference) row pair."""
        candidates = dataframe[input_key]
        references = dataframe[reference_key]
        self.logger.info(f"Evaluating {self.score_name}...")
        scores = [
            self._score_func(candidate, reference)
            for candidate, reference in tqdm(zip(candidates, references), desc="BleuScorer Evaluating...")
        ]
        self.logger.info("Evaluation complete!")
        return scores

    def run(self, storage: DataFlowStorage, input_key: str, reference_key: str, output_key: str='BleuScore'):
        """Read the dataframe from storage, attach scores under output_key, and write it back."""
        self.input_key = input_key
        self.reference_key = reference_key
        self.output_key = output_key
        frame = storage.read("dataframe")
        frame[self.output_key] = self.eval(frame, input_key, reference_key)
        storage.write(frame)
DataFlow/dataflow/operators/eval/GeneralText/gen/cider/__init__.py ADDED
File without changes
DataFlow/dataflow/operators/eval/GeneralText/gen/cider/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (198 Bytes). View file
 
DataFlow/dataflow/operators/eval/GeneralText/gen/cider/__pycache__/cider.cpython-310.pyc ADDED
Binary file (5.54 kB). View file
 
DataFlow/dataflow/operators/eval/GeneralText/gen/cider/cider.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import pickle
4
+ import numpy as np
5
+ from collections import defaultdict
6
+ import os
7
+ from six.moves import xrange
8
+ import six
9
+
10
def precook(s, n=4, out=False):
    """Count all n-grams of orders 1..n in the whitespace-tokenized string *s*.

    Args:
        s: sentence; tokenized by splitting on whitespace.
        n: maximum n-gram order to count.
        out: unused here; kept for interface compatibility with the
            pycocoevalcap implementation this code derives from.

    Returns:
        defaultdict(int) mapping each n-gram tuple to its occurrence count.
    """
    words = s.split()
    counts = defaultdict(int)
    # Modernization: use the built-in range instead of six.moves.xrange
    # (identical behavior on Python 3, drops the reliance on six).
    for k in range(1, n + 1):
        for i in range(len(words) - k + 1):
            ngram = tuple(words[i:i + k])
            counts[ngram] += 1
    return counts
18
+
19
def cook_refs(refs, n=4):
    """Precompute n-gram counts for each reference sentence."""
    cooked = []
    for ref in refs:
        cooked.append(precook(ref, n))
    return cooked
21
+
22
def cook_test(test, n=4):
    """Precompute n-gram counts for a candidate sentence."""
    return precook(test, n, out=True)
24
+
25
class Cider(object):
    """CIDEr scorer.

    Accumulates cooked n-gram counts for candidate/reference pairs and scores
    them with TF-IDF-weighted cosine similarity per n-gram order, with a
    Gaussian penalty on the length difference (Vedantam et al., CIDEr).
    """

    def copy(self):
        """Return a shallow copy sharing the cooked candidate/reference lists."""
        new = Cider(n=self.n)
        new.ctest = copy.copy(self.ctest)
        new.crefs = copy.copy(self.crefs)
        return new

    def __init__(self, test=None, refs=None, n=4, sigma=6.0, idf=None):
        # n: maximum n-gram order; sigma: std-dev of the Gaussian length penalty.
        self.n = n
        self.sigma = sigma
        self.crefs = []  # per-sample lists of reference n-gram count dicts
        self.ctest = []  # per-sample candidate n-gram count dict (or None)
        self.document_frequency = defaultdict(float)
        self.ref_len = None  # log of the (effective) corpus size

        if idf:
            # Precomputed IDF supplied: assumes idf is a dict with keys 'df'
            # (ngram -> document frequency) and 'ref_len' — TODO confirm this
            # matches the pickled coco-val-df format used by the caller.
            self.document_frequency = idf['df']
            self.ref_len = np.log(float(idf['ref_len']))  # Use reference length from the IDF

        self.cook_append(test, refs)

    def cook_append(self, test, refs):
        """Cook and store one candidate/reference pair (skipped when refs is None)."""
        if refs is not None:
            self.crefs.append(cook_refs(refs))
            if test is not None:
                self.ctest.append(cook_test(test))
            else:
                self.ctest.append(None)

    def size(self):
        """Number of stored samples; crefs and ctest must stay in lockstep."""
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        """Accumulate either a (test, refs) tuple or another Cider instance."""
        if type(other) is tuple:
            self.cook_append(other[0], other[1])
        else:
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)
        return self

    def compute_doc_freq(self):
        '''Compute term frequency for reference data to generate IDF.'''
        if not self.document_frequency:  # Handle empty DF (for 'corpus' mode)
            for refs in self.crefs:
                # Count each n-gram at most once per sample (document frequency,
                # not term frequency).
                for ngram in set([ngram for ref in refs for (ngram, count) in ref.items()]):
                    self.document_frequency[ngram] += 1

    def compute_cider(self, df_mode):
        """Return per-sample CIDEr scores (mean over n-gram orders, scaled by 10)."""

        def counts2vec(cnts):
            # Convert raw n-gram counts into per-order TF-IDF vectors; also
            # return the per-order L2 norms and a sentence "length" proxy.
            vec = [defaultdict(float) for _ in range(self.n)]
            length = 0
            norm = [0.0 for _ in range(self.n)]
            for (ngram, term_freq) in cnts.items():
                # max(1.0, df) avoids log(0) for unseen n-grams.
                df = np.log(max(1.0, self.document_frequency[ngram]))
                n = len(ngram) - 1  # 0-based n-gram order
                vec[n][ngram] = float(term_freq) * (self.ref_len - df)  # tf * idf (in log space)
                norm[n] += pow(vec[n][ngram], 2)

                if n == 1:
                    # NOTE(review): length is accumulated from second-order
                    # (bigram) counts — this mirrors the upstream
                    # pycocoevalcap implementation; confirm if intentional.
                    length += term_freq
            norm = [np.sqrt(n) for n in norm]
            return vec, norm, length

        def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
            # Per-order cosine similarity with hypothesis counts clipped at
            # the reference counts, times a Gaussian length penalty.
            delta = float(length_hyp - length_ref)
            val = np.array([0.0 for _ in range(self.n)])
            for n in range(self.n):
                for (ngram, count) in vec_hyp[n].items():
                    val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]

                if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
                    val[n] /= (norm_hyp[n] * norm_ref[n])

                val[n] *= np.e**(-(delta**2) / (2 * self.sigma**2))
            return val

        if df_mode == "corpus":
            self.ref_len = np.log(float(len(self.crefs)))  # Use total references in corpus as ref_len

        scores = []
        for test, refs in zip(self.ctest, self.crefs):
            vec, norm, length = counts2vec(test)
            score = np.array([0.0 for _ in range(self.n)])
            for ref in refs:
                vec_ref, norm_ref, length_ref = counts2vec(ref)
                score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
            # Average over n-gram orders, then over references, then scale by 10.
            score_avg = np.mean(score)
            score_avg /= len(refs)
            score_avg *= 10.0
            scores.append(score_avg)
        return scores

    def compute_score(self, df_mode, option=None, verbose=0):
        '''Compute the CIDEr score based on df_mode (corpus or IDF-based).

        Returns a tuple (mean score over all samples, per-sample score array).
        option and verbose are accepted for interface compatibility but unused.
        '''
        self.compute_doc_freq()

        if df_mode == "corpus":
            if not self.document_frequency:  # Handle the case where DF is empty
                raise ValueError("Document frequency is empty. Please check the corpus data.")

            min_required_data = max(self.document_frequency.values())
            # print(min_required_data)# For corpus mode, we require at least one reference
            # if len(self.ctest) < min_required_data:
            #     raise ValueError(f"Insufficient test data: {len(self.ctest)} samples, but at least {min_required_data} are required.")

        score = self.compute_cider(df_mode)
        return np.mean(np.array(score)), np.array(score)
DataFlow/dataflow/operators/eval/GeneralText/gen/cider_scorer.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import pickle
4
+ from tqdm import tqdm
5
+ from dataflow.core import OperatorABC
6
+ from dataflow.utils.storage import DataFlowStorage
7
+ from dataflow.utils.registry import OPERATOR_REGISTRY
8
+ from dataflow import get_logger
9
+ from dataflow.operators.eval.GeneralText.gen.cider.cider import Cider
10
+
11
def load_idf(idf_path):
    """Deserialize a pickled IDF statistics file and return its contents."""
    with open(idf_path, 'rb') as handle:
        return pickle.load(handle, encoding='utf-8')
15
+
16
@OPERATOR_REGISTRY.register()
class CiderScorer(OperatorABC):
    """Operator that scores candidate text against a reference with CIDEr."""

    def __init__(self, n=4, sigma=6.0, df_mode="coco-val-df", idf_path="./dataflow/operators/eval/GeneralText/gen/cider/coco-val-df.p"):
        self.logger = get_logger()
        self.logger.info(f'Initializing {self.__class__.__name__}...')
        self.score_name = 'CiderScore'
        self.n = n          # Max n-gram length (default: 4)
        self.sigma = sigma  # Sigma for the Gaussian length penalty (default: 6.0)
        self.df_mode = df_mode  # "corpus" or the name of a precomputed IDF file
        if self.df_mode != "corpus":
            # The idf file can be downloaded at https://github.com/ramavedantam/coco-caption/blob/master/data/coco-val-df.p
            # Put the file in the correct idf_path
            self.idf = load_idf(idf_path)
        else:
            self.idf = None  # 'corpus' mode derives document frequencies from the references
        self.logger.info(f'{self.__class__.__name__} initialized.')

    def _score_func(self, eval_text, ref_text):
        """Compute the CIDEr score of one candidate against a single reference."""
        cider_scorer = Cider(
            test=eval_text,
            refs=[ref_text],
            n=self.n,
            sigma=self.sigma,
            idf=self.idf  # None when df_mode == "corpus"
        )
        # Consistency fix: forward the configured self.df_mode instead of
        # re-deriving it from self.idf. Cider.compute_score only distinguishes
        # "corpus" from anything else, and self.df_mode == "corpus" exactly
        # when self.idf is None, so behavior is unchanged.
        cider_score, _ = cider_scorer.compute_score(df_mode=self.df_mode)
        return cider_score

    def eval(self, dataframe, input_key, reference_key):
        """Compute a CIDEr score for every (candidate, reference) row pair."""
        eval_data = dataframe[input_key]
        ref_data = dataframe[reference_key]
        self.logger.info(f"Evaluating {self.score_name}...")
        scores = [self._score_func(eval_text, ref_text) for eval_text, ref_text in tqdm(zip(eval_data, ref_data), desc="CiderScorer Evaluating...")]
        self.logger.info("Evaluation complete!")
        return scores

    def run(self, storage: DataFlowStorage, input_key: str, reference_key: str, output_key: str='CiderScore'):
        """Read the dataframe from storage, attach scores under output_key, and write it back."""
        self.input_key = input_key
        self.reference_key = reference_key
        self.output_key = output_key
        dataframe = storage.read("dataframe")
        scores = self.eval(dataframe, input_key, reference_key)
        dataframe[self.output_key] = scores
        storage.write(dataframe)
DataFlow/dataflow/operators/eval/GeneralText/models/Kenlm/__pycache__/model.cpython-310.pyc ADDED
Binary file (5.21 kB). View file
 
DataFlow/dataflow/operators/eval/GeneralText/models/Kenlm/model.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import unicodedata
4
+ from typing import Dict
5
+
6
+ import kenlm
7
+ import sentencepiece
8
+
9
class SentencePiece:
    """Thin wrapper around a SentencePiece model used to tokenize text for KenLM."""

    def __init__(
        self,
        model: str,
    ):
        super().__init__()
        self.sp = sentencepiece.SentencePieceProcessor()
        self.sp.load(str(model))

    def do(self, text: str) -> str:
        """Tokenize *text* and return the pieces joined by single spaces.

        Fix: the original annotations declared ``dict`` for both the argument
        and the return value, but the method clearly operates on strings
        (``encode_as_pieces`` input, ``" ".join`` output).
        """
        tokenized = self.sp.encode_as_pieces(text)
        return " ".join(tokenized)
21
+
22
+
23
class KenlmModel:
    """KenLM-based perplexity scorer with cc_net-style text normalization."""

    # Matches any single decimal digit (for number normalization).
    digit_re: re.Pattern = re.compile(r"\d")
    # Mapping of full-width / typographic punctuation to ASCII equivalents.
    # NOTE(review): the '"1" -> quote' entry looks suspicious but is preserved
    # verbatim from the upstream cc_net table — confirm before changing.
    unicode_punct: Dict[str, str] = {
        ",": ",",
        "。": ".",
        "、": ",",
        "„": '"',
        "”": '"',
        "“": '"',
        "«": '"',
        "»": '"',
        "1": '"',
        "」": '"',
        "「": '"',
        "《": '"',
        "》": '"',
        "´": "'",
        "∶": ":",
        ":": ":",
        "?": "?",
        "!": "!",
        "(": "(",
        ")": ")",
        ";": ";",
        "–": "-",
        "—": " - ",
        ".": ". ",
        "~": "~",
        "’": "'",
        "…": "...",
        "━": "-",
        "〈": "<",
        "〉": ">",
        "【": "[",
        "】": "]",
        "%": "%",
        "►": "-",
    }
    unicode_punct_re = re.compile(f"[{''.join(unicode_punct.keys())}]")
    # C0 (0-31) and C1 (127-159) control characters to strip.
    non_printing_chars_re = re.compile(
        f"[{''.join(map(chr, list(range(0,32)) + list(range(127,160))))}]"
    )
    kenlm_model_dir = None
    sentence_piece_model_dir = None

    def __init__(
        self,
        model_dataset: str,
        language: str,
        lower_case: bool = False,
        remove_accents: bool = False,
        normalize_numbers: bool = True,
        punctuation: int = 1,
    ):
        """Load the KenLM model and SentencePiece tokenizer for *language*.

        Args:
            model_dataset: directory holding ``{language}.arpa.bin`` and
                ``{language}.sp.model``.
            language: language code used to locate the model files.
            lower_case: lowercase text during normalization.
            remove_accents: strip combining accents during normalization.
            normalize_numbers: replace every digit with '0'.
            punctuation: 1 = map unicode punctuation to ASCII, 2 = remove it.
        """
        self.model = kenlm.Model(os.path.join(model_dataset, f"{language}.arpa.bin"))
        self.tokenizer = SentencePiece(os.path.join(model_dataset, f"{language}.sp.model"))
        self.accent = remove_accents
        self.case = lower_case
        self.numbers = normalize_numbers
        self.punct = punctuation

    @classmethod
    def from_pretrained(
        cls,
        model_dataset: str,
        language: str,
    ):
        """Construct a model with the default cc_net normalization settings."""
        return cls(
            model_dataset,
            language,
            False,
            False,
            True,
            1,
        )

    def pp(self, log_score, length):
        """Convert a total log10 score and token count into perplexity."""
        return 10.0 ** (-log_score / length)

    def get_perplexity(self, doc: str, normalize_cc_net: bool = True):
        """Return the document perplexity under the KenLM model, rounded to 0.1."""
        if normalize_cc_net:
            doc = self.normalize(
                doc,
                accent=self.accent,
                case=self.case,
                numbers=self.numbers,
                punct=self.punct,
            )
        # Tokenize (after normalizing): See https://github.com/facebookresearch/cc_net/blob/bda555bd1cf1ee2e0b925363e62a61cd46c8b60d/cc_net/mine.py#L352 for full pipeline
        doc = self.tokenizer.do(doc)
        doc_log_score, doc_length = 0, 0
        for line in doc.split("\n"):
            log_score = self.model.score(line)
            # +1 accounts for the implicit end-of-sentence token KenLM scores.
            length = len(line.split()) + 1
            doc_log_score += log_score
            doc_length += length
        return round(self.pp(doc_log_score, doc_length), 1)

    def normalize(
        self,
        line: str,
        accent: bool = True,
        case: bool = True,
        numbers: bool = True,
        punct: int = 1,
    ) -> str:
        """Apply the cc_net normalization steps (case, accents, digits, punctuation)."""
        line = line.strip()
        if not line:
            return line
        if case:
            line = line.lower()
        if accent:
            line = self.strip_accents(line)
        if numbers:
            line = self.digit_re.sub("0", line)
        if punct == 1:
            line = self.replace_unicode_punct(line)
        elif punct == 2:
            line = self.remove_unicode_punct(line)
        line = self.remove_non_printing_char(line)
        return line

    def strip_accents(self, line: str) -> str:
        """Strips accents (combining marks) from a piece of text."""
        nfd = unicodedata.normalize("NFD", line)
        output = [c for c in nfd if unicodedata.category(c) != "Mn"]
        # Bug fix: the original fast path tested `len(output) == line`,
        # comparing an int to a str — always False, i.e. dead code. Removing
        # it keeps behavior identical (the join path was always taken) while
        # eliminating the type-confused comparison.
        return "".join(output)

    def replace_unicode_punct(self, text: str) -> str:
        """Map known unicode punctuation characters to ASCII equivalents."""
        return "".join(self.unicode_punct.get(c, c) for c in text)

    def remove_unicode_punct(self, text: str) -> str:
        """More aggressive version of replace_unicode_punct but also faster."""
        return self.unicode_punct_re.sub("", text)

    def remove_non_printing_char(self, text: str) -> str:
        """Drop C0/C1 control characters."""
        return self.non_printing_chars_re.sub("", text)
DataFlow/dataflow/operators/eval/GeneralText/models/Qurating/__pycache__/qurater_annotate.cpython-310.pyc ADDED
Binary file (7.02 kB). View file
 
DataFlow/dataflow/operators/eval/GeneralText/models/Qurating/modeling/__pycache__/modeling_flash_llama.cpython-310.pyc ADDED
Binary file (25 kB). View file
 
DataFlow/dataflow/operators/eval/GeneralText/models/Qurating/modeling/modeling_flash_llama.py ADDED
@@ -0,0 +1,853 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch LLaMA model."""
21
+ from typing import List, Optional, Tuple, Union, Any
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28
+
29
+ import torch.distributed as dist
30
+
31
+ from transformers.activations import ACT2FN
32
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
33
+ from transformers.modeling_utils import PreTrainedModel
34
+ from transformers.utils import logging
35
+ from transformers.models.llama.configuration_llama import LlamaConfig
36
+
37
def try_import_flash_attention():
    """Probe that the flash-attention CUDA kernels are importable, raising an
    actionable ImportError when they are not.

    NOTE(review): the imported names are bound in this function's *local*
    scope and discarded on return, while code later in this module references
    flash_attn_kvpacked_func etc. as module-level names — presumably those
    globals are provided elsewhere; confirm how the names actually reach
    module scope.
    """
    try:
        from flash_attn import flash_attn_kvpacked_func, flash_attn_varlen_kvpacked_func, flash_attn_with_kvcache
        from flash_attn.bert_padding import unpad_input, pad_input
        from flash_attn.layers.rotary import apply_rotary_emb_func
    except ImportError as e:
        # The rotary kernel ships as a separate csrc subpackage, so distinguish
        # that failure from a plainly missing flash_attn install.
        if 'flash_attn.layers.rotary' in str(e):
            raise ImportError('Please install RoPE kernels: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary`')
        else:
            raise ImportError('Please install flash_attention dependency in GPU environment')
47
+ from dataflow import get_logger
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+ # @torch.jit.script
52
def rmsnorm_func(hidden_states, weight, variance_epsilon):
    """RMS-normalize *hidden_states* in fp32, scale by *weight*, cast back."""
    orig_dtype = hidden_states.dtype
    states = hidden_states.to(torch.float32)
    mean_square = states.pow(2).mean(-1, keepdim=True)
    normalized = states * torch.rsqrt(mean_square + variance_epsilon)
    return (weight * normalized).to(orig_dtype)
58
+
59
+
60
class LlamaRMSNorm(nn.Module):
    """RMS normalization layer, numerically equivalent to T5LayerNorm."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learnable per-channel scale, initialized to identity.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # Non-persistent buffer so eps follows the module's device without
        # being written into checkpoints.
        self.register_buffer("variance_epsilon", torch.tensor(eps), persistent=False)

    def forward(self, hidden_states):
        """Normalize *hidden_states* by their RMS and apply the learned scale."""
        return rmsnorm_func(hidden_states, self.weight, self.variance_epsilon)
75
+
76
+
77
class FlashRotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(self, dim: int, base=10000.0, interleaved=False, scale_base=None,
                 scaling_factor=1.0, pos_idx_in_fp32=True, device=None):
        """
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
            otherwise they might be in lower precision.
            This option was added because previously (before 2023-07-02), when we construct
            the position indices, we use the dtype of self.inv_freq. In most cases this would
            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
            self.inv_freq would be bf16, and the position indices are also in bf16.
            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
            embeddings for some positions will coincide.
            To maintain compatibility with models previously trained in pure bf16,
            we add this option.
        scaling_factor: RotaryEmbedding extended with linear scaling.
        """
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = self._compute_inv_freq(device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.interleaved = interleaved
        self.scale_base = scale_base
        self.scaling_factor = scaling_factor
        # XPos per-dimension decay; None disables the XPos path entirely.
        scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
                 / (1.4 * dim) if scale_base is not None else None)
        self.register_buffer("scale", scale)

        # Lazily-built cos/sin tables, regenerated when seqlen/device/dtype change.
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _compute_inv_freq(self, device=None):
        # Standard RoPE inverse frequencies: 1 / base^(2i/dim).
        return 1 / (self.base ** (torch.arange(0, self.dim, 2, device=device,
                                               dtype=torch.float32) / self.dim))


    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Reset the tables if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        # NOTE(review): on the very first call self._cos_cached is None and
        # only the short-circuit on `seqlen > self._seq_len_cached` (0) avoids
        # the attribute access — this would fail if the first call passed
        # seqlen == 0; confirm callers never do.
        if (seqlen > self._seq_len_cached or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                t /= self.scaling_factor
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self.inv_freq.to(torch.float32)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                t /= self.scaling_factor
                inv_freq = self.inv_freq
            # Don't do einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                # XPos: positions are centered around the middle of the window.
                power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                          - seqlen // 2) / self.scale_base)
                scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(self,
                q: torch.Tensor, k: torch.Tensor,
                seqlen_offset: int = 0,
                unpadded_lengths: Optional[Tuple[torch.Tensor]] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        q: (batch, seqlen, nheads, headdim)
        k: (batch, seqlen, nheads, headdim)
        seqlen_offset: can be used in generation where the qkv being passed in is only the last
            token in the batch.
        """
        if unpadded_lengths is not None:
            cu_seqlens, max_seqlen = unpadded_lengths
        else:
            cu_seqlens, max_seqlen = None, q.shape[1]
        self._update_cos_sin_cache(max_seqlen + seqlen_offset, device=q.device, dtype=q.dtype)

        if self.scale is None:
            # NOTE(review): apply_rotary_emb_func is only imported inside
            # try_import_flash_attention() in the visible code — confirm how
            # the name reaches module scope.
            return apply_rotary_emb_func(
                q, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                self.interleaved, True,  # inplace=True,
                cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
            ), apply_rotary_emb_func(
                k, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                self.interleaved, True,  # inplace=True
                cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
            )
        else:
            # The XPos (scale_base is not None) path is not implemented here.
            assert False
207
+
208
+
209
class LlamaMLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward block: down(act(gate(x)) * up(x))."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # Three bias-free projections forming the gated MLP.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated feed-forward transform."""
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
222
+
223
+
224
@torch.jit.script
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each entry of the second-to-last axis n_rep times (in place order),
    expanding grouped KV heads to match the number of query heads."""
    if n_rep == 1:
        return hidden_states
    shape = list(hidden_states.shape)
    target_shape = shape[:-2] + [-1] + [shape[-1]]
    expand_dims = [-1] * (len(shape) - 1) + [n_rep] + [-1]
    expanded = hidden_states.unsqueeze(-2).expand(expand_dims)
    return expanded.reshape(target_shape)
232
+
233
+
234
class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        # Grouped-query attention: falls back to full MHA when the config has
        # no num_key_value_heads attribute.
        self.num_key_value_heads = getattr(config, "num_key_value_heads", self.num_heads)
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # sqrt(head_dim); 1/norm_factor is passed as the softmax scale below.
        # Kept as a non-persistent buffer so it follows device/dtype moves.
        self.register_buffer(
            "norm_factor",
            torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
            persistent=False,
        )

        if not getattr(self.config, "rope_scaling", None):
            scaling_factor = 1
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            # Only linear RoPE scaling is supported by this implementation.
            assert scaling_type == 'linear'
        theta = getattr(self.config, "rope_theta", 10000)
        self.rotary_emb = FlashRotaryEmbedding(
            self.head_dim, base=theta, interleaved=False, scaling_factor=scaling_factor,
        )

        # NOTE(review): flash_attn_kvpacked_func is referenced as a
        # module-level name; in the visible code it is only imported inside
        # try_import_flash_attention() — confirm how it reaches module scope.
        self.distributed_attn_func = flash_attn_kvpacked_func

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # Reshape (bsz, seq, hidden) -> (bsz, heads, seq, head_dim).
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        unpadded_lengths: Optional[Tuple[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute causal flash attention; returns (output, attn_weights, new_cache)."""
        h_size = hidden_states.size(-1)

        has_layer_past = past_key_value is not None

        # past_key_value is a (cached_kv_tensor, cached_length) pair.
        if has_layer_past:
            past_kv = past_key_value[0]
            past_len = past_key_value[1]
        else:
            past_len = 0

        # NOTE: Hack to include position_ids, assuming they are increasing uniformly per block
        if position_ids is not None:
            past_len += position_ids.min()

        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Split projections into heads: trailing dims become (n_heads, head_dim).
        q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
        k = k.view(*k.shape[:-1], self.num_key_value_heads, self.head_dim)
        v = v.view(*v.shape[:-1], self.num_key_value_heads, self.head_dim)

        # Apply rotary embeddings, offset by the cached prefix length.
        q, k = self.rotary_emb(q, k, past_len, unpadded_lengths)

        # Pack K and V together and expand KV heads to match query heads (GQA).
        kv = torch.stack([k, v], -3)
        kv = repeat_kv(kv, self.num_key_value_groups)

        # Cache QKV values
        if has_layer_past:
            new_len = past_len+q.size(1)
            if new_len > past_kv.size(1):
                # Grow the cache in 256-token chunks to amortize reallocation.
                past_kv = torch.cat([past_kv, torch.empty(hidden_states.size(0), 256, 2, kv.size(3), kv.size(4), dtype=kv.dtype, device=kv.device)], 1)
            past_kv[:, past_len:new_len] = kv
            kv = past_kv[:, :new_len]
        else:
            past_kv = kv

        if unpadded_lengths is not None:
            # varlen, ignore padding tokens, efficient for large batch with many paddings
            assert attention_mask is not None
            cu_seqlens, max_seqlen = unpadded_lengths

            attn_outputs = flash_attn_varlen_kvpacked_func(
                q, kv,
                cu_seqlens, cu_seqlens,
                max_seqlen, max_seqlen,
                dropout_p=0.0, softmax_scale=1.0/self.norm_factor,
                causal=True, return_attn_probs=output_attentions
            )
        # elif use_cache and past_key_value is not None:
        #     attn_outputs = flash_attn_with_kvcache(
        #         q,
        #         kv[:, :, 0],
        #         kv[:, :, 1],
        #         softmax_scale=1.0/self.norm_factor,
        #         causal=True,
        #     )
        else:
            attn_outputs = flash_attn_kvpacked_func(
                q, kv,
                dropout_p=0.0,
                softmax_scale=1.0/self.norm_factor,
                causal=True,
                return_attn_probs=output_attentions,
            )
        past_key_value = (past_kv, past_len+q.size(1)) if use_cache else None


        # With return_attn_probs the kernel returns a tuple; the code below
        # takes output at index 0 and probabilities at index 2.
        attn_output = attn_outputs[0] if output_attentions else attn_outputs
        attn_output = attn_output.reshape(*attn_output.shape[:-2], h_size)
        attn_weights = attn_outputs[2] if output_attentions else None

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
367
+
368
+
369
class LlamaDecoderLayer(nn.Module):
    """One Llama transformer block: pre-norm self-attention followed by a
    pre-norm MLP, each wrapped in a residual connection."""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Marker consumed by FSDP auto-wrap policies: wrap each layer separately.
        self._fsdp_wrap = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        unpadded_lengths: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            position_ids (`torch.LongTensor`, *optional*): position indices forwarded to the attention module.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            unpadded_lengths (`Tuple(torch.Tensor)`, *optional*): `(cu_seqlens, max_seqlen)` for the
                flash-attn varlen path when padding has been stripped.
            output_attentions (`bool`, *optional*): whether to also return the attention tensors.
            use_cache (`bool`, *optional*): if `True`, the updated key/value cache is returned so it
                can be used to speed up decoding.
        """
        # --- self-attention sub-block (pre-norm + residual) ---
        attn_input = self.input_layernorm(hidden_states)
        attn_output, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            unpadded_lengths=unpadded_lengths,
        )
        hidden_states = hidden_states + attn_output

        # --- feed-forward sub-block (pre-norm + residual) ---
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_output

        outputs = (hidden_states,)
        if output_attentions:
            outputs = outputs + (self_attn_weights,)
        if use_cache:
            outputs = outputs + (present_key_value,)
        return outputs
434
+
435
+
436
class LlamaPreTrainedModel(PreTrainedModel):
    """Base class wiring the Llama config/weight-init conventions into
    the HuggingFace `PreTrainedModel` machinery."""

    config_class = LlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, initializer_range);
        zero out biases and the padding embedding row."""
        std = self.config.initializer_range
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
454
+
455
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        try_import_flash_attention()
        self.post_init()

    def get_input_embeddings(self):
        # Token embedding table.
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack.

        If the batch contains padding and no KV cache is in use, hidden states are
        flattened into flash-attn's packed ("unpadded") layout so the varlen kernels
        skip padding tokens; they are re-padded before the final norm. Flags left as
        `None` fall back to the corresponding `self.config` values.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        # position_ids = None

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds
        bsz = hidden_states.size(0)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # Strip padding only when there actually is padding and no KV cache is used:
        # the packed layout is incompatible with the cache-append logic in attention.
        if (
            ((attention_mask is not None) and (not attention_mask.all().item()))
            and not use_cache
        ):
            try: # for flash-attn latest version
                hidden_states, unpad_indices, cu_seqlens, max_seqlen, _ = unpad_input(hidden_states, attention_mask)
            except: # for flash-attn 2.3.3 verstion
                hidden_states, unpad_indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)
            unpadded_lengths = (cu_seqlens, max_seqlen)
        else:
            unpadded_lengths = None

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                # Re-pad intermediate states so collected hidden states have a
                # uniform (batch, seq, dim) shape regardless of the packed layout.
                if unpadded_lengths is not None:
                    all_hidden_states += (pad_input(hidden_states, unpad_indices, bsz, max_seqlen),)
                else:
                    all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                # Positional args mirror LlamaDecoderLayer.forward; past_key_value is
                # None and use_cache is False under checkpointing (cache was disabled above).
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    decoder_layer,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                    unpadded_lengths,
                    output_attentions,
                    False,
                    use_reentrant=False
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    unpadded_lengths=unpadded_lengths,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # Cache is the last element of the layer tuple; its index shifts
                # by one when attention weights are also returned.
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        if unpadded_lengths is not None:
            hidden_states = pad_input(hidden_states, unpad_indices, bsz, max_seqlen)
        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
602
+
603
+
604
class LlamaForCausalLM(LlamaPreTrainedModel):
    """Llama decoder with a language-modeling head for next-token prediction."""

    # lm_head may share weights with the input embedding table.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        try_import_flash_attention()
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # NOTE(review): accepted but never used in this body — presumably kept for
        # compatibility with an external training loop; confirm before removing.
        avg_valid_labels_per_chunk: Optional[float] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        # `.float()` upcasts logits so the cross-entropy below is computed in fp32
        # even when the backbone runs in a lower precision.
        logits = self.lm_head(hidden_states).float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        # NOTE(review): truthiness check — an empty-but-present cache would be
        # treated the same as no cache; confirm callers never pass an empty tuple.
        if past_key_values:
            # With a cache, only the newest token needs to be fed forward.
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        # Reorder each layer's cached states along the batch dim to follow
        # beam-search hypothesis reordering.
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
750
+
751
+
752
class LlamaForSequenceClassification(LlamaPreTrainedModel):
    """Llama decoder with a linear classification head scored on the last
    non-padding token of each sequence."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
        try_import_flash_attention()
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        # Pool the logits at the last real token of each sequence. Without a pad
        # token we cannot locate it per-row, so only batch_size == 1 is supported.
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # Index of the last non-pad token per row.
                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
            else:
                # With inputs_embeds there are no token ids to compare against:
                # fall back to the final position.
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer the problem type on first use and persist it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
DataFlow/dataflow/operators/eval/GeneralText/models/Qurating/qurater_annotate.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_from_disk, load_dataset, concatenate_datasets
2
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
+ from .modeling.modeling_flash_llama import LlamaForSequenceClassification
4
+ import torch
5
+ import argparse
6
+ import numpy as np
7
+
8
class TokenizeAndChunk:
    """Callable for `datasets.map(batched=True)`: tokenizes each text (unless
    pre-tokenized ids are present) and splits the token stream into fixed-size
    chunks of at most `tokens` tokens."""

    def __init__(self, tokenizer_name, text_field, tokens_field, tokens, model_cache_dir):
        self.tokens = tokens
        self.tokenizer_name = tokenizer_name
        self.text_field = text_field
        self.tokens_field = tokens_field

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True, cache_dir = model_cache_dir)
        self.tokenizer.pad_token_id = 0

    def __getstate__(self):
        # Pickle only the lightweight config; the tokenizer is rebuilt on unpickle
        # (NOTE(review): model_cache_dir is not pickled, so unpickling reloads the
        # tokenizer with the default cache — confirm that is intended).
        return {
            "tokenizer_name": self.tokenizer_name,
            "text_field": self.text_field,
            "tokens_field": self.tokens_field,
            "tokens": self.tokens,
        }

    def __setstate__(self, state):
        self.__init__(**state)

    def tokenize_and_chunk(self, source_tokens):
        """Split every token sequence into chunks of at most `self.tokens` tokens.

        Returns a pair `(chunk_ids, chunk_counts)` of per-sequence lists.
        """
        split_per_seq = [
            torch.tensor(seq, dtype=torch.long).split(self.tokens)
            for seq in source_tokens
        ]
        chunk_ids = [[piece.tolist() for piece in pieces] for pieces in split_per_seq]
        chunk_counts = [[len(piece) for piece in pieces] for pieces in split_per_seq]
        return chunk_ids, chunk_counts

    def __call__(self, example):
        # Prefer pre-tokenized ids when the batch already carries them.
        if self.tokens_field in example:
            source_tokens = example[self.tokens_field]
        else:
            source_tokens = self.tokenizer(example[self.text_field], truncation=False, padding=False, add_special_tokens=False).input_ids

        chunks_token_ids, chunks_token_counts = self.tokenize_and_chunk(source_tokens)

        assert len(example[self.text_field]) == len(chunks_token_ids)
        assert len(example[self.text_field]) == len(chunks_token_counts)

        return {
            "chunks_token_ids": chunks_token_ids,
            "chunks_token_counts": chunks_token_counts,
        }
55
+
56
+
57
class ModelAnnotator:
    """Callable for `datasets.map(batched=True, with_indices=True)`: scores every
    token chunk with a sequence-classification model and aggregates per-label
    chunk scores back onto each source example."""

    def __init__(self, model_name, labels, device_batch_size, device, model_cache_dir):

        self.model_name = model_name
        self.labels = labels
        self.device_batch_size = device_batch_size

        self.model = LlamaForSequenceClassification.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
            cache_dir=model_cache_dir)
        self.model.config.pad_token_id = 0
        self.model.eval()

        self.device = device
        print(f"Using device {self.device}")
        self.model.to(self.device)

        self.num_labels = len(labels)
        assert self.num_labels == self.model.config.num_labels, f"Number of labels ({self.num_labels}) does not match model config ({self.model.config.num_labels})"

    def __getstate__(self):
        # Pickle only the config; the model is re-loaded on unpickle.
        # NOTE(review): device and model_cache_dir are not pickled, so
        # __setstate__ -> __init__(**state) would fail on the missing required
        # args — confirm these objects are never pickled in practice.
        return {
            "model_name": self.model_name,
            "labels": self.labels,
            "device_batch_size": self.device_batch_size,
        }

    def __setstate__(self, state):
        self.__init__(**state)

    @torch.inference_mode()
    def score_chunks(self, chunks_token_ids, chunks_token_counts):
        """Score a flat list of chunks; returns a float32 tensor of shape
        (num_chunks, num_labels).

        `chunks_token_counts` must be a 1-D integer tensor (argsort/max below).
        Chunks are processed in length-sorted order so each device batch pads
        to a similar max length.
        """
        sorted_indices = torch.argsort(chunks_token_counts)

        scores = torch.zeros(len(chunks_token_ids), self.num_labels, dtype=torch.float32)

        for batch_indices in sorted_indices.split(self.device_batch_size):
            max_len = chunks_token_counts[batch_indices].max()

            # Zero-pad each chunk to the batch max length (pad_token_id is 0).
            input_ids = torch.zeros((len(batch_indices), max_len), dtype=torch.long)
            attention_mask = torch.zeros((len(batch_indices), max_len), dtype=torch.long)

            for i, j in enumerate(batch_indices):
                seq = chunks_token_ids[j]
                input_ids[i, :len(seq)] = seq
                attention_mask[i, :len(seq)] = 1

            outputs = self.model(input_ids.to(self.device), attention_mask=attention_mask.to(self.device), use_cache=False)
            # Scatter batch scores back to their original (unsorted) positions.
            scores[batch_indices] = outputs.logits.float().cpu()
        return scores

    def __call__(self, example, indices):
        """Score one mapped batch; returns per-example chunk scores plus a
        token-count-weighted average per label."""
        num_seqs = len(indices)

        # Flatten (example, chunk) nesting; source_ids remembers which example
        # each flattened chunk came from.
        source_ids = [i for i, counts in enumerate(example["chunks_token_counts"]) for _ in range(len(counts))]
        chunks_token_ids = [torch.tensor(chunk, dtype=torch.long) for chunks in example["chunks_token_ids"] for chunk in chunks]
        flattened_chunks_token_counts = torch.tensor([chunk for chunks in example["chunks_token_counts"] for chunk in chunks], dtype=torch.long)

        flattened_scores = self.score_chunks(chunks_token_ids, flattened_chunks_token_counts)

        chunk_token_counts = example["chunks_token_counts"]
        # chunk_scores[label][example] -> list of that example's chunk scores.
        chunk_scores = [[[] for _ in range(num_seqs)] for _ in range(self.num_labels)]

        for source_id, score in zip(source_ids, flattened_scores):
            for label in range(self.num_labels):
                chunk_scores[label][source_id].append(score[label].item())

        output = {
            "index": indices,
            "chunk_lengths": chunk_token_counts,
            "length": [sum(counts) for counts in chunk_token_counts],
        }

        for i, label in enumerate(self.labels):
            output[f"{label}_chunks"] = chunk_scores[i]
            # Weight each chunk's score by its token count.
            output[f"{label}_average"] = [
                np.average(scores, weights=token_counts).item()
                for scores, token_counts in zip(chunk_scores[i], chunk_token_counts)
            ]

        return output
139
+
140
+
141
if __name__ == "__main__":
    # CLI: annotate a dataset with per-label quality scores.
    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=str)
    parser.add_argument("output", type=str)

    parser.add_argument("-F", "--data_files", type=str, nargs="+", default=[])
    parser.add_argument("-S", "--shard", type=int, nargs=2, default=[0, 1])
    parser.add_argument("-M", "--model", type=str, required=True)
    parser.add_argument("-t", "--tokens", type=int, default=512)
    parser.add_argument("--map_batch_size", type=int, default=512)
    parser.add_argument("-b", "--device_batch_size", type=int, default=16)
    parser.add_argument("-w", "--num_workers", type=int, default=1)
    parser.add_argument("--text_field", type=str, default="text")
    parser.add_argument("--tokens_field", type=str, default="input_ids")
    parser.add_argument("--labels", type=str, nargs="+")
    # Backward-compatible additions so the constructors below receive their
    # required arguments (see BUG FIX notes).
    parser.add_argument("--device", type=str, default=None,
                        help="Torch device for scoring; defaults to cuda when available, else cpu.")
    parser.add_argument("--model_cache_dir", type=str, default=None,
                        help="HuggingFace cache directory for tokenizer/model downloads.")

    args = parser.parse_args()
    print(args)

    if args.input == "json":
        dataset = load_dataset("json", data_files=args.data_files, split="train")
    else:
        dataset = load_from_disk(args.input)

    src_dataset = dataset.shard(args.shard[1], args.shard[0], contiguous=True)
    dataset = src_dataset

    device = args.device if args.device is not None else ("cuda" if torch.cuda.is_available() else "cpu")

    print(dataset)
    print("Total number of examples:", len(dataset))
    # BUG FIX: TokenizeAndChunk.__init__ requires `model_cache_dir`; the original
    # call omitted it and raised TypeError before any work was done.
    dataset = dataset.map(
        TokenizeAndChunk(args.model, args.text_field, args.tokens_field, args.tokens, args.model_cache_dir),
        batched=True,
        batch_size=args.map_batch_size,
        num_proc=args.num_workers,
        remove_columns=dataset.column_names)

    print("After tokenization: Total number of examples:", len(dataset))
    # BUG FIX: ModelAnnotator.__init__ requires `device` and `model_cache_dir`;
    # the original call omitted both.
    dataset = dataset.map(
        ModelAnnotator(args.model, args.labels, args.device_batch_size, device, args.model_cache_dir),
        batched=True,
        with_indices=True,
        batch_size=args.map_batch_size,
        remove_columns=dataset.column_names)

    # Re-attach the original columns next to the new score columns.
    dataset = concatenate_datasets([dataset, src_dataset], axis=1)

    print("After annotation: Total number of examples:", len(dataset))

    print(f"Saving to {args.output}")
    dataset.save_to_disk(args.output)
DataFlow/dataflow/operators/eval/GeneralText/models/Superfiltering/__pycache__/data_analysis.cpython-310.pyc ADDED
Binary file (1.51 kB). View file
 
DataFlow/dataflow/operators/eval/GeneralText/models/Superfiltering/data_analysis.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ import argparse
5
+ from tqdm import tqdm
6
+
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM
8
+
9
+
10
# Bare prompt templates (no system preamble): the instruction, optionally
# followed by the input, each terminated with a newline.
PROMPT_DICT_NONE = {
    "prompt_input": (
        "{instruction}\n{input}\n"
    ),
    "prompt_no_input": (
        "{instruction}\n"
    ),
}
18
# Used to get the ppl and emb for the whole input
def get_perplexity_and_embedding_whole_text(tokenizer, model, text, max_length, device):
    """Score `text` with `model` as a causal LM over the full input.

    Returns `(perplexity, loss)` as Python floats. (Despite the name, no
    embedding is returned by this implementation.)
    """
    encoded = tokenizer.encode(
        text, return_tensors="pt", truncation=True, max_length=max_length
    ).to(device)

    with torch.no_grad():
        result = model(encoded, labels=encoded.contiguous())
    nll = result.loss
    ppl = torch.exp(nll)

    return ppl.to('cpu').item(), nll.to('cpu').item()
29
+
30
+
31
# Used to get the ppl and emb for part of input, used in conditional version, and token-wise loss
def get_perplexity_and_embedding_part_text(tokenizer, model, text, target_span, max_length, device):
    """Score only the trailing `target_span` of `text` with `model`.

    Tokens before the last occurrence of `target_span` have their labels set to
    -100 so the LM loss covers the span alone. Returns `(perplexity, loss)` as
    Python floats, or `(0, 0)` on any failure (best-effort batch scoring).
    """
    try:
        input_ids = tokenizer.encode(text, return_tensors="pt", truncation=True, max_length=max_length).to(device)

        # Mask everything before the target span: -100 labels are ignored by
        # the model's cross-entropy loss.
        start_index = text.rfind(target_span)
        start_token = len(tokenizer.encode(text[:start_index]))

        labels = input_ids.clone()
        labels[0, :start_token] = -100

        with torch.no_grad():
            outputs = model(input_ids, labels=labels)

        loss = outputs.loss
        perplexity = torch.exp(loss)

        return perplexity.to('cpu').item(), loss.to('cpu').item()

    # BUG FIX: was a bare `except:`, which also swallowed KeyboardInterrupt and
    # SystemExit; keep the best-effort (0, 0) fallback but only for real errors.
    except Exception:
        return 0, 0
DataFlow/dataflow/operators/eval/GeneralText/models/__pycache__/debertav3_scorer.cpython-310.pyc ADDED
Binary file (3.61 kB). View file