Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +14 -0
- DataFlow/OpenSeek_CoT/Openseek_Math_Chain_of_Thoughts_pipeline.py.py +137 -0
- DataFlow/dataflow/__init__.py +14 -0
- DataFlow/dataflow/cli.py +80 -0
- DataFlow/dataflow/core/LLMServing.py +27 -0
- DataFlow/dataflow/core/Operator.py +31 -0
- DataFlow/dataflow/core/__init__.py +7 -0
- DataFlow/dataflow/core/__pycache__/LLMServing.cpython-310.pyc +0 -0
- DataFlow/dataflow/core/__pycache__/Operator.cpython-310.pyc +0 -0
- DataFlow/dataflow/core/__pycache__/__init__.cpython-310.pyc +0 -0
- DataFlow/dataflow/logger.py +38 -0
- DataFlow/dataflow/operators/__init__.py +4 -0
- DataFlow/dataflow/operators/eval/AgenticRAG/statistics/f1_scorer.py +108 -0
- DataFlow/dataflow/operators/eval/GeneralText/APIcaller/__pycache__/perspective_scorer.cpython-310.pyc +0 -0
- DataFlow/dataflow/operators/eval/GeneralText/APIcaller/meta_scorer.py +70 -0
- DataFlow/dataflow/operators/eval/GeneralText/APIcaller/treeinstruct_scorer.py +53 -0
- DataFlow/dataflow/operators/eval/GeneralText/__init__.py +55 -0
- DataFlow/dataflow/operators/eval/GeneralText/diversity/__pycache__/task2vec_scorer.cpython-310.pyc +0 -0
- DataFlow/dataflow/operators/eval/GeneralText/diversity/__pycache__/vendi_scorer.cpython-310.pyc +0 -0
- DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec/__pycache__/task2vec.cpython-310.pyc +0 -0
- DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec/__pycache__/task_similarity.cpython-310.pyc +0 -0
- DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec/__pycache__/utils.cpython-310.pyc +0 -0
- DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec/task2vec.py +544 -0
- DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec/task_similarity.py +485 -0
- DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec/utils.py +76 -0
- DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec_scorer.py +76 -0
- DataFlow/dataflow/operators/eval/GeneralText/diversity/vendi_scorer.py +36 -0
- DataFlow/dataflow/operators/eval/GeneralText/gen/__pycache__/bert_scorer.cpython-310.pyc +0 -0
- DataFlow/dataflow/operators/eval/GeneralText/gen/__pycache__/bleu_scorer.cpython-310.pyc +0 -0
- DataFlow/dataflow/operators/eval/GeneralText/gen/__pycache__/cider_scorer.cpython-310.pyc +0 -0
- DataFlow/dataflow/operators/eval/GeneralText/gen/bert_scorer.py +46 -0
- DataFlow/dataflow/operators/eval/GeneralText/gen/bleu/__init__.py +0 -0
- DataFlow/dataflow/operators/eval/GeneralText/gen/bleu/__pycache__/__init__.cpython-310.pyc +0 -0
- DataFlow/dataflow/operators/eval/GeneralText/gen/bleu/__pycache__/bleu.cpython-310.pyc +0 -0
- DataFlow/dataflow/operators/eval/GeneralText/gen/bleu/bleu.py +236 -0
- DataFlow/dataflow/operators/eval/GeneralText/gen/bleu_scorer.py +47 -0
- DataFlow/dataflow/operators/eval/GeneralText/gen/cider/__init__.py +0 -0
- DataFlow/dataflow/operators/eval/GeneralText/gen/cider/__pycache__/__init__.cpython-310.pyc +0 -0
- DataFlow/dataflow/operators/eval/GeneralText/gen/cider/__pycache__/cider.cpython-310.pyc +0 -0
- DataFlow/dataflow/operators/eval/GeneralText/gen/cider/cider.py +134 -0
- DataFlow/dataflow/operators/eval/GeneralText/gen/cider_scorer.py +60 -0
- DataFlow/dataflow/operators/eval/GeneralText/models/Kenlm/__pycache__/model.cpython-310.pyc +0 -0
- DataFlow/dataflow/operators/eval/GeneralText/models/Kenlm/model.py +161 -0
- DataFlow/dataflow/operators/eval/GeneralText/models/Qurating/__pycache__/qurater_annotate.cpython-310.pyc +0 -0
- DataFlow/dataflow/operators/eval/GeneralText/models/Qurating/modeling/__pycache__/modeling_flash_llama.cpython-310.pyc +0 -0
- DataFlow/dataflow/operators/eval/GeneralText/models/Qurating/modeling/modeling_flash_llama.py +853 -0
- DataFlow/dataflow/operators/eval/GeneralText/models/Qurating/qurater_annotate.py +190 -0
- DataFlow/dataflow/operators/eval/GeneralText/models/Superfiltering/__pycache__/data_analysis.cpython-310.pyc +0 -0
- DataFlow/dataflow/operators/eval/GeneralText/models/Superfiltering/data_analysis.py +53 -0
- DataFlow/dataflow/operators/eval/GeneralText/models/__pycache__/debertav3_scorer.cpython-310.pyc +0 -0
.gitattributes
CHANGED
|
@@ -34,3 +34,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
baidu.zip filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
baidu.zip filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
Qwen2.5-Math/evaluation/latex2sympy/antlr-4.11.1-complete.jar filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
report.pdf filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
Qwen2.5-Math/evaluation/data/tabmwp/test.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
Reproducibility/evaluation_log/result/ernie_openseek/math/test_qwen25-math-cot_-1_seed0_t0.0_s0_e-1.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
Reproducibility/cleaned_data/for_erniekit_training/sft_dataflow_finemath.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
Qwen2.5-Math/evaluation/latex2sympy/gen/__pycache__/PSLexer.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
Reproducibility/cleaned_data/for_erniekit_training/sft_dataflow_dolmino.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
Reproducibility/cleaned_data/dataflow_finemath.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
Qwen2.5-Math/evaluation/latex2sympy/gen/__pycache__/PSLexer.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
Reproducibility/evaluation_log/result/ernie_dataflow/math/test_qwen25-math-cot_-1_seed0_t0.0_s0_e-1.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
Qwen2.5-Math/evaluation/latex2sympy/gen/__pycache__/PSParser.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
Qwen2.5-Math/evaluation/latex2sympy/gen/__pycache__/PSParser.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
Reproducibility/cleaned_data/dataflow_dolmino.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
ernie/ERNIE/examples/pre-training/demo_data/data-1-part1.idx filter=lfs diff=lfs merge=lfs -text
|
DataFlow/OpenSeek_CoT/Openseek_Math_Chain_of_Thoughts_pipeline.py.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataflow.operators.generate import (
|
| 2 |
+
QuestionCategoryClassifier,
|
| 3 |
+
QuestionDifficultyClassifier,
|
| 4 |
+
QuestionGenerator,
|
| 5 |
+
AnswerGenerator,
|
| 6 |
+
)
|
| 7 |
+
|
| 8 |
+
from dataflow.operators.filter import (
|
| 9 |
+
QuestionFilter,
|
| 10 |
+
AnswerPipelineRoot,
|
| 11 |
+
AnswerFormatterFilter,
|
| 12 |
+
AnswerTokenLengthFilter,
|
| 13 |
+
AnswerGroundTruthFilter,
|
| 14 |
+
AnswerNgramFilter,
|
| 15 |
+
)
|
| 16 |
+
from dataflow.utils.storage import FileStorage
|
| 17 |
+
from dataflow.serving import APILLMServing_request, LocalModelLLMServing
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class ReasoningPipeline():
|
| 21 |
+
def __init__(self):
|
| 22 |
+
|
| 23 |
+
self.storage = FileStorage(
|
| 24 |
+
# first_entry_file_name="/cpfs/user/boyuan/verl_workspace/baidu/DataFlow/demo/example_data/ReasoningPipeline/pipeline_math_short.json",
|
| 25 |
+
# first_entry_file_name="/cpfs/user/boyuan/verl_workspace/baidu/data/fulldata/math/dolmino-mix-1124-math-merged.jsonl",
|
| 26 |
+
first_entry_file_name="/cpfs/user/boyuan/verl_workspace/baidu/data/fulldata/math/second_half_math.jsonl",
|
| 27 |
+
cache_path="./second_half_math",
|
| 28 |
+
file_name_prefix="dataflow_cache_step",
|
| 29 |
+
cache_type="jsonl",
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
# use API server as LLM serving
|
| 33 |
+
llm_serving = APILLMServing_request(
|
| 34 |
+
api_url="http://10.39.1.99:23456/v1/chat/completions",
|
| 35 |
+
# api_url = "https://aistudio.baidu.com/llm/lmapi/v3",
|
| 36 |
+
# api_url= "https://api.deepseek.com/v1/chat/completions",
|
| 37 |
+
model_name="ernie300",
|
| 38 |
+
# model_name="ernie-4.5-turbo-128k-preview",
|
| 39 |
+
# model_name="qwen3",
|
| 40 |
+
max_workers=50
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
self.question_filter_step3 = QuestionFilter(
|
| 45 |
+
system_prompt="You are an expert in evaluating mathematical problems. Follow the user's instructions strictly and output your final judgment in the required JSON format.",
|
| 46 |
+
llm_serving=llm_serving
|
| 47 |
+
)
|
| 48 |
+
self.question_difficulty_classifier_step4 = QuestionDifficultyClassifier(
|
| 49 |
+
llm_serving=llm_serving
|
| 50 |
+
)
|
| 51 |
+
self.question_category_classifier_step5 = QuestionCategoryClassifier(
|
| 52 |
+
llm_serving=llm_serving
|
| 53 |
+
)
|
| 54 |
+
########################## branch ############################
|
| 55 |
+
# self.answer_pipeline_root_step6 = AnswerPipelineRoot()
|
| 56 |
+
########################## answer ############################
|
| 57 |
+
self.answer_generator_step7 = AnswerGenerator(
|
| 58 |
+
llm_serving=llm_serving
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
self.answer_format_filter_step8 = AnswerFormatterFilter()
|
| 62 |
+
|
| 63 |
+
self.answer_token_length_filter_step9 = AnswerTokenLengthFilter(
|
| 64 |
+
max_answer_token_length = 8192,
|
| 65 |
+
tokenizer_dir = "/cpfs/user/boyuan/verl_workspace/baidu/models300/qwen3",
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
self.answer_groundtruth_filter_step10 = AnswerGroundTruthFilter()
|
| 69 |
+
|
| 70 |
+
self.answer_ngram_filter_step11 = AnswerNgramFilter(
|
| 71 |
+
min_score = 0.1,
|
| 72 |
+
max_score = 1.0,
|
| 73 |
+
ngrams = 5
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
def forward(self):
|
| 77 |
+
|
| 78 |
+
# self.question_filter_step1.run(
|
| 79 |
+
# storage = self.storage.step(),
|
| 80 |
+
# input_key = "instruction",
|
| 81 |
+
# )
|
| 82 |
+
|
| 83 |
+
# self.question_gen_step2.run(
|
| 84 |
+
# storage = self.storage.step(),
|
| 85 |
+
# input_key = "instruction",
|
| 86 |
+
# )
|
| 87 |
+
|
| 88 |
+
self.question_filter_step3.run(
|
| 89 |
+
storage = self.storage.step(),
|
| 90 |
+
input_key = "instruction",
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
self.question_difficulty_classifier_step4.run(
|
| 94 |
+
storage = self.storage.step(),
|
| 95 |
+
input_key = "instruction",
|
| 96 |
+
output_key = "question_difficulty"
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
self.question_category_classifier_step5.run(
|
| 100 |
+
storage = self.storage.step(),
|
| 101 |
+
input_key = "instruction",
|
| 102 |
+
output_key = "question_category"
|
| 103 |
+
)
|
| 104 |
+
############# branch #############
|
| 105 |
+
# self.answer_pipeline_root_step6.run(
|
| 106 |
+
# storage = self.storage.step(),
|
| 107 |
+
# input_answer_key = "output",
|
| 108 |
+
# input_gt_key = "golden_answer"
|
| 109 |
+
# )
|
| 110 |
+
############## answer #############
|
| 111 |
+
self.answer_generator_step7.run(
|
| 112 |
+
storage = self.storage.step(),
|
| 113 |
+
input_key = "instruction",
|
| 114 |
+
output_key = "generated_cot"
|
| 115 |
+
)
|
| 116 |
+
self.answer_format_filter_step8.run(
|
| 117 |
+
storage = self.storage.step(),
|
| 118 |
+
input_key = "generated_cot",
|
| 119 |
+
)
|
| 120 |
+
self.answer_token_length_filter_step9.run(
|
| 121 |
+
storage = self.storage.step(),
|
| 122 |
+
input_key = "generated_cot"
|
| 123 |
+
)
|
| 124 |
+
# self.answer_groundtruth_filter_step10.run(
|
| 125 |
+
# storage = self.storage.step(),
|
| 126 |
+
# test_answer_key = "generated_cot",
|
| 127 |
+
# gt_answer_key = "golden_answer"
|
| 128 |
+
# )
|
| 129 |
+
self.answer_ngram_filter_step11.run(
|
| 130 |
+
storage = self.storage.step(),
|
| 131 |
+
question_key = "instruction",
|
| 132 |
+
answer_key = "generated_cot"
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
if __name__ == "__main__":
|
| 136 |
+
model = ReasoningPipeline()
|
| 137 |
+
model.forward()
|
DataFlow/dataflow/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .utils import *
|
| 2 |
+
from .version import __version__, version_info
|
| 3 |
+
from .logger import get_logger
|
| 4 |
+
from .operators import *
|
| 5 |
+
__all__ = [
|
| 6 |
+
'__version__',
|
| 7 |
+
'version_info',
|
| 8 |
+
'get_logger',
|
| 9 |
+
]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def hello():
|
| 14 |
+
return "Hello from open-dataflow!"
|
DataFlow/dataflow/cli.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import requests
|
| 4 |
+
|
| 5 |
+
from colorama import init, Fore, Style
|
| 6 |
+
|
| 7 |
+
# from dataflow.utils.paths import BencoPath
|
| 8 |
+
from dataflow.cli_funcs import cli_env, cli_init
|
| 9 |
+
import importlib.metadata
|
| 10 |
+
|
| 11 |
+
PYPI_API_URL = 'https://pypi.org/pypi/open-dataflow/json'
|
| 12 |
+
from dataflow.version import __version__
|
| 13 |
+
|
| 14 |
+
def version_and_check_for_updates():
|
| 15 |
+
# print a bar by the length of the shell width
|
| 16 |
+
print(Fore.BLUE + "=" * os.get_terminal_size().columns + Style.RESET_ALL)
|
| 17 |
+
print(f'open-dataflow codebase version: {__version__}')
|
| 18 |
+
try:
|
| 19 |
+
response = requests.get(PYPI_API_URL, timeout=5)
|
| 20 |
+
response.raise_for_status() # 如果响应码不是200,则抛出异常
|
| 21 |
+
pypi_data = response.json()
|
| 22 |
+
cloud_version = pypi_data['info']['version'] # 获取最新版本号
|
| 23 |
+
# cloud_version = '0.1.21' # for debug & test
|
| 24 |
+
print("\tChecking for updates...")
|
| 25 |
+
print("\tLocal version: ", __version__)
|
| 26 |
+
print("\tPyPI newest version: ", cloud_version)
|
| 27 |
+
|
| 28 |
+
local_version = __version__ # 通过 importlib.metadata 获取当前安装版本
|
| 29 |
+
|
| 30 |
+
if cloud_version != local_version:
|
| 31 |
+
print(Fore.YELLOW + f"New version available: {cloud_version}. Your version: {local_version}." + Style.RESET_ALL)
|
| 32 |
+
print("Run 'pip install --upgrade open-dataflow' to upgrade.")
|
| 33 |
+
else:
|
| 34 |
+
print(Fore.GREEN + f"You are using the latest version: {local_version}." + Style.RESET_ALL)
|
| 35 |
+
except requests.exceptions.RequestException as e:
|
| 36 |
+
print(Fore.RED + "Failed to check for updates from PyPI. Please check your internet connection." + Style.RESET_ALL)
|
| 37 |
+
print(f"Error: {e}")
|
| 38 |
+
print(Fore.BLUE + "=" * os.get_terminal_size().columns + Style.RESET_ALL)
|
| 39 |
+
def main():
|
| 40 |
+
parser = argparse.ArgumentParser(description='Command line interface for DataFlow, with codebase version: ' + __version__)
|
| 41 |
+
|
| 42 |
+
# Add version argument with update check only when user requests version
|
| 43 |
+
parser.add_argument('-v', '--version', action='store_true', help="Show the version of the tool")
|
| 44 |
+
|
| 45 |
+
subparsers = parser.add_subparsers(dest='command', required=False)
|
| 46 |
+
|
| 47 |
+
# init command
|
| 48 |
+
parser_init = subparsers.add_parser('init', help='Initialize the scripts and configs in a directory')
|
| 49 |
+
init_subparsers = parser_init.add_subparsers(dest='subcommand', required=False)
|
| 50 |
+
|
| 51 |
+
# init all
|
| 52 |
+
parser_init_all = init_subparsers.add_parser('all', help='Initialize all components')
|
| 53 |
+
parser_init_all.set_defaults(subcommand='all')
|
| 54 |
+
|
| 55 |
+
# init reasoning
|
| 56 |
+
parser_init_reasoning = init_subparsers.add_parser('reasoning', help='Initialize reasoning components')
|
| 57 |
+
parser_init_reasoning.set_defaults(subcommand='reasoning')
|
| 58 |
+
|
| 59 |
+
# env command
|
| 60 |
+
parser_env = subparsers.add_parser('env', help='Show environment information')
|
| 61 |
+
|
| 62 |
+
# parser.add_argument('--config', type=str, help='Path to the configuration file')
|
| 63 |
+
|
| 64 |
+
args = parser.parse_args()
|
| 65 |
+
if args.version:
|
| 66 |
+
version_and_check_for_updates()
|
| 67 |
+
|
| 68 |
+
if args.command == 'init':
|
| 69 |
+
if args.subcommand is None:
|
| 70 |
+
args.subcommand = 'base'
|
| 71 |
+
cli_init(subcommand=args.subcommand)
|
| 72 |
+
# print("TODO Calling cli_init with subcommand:", args.subcommand)
|
| 73 |
+
from dataflow.cli_funcs.paths import DataFlowPath
|
| 74 |
+
# print(DataFlowPath.get_dataflow_dir())
|
| 75 |
+
# print(DataFlowPath.get_dataflow_scripts_dir())
|
| 76 |
+
elif args.command == 'env':
|
| 77 |
+
cli_env()
|
| 78 |
+
|
| 79 |
+
if __name__ == '__main__':
|
| 80 |
+
main()
|
DataFlow/dataflow/core/LLMServing.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import Any, List
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class LLMServingABC(ABC):
|
| 6 |
+
"""Abstract base class for data generators. Which may be used to generate data from a model or API. Called by operators
|
| 7 |
+
"""
|
| 8 |
+
@abstractmethod
|
| 9 |
+
def generate_from_input(self, user_inputs: List[str], system_prompt: str) -> List[str]:
|
| 10 |
+
"""
|
| 11 |
+
Generate data from input.
|
| 12 |
+
input: List[str], the input of the generator
|
| 13 |
+
"""
|
| 14 |
+
pass
|
| 15 |
+
@abstractmethod
|
| 16 |
+
def cleanup(self):
|
| 17 |
+
"""
|
| 18 |
+
Cleanup the generator and garbage collect all GPU/CPU memory.
|
| 19 |
+
"""
|
| 20 |
+
pass
|
| 21 |
+
|
| 22 |
+
def load_model(self, model_name_or_path: str, **kwargs: Any):
|
| 23 |
+
"""
|
| 24 |
+
Load the model from the given path.
|
| 25 |
+
This method is optional and can be overridden by subclasses if needed.
|
| 26 |
+
"""
|
| 27 |
+
raise NotImplementedError("This method should be implemented by subclasses.")
|
DataFlow/dataflow/core/Operator.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from dataflow.logger import get_logger
|
| 3 |
+
|
| 4 |
+
class OperatorABC(ABC):
|
| 5 |
+
|
| 6 |
+
# @abstractmethod
|
| 7 |
+
# def check_config(self, config: dict) -> None:
|
| 8 |
+
# """
|
| 9 |
+
# Check the config of the operator. If config lacks any required keys, raise an error.
|
| 10 |
+
# """
|
| 11 |
+
# pass
|
| 12 |
+
|
| 13 |
+
@abstractmethod
|
| 14 |
+
def run(self) -> None:
|
| 15 |
+
"""
|
| 16 |
+
Main function to run the operator.
|
| 17 |
+
"""
|
| 18 |
+
pass
|
| 19 |
+
|
| 20 |
+
def get_operator(operator_name, args) -> OperatorABC:
|
| 21 |
+
from dataflow.utils import OPERATOR_REGISTRY
|
| 22 |
+
print(operator_name, args)
|
| 23 |
+
operator = OPERATOR_REGISTRY.get(operator_name)(args)
|
| 24 |
+
logger = get_logger()
|
| 25 |
+
if operator is not None:
|
| 26 |
+
logger.info(f"Successfully get operator {operator_name}, args {args}")
|
| 27 |
+
else:
|
| 28 |
+
logger.error(f"operator {operator_name} is not found")
|
| 29 |
+
assert operator is not None
|
| 30 |
+
print(operator)
|
| 31 |
+
return operator
|
DataFlow/dataflow/core/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .Operator import OperatorABC, get_operator
|
| 2 |
+
from .LLMServing import LLMServingABC
|
| 3 |
+
__all__ = [
|
| 4 |
+
'OperatorABC',
|
| 5 |
+
'get_operator',
|
| 6 |
+
'LLMServingABC',
|
| 7 |
+
]
|
DataFlow/dataflow/core/__pycache__/LLMServing.cpython-310.pyc
ADDED
|
Binary file (1.5 kB). View file
|
|
|
DataFlow/dataflow/core/__pycache__/Operator.cpython-310.pyc
ADDED
|
Binary file (1.05 kB). View file
|
|
|
DataFlow/dataflow/core/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (317 Bytes). View file
|
|
|
DataFlow/dataflow/logger.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import colorlog
|
| 3 |
+
|
| 4 |
+
# 自定义日志等级
|
| 5 |
+
SUCCESS_LEVEL_NUM = 25
|
| 6 |
+
logging.addLevelName(SUCCESS_LEVEL_NUM, "SUCCESS")
|
| 7 |
+
|
| 8 |
+
def success(self, message, *args, **kwargs):
|
| 9 |
+
if self.isEnabledFor(SUCCESS_LEVEL_NUM):
|
| 10 |
+
self._log(SUCCESS_LEVEL_NUM, message, args, **kwargs)
|
| 11 |
+
|
| 12 |
+
logging.Logger.success = success # 添加方法到 Logger 类
|
| 13 |
+
|
| 14 |
+
def get_logger(level=logging.INFO) -> logging.Logger:
|
| 15 |
+
# 创建logger对象
|
| 16 |
+
logger = logging.getLogger("DataFlow")
|
| 17 |
+
if not logger.handlers:
|
| 18 |
+
logger.setLevel(level)
|
| 19 |
+
# 创建控制台日志处理器
|
| 20 |
+
console_handler = logging.StreamHandler()
|
| 21 |
+
console_handler.setLevel(level)
|
| 22 |
+
# 定义颜色输出格式
|
| 23 |
+
color_formatter = colorlog.ColoredFormatter(
|
| 24 |
+
'%(log_color)s %(asctime)s | %(filename)-20s- %(module)-20s- %(funcName)-20s- %(lineno)5d - %(name)-10s | %(levelname)8s | Processno %(process)5d - Threadno %(thread)-15d : %(message)s',
|
| 25 |
+
log_colors={
|
| 26 |
+
'DEBUG': 'cyan',
|
| 27 |
+
# 'INFO': 'white',
|
| 28 |
+
'SUCCESS': 'green',
|
| 29 |
+
'WARNING': 'yellow',
|
| 30 |
+
'ERROR': 'red',
|
| 31 |
+
'CRITICAL': 'red,bg_white',
|
| 32 |
+
}
|
| 33 |
+
)
|
| 34 |
+
# 将颜色输出格式添加到控制台日志处理器
|
| 35 |
+
console_handler.setFormatter(color_formatter)
|
| 36 |
+
# 将控制台日志处理器添加到logger对象
|
| 37 |
+
logger.addHandler(console_handler)
|
| 38 |
+
return logger
|
DataFlow/dataflow/operators/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# from .eval import *
|
| 2 |
+
# from .generate import *
|
| 3 |
+
# from .filter import *
|
| 4 |
+
# from .refine import *
|
DataFlow/dataflow/operators/eval/AgenticRAG/statistics/f1_scorer.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import string
|
| 3 |
+
from collections import Counter
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from dataflow.core import OperatorABC
|
| 7 |
+
from dataflow.utils.storage import DataFlowStorage
|
| 8 |
+
from dataflow.utils.registry import OPERATOR_REGISTRY
|
| 9 |
+
from dataflow import get_logger
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@OPERATOR_REGISTRY.register()
|
| 13 |
+
class F1Scorer(OperatorABC):
|
| 14 |
+
|
| 15 |
+
def __init__(self, prediction_key, ground_truth_key):
|
| 16 |
+
self.logger = get_logger()
|
| 17 |
+
self.logger.info(f"Initializing {self.__class__.__name__}...")
|
| 18 |
+
self.prediction_key = prediction_key
|
| 19 |
+
self.ground_truth_key = ground_truth_key
|
| 20 |
+
self.logger.info(f"{self.__class__.__name__} initialized.")
|
| 21 |
+
|
| 22 |
+
@staticmethod
|
| 23 |
+
def get_desc(lang: str = "zh"):
|
| 24 |
+
return "用于评估预测答案与多个参考答案之间的 F1 分数"
|
| 25 |
+
|
| 26 |
+
def _validate_dataframe(self, dataframe: pd.DataFrame):
|
| 27 |
+
required_keys = [self.prediction_key, self.ground_truth_key]
|
| 28 |
+
forbidden_keys = [self.output_key ]
|
| 29 |
+
|
| 30 |
+
missing = [k for k in required_keys if k not in dataframe.columns]
|
| 31 |
+
conflict = [k for k in forbidden_keys if k in dataframe.columns]
|
| 32 |
+
|
| 33 |
+
if missing:
|
| 34 |
+
raise ValueError(f"Missing required column(s): {missing}")
|
| 35 |
+
if conflict:
|
| 36 |
+
raise ValueError(f"The following column(s) already exist and would be overwritten: {conflict}")
|
| 37 |
+
|
| 38 |
+
def normalize_answer(self, s: str) -> str:
|
| 39 |
+
def remove_articles(text):
|
| 40 |
+
return re.sub(r"\b(a|an|the)\b", " ", text)
|
| 41 |
+
|
| 42 |
+
def white_space_fix(text):
|
| 43 |
+
return " ".join(text.split())
|
| 44 |
+
|
| 45 |
+
def remove_punc(text):
|
| 46 |
+
exclude = set(string.punctuation)
|
| 47 |
+
return "".join(ch for ch in text if ch not in exclude)
|
| 48 |
+
|
| 49 |
+
def lower(text):
|
| 50 |
+
return text.lower()
|
| 51 |
+
|
| 52 |
+
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
| 53 |
+
|
| 54 |
+
def compute_f1(self, prediction: str, ground_truths) -> float:
|
| 55 |
+
if prediction is None or ground_truths is None:
|
| 56 |
+
return 0.0
|
| 57 |
+
|
| 58 |
+
if isinstance(ground_truths, str):
|
| 59 |
+
ground_truths = [ground_truths]
|
| 60 |
+
|
| 61 |
+
max_f1 = 0.0
|
| 62 |
+
|
| 63 |
+
for ground_truth in ground_truths:
|
| 64 |
+
if ground_truth is None:
|
| 65 |
+
continue
|
| 66 |
+
|
| 67 |
+
normalized_prediction = self.normalize_answer(prediction)
|
| 68 |
+
normalized_ground_truth = self.normalize_answer(ground_truth)
|
| 69 |
+
|
| 70 |
+
if normalized_prediction in ["yes", "no", "noanswer"] or normalized_ground_truth in ["yes", "no", "noanswer"]:
|
| 71 |
+
if normalized_prediction != normalized_ground_truth:
|
| 72 |
+
continue
|
| 73 |
+
|
| 74 |
+
pred_tokens = normalized_prediction.split()
|
| 75 |
+
gold_tokens = normalized_ground_truth.split()
|
| 76 |
+
common = Counter(pred_tokens) & Counter(gold_tokens)
|
| 77 |
+
num_same = sum(common.values())
|
| 78 |
+
|
| 79 |
+
if num_same == 0:
|
| 80 |
+
continue
|
| 81 |
+
|
| 82 |
+
precision = num_same / len(pred_tokens)
|
| 83 |
+
recall = num_same / len(gold_tokens)
|
| 84 |
+
f1 = (2 * precision * recall) / (precision + recall)
|
| 85 |
+
max_f1 = max(max_f1, f1)
|
| 86 |
+
|
| 87 |
+
return max_f1
|
| 88 |
+
|
| 89 |
+
def eval(self, dataframe: pd.DataFrame) -> list:
|
| 90 |
+
self.logger.info(f"Evaluating {self.output_key}...")
|
| 91 |
+
f1_scores = []
|
| 92 |
+
|
| 93 |
+
for _, row in tqdm(dataframe.iterrows(), total=len(dataframe), desc="F1Scorer Evaluating..."):
|
| 94 |
+
prediction = row.get(self.prediction_key, None)
|
| 95 |
+
ground_truths = row.get(self.ground_truth_key, None)
|
| 96 |
+
score = self.compute_f1(prediction, ground_truths)
|
| 97 |
+
f1_scores.append(score)
|
| 98 |
+
|
| 99 |
+
self.logger.info("Evaluation complete!")
|
| 100 |
+
return f1_scores
|
| 101 |
+
|
| 102 |
+
def run(self, storage: DataFlowStorage, output_key):
|
| 103 |
+
dataframe = storage.read("dataframe")
|
| 104 |
+
self.output_key = output_key
|
| 105 |
+
self._validate_dataframe(dataframe)
|
| 106 |
+
scores = self.eval(dataframe)
|
| 107 |
+
dataframe[self.output_key] = scores
|
| 108 |
+
storage.write(dataframe)
|
DataFlow/dataflow/operators/eval/GeneralText/APIcaller/__pycache__/perspective_scorer.cpython-310.pyc
ADDED
|
Binary file (2.3 kB). View file
|
|
|
DataFlow/dataflow/operators/eval/GeneralText/APIcaller/meta_scorer.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataflow.utils.registry import OPERATOR_REGISTRY
|
| 2 |
+
from dataflow import get_logger
|
| 3 |
+
from dataflow.core import OperatorABC
|
| 4 |
+
from dataflow.utils.storage import DataFlowStorage
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from dataflow.core import LLMServingABC
|
| 7 |
+
from dataflow.prompts.general_text import MetaPrompt
|
| 8 |
+
import ast
|
| 9 |
+
|
| 10 |
+
@OPERATOR_REGISTRY.register()
|
| 11 |
+
class MetaScorer(OperatorABC):
|
| 12 |
+
def __init__(self, llm_serving: LLMServingABC = None):
|
| 13 |
+
self.logger = get_logger()
|
| 14 |
+
self.logger.info(f'Initializing {self.__class__.__name__}...')
|
| 15 |
+
self.llm_serving = llm_serving
|
| 16 |
+
self.score_name = 'MetaScore'
|
| 17 |
+
self.prompt = MetaPrompt()
|
| 18 |
+
self.logger.info(f'{self.__class__.__name__} initialized.')
|
| 19 |
+
|
| 20 |
+
self.output_columns = [
|
| 21 |
+
"Text Structure",
|
| 22 |
+
"Diversity & Complexity",
|
| 23 |
+
"Fluency & Understandability",
|
| 24 |
+
"Safety",
|
| 25 |
+
"Educational Value",
|
| 26 |
+
"Content Accuracy & Effectiveness"
|
| 27 |
+
]
|
| 28 |
+
|
| 29 |
+
def get_score(self, samples, input_key):
|
| 30 |
+
system_prompt = self.prompt.build_system_prompt()
|
| 31 |
+
user_prompts = []
|
| 32 |
+
for sample in samples:
|
| 33 |
+
input_text = sample.get(input_key, '')
|
| 34 |
+
user_prompt = self.prompt.build_user_prompt(input_text)
|
| 35 |
+
full_prompt = system_prompt + "\n" + user_prompt
|
| 36 |
+
user_prompts.append(full_prompt)
|
| 37 |
+
|
| 38 |
+
responses = self.llm_serving.generate_from_input(user_inputs=user_prompts)
|
| 39 |
+
scores = []
|
| 40 |
+
|
| 41 |
+
for i, response in enumerate(responses):
|
| 42 |
+
try:
|
| 43 |
+
lines = response.strip().split("\n")
|
| 44 |
+
last_line = lines[-1].strip()
|
| 45 |
+
parsed_scores = ast.literal_eval(last_line)
|
| 46 |
+
if isinstance(parsed_scores, list) and len(parsed_scores) == 6:
|
| 47 |
+
scores.append(parsed_scores)
|
| 48 |
+
else:
|
| 49 |
+
raise ValueError("Score format invalid")
|
| 50 |
+
except Exception as e:
|
| 51 |
+
self.logger.warning(f"Failed to extract score from response {i}: {e}")
|
| 52 |
+
scores.append([float('nan')] * 6)
|
| 53 |
+
|
| 54 |
+
return scores
|
| 55 |
+
|
| 56 |
+
def eval(self, dataframe: pd.DataFrame, input_key: str):
|
| 57 |
+
samples = dataframe.to_dict(orient='records')
|
| 58 |
+
self.logger.info(f"Evaluating {self.score_name}...")
|
| 59 |
+
scores = self.get_score(samples, input_key)
|
| 60 |
+
self.logger.info("Evaluation complete!")
|
| 61 |
+
return scores
|
| 62 |
+
|
| 63 |
+
def run(self, storage: DataFlowStorage, input_key: str):
|
| 64 |
+
self.input_key = input_key
|
| 65 |
+
dataframe = storage.read("dataframe")
|
| 66 |
+
scores = self.eval(dataframe, self.input_key)
|
| 67 |
+
# 展开6列固定命名
|
| 68 |
+
score_df = pd.DataFrame(scores, columns=self.output_columns)
|
| 69 |
+
dataframe = pd.concat([dataframe, score_df], axis=1)
|
| 70 |
+
storage.write(dataframe)
|
DataFlow/dataflow/operators/eval/GeneralText/APIcaller/treeinstruct_scorer.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataflow.utils.registry import OPERATOR_REGISTRY
|
| 2 |
+
from dataflow import get_logger
|
| 3 |
+
from dataflow.core import OperatorABC
|
| 4 |
+
from dataflow.utils.storage import DataFlowStorage
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from dataflow.core import LLMServingABC
|
| 7 |
+
from dataflow.prompts.general_text import TreeinstructPrompt
|
| 8 |
+
|
| 9 |
+
@OPERATOR_REGISTRY.register()
class TreeinstructScorer(OperatorABC):
    """LLM-based instruction-complexity scorer (TreeInstruct style).

    For each row, the instruction is embedded into the system prompt and the
    LLM is asked to rate it; the first token of the reply's last line is
    parsed as a float score.  Rows whose reply cannot be parsed receive NaN
    instead of aborting the whole batch (consistent with the other
    APIcaller scorers in this package).
    """

    def __init__(self, llm_serving: LLMServingABC = None):
        self.logger = get_logger()
        self.logger.info(f'Initializing {self.__class__.__name__}...')
        self.llm_serving = llm_serving
        self.score_name = 'TreeinstructScore'
        self.prompt = TreeinstructPrompt()
        self.logger.info(f'{self.__class__.__name__} initialized.')

    def get_score(self, samples, input_instruction_key):
        """Return one float score per sample; NaN where parsing fails."""
        system_prompts = []
        user_prompts = []
        for sample in samples:
            # Default to an empty string (not a one-element list) when the key
            # is missing, matching the sibling scorers' `sample.get(key, '')`.
            instruction = sample.get(input_instruction_key, '')
            system_prompts.append(self.prompt.build_system_prompt(instruction))
            user_prompts.append(self.prompt.build_user_prompt())

        inputs = [system + "\n" + user for system, user in zip(system_prompts, user_prompts)]
        responses = self.llm_serving.generate_from_input(user_inputs=inputs)

        scores = []
        for i, response in enumerate(responses):
            # Guard the parse: one malformed LLM reply must not lose the batch.
            try:
                score_line = response.strip().split("\n")[-1]
                score = float(score_line.split()[0])
            except (ValueError, IndexError, AttributeError) as e:
                self.logger.warning(f"Failed to extract score from response {i}: {e}")
                score = float('nan')
            scores.append(score)

        return scores

    def eval(self, dataframe: pd.DataFrame, input_instruction_key: str):
        """Score every row of *dataframe* and return the list of floats."""
        self.logger.info(f"Evaluating {self.score_name}...")
        samples = dataframe.to_dict(orient='records')
        scores = self.get_score(samples, input_instruction_key)
        self.logger.info("Evaluation complete!")
        return scores

    def run(self, storage: DataFlowStorage, input_instruction_key: str, output_key: str = 'TreeinstructScore'):
        """Read the dataframe, add a score column under *output_key*, write back."""
        self.input_instruction_key = input_instruction_key
        self.output_key = output_key
        dataframe = storage.read("dataframe")
        scores = self.eval(dataframe, self.input_instruction_key)
        dataframe[self.output_key] = scores
        storage.write(dataframe)
|
DataFlow/dataflow/operators/eval/GeneralText/__init__.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# from .statistics.ngram_scorer import NgramScorer
|
| 2 |
+
# from .statistics.lexical_diversity_scorer import LexicalDiversityScorer
|
| 3 |
+
# from .statistics.langkit_scorer import LangkitScorer
|
| 4 |
+
|
| 5 |
+
# from .models.deita_quality_scorer import DeitaQualityScorer
|
| 6 |
+
# from .models.instag_scorer import InstagScorer
|
| 7 |
+
# from .models.debertav3_scorer import DebertaV3Scorer
|
| 8 |
+
# from .models.deita_complexity_scorer import DeitaComplexityScorer
|
| 9 |
+
# from .models.fineweb_edu_scorer import FineWebEduScorer
|
| 10 |
+
# from .models.pair_qual_scorer import PairQualScorer
|
| 11 |
+
# from .models.presidio_scorer import PresidioScorer
|
| 12 |
+
# from .models.rm_scorer import RMScorer
|
| 13 |
+
# from .models.textbook_scorer import TextbookScorer
|
| 14 |
+
# from .models.superfiltering_scorer import SuperfilteringScorer
|
| 15 |
+
# from .models.qurating_scorer import QuratingScorer
|
| 16 |
+
# from .models.perplexity_scorer import PerplexityScorer
|
| 17 |
+
|
| 18 |
+
# from .APIcaller.alpagasus_scorer import AlpagasusScorer
|
| 19 |
+
# from .APIcaller.treeinstruct_scorer import TreeinstructScorer
|
| 20 |
+
# from .APIcaller.perspective_scorer import PerspectiveScorer
|
| 21 |
+
# from .APIcaller.meta_scorer import MetaScorer
|
| 22 |
+
|
| 23 |
+
# from .diversity.vendi_scorer import VendiScorer
|
| 24 |
+
# from .diversity.task2vec_scorer import Task2VecScorer
|
| 25 |
+
|
| 26 |
+
# from .gen.bleu_scorer import BleuScorer
|
| 27 |
+
# from .gen.cider_scorer import CiderScorer
|
| 28 |
+
# from .gen.bert_scorer import BERTScorer
|
| 29 |
+
|
| 30 |
+
# __all__ = [
|
| 31 |
+
# 'NgramScorer',
|
| 32 |
+
# 'LexicalDiversityScorer',
|
| 33 |
+
# 'LangkitScorer',
|
| 34 |
+
# 'DeitaQualityScorer',
|
| 35 |
+
# 'InstagScorer',
|
| 36 |
+
# 'DebertaV3Scorer',
|
| 37 |
+
# 'DeitaComplexityScorer',
|
| 38 |
+
# 'FineWebEduScorer',
|
| 39 |
+
# 'PairQualScorer',
|
| 40 |
+
# 'PresidioScorer',
|
| 41 |
+
# 'RMScorer',
|
| 42 |
+
# 'TextbookScorer',
|
| 43 |
+
# 'SuperfilteringScorer',
|
| 44 |
+
# 'QuratingScorer',
|
| 45 |
+
# 'PerplexityScorer',
|
| 46 |
+
# 'AlpagasusScorer',
|
| 47 |
+
# 'TreeinstructScorer',
|
| 48 |
+
# 'PerspectiveScorer',
|
| 49 |
+
# "MetaScorer",
|
| 50 |
+
# 'VendiScorer',
|
| 51 |
+
# 'Task2VecScorer',
|
| 52 |
+
# 'BleuScorer',
|
| 53 |
+
# 'CiderScorer',
|
| 54 |
+
# 'BERTScorer'
|
| 55 |
+
# ]
|
DataFlow/dataflow/operators/eval/GeneralText/diversity/__pycache__/task2vec_scorer.cpython-310.pyc
ADDED
|
Binary file (4.26 kB). View file
|
|
|
DataFlow/dataflow/operators/eval/GeneralText/diversity/__pycache__/vendi_scorer.cpython-310.pyc
ADDED
|
Binary file (1.83 kB). View file
|
|
|
DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec/__pycache__/task2vec.cpython-310.pyc
ADDED
|
Binary file (17.6 kB). View file
|
|
|
DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec/__pycache__/task_similarity.cpython-310.pyc
ADDED
|
Binary file (13.7 kB). View file
|
|
|
DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (2.62 kB). View file
|
|
|
DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec/task2vec.py
ADDED
|
@@ -0,0 +1,544 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
| 4 |
+
# may not use this file except in compliance with the License. A copy of
|
| 5 |
+
# the License is located at
|
| 6 |
+
#
|
| 7 |
+
# http://aws.amazon.com/apache2.0/
|
| 8 |
+
#
|
| 9 |
+
# or in the "license" file accompanying this file. This file is
|
| 10 |
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
| 11 |
+
# ANY KIND, either express or implied. See the License for the specific
|
| 12 |
+
# language governing permissions and limitations under the License.
|
| 13 |
+
|
| 14 |
+
import itertools
|
| 15 |
+
import math
|
| 16 |
+
import random
|
| 17 |
+
from abc import ABC, abstractmethod
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
from torch.optim import Optimizer
|
| 23 |
+
import numpy as np
|
| 24 |
+
from tqdm.auto import tqdm, trange
|
| 25 |
+
import logging
|
| 26 |
+
from torch.utils.data import DataLoader, Dataset
|
| 27 |
+
|
| 28 |
+
from .utils import AverageMeter, get_error, get_device
|
| 29 |
+
|
| 30 |
+
## LLM DIV
|
| 31 |
+
def set_seed(seed):
    """Seed every RNG in play (python `random`, numpy, torch CPU and all
    CUDA devices) so that a run is reproducible end to end."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
|
| 36 |
+
|
| 37 |
+
## LLM DIV
|
| 38 |
+
def get_loss(logits: torch.tensor, targets: torch.tensor, ignore_index=None) -> torch.tensor:
    """Next-token cross-entropy for an autoregressive LM.

    Drops the last time-step of the logits and the first token of the targets
    (the standard one-position shift), then averages CE over the remaining
    positions, skipping ``ignore_index`` tokens.  ``ignore_index`` is required.
    """
    assert logits.dim() == 3 and ignore_index is not None
    criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)
    # CrossEntropyLoss wants (batch, vocab, seq), hence the transpose.
    shifted_logits = logits[:, :-1, :].transpose(1, 2)
    shifted_targets = targets[:, 1:]
    return criterion(shifted_logits, shifted_targets)
|
| 49 |
+
|
| 50 |
+
class Embedding:
    """Container for a task embedding: the diagonal of the Fisher Information
    Matrix over all network filters, plus a per-entry scale and free-form
    metadata.

    Notes:
        - `hessian` holds the FIM diagonal, one entry per filter.
        - its length equals the total number of filters in the network.
    """

    def __init__(self, hessian, scale, meta=None):
        # Coerce to numpy arrays so downstream similarity code can rely on
        # array semantics regardless of what the caller passed in.
        self.hessian, self.scale = np.array(hessian), np.array(scale)
        self.meta = meta

    def __repr__(self):
        return f'{self.hessian}'
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class ProbeNetwork(ABC, nn.Module):
    """Abstract class that all probe networks should inherit from.

    This is a standard torch.nn.Module but needs to expose a classifier property that returns the final classification
    module (e.g., the last fully connected layer).
    """

    @property
    @abstractmethod
    def classifier(self):
        # Subclasses must return the submodule acting as the final classifier;
        # Task2Vec reads (and re-fits) this head when computing the embedding.
        raise NotImplementedError("Override the classifier property to return the submodules of the network that"
                                  " should be interpreted as the classifier")

    @classifier.setter
    @abstractmethod
    def classifier(self, val):
        # Symmetric setter so the probe's head can be swapped for a new task.
        raise NotImplementedError("Override the classifier setter to set the submodules of the network that"
                                  " should be interpreted as the classifier")
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class Task2Vec:
|
| 89 |
+
|
| 90 |
+
def __init__(self, model: ProbeNetwork, skip_layers=0, max_samples=None, classifier_opts=None,
             method='montecarlo', method_opts=None, loader_opts=None, bernoulli=False, mode='autoregressive'): ## LLM DIV
    # Configuration holder for a Task2Vec embedding run.
    #   model           - probe network whose FIM diagonal becomes the embedding
    #   skip_layers     - number of initial layers to exclude (classification mode)
    #   classifier_opts - options for (re)fitting / finetuning the head
    #   method          - FIM approximation: 'montecarlo' or 'variational'
    #   mode            - 'autoregressive' (LM probe) or classification-style
    # Mutable-default pitfall avoided: fresh dicts per instance.
    if classifier_opts is None:
        classifier_opts = {}
    if method_opts is None:
        method_opts = {}
    if loader_opts is None:
        loader_opts = {}
    assert method in ('variational', 'montecarlo')
    assert skip_layers >= 0

    self.model = model
    # Fix batch norm running statistics (i.e., put batch_norm layers in eval mode)
    self.model.train()
    self.device = get_device(self.model)
    self.skip_layers = skip_layers
    self.max_samples = max_samples
    self.classifier_opts = classifier_opts
    self.method = method
    self.method_opts = method_opts
    self.loader_opts = loader_opts
    self.bernoulli = bernoulli
    self.mode = mode
    if self.mode == "autoregressive":
        # get_loss is a plain function (shifted token-level CE), not an nn.Module.
        self.loss_fn = get_loss
    else:
        self.loss_fn = nn.CrossEntropyLoss() if not self.bernoulli else nn.BCEWithLogitsLoss()
        # NOTE(review): .to() only exists on the nn.Module losses, so this line is
        # assumed to live on the non-autoregressive branch (get_loss has no .to);
        # confirm against the upstream source.
        self.loss_fn = self.loss_fn.to(self.device)
|
| 118 |
+
|
| 119 |
+
def embed(self, dataset: Dataset, epochs: int = 5):
    # Compute the Task2Vec embedding for `dataset`.
    # Autoregressive mode: optionally finetune the LM head, then compute the
    # Fisher and return (embedding, final_loss).
    # Classification mode: cache features, fit the last-layer classifier,
    # compute the Fisher, and return just the embedding.
    # NOTE(review): the two branches have different return arity — callers must
    # know which mode they are in.
    ## LLM DIV
    # Cache the last layer features (needed to train the classifier) and (if needed) the intermediate layer features
    # so that we can skip the initial layers when computing the embedding
    # dataset.train()
    if self.mode == "autoregressive":
        loss = None
        print(f'{self.classifier_opts=}')
        if self.classifier_opts:  # is it something truthy? e.g., dict with something in it?
            if self.classifier_opts.get('finetune', False):  # finetune only if specified True, else no finetuning if not specified or False.
                # NOTE(review): this condition looks inverted — finetune=True yields
                # epochs=0 (i.e. no finetuning), which contradicts the warning text
                # below; confirm intent before relying on it.
                epochs = 0
                print(f'Warning: classifier_opts doesnt specify finetune or break early, thus no finetuning is being done. See: {self.classifier_opts=} {epochs=}')
                loss = self._finetune_classifier(dataset, loader_opts=self.loader_opts, classifier_opts=self.classifier_opts, max_samples=self.max_samples, epochs=epochs)
            else:
                loss = self._finetune_classifier(dataset, loader_opts=self.loader_opts, classifier_opts=self.classifier_opts, max_samples=self.max_samples, epochs=epochs)
        else:  # self.classifier_opts might be None or {}
            loss = self._finetune_classifier(dataset, loader_opts=self.loader_opts, classifier_opts=self.classifier_opts, max_samples=self.max_samples, epochs=epochs)
        print(f'{loss=} (after fine tune, if not done it will be None)')
        assert loss is not None, f'Err: {loss=}'
        self.compute_fisher(dataset)
        embedding = self.extract_embedding(self.model)
        return embedding, loss
    else:
        # Classification path: cache inputs to the probed layers so the Fisher
        # can be computed on the truncated network starting at `skip_layers`.
        if self.skip_layers > 0:
            self._cache_features(dataset, indexes=(self.skip_layers, -1), loader_opts=self.loader_opts,
                                 max_samples=self.max_samples)
        else:
            self._cache_features(dataset, max_samples=self.max_samples)
        # Fits the last layer classifier using cached features
        self._fit_classifier(**self.classifier_opts)

        if self.skip_layers > 0:
            # Replace the dataset with (cached intermediate features, targets)
            # so compute_fisher only traverses the layers being embedded.
            dataset = torch.utils.data.TensorDataset(self.model.layers[self.skip_layers].input_features,
                                                     self.model.layers[-1].targets)

        # dataset.eval() # I added this so that the embedding is computed on the val set
        self.compute_fisher(dataset)
        embedding = self.extract_embedding(self.model)
        # dataset.train() # returns to using the support set
        return embedding
|
| 159 |
+
|
| 160 |
+
### LLM DIV
|
| 161 |
+
def _finetune_classifier(self, dataset: Dataset, loader_opts: dict = None, classifier_opts: dict = None, max_samples=None, epochs = 5, learning_rate = 5e-5, adam_epsilon = 1e-8):
    """Fits the last layer of the HuggingFace transformer probe network.

    Only the parameters of ``self.model.lm_head`` are optimized (AdamW); the
    rest of the model is left untouched.  Returns the last computed batch
    loss as a float.
    """
    logging.info("Finetune classifier...")
    if loader_opts is None:
        loader_opts = {}
    if classifier_opts is None:
        classifier_opts = {}
    data_loader = DataLoader(dataset, shuffle=False, batch_size=loader_opts.get('batch_size', 8),
                             num_workers=loader_opts.get('num_workers', 0), drop_last=False)

    device = next(self.model.parameters()).device
    print("MODEL DEVICE: ", device)

    # num_examples = int(classifier_opts.get("task_batch_size", 256) / loader_opts.get('batch_size', 8))
    # NOTE: materializes the whole loader just to count batches (extra pass over data).
    num_examples = len(list(data_loader)) # not ideal but it's quicker in dev time, usually we won't feed the entire data set to task2vec so this should be fine
    n_batches = num_examples

    # Only the LM head participates in the optimization.
    optimizer_grouped_parameters = [
        {'params': [p for p in self.model.lm_head.parameters()],
         'weight_decay': classifier_opts.get("weight_decay",0.0001)},
    ]

    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=classifier_opts.get("learning_rate",learning_rate), eps=classifier_opts.get("adam_epsilon",adam_epsilon))

    # Train!
    logging.info("***** Running training *****")
    # logging.info("  Num examples = %d", num_examples)
    logging.info("  Num Epochs = %d", epochs)
    logging.info("  Batch size = %d", loader_opts.get('batch_size', 8))

    train_iterator = trange(classifier_opts.get("epochs", epochs), desc="Epoch", leave=False)
    set_seed(classifier_opts.get("seed", 42))  # Added here for reproductibility (even between python 2 and 3)

    self.model.train()
    for epoch in train_iterator:
        metrics = AverageMeter()
        epoch_iterator = tqdm(data_loader, desc="Iteration", total=n_batches, leave=False)
        for step, batch in enumerate(epoch_iterator):
            optimizer.zero_grad()
            # assumes batches are HF-tokenizer dicts with input_ids/attention_mask
            # — TODO confirm against the dataset builder.
            inputs = {'input_ids': batch['input_ids'].to(device),
                      'attention_mask': batch['attention_mask'].to(device)}
            logits = self.model(**inputs, labels=inputs["input_ids"]).logits
            # 50256 is GPT-2's <|endoftext|> id — presumably a GPT-2 tokenizer;
            # verify before using another probe model.
            loss = self.loss_fn(logits, inputs["input_ids"], ignore_index=50256)
            print(f'\nInitial loss {loss.item()} ({step=} {epoch=})') if step == 0 else None
            error = get_error(logits, inputs['input_ids'], ignore_index=50256)
            loss.backward()
            optimizer.step()

            metrics.update(n=batch['input_ids'].shape[0], loss=loss.item(), error=error)
            epoch_iterator.update(1)

            if classifier_opts.get("break_early", False):
                print("----> breaking early")
                break
        if classifier_opts.get("break_early", False):
            break
        logging.info(f"[epoch {epoch}]: " + "\t".join(f"{k}: {v}" for k, v in metrics.avg.items()))
    print(f'\nfinal loss {step=} {epoch=} of final layer loss {loss.item()} (note we are not recomputing loss after a step so this loss printed is larger than it should be/one off)')
    return loss.item()
|
| 220 |
+
|
| 221 |
+
### LLM DIV
|
| 222 |
+
def montecarlo_fisher_autoregressive(self, dataset: Dataset, epochs: int = 1):
    """Monte-Carlo estimate of the diagonal FIM for an autoregressive LM.

    Accumulates squared gradients (per parameter) of the loss computed against
    labels sampled from the model's own output distribution, then averages by
    the number of contributing batches.  Results are stored on each parameter
    as ``p.grad2_acc``.
    """
    logging.info("Using montecarlo Fisher")
    if self.loader_opts is None:
        loader_opts = {}
    else:
        loader_opts = self.loader_opts

    data_loader = DataLoader(dataset, shuffle=False, batch_size=loader_opts.get('batch_size', 8),
                             num_workers=loader_opts.get('num_workers', 0), drop_last=False)
    device = get_device(self.model)

    # num_examples = int(classifier_opts.get("task_batch_size", 256) / loader_opts.get('batch_size', 8))
    num_examples = len(list(data_loader)) # not idea but it's quicker in dev time, usually we won't feed the entire data set to task2vec so this should be fine
    n_batches = num_examples

    logging.info("Computing Fisher...")
    # Attach per-parameter accumulators for the squared gradients.
    for p in self.model.parameters():
        p.grad2_acc = torch.zeros_like(p.data)
        p.grad_counter = 0

    for k in range(epochs):
        logging.info(f"\tepoch {k + 1}/{epochs}")

        epoch_iterator = tqdm(data_loader, desc="Iteration", total=n_batches, leave=False)
        for step, batch in enumerate(epoch_iterator):
            inputs = {'input_ids': batch['input_ids'].to(device),
                      'attention_mask': batch['attention_mask'].to(device)}
            logits = self.model(**inputs, labels=inputs["input_ids"]).logits

            # The gradients used to compute the FIM needs to be for y sampled from
            # the model distribution y ~ p_w(y|x), not for y from the dataset
            if self.bernoulli:
                target = torch.bernoulli(F.sigmoid(logits[:,:-1,:])).detach()
            else:
                softmax_output = F.softmax(logits, dim=-1)
                # Sample one token per position, per sequence in the batch.
                lst = [torch.multinomial(softmax_output[i,:,:], 1).detach().view(-1) for i in range(len(softmax_output))]
                target = torch.stack(lst, dim=0)

            # 50256 is GPT-2's <|endoftext|> id — assumes a GPT-2 tokenizer; confirm.
            loss = self.loss_fn(logits, target, ignore_index=50256)
            self.model.zero_grad()
            loss.backward()
            for p in self.model.parameters():
                if p.grad is not None:
                    p.grad2_acc += p.grad.data ** 2
                    p.grad_counter += 1
            if self.classifier_opts.get("break_early", False):
                break # for debugging faster, otherwise FIM is really slow
        if self.classifier_opts.get("break_early", False):
            break # for debugging faster, otherwise FIM is really slow
    # Average; drop accumulators for parameters that never received a gradient.
    for p in self.model.parameters():
        if p.grad_counter == 0:
            del p.grad2_acc
        else:
            p.grad2_acc /= p.grad_counter
    logging.info("done")
|
| 277 |
+
|
| 278 |
+
def montecarlo_fisher(self, dataset: Dataset, epochs: int = 1):
    """Monte-Carlo estimate of the diagonal FIM for a classification probe.

    Labels are sampled from the model's own output distribution (bernoulli or
    categorical), squared gradients are accumulated per parameter in
    ``p.grad2_acc`` and averaged over contributing batches.
    """
    logging.info("Using montecarlo Fisher")
    if self.skip_layers > 0:
        # Use the cached intermediate features/targets so only the embedded
        # portion of the network is traversed.
        dataset = torch.utils.data.TensorDataset(self.model.layers[self.skip_layers].input_features,
                                                 self.model.layers[-1].targets)
    # _get_loader is defined elsewhere in this module (outside this view).
    data_loader = _get_loader(dataset, **self.loader_opts)
    device = get_device(self.model)
    logging.info("Computing Fisher...")

    for p in self.model.parameters():
        p.grad2_acc = torch.zeros_like(p.data)
        p.grad_counter = 0
    for k in range(epochs):
        logging.info(f"\tepoch {k + 1}/{epochs}")
        for i, (data, target) in enumerate(tqdm(data_loader, leave=False, desc="Computing Fisher")):
            data = data.to(device)
            output = self.model(data, start_from=self.skip_layers)
            # The gradients used to compute the FIM needs to be for y sampled from
            # the model distribution y ~ p_w(y|x), not for y from the dataset
            if self.bernoulli:
                target = torch.bernoulli(F.sigmoid(output)).detach()
            else:
                target = torch.multinomial(F.softmax(output, dim=-1), 1).detach().view(-1)
            loss = self.loss_fn(output, target)
            self.model.zero_grad()
            loss.backward()
            for p in self.model.parameters():
                if p.grad is not None:
                    p.grad2_acc += p.grad.data ** 2
                    p.grad_counter += 1
    for p in self.model.parameters():
        if p.grad_counter == 0:
            del p.grad2_acc
        else:
            p.grad2_acc /= p.grad_counter
    logging.info("done")
|
| 314 |
+
|
| 315 |
+
def _run_epoch(self, data_loader: DataLoader, model: ProbeNetwork, loss_fn,
               optimizer: Optimizer, epoch: int, train: bool = True,
               add_compression_loss: bool = False, skip_layers=0, beta=1.0e-7):
    """Run one pass over *data_loader*; optionally optimize (train=True) and
    optionally add the beta-weighted variational compression loss.

    Returns the dict of averaged metrics for the epoch.
    """
    metrics = AverageMeter()
    device = get_device(model)

    for i, (input, target) in enumerate(tqdm(data_loader, leave=False, desc="Computing Fisher")):
        input = input.to(device)
        target = target.to(device)
        output = model(input, start_from=skip_layers)

        loss = loss_fn(output, target)
        # NOTE(review): `variational` is not among this file's visible imports —
        # confirm it is imported elsewhere or this branch NameErrors.
        lz = beta * variational.get_compression_loss(model) if add_compression_loss else torch.zeros_like(loss)
        loss += lz

        error = get_error(output, target)

        metrics.update(n=input.size(0), loss=loss.item(), lz=lz.item(), error=error)
        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    # logging.info(
    # NOTE(review): "data_time"/"batch_time" are never fed to metrics.update
    # above — this print may KeyError; verify AverageMeter's default behavior.
    print(
        "{}: [{epoch}] ".format('Epoch' if train else '', epoch=epoch) +
        "Data/Batch: {:.3f}/{:.3f} ".format(metrics.avg["data_time"], metrics.avg["batch_time"]) +
        "Loss {:.3f} Lz: {:.3f} ".format(metrics.avg["loss"], metrics.avg["lz"]) +
        "Error: {:.2f}".format(metrics.avg["error"])
    )
    return metrics.avg
|
| 345 |
+
|
| 346 |
+
def variational_fisher(self, dataset: Dataset, epochs=1, beta=1e-7):
    """Variational estimate of the FIM: makes the probed layers variational,
    optimizes their noise parameters (plus batchnorm and the classifier head)
    against the task loss + beta-weighted compression loss, then restores the
    original requires_grad flags."""
    logging.info("Training variational fisher...")
    parameters = []
    for layer in self.model.layers[self.skip_layers:-1]:
        if isinstance(layer, nn.Module):  # Skip lambda functions
            variational.make_variational(layer)
            parameters += variational.get_variational_vars(layer)
    bn_params = []
    # Allows batchnorm parameters to change
    for m in self.model.modules():
        if isinstance(m, nn.BatchNorm2d):
            bn_params += list(m.parameters())
    # Avoids computing the gradients wrt to the weights to save time and memory
    for p in self.model.parameters():
        if p not in set(parameters) and p not in set(self.model.classifier.parameters()):
            p.old_requires_grad = p.requires_grad
            p.requires_grad = False

    # Variational vars get the high lr; batchnorm and classifier a lower one.
    optimizer = torch.optim.Adam([
        {'params': parameters},
        {'params': bn_params, 'lr': 5e-4},
        {'params': self.model.classifier.parameters(), 'lr': 5e-4}],
        lr=1e-2, betas=(.9, 0.999))
    if self.skip_layers > 0:
        # Train on the cached intermediate features rather than raw inputs.
        dataset = torch.utils.data.TensorDataset(self.model.layers[self.skip_layers].input_features,
                                                 self.model.layers[-1].targets)
    train_loader = _get_loader(dataset, **self.loader_opts)

    for epoch in range(epochs):
        self._run_epoch(train_loader, self.model, self.loss_fn, optimizer, epoch, beta=beta,
                        add_compression_loss=True, train=True)

    # Resets original value of requires_grad
    for p in self.model.parameters():
        if hasattr(p, 'old_requires_grad'):
            p.requires_grad = p.old_requires_grad
            del p.old_requires_grad
|
| 383 |
+
|
| 384 |
+
def compute_fisher(self, dataset: Dataset):
    """
    Computes the Fisher Information of the weights of the model wrt the model output on the dataset and stores it.

    The Fisher Information Matrix is defined as:
        F = E_{x ~ dataset} E_{y ~ p_w(y|x)} [\nabla_w log p_w(y|x) \nabla_w log p_w(y|x)^t]
    where p_w(y|x) is the output probability vector of the network and w are the weights of the network.
    Notice that the label y is sampled from the model output distribution and not from the dataset.

    Only the diagonal of F is approximated; the result is stored on the model
    layers and can be read back with `extract_embedding`.  The approximation
    method is chosen in __init__ ('montecarlo' or 'variational').

    :param dataset: dataset with the task to compute the Fisher on
    """
    if self.method == 'variational':
        approximation = self.variational_fisher
    elif self.method == 'montecarlo':
        # Autoregressive LMs need the token-level variant of the estimator.
        approximation = (self.montecarlo_fisher_autoregressive
                         if self.mode == 'autoregressive'
                         else self.montecarlo_fisher)
    else:
        raise ValueError(f"Invalid Fisher method {self.method}")
    approximation(dataset, **self.method_opts)
|
| 408 |
+
|
| 409 |
+
def _cache_features(self, dataset: Dataset, indexes=(-1,), max_samples=None, loader_opts: dict = None):
    """Run the dataset through the model once, recording the inputs seen by
    the layers at *indexes* (via forward pre-hooks) and the targets.

    After the pass, ``model.layers[i].input_features`` holds the concatenated
    cached inputs for each requested layer and ``model.layers[-1].targets``
    the concatenated labels.
    """
    logging.info("Caching features...")
    if loader_opts is None:
        loader_opts = {}
    data_loader = DataLoader(dataset, shuffle=False, batch_size=loader_opts.get('batch_size', 64),
                             num_workers=loader_opts.get('num_workers', 0), drop_last=False)

    device = next(self.model.parameters()).device

    def _hook(layer, inputs):
        # Pre-hook: stash a CPU copy of the layer's incoming activation.
        if not hasattr(layer, 'input_features'):
            layer.input_features = []
        layer.input_features.append(inputs[0].data.cpu().clone())

    hooks = [self.model.layers[index].register_forward_pre_hook(_hook)
             for index in indexes]
    if max_samples is not None:
        # Cap the number of batches so at most ~max_samples examples are cached.
        n_batches = min(
            math.floor(max_samples / data_loader.batch_size) - 1, len(data_loader))
    else:
        n_batches = len(data_loader)
    targets = []

    for i, (input, target) in tqdm(enumerate(itertools.islice(data_loader, 0, n_batches)), total=n_batches,
                                   leave=False,
                                   desc="Caching features"):
        targets.append(target.clone())
        self.model(input.to(device))
    # Remove hooks before concatenating so later forwards don't keep caching.
    for hook in hooks:
        hook.remove()
    for index in indexes:
        self.model.layers[index].input_features = torch.cat(self.model.layers[index].input_features)
    self.model.layers[-1].targets = torch.cat(targets)
|
| 442 |
+
|
| 443 |
+
def _fit_classifier(self, optimizer='adam', learning_rate=0.0004, weight_decay=0.0001,
                    epochs=10):
    """Fits the last layer of the network using the cached features.

    Requires `_cache_features` to have populated `self.model.classifier.input_features`
    and `.targets` first.

    :param optimizer: 'adam' or 'sgd'.
    :param learning_rate: optimizer learning rate.
    :param weight_decay: L2 penalty applied by the optimizer.
    :param epochs: number of passes over the cached features.
    :raises ValueError: if features were not cached or the optimizer name is unknown.
    """
    logging.info("Fitting final classifier...")
    if not hasattr(self.model.classifier, 'input_features'):
        raise ValueError("You need to run `cache_features` on model before running `fit_classifier`")
    targets = self.model.classifier.targets.to(self.device)
    features = self.model.classifier.input_features.to(self.device)

    dataset = torch.utils.data.TensorDataset(features, targets)
    # Class-balanced loader over the cached (feature, target) pairs.
    data_loader = _get_loader(dataset, **self.loader_opts)

    # NOTE(review): the optimizer trains self.model.fc while the forward pass uses
    # self.model.classifier — confirm both names refer to the same head in ProbeNetwork.
    if optimizer == 'adam':
        optimizer = torch.optim.Adam(self.model.fc.parameters(), lr=learning_rate, weight_decay=weight_decay)
    elif optimizer == 'sgd':
        optimizer = torch.optim.SGD(self.model.fc.parameters(), lr=learning_rate, weight_decay=weight_decay)
    else:
        raise ValueError(f'Unsupported optimizer {optimizer}')

    loss_fn = nn.CrossEntropyLoss()
    for epoch in tqdm(range(epochs), desc="Fitting classifier", leave=False):
        metrics = AverageMeter()
        for data, target in data_loader:
            optimizer.zero_grad()
            output = self.model.classifier(data)
            # Bug fix: reuse `output` for the loss instead of running a second,
            # redundant forward pass through the classifier.
            loss = loss_fn(output, target)
            error = get_error(output, target)
            loss.backward()
            optimizer.step()
            metrics.update(n=data.size(0), loss=loss.item(), error=error)
        logging.info(f"[epoch {epoch}]: " + "\t".join(f"{k}: {v}" for k, v in metrics.avg.items()))
    print(f'\nfinal loss after fitting final layer {loss=}')
|
| 475 |
+
|
| 476 |
+
def extract_embedding(self, model: ProbeNetwork):
    """
    Reads the values stored by `compute_fisher` and returns them in a common format that describes the diagonal of the
    Fisher Information Matrix for each layer.

    :param model: probe network whose modules hold the accumulated Fisher statistics.
    :return: an Embedding with one `hessian`/`scale` entry per output filter (meta=None).
    """
    if self.mode == 'autoregressive':
        hess, scale = [], []
        for name, module in model.named_modules():
            # The LM head is task-specific, so it is excluded from the embedding.
            if module is model.lm_head:
                continue
            # The other Fisher approximation methods directly approximate the hessian at the minimum
            if hasattr(module, 'weight') and hasattr(module.weight, 'grad2_acc'):
                grad2 = module.weight.grad2_acc.cpu().detach().numpy()
                # Average the squared gradients within each filter (first weight dimension).
                filterwise_hess = grad2.reshape(grad2.shape[0], -1).mean(axis=1)
                hess.append(filterwise_hess)
                # No variational prior in this mode: use a unit scale per filter.
                scale.append(np.ones_like(filterwise_hess))
    else:
        hess, scale = [], []
        for name, module in model.named_modules():
            # Skip the classification head for the same reason as above.
            if module is model.classifier:
                continue
            # The variational Fisher approximation estimates the variance of noise that can be added to the weights
            # without increasing the error more than a threshold. The inverse of this is proportional to an
            # approximation of the hessian in the local minimum.
            if hasattr(module, 'logvar0') and hasattr(module, 'loglambda2'):
                logvar = module.logvar0.view(-1).detach().cpu().numpy()
                hess.append(np.exp(-logvar))
                loglambda2 = module.loglambda2.detach().cpu().numpy()
                # One shared prior precision per module, repeated per filter.
                scale.append(np.exp(-loglambda2).repeat(logvar.size))
            # The other Fisher approximation methods directly approximate the hessian at the minimum
            elif hasattr(module, 'weight') and hasattr(module.weight, 'grad2_acc'):
                grad2 = module.weight.grad2_acc.cpu().detach().numpy()
                filterwise_hess = grad2.reshape(grad2.shape[0], -1).mean(axis=1)
                hess.append(filterwise_hess)
                scale.append(np.ones_like(filterwise_hess))
    # NOTE(review): np.concatenate raises on empty lists — assumes compute_fisher
    # ran first and left statistics on at least one module; confirm.
    return Embedding(hessian=np.concatenate(hess), scale=np.concatenate(scale), meta=None)
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
def _get_loader(trainset, testset=None, batch_size=64, num_workers=0, num_samples=10000, drop_last=True):
|
| 518 |
+
if getattr(trainset, 'is_multi_label', False):
|
| 519 |
+
raise ValueError("Multi-label datasets not supported")
|
| 520 |
+
# TODO: Find a way to standardize this
|
| 521 |
+
if hasattr(trainset, 'labels'):
|
| 522 |
+
labels = trainset.labels
|
| 523 |
+
elif hasattr(trainset, 'targets'):
|
| 524 |
+
labels = trainset.targets
|
| 525 |
+
else:
|
| 526 |
+
labels = list(trainset.tensors[1].cpu().numpy())
|
| 527 |
+
num_classes = int(getattr(trainset, 'num_classes', max(labels) + 1))
|
| 528 |
+
class_count = np.eye(num_classes)[labels].sum(axis=0)
|
| 529 |
+
weights = 1. / class_count[labels] / num_classes
|
| 530 |
+
weights /= weights.sum()
|
| 531 |
+
|
| 532 |
+
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples=num_samples)
|
| 533 |
+
# No need for mutli-threaded loading if everything is already in memory,
|
| 534 |
+
# and would raise an error if TensorDataset is on CUDA
|
| 535 |
+
num_workers = num_workers if not isinstance(trainset, torch.utils.data.TensorDataset) else 0
|
| 536 |
+
trainloader = torch.utils.data.DataLoader(trainset, sampler=sampler, batch_size=batch_size,
|
| 537 |
+
num_workers=num_workers, drop_last=drop_last)
|
| 538 |
+
|
| 539 |
+
if testset is None:
|
| 540 |
+
return trainloader
|
| 541 |
+
else:
|
| 542 |
+
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, pin_memory=True, shuffle=False,
|
| 543 |
+
num_workers=num_workers)
|
| 544 |
+
return trainloader, testloader
|
DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec/task_similarity.py
ADDED
|
@@ -0,0 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
| 6 |
+
# may not use this file except in compliance with the License. A copy of
|
| 7 |
+
# the License is located at
|
| 8 |
+
#
|
| 9 |
+
# http://aws.amazon.com/apache2.0/
|
| 10 |
+
#
|
| 11 |
+
# or in the "license" file accompanying this file. This file is
|
| 12 |
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
| 13 |
+
# ANY KIND, either express or implied. See the License for the specific
|
| 14 |
+
# language governing permissions and limitations under the License.
|
| 15 |
+
|
| 16 |
+
import itertools
|
| 17 |
+
from typing import Tuple
|
| 18 |
+
|
| 19 |
+
import scipy.spatial.distance as distance
|
| 20 |
+
import numpy as np
|
| 21 |
+
import copy
|
| 22 |
+
import pickle
|
| 23 |
+
|
| 24 |
+
# import uutils
|
| 25 |
+
|
| 26 |
+
_DISTANCES = {}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# TODO: Remove methods that do not perform well
|
| 30 |
+
|
| 31 |
+
def _register_distance(distance_fn):
    """Decorator: register `distance_fn` in the module-level _DISTANCES table under its own name."""
    name = distance_fn.__name__
    _DISTANCES[name] = distance_fn
    return distance_fn
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def is_excluded(k):
    """Return True when layer/parameter name `k` belongs to a classifier head ('fc' or 'linear')."""
    return any(token in k for token in ('fc', 'linear'))
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def load_embedding(filename):
    """Unpickle and return the task embedding stored at `filename`."""
    with open(filename, 'rb') as fh:
        return pickle.load(fh)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_trivial_embedding_from(e):
    """Return a deep copy of embedding `e` with every per-filter log-variance replaced
    by its layer's prior log-variance (filter_lambda2), i.e. the 'trivial' embedding."""
    trivial = copy.deepcopy(e)
    for layer in trivial['layers']:
        filled = np.full_like(np.array(layer['filter_logvar']), layer['filter_lambda2'])
        layer['filter_logvar'] = list(filled)
    return trivial
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def binary_entropy(p):
    """Entropy (nats) of a Bernoulli(p) variable; xlogy makes p == 0 and p == 1 safe."""
    from scipy.special import xlogy
    return -(xlogy(p, p) + xlogy(1.0 - p, 1.0 - p))
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def get_layerwise_variance(e, normalized=False):
    """Per-layer posterior variances exp(filter_logvar); optionally L2-normalized per layer."""
    variances = [np.exp(layer['filter_logvar']) for layer in e['layers']]
    if not normalized:
        return variances
    return [v / np.linalg.norm(v) for v in variances]
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def get_variance(e, normalized=False):
    """Per-filter posterior variance (inverse of the FIM diagonal `e.hessian`);
    if `normalized`, divide by the prior variance 1/e.scale."""
    variance = 1. / np.asarray(e.hessian)
    if normalized:
        prior_variance = 1. / np.asarray(e.scale)
        variance = variance / prior_variance
    return variance
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def get_variances(*embeddings, normalized=False):
    """Vectorized convenience wrapper: get_variance applied to every embedding."""
    return [get_variance(emb, normalized=normalized) for emb in embeddings]
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def get_hessian(e, normalized=False):
    """FIM diagonal of embedding `e`; if `normalized`, divide out the per-filter scale."""
    hessian = np.asarray(e.hessian)
    if not normalized:
        return hessian
    return hessian / np.asarray(e.scale)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def get_hessians(*embeddings, normalized=False):
    """Convenience wrapper: get_hessian applied to every embedding."""
    return [get_hessian(emb, normalized=normalized) for emb in embeddings]
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def get_scaled_hessian(e0, e1):
    """Jointly normalize the two FIM diagonals so they sum elementwise to ~1
    (the small epsilon guards against division by zero)."""
    h0, h1 = get_hessians(e0, e1, normalized=False)
    total = h0 + h1 + 1e-8
    return h0 / total, h1 / total
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def get_full_kl(e0, e1):
    """Both directed elementwise Gaussian KL divergences between the embeddings' variances."""
    var0, var1 = get_variance(e0), get_variance(e1)

    def _gauss_kl(a, b):
        return .5 * (a / b - 1 + np.log(b) - np.log(a))

    return _gauss_kl(var0, var1), _gauss_kl(var1, var0)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def layerwise_kl(e0, e1):
    """Per-layer summed KL divergence KL(e0 || e1) between layerwise variances."""
    pairs = zip(get_layerwise_variance(e0), get_layerwise_variance(e1))
    return [np.sum(.5 * (v0 / v1 - 1 + np.log(v1) - np.log(v0))) for v0, v1 in pairs]
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def layerwise_cosine(e0, e1):
    """Per-layer cosine distance between L2-normalized layerwise variances."""
    norm0 = get_layerwise_variance(e0, normalized=True)
    norm1 = get_layerwise_variance(e1, normalized=True)
    return [distance.cosine(v0, v1) for v0, v1 in zip(norm0, norm1)]
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@_register_distance
def kl(e0, e1):
    """Symmetrized KL distance: the elementwise maximum of both directed Gaussian KLs, summed."""
    var0, var1 = get_variance(e0), get_variance(e1)
    forward = .5 * (var0 / var1 - 1 + np.log(var1) - np.log(var0))
    backward = .5 * (var1 / var0 - 1 + np.log(var0) - np.log(var1))
    return np.maximum(forward, backward).sum()
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@_register_distance
def asymmetric_kl(e0, e1):
    """Directed KL divergence KL(e0 || e1) summed over filters (NOT symmetric:
    asymmetric_kl(a, b) != asymmetric_kl(b, a) in general)."""
    var0, var1 = get_variance(e0), get_variance(e1)
    kl0 = .5 * (var0 / var1 - 1 + np.log(var1) - np.log(var0))
    # The reverse-direction KL was previously computed and never used; removed as dead code.
    return kl0.sum()
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
@_register_distance
def jsd(e0, e1):
    """Jensen-Shannon-style divergence: mean of the two KLs to the midpoint variance."""
    var0, var1 = get_variance(e0), get_variance(e1)
    mid = (var0 + var1) / 2.
    kl0 = .5 * (var0 / mid - 1 + np.log(mid) - np.log(var0))
    kl1 = .5 * (var1 / mid - 1 + np.log(mid) - np.log(var1))
    return ((kl0 + kl1) / 2.).mean()
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
@_register_distance
def cosine(e0, e1):
    """Cosine distance between the jointly-scaled FIM diagonals (the standard task2vec distance)."""
    scaled0, scaled1 = get_scaled_hessian(e0, e1)
    return distance.cosine(scaled0, scaled1)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
@_register_distance
def normalized_cosine(e0, e1):
    """Cosine distance between the prior-normalized variance vectors."""
    v0, v1 = get_variances(e0, e1, normalized=True)
    return distance.cosine(v0, v1)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
@_register_distance
def correlation(e0, e1):
    """Correlation distance between the raw (un-normalized) variance vectors."""
    v0, v1 = get_variances(e0, e1, normalized=False)
    return distance.correlation(v0, v1)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@_register_distance
def entropy(e0, e1):
    """log(2) minus the mean binary entropy of the scaled hessian; 0 for identical embeddings."""
    scaled0, _ = get_scaled_hessian(e0, e1)
    return np.log(2) - binary_entropy(scaled0).mean()
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def get_normalized_embeddings(embeddings, normalization=None):
    """Stack the embeddings as FIM-diagonal rows (None entries become zero rows) and
    divide by a per-column RMS normalization (computed here unless one is provided).

    Returns (normalized matrix, normalization vector)."""
    rows = [1. / get_variance(e, normalized=False) if e is not None else None for e in embeddings]
    zero_row = np.zeros_like([row for row in rows if row is not None][0])
    stacked = np.array([row if row is not None else zero_row for row in rows])
    # FIXME: compute variance using only valid embeddings
    if normalization is None:
        normalization = np.sqrt((stacked ** 2).mean(axis=0, keepdims=True))
    stacked /= normalization
    return stacked, normalization
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def pdist(embeddings, distance='cosine') -> np.ndarray:
    """Pairwise distance matrix between embeddings.

    Symmetric distances fill the upper triangle and mirror it; 'asymmetric_kl'
    evaluates every ordered pair (including the diagonal)."""
    metric = _DISTANCES[distance]
    n = len(embeddings)
    matrix = np.zeros((n, n))
    if distance == 'asymmetric_kl':
        for i, a in enumerate(embeddings):
            for j, b in enumerate(embeddings):
                matrix[i, j] = metric(a, b)
    else:
        for (i, a), (j, b) in itertools.combinations(enumerate(embeddings), 2):
            d = metric(a, b)
            matrix[i, j] = d
            matrix[j, i] = d
    return matrix
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def cross_pdist(embeddings1, embeddings2, distance='cosine') -> np.ndarray:
    """
    Compute pairwise distance between embeddings1 and embeddings2.

    Returns a (len(embeddings1), len(embeddings2)) matrix of distances.

    ref: https://chat.openai.com/share/a5ca38dc-3393-4cfd-971c-4a29b0c56b63
    """
    distance_fn = _DISTANCES[distance]
    distance_matrix = np.zeros([len(embeddings1), len(embeddings2)])
    # Fix: the original branched on distance != 'asymmetric_kl', but both branches
    # were token-identical — the dead duplicated branch has been removed.
    for i, e1 in enumerate(embeddings1):
        for j, e2 in enumerate(embeddings2):
            distance_matrix[i, j] = distance_fn(e1, e2)
    return distance_matrix
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def cdist(from_embeddings, to_embeddings, distance='cosine'):
    """Rectangular distance matrix; pairs with a missing (None) embedding stay 0."""
    metric = _DISTANCES[distance]
    matrix = np.zeros([len(from_embeddings), len(to_embeddings)])
    for i, src in enumerate(from_embeddings):
        for j, dst in enumerate(to_embeddings):
            if src is not None and dst is not None:
                matrix[i, j] = metric(src, dst)
    return matrix
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def plot_distance_matrix(embeddings, labels=None, distance='cosine', show_plot=True):
    """Hierarchically-clustered heatmap (seaborn clustermap) of pairwise task distances.

    :param embeddings: task embeddings accepted by `pdist`.
    :param labels: optional axis labels (one per embedding).
    :param distance: name of a registered distance in _DISTANCES.
    :param show_plot: call plt.show() at the end.
    """
    import seaborn as sns
    from scipy.cluster.hierarchy import linkage
    from scipy.spatial.distance import squareform
    import pandas as pd
    import matplotlib.pyplot as plt
    distance_matrix = pdist(embeddings, distance=distance)
    # Condensed form required by scipy's linkage; checks=False tolerates tiny asymmetries.
    cond_distance_matrix = squareform(distance_matrix, checks=False)
    linkage_matrix = linkage(cond_distance_matrix, method='complete', optimal_ordering=True)
    if labels is not None:
        # Wrap in a DataFrame so the clustermap axes carry the task names.
        distance_matrix = pd.DataFrame(distance_matrix, index=labels, columns=labels)
    sns.clustermap(distance_matrix, row_linkage=linkage_matrix, col_linkage=linkage_matrix, cmap='viridis_r')
    if show_plot:
        plt.show()
|
| 241 |
+
|
| 242 |
+
## LLM DIV
|
| 243 |
+
def plot_distance_matrix_heatmap_only(embeddings, labels=None, distance='cosine', show_plot=True, title=None, save_file=None):
    """Plain (un-clustered) heatmap of pairwise task distances.

    :param embeddings: task embeddings accepted by `pdist`.
    :param labels: optional axis labels.
    :param distance: name of a registered distance in _DISTANCES.
    :param show_plot: call plt.show() at the end.
    :param title: optional plot title.
    :param save_file: if given, the figure is saved to plots/<save_file>.png
        (assumes the 'plots/' directory already exists).
    """
    import seaborn as sns
    import pandas as pd
    import matplotlib.pyplot as plt
    distance_matrix = pdist(embeddings, distance=distance)
    if labels is not None:
        distance_matrix = pd.DataFrame(distance_matrix, index=labels, columns=labels)
    sns.heatmap(distance_matrix, cmap='viridis_r')
    if title:
        plt.title(title)
    if save_file:
        _ = plt.savefig("plots/" + save_file + ".png", bbox_inches='tight')
    if show_plot:
        plt.show()
|
| 257 |
+
|
| 258 |
+
## LLM DIV
|
| 259 |
+
def plot_distance_matrix_from_distance_matrix(distance_matrix, labels=None, show_plot=True, title=None, save_file=None, cluster=False, plot_multi=False):
    """Plot an already-computed distance matrix as a heatmap or clustermap.

    :param distance_matrix: square np.ndarray — or, when `plot_multi` is True,
        a LIST of such arrays (one per subplot in a 3x2 grid).
    :param labels: optional axis labels (single-matrix modes only).
    :param show_plot: call plt.show() at the end.
    :param title: optional plot title.
    :param save_file: if given, save to plots/<save_file>.png.
    :param cluster: use a hierarchically-clustered clustermap instead of a heatmap.
    :param plot_multi: plot up to 6 matrices in a 3x2 subplot grid (ignores `cluster`).
    """
    import seaborn as sns
    from scipy.cluster.hierarchy import linkage
    from scipy.spatial.distance import squareform
    import pandas as pd
    import matplotlib.pyplot as plt

    # plot multiple subplots in one figure
    # distance_matrix passed in is a list of distance_matrix (np.arrays)
    if plot_multi and not cluster:
        num_rows, num_cols = 3, 2
        f, ax = plt.subplots(num_rows, num_cols)  # , figsize=(12, 15))
        i = 0
        # Bug fix: the original iterated range(len(num_rows)) / range(len(num_cols)),
        # which raises TypeError because num_rows/num_cols are ints.
        for row_ind in range(num_rows):
            for col_ind in range(num_cols):
                if i >= len(distance_matrix):
                    break  # fewer matrices than subplot cells
                sns.heatmap(distance_matrix[i], cmap='viridis_r', ax=ax[row_ind, col_ind])
                i += 1
    else:
        # linkage/labels only make sense for a single square matrix, so they are
        # computed here (the original computed them before the plot_multi branch,
        # which crashed squareform on list input before any plotting happened).
        cond_distance_matrix = squareform(distance_matrix, checks=False)
        linkage_matrix = linkage(cond_distance_matrix, method='complete', optimal_ordering=True)
        if labels is not None:
            distance_matrix = pd.DataFrame(distance_matrix, index=labels, columns=labels)
        if cluster:
            sns.clustermap(distance_matrix, row_linkage=linkage_matrix, col_linkage=linkage_matrix, cmap='viridis_r')
        else:
            sns.heatmap(distance_matrix, cmap='viridis_r')

    if title:
        plt.title(title)
    if save_file:
        _ = plt.savefig("plots/" + save_file + ".png", bbox_inches='tight')
    if show_plot:
        plt.show()
|
| 293 |
+
|
| 294 |
+
## LLM DIV
|
| 295 |
+
# plot multiple subplots in one figure
|
| 296 |
+
# distance_matrix passed in is a list of distance_matrix np.arrays
|
| 297 |
+
def plot_multi_distance_matrix_from_distance_matrix_list(distance_matrix_lst, title_lst, labels, main_title=None, show_plot=True, title=None, save_file=None, vmin=None, vmax=None):
    """Plot a list of distance matrices as heatmaps in a two-column subplot grid.

    :param distance_matrix_lst: list of square np.ndarray distance matrices.
    :param title_lst: per-subplot titles (same length as distance_matrix_lst).
    :param labels: per-matrix axis labels; labels[i] is applied to matrix i.
    :param main_title: optional figure-level title.
    :param show_plot: call plt.show() at the end.
    :param title: unused — kept for signature compatibility with callers.
    :param save_file: if given, save to plots/<save_file>.png.
    :param vmin, vmax: optional shared color-scale bounds across subplots.
    """
    import seaborn as sns
    from scipy.cluster.hierarchy import linkage
    from scipy.spatial.distance import squareform
    import pandas as pd
    import matplotlib.pyplot as plt
    import math
    # Two columns; enough rows for all matrices.
    num_rows, num_cols = math.ceil(len(distance_matrix_lst)/2), 2
    # NOTE(review): both branches set the same figsize — presumably these were
    # meant to differ; confirm intended sizes.
    if len(distance_matrix_lst) % 2 == 1:
        figsize = (12,10)
    else:
        figsize = (12,10)
    f, ax = plt.subplots(num_rows, num_cols, figsize=figsize)
    i = 0
    for row_ind in range(num_rows):
        for col_ind in range(num_cols):
            if i >= len(distance_matrix_lst):
                break  # grid may have one unused cell
            distance_matrix = distance_matrix_lst[i]
            distance_matrix = pd.DataFrame(distance_matrix, index=labels[i], columns=labels[i])
            if len(distance_matrix_lst) > 2:
                # >2 plots: `ax` is 2-D, index by (row, col).
                ax[row_ind, col_ind].set_aspect('equal')
                if vmin is not None and vmax is not None:
                    sns.heatmap(distance_matrix, cmap='viridis_r', ax=ax[row_ind, col_ind], vmin=vmin, vmax=vmax)
                else:
                    sns.heatmap(distance_matrix, cmap='viridis_r', ax=ax[row_ind, col_ind])
                ax[row_ind, col_ind].set_title(title_lst[i])
            else:
                # <=2 plots: a single row, so `ax` is 1-D and indexed by column only.
                ax[col_ind].set_aspect('equal')
                sns.heatmap(distance_matrix, cmap='viridis_r', ax=ax[col_ind])
                ax[col_ind].set_title(title_lst[i])

            i += 1
    if len(distance_matrix_lst) % 2 == 1:
        # Odd count: remove the empty bottom-right axes.
        f.delaxes(ax[num_rows-1,1])

    if main_title:
        f.suptitle(main_title)
        f.subplots_adjust(top=0.5)

    if len(distance_matrix_lst) % 2 == 1:
        plt.tight_layout(h_pad=2)
    else:
        plt.tight_layout(h_pad=2, w_pad=5)
    if save_file:
        _ = plt.savefig("plots/" + save_file + ".png", bbox_inches='tight')
    if show_plot:
        plt.show()
|
| 345 |
+
|
| 346 |
+
## LLM DIV
|
| 347 |
+
def stats_of_distance_matrix(distance_matrix: np.ndarray,
                             remove_diagonal: bool = True,
                             variance_type: str = 'std',  # TODO: was ci_0.95. Changed to rid uutils call
                             get_total: bool = False,
                             ) -> Tuple[float, float]:
    """Mean and spread of the entries of a (symmetric) distance matrix.

    :param distance_matrix: square pairwise-distance matrix.
    :param remove_diagonal: drop the diagonal (and the mirrored half) before averaging.
    :param variance_type: only 'std' is currently supported.
    :param get_total: if True, ALSO return the sum — i.e. a 3-tuple (mu, var, total)
        instead of the annotated 2-tuple.
    :return: (mean, std) or (mean, std, total) of the off-diagonal distances.
    """
    if remove_diagonal:
        # - remove diagonal: ref https://stackoverflow.com/questions/46736258/deleting-diagonal-elements-of-a-numpy-array
        triu: np.ndarray = np.triu(distance_matrix)
        tril: np.ndarray = np.tril(distance_matrix)
        # distance_matrix = distance_matrix[~np.eye(distance_matrix.shape[0], dtype=bool)].reshape(distance_matrix.shape[0], -1)
        # remove diagonal and dummy zeros where the other triangular matrix was artificially placed.
        # NOTE(review): this also silently drops any genuine 0.0 distances between
        # distinct tasks, not just the diagonal/lower-triangle zeros — confirm acceptable.
        distance_matrix = triu[triu != 0.0]

    # - flatten
    distance_matrix: np.ndarray = distance_matrix.flatten()

    # - compute stats of distance matrix
    if variance_type == 'std':
        mu, var = distance_matrix.mean(), distance_matrix.std()
    # elif variance_type == 'ci_0.95':
    #     from uutils.torch_uu.metrics.confidence_intervals import mean_confidence_interval
    #     mu, var = mean_confidence_interval(distance_matrix, confidence=0.95)
    else:
        raise ValueError(f'Invalid variance type, got: {variance_type=}')

    # - double checks the mean was computed corrects. Since it's symmetric the mean after removing diagonal should be equal to just one side of the diagonals
    if remove_diagonal:
        # from uutils.torch_uu import approx_equal
        # assert approx_equal(triu.sum(), tril.sum(), tolerance=1e-4), f'Distance matrix is not symmetric, are you sure this is correct?'
        # assert approx_equal(distance_matrix.mean(), triu[triu != 0.0].mean(), tolerance=1e-4), f'Mean should be equal to triangular matrix'
        # assert approx_equal(mu, triu[triu != 0.0].mean(), tolerance=1e-4)

        # Manual sanity prints replacing the disabled asserts above.
        print('Lower tri sum', tril.sum(), ' / Upper tri sum', triu.sum(), '| These should be approx equal!!')
        print('Total mean', distance_matrix.mean(), ' / Upper mean', triu[triu != 0.0].mean(), ' / Lower mean', tril[tril != 0.0].mean(), '| These should all be approx equal!!')
        print('mu (div coefficient)', mu, ' / Upper mean', triu[triu != 0.0].mean(), '| These should all be approx equal!!')
    if get_total:
        total = distance_matrix.sum()
        return mu, var, total
    else:
        return mu, var
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def stats_cross_distance_matrix(distance_matrix: np.ndarray,
                                remove_diagonal: bool = False,
                                variance_type: str = 'std',  # TODO: was ci_0.95. Changed to rid uutils call
                                get_total: bool = False,
                                ) -> Tuple[float, float]:
    """Like stats_of_distance_matrix, but defaults to KEEPING the diagonal,
    since a cross-distance matrix is generally neither square nor symmetric."""
    return stats_of_distance_matrix(
        distance_matrix,
        remove_diagonal=remove_diagonal,
        variance_type=variance_type,
        get_total=get_total,
    )
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def plot_histogram_of_distances(distance_matrix: np.ndarray, title, show_plot=True, save_file=None, bins_width=None, grid=True):
    """Histogram of the upper-triangle pairwise distances, with a dashed line at the mean.

    :param distance_matrix: square pairwise-distance matrix.
    :param title: plot title.
    :param show_plot: call plt.show() at the end.
    :param save_file: if given, save to plots/<save_file>.png.
    :param bins_width: fixed bin width; matplotlib chooses bins when None.
    :param grid: draw a background grid.
    """
    import matplotlib.pyplot as plt
    # Keep only the upper triangle; NOTE this also drops any genuine 0.0 distances.
    triu = np.triu(distance_matrix)
    triu = triu[triu != 0.0]
    distance_values = triu.flatten()

    if grid:
        plt.grid(zorder=0)
    # Dashed vertical line marking the mean distance (the diversity coefficient).
    plt.axvline(np.mean(distance_values), color='k', linestyle='dashed', linewidth=1, zorder=4)
    if bins_width is not None:
        plt.hist(distance_values, edgecolor ="black", bins=np.arange(min(distance_values), max(distance_values) + bins_width, bins_width), zorder=3)
    else:
        plt.hist(distance_values, edgecolor ="black", zorder=3)
    plt.title(title)
    plt.xlabel("Cosine Distance between Task Pairs")
    plt.ylabel("Frequency")

    plt.tight_layout()
    if save_file:
        _ = plt.savefig("plots/" + save_file + ".png", bbox_inches='tight')

    if show_plot:
        plt.show()
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
## LLM DIV
|
| 423 |
+
# plot multiple subplots in one figure
|
| 424 |
+
# distance_matrix passed in is a list of distance_matrix (np.arrays)
|
| 425 |
+
def plot_multi_histogram_of_distances(distance_matrix_lst, title_lst, main_title=None, show_plot=True, save_file=None,
                                      xlabel="Cosine Distance between Task Pairs", grid=True, bins_width=None,
                                      num_cols=2, figsize=(12,10)):
    """Plot one histogram of upper-triangle distances per matrix, in a subplot grid.

    :param distance_matrix_lst: list of square distance matrices (np.arrays).
    :param title_lst: per-subplot titles (same length as distance_matrix_lst).
    :param main_title: optional figure-level title.
    :param show_plot: call plt.show() at the end.
    :param save_file: if given, save to plots/<save_file>.png.
    :param xlabel: shared x-axis label.
    :param grid: draw a background grid on each subplot.
    :param bins_width: fixed bin width; matplotlib picks bins when None.
    :param num_cols: subplot columns; rows are derived from the matrix count.
    :param figsize: figure size passed to plt.subplots.
    """
    import matplotlib.pyplot as plt
    import math

    if num_cols == 2:
        num_rows = math.ceil(len(distance_matrix_lst)/2)
    else:
        num_rows = math.ceil(len(distance_matrix_lst)/num_cols)

    f, ax = plt.subplots(num_rows, num_cols, figsize=figsize)
    i = 0
    for row_ind in range(num_rows):
        for col_ind in range(num_cols):
            if i >= len(distance_matrix_lst):
                break  # grid may have unused cells
            # Upper triangle only; note this also drops genuine 0.0 distances.
            triu = np.triu(distance_matrix_lst[i])
            triu = triu[triu != 0.0]
            distance_values = triu.flatten()

            if len(distance_matrix_lst) > 2:
                # >2 plots: `ax` is 2-D, index by (row, col).
                if grid:
                    ax[row_ind, col_ind].grid(zorder=0)
                if bins_width is not None:
                    ax[row_ind, col_ind].hist(distance_values, edgecolor ="black", zorder=3, bins=np.arange(min(distance_values), max(distance_values) + bins_width, bins_width))
                else:
                    ax[row_ind, col_ind].hist(distance_values, edgecolor ="black", zorder=3)
                ax[row_ind, col_ind].set_xlabel(xlabel)
                ax[row_ind, col_ind].set_ylabel("Frequency")
                ax[row_ind, col_ind].axvline(np.mean(distance_values), color='k', linestyle='dashed', linewidth=1, zorder=4)
                ax[row_ind, col_ind].set_title(title_lst[i])
            else:
                # <=2 plots: single row, so `ax` is 1-D.
                if grid:
                    ax[col_ind].grid(zorder=0)
                # Bug fix: the original made an extra unconditional hist() call here,
                # drawing each histogram twice (doubling every bar's apparent count).
                if bins_width is not None:
                    ax[col_ind].hist(distance_values, edgecolor ="black", zorder=3, bins=np.arange(min(distance_values), max(distance_values) + bins_width, bins_width))
                else:
                    ax[col_ind].hist(distance_values, edgecolor ="black", zorder=3)
                ax[col_ind].set_xlabel(xlabel)
                ax[col_ind].set_ylabel("Frequency")
                ax[col_ind].set_title(title_lst[i])
            i += 1
    if len(distance_matrix_lst) % 2 == 1 and num_cols == 2:
        # Odd count with two columns: remove the empty bottom-right axes.
        f.delaxes(ax[num_rows-1,1])

    if main_title:
        f.suptitle(main_title)
        f.subplots_adjust(top=1)

    plt.grid(True)
    plt.tight_layout()
    if save_file:
        _ = plt.savefig("plots/" + save_file + ".png", bbox_inches='tight')
    if show_plot:
        plt.show()
|
DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec/utils.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
| 4 |
+
# may not use this file except in compliance with the License. A copy of
|
| 5 |
+
# the License is located at
|
| 6 |
+
#
|
| 7 |
+
# http://aws.amazon.com/apache2.0/
|
| 8 |
+
#
|
| 9 |
+
# or in the "license" file accompanying this file. This file is
|
| 10 |
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
| 11 |
+
# ANY KIND, either express or implied. See the License for the specific
|
| 12 |
+
# language governing permissions and limitations under the License.
|
| 13 |
+
|
| 14 |
+
from collections import defaultdict
|
| 15 |
+
import torch
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class AverageMeter(object):
    """Computes and stores the average and current value.

    Each statistic is tracked per metric name; metrics are created lazily on
    first update via defaultdicts.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every tracked metric."""
        self.val = defaultdict(int)      # most recent value per metric
        self.avg = defaultdict(float)    # running mean per metric
        self.sum = defaultdict(int)      # weighted sum per metric
        self.count = defaultdict(int)    # total observation count per metric

    def update(self, n=1, **val):
        """Record `n` observations of each metric passed as a keyword argument."""
        for name, value in val.items():
            self.val[name] = value
            self.sum[name] += value * n
            self.count[name] += n
            self.avg[name] = self.sum[name] / self.count[name]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def set_batchnorm_mode(model, train=True):
    """Set every BatchNorm1d/BatchNorm2d layer to train or eval mode, independently of the mode of the rest of the model."""
    def _apply(module):
        # Only batch-norm layers are touched; other modules keep their mode.
        if isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d)):
            module.train(train)

    model.apply(_apply)
|
| 49 |
+
|
| 50 |
+
### LLM DIV
def get_error(output, target, mode='autoregressive', ignore_index=None):
    """Return the error of `output` against `target`.

    In 'autoregressive' mode `output` holds next-token logits; positions whose
    target equals `ignore_index` are masked out, and 1 - accuracy is returned
    as a tensor. In any other mode `output` is treated as classification
    logits and the error percentage is returned as a float.
    """
    if mode == 'autoregressive':  # output = logits here
        assert ignore_index is not None
        # Drop the last logit and the first target to align predictions with
        # the next token.
        shifted_logits = output[:, :-1, :]
        predictions = torch.argmax(shifted_logits, dim=-1)
        shifted_target = target[:, 1:]
        matches = torch.eq(predictions, shifted_target.unsqueeze(0))
        if ignore_index is not None:
            matches = matches[:, shifted_target != ignore_index]
        accuracy = matches.float().mean()
        return 1 - accuracy
    pred = output.argmax(dim=1)
    num_correct = pred.eq(target).float().sum()
    return float((1. - num_correct / output.size(0)) * 100.)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def adjust_learning_rate(optimizer, epoch, optimizer_cfg):
    """Step-decay schedule: multiply the base lr by 0.1 once for every
    milestone in `optimizer_cfg.schedule` that is strictly below `epoch`."""
    num_decays = int(np.less(optimizer_cfg.schedule, epoch).sum())
    new_lr = optimizer_cfg.lr * (0.1 ** num_decays)
    for group in optimizer.param_groups:
        group['lr'] = new_lr
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def get_device(model: torch.nn.Module):
    """Return the device of the model's first parameter (assumes all parameters live on one device)."""
    first_param = next(model.parameters())
    return first_param.device
|
DataFlow/dataflow/operators/eval/GeneralText/diversity/task2vec_scorer.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataflow.operators.eval.GeneralText.diversity.task2vec.task2vec import Task2Vec
|
| 2 |
+
from dataflow.operators.eval.GeneralText.diversity.task2vec import task_similarity
|
| 3 |
+
import torch
|
| 4 |
+
import random
|
| 5 |
+
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
| 6 |
+
from dataflow.utils.storage import DataFlowStorage
|
| 7 |
+
from dataflow.core import OperatorABC
|
| 8 |
+
from dataflow.utils.registry import OPERATOR_REGISTRY
|
| 9 |
+
from torch.utils.data import Dataset
|
| 10 |
+
from dataflow import get_logger
|
| 11 |
+
from typing import Optional
|
| 12 |
+
# Task2Vec dataset diversity evaluation
|
| 13 |
+
# Cited from: Beyond Scale: the Diversity Coefficient as a Data Quality Metric Demonstrates LLMs are Pre-trained on Formally Diverse Data
|
| 14 |
+
@OPERATOR_REGISTRY.register()
class Task2VecScorer(OperatorABC):
    """Estimate dataset diversity via Task2Vec embeddings of random sample batches.

    Draws `sample_nums` batches of `sample_size` texts, embeds each batch with a
    GPT-2 probe network, and reports the mean pairwise cosine distance between
    embeddings (the "diversity coefficient") plus its confidence interval.
    """

    def __init__(self, device='cuda', sample_nums=10, sample_size=1, method: Optional[str]='montecarlo', model_cache_dir='./dataflow_cache'):
        self.logger = get_logger()
        self.logger.info(f'Initializing {self.__class__.__name__}...')
        # evaluating diversity by extract sample_nums * sample_size samples
        self.sample_nums = sample_nums
        self.sample_size = sample_size
        self.device = device
        self.model_cache_dir = model_cache_dir
        self.score_name = 'Task2VecScore'
        self.method = method
        # Validate before downloading/loading the (potentially large) GPT-2 weights below.
        if method not in ['montecarlo', 'variational']:
            raise ValueError(f"Invalid method '{method}'. Valid options are 'montecarlo' and 'variational'.")
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=self.model_cache_dir)
        self.probe_network = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir=self.model_cache_dir)
        # Replaces the device string with a torch.device; falls back to CPU when CUDA is unavailable.
        self.device = torch.device(self.device if self.device and torch.cuda.is_available() else "cpu")
        self.probe_network = self.probe_network.to(self.device)
        self.logger.info(f'{self.__class__.__name__} initialized.')

    def preprocess(self, texts):
        """Tokenize `texts` into a padded batch of tensors on `self.device`."""
        # GPT-2 ships without a pad token; reuse EOS so padded batching works.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        tokenized_outputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        return {key: value.to(self.device) for key, value in tokenized_outputs.items()}

    def get_score(self, sentences):
        """Return the Task2Vec diversity coefficient and confidence interval over `sentences`."""
        embeddings = []
        data_length = len(sentences)
        for sample_num in range(self.sample_nums):
            self.logger.info(f'--> Sample {sample_num + 1}/{self.sample_nums}')
            # Sampling is without replacement within a batch, but batches drawn
            # in different iterations may overlap.
            indices = random.sample(range(data_length), self.sample_size)
            texts = [sentences[i] for i in indices]
            tokenized_batch = self.preprocess(texts)
            tokenized_dataset = CustomTensorDataset(tokenized_batch)
            # A fresh Task2Vec embedder per batch; only the embedding is kept.
            embedding, _ = Task2Vec(self.probe_network, method=self.method).embed(tokenized_dataset)
            embeddings.append(embedding)
        distance_matrix = task_similarity.pdist(embeddings, distance='cosine')
        div_coeff, conf_interval = task_similarity.stats_of_distance_matrix(distance_matrix)

        return {
            "Task2VecDiversityScore": div_coeff,
            "ConfidenceInterval": conf_interval
        }

    def run(self, storage: DataFlowStorage, input_key: str):
        """Read texts under `input_key` from storage and return the diversity scores.

        NOTE: unlike writer-style operators, this returns the score dict and
        does not write anything back to storage.
        """
        dataframe = storage.read("dataframe")
        samples = dataframe[input_key].to_list()
        self.logger.info(f"Evaluating {self.score_name}...")
        task2vec_score = self.get_score(samples)
        self.logger.info("Evaluation complete!")
        self.logger.info(f"Task2Vec Diversity Score: {task2vec_score}")
        return task2vec_score
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class CustomTensorDataset(Dataset):
    """Dataset view over a dict of equally-sized tensors (e.g. a tokenizer batch)."""

    def __init__(self, tokenized_batch):
        self.tokenized_batch = tokenized_batch

    def __getitem__(self, index):
        # Slice every field of the batch at the same row index.
        return {name: field[index] for name, field in self.tokenized_batch.items()}

    def __len__(self):
        first_field = next(iter(self.tokenized_batch.values()))
        return len(first_field)
|
DataFlow/dataflow/operators/eval/GeneralText/diversity/vendi_scorer.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from vendi_score import text_utils
|
| 2 |
+
from dataflow.utils.storage import DataFlowStorage
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from dataflow.core import OperatorABC
|
| 5 |
+
from dataflow.utils.registry import OPERATOR_REGISTRY
|
| 6 |
+
from dataflow import get_logger
|
| 7 |
+
|
| 8 |
+
# VendiScore dataset diversity evaluation
|
| 9 |
+
# Cited from: The Vendi Score: A Diversity Evaluation Metric for Machine Learning
|
| 10 |
+
@OPERATOR_REGISTRY.register()
class VendiScorer(OperatorABC):
    """Dataset-level Vendi diversity scores using BERT and SimCSE embeddings."""

    def __init__(self, device='cuda'):
        self.logger = get_logger()
        self.logger.info(f'Initializing {self.__class__.__name__}...')
        self.bert_model_path = 'bert-base-uncased'
        self.simcse_model_path = 'princeton-nlp/unsup-simcse-bert-base-uncased'
        self.device = device
        self.score_name = 'VendiScore'
        self.logger.info(f'{self.__class__.__name__} initialized.')

    def get_score(self, sentences):
        """Return a dict with BERT- and SimCSE-based Vendi scores, rounded to 2 decimals."""
        scores = {}
        for label, model_path in (
            ("BERTVendiScore", self.bert_model_path),
            ("SimCSEVendiScore", self.simcse_model_path),
        ):
            value = text_utils.embedding_vendi_score(sentences, model_path=model_path, device=self.device)
            scores[label] = round(value, 2)
        return scores

    def run(self, storage: DataFlowStorage, input_key: str):
        """Read texts under `input_key` and return the Vendi scores (not written back to storage)."""
        dataframe = storage.read("dataframe")
        sentences = dataframe[input_key].to_list()
        self.logger.info(f"Evaluating {self.score_name}...")
        vendiscore = self.get_score(sentences)
        self.logger.info("Evaluation complete!")
        self.logger.info(f"VendiScore: {vendiscore}")
        return vendiscore
|
DataFlow/dataflow/operators/eval/GeneralText/gen/__pycache__/bert_scorer.cpython-310.pyc
ADDED
|
Binary file (2 kB). View file
|
|
|
DataFlow/dataflow/operators/eval/GeneralText/gen/__pycache__/bleu_scorer.cpython-310.pyc
ADDED
|
Binary file (2.39 kB). View file
|
|
|
DataFlow/dataflow/operators/eval/GeneralText/gen/__pycache__/cider_scorer.cpython-310.pyc
ADDED
|
Binary file (2.67 kB). View file
|
|
|
DataFlow/dataflow/operators/eval/GeneralText/gen/bert_scorer.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataflow.core import OperatorABC
|
| 2 |
+
from dataflow.utils.storage import DataFlowStorage
|
| 3 |
+
from dataflow.utils.registry import OPERATOR_REGISTRY
|
| 4 |
+
from dataflow import get_logger
|
| 5 |
+
import evaluate
|
| 6 |
+
|
| 7 |
+
@OPERATOR_REGISTRY.register()
class BERTScorer(OperatorABC):
    """Operator computing per-row BERTScore F1 between an input column and a reference column."""

    def __init__(self, lang='en', model_cache_dir='./dataflow_cache'):
        self.logger = get_logger()
        self.logger.info(f'Initializing {self.__class__.__name__}...')
        self.data_type = "text"
        self.score_name = "BERTScore"
        self.lang = lang
        self.model_type = "distilbert-base-uncased"
        self.idf = False  # no IDF weighting
        self.rescale_with_baseline = False  # raw (unrescaled) scores
        self.bertscore = evaluate.load("bertscore", cache_dir=model_cache_dir)
        self.logger.info(f'{self.__class__.__name__} initialized.')

    def eval(self, dataframe, input_key, reference_key):
        """Return a list of BERTScore F1 values, one per dataframe row.

        Raises:
            ValueError: if `reference_key` is None or not a column of `dataframe`.
        """
        # FIX: the original checked `ref_data is None` AFTER calling
        # dataframe[reference_key].to_list(), which can never yield None (a
        # missing column raises KeyError first, and to_list() returns a list).
        # Validate the reference column up front instead.
        if reference_key is None or reference_key not in dataframe:
            raise ValueError("Reference data must be provided for BERTScorer")
        eval_data = dataframe[input_key].to_list()
        ref_data = dataframe[reference_key].to_list()
        self.logger.info(f"Evaluating {self.score_name}...")
        results = self.bertscore.compute(
            predictions=eval_data,
            references=ref_data,
            lang=self.lang,
            model_type=self.model_type,
            idf=self.idf,
            rescale_with_baseline=self.rescale_with_baseline
        )
        scores = results["f1"]
        self.logger.info("Evaluation complete!")
        return scores

    def run(self, storage: DataFlowStorage, input_key: str, reference_key: str, output_key: str='BertScore'):
        """Read the dataframe from storage, score each row, and write scores to `output_key`."""
        self.input_key = input_key
        self.reference_key = reference_key
        self.output_key = output_key
        dataframe = storage.read("dataframe")
        scores = self.eval(dataframe, input_key, reference_key)
        dataframe[self.output_key] = scores
        storage.write(dataframe)
|
DataFlow/dataflow/operators/eval/GeneralText/gen/bleu/__init__.py
ADDED
|
File without changes
|
DataFlow/dataflow/operators/eval/GeneralText/gen/bleu/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (197 Bytes). View file
|
|
|
DataFlow/dataflow/operators/eval/GeneralText/gen/bleu/__pycache__/bleu.cpython-310.pyc
ADDED
|
Binary file (6.82 kB). View file
|
|
|
DataFlow/dataflow/operators/eval/GeneralText/gen/bleu/bleu.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import sys, math, re
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
|
| 5 |
+
import six
|
| 6 |
+
from six.moves import xrange as range
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def precook(s, n=4, out=False):
    """Count all n-grams of order 1..n in sentence `s`.

    Returns (length_in_words, dict mapping ngram tuple -> count).
    `out` is unused but kept for interface compatibility with callers.
    """
    words = s.split()
    counts = defaultdict(int)
    for k in range(1, n + 1):
        for i in range(len(words) - k + 1):
            ngram = tuple(words[i:i + k])
            counts[ngram] += 1
    return (len(words), counts)


def cook_refs(refs, eff=None, n=4):
    """Pre-process reference sentences for one segment.

    Returns (reflen, maxcounts) where maxcounts[ngram] is the maximum count of
    that ngram over all references (BLEU's clipping counts). `reflen` is the
    list of reference lengths, or a single number when eff is
    "shortest"/"average"; any other eff (including "closest") leaves the list
    for per-candidate resolution later.
    """
    reflen = []
    maxcounts = {}
    for ref in refs:
        rl, counts = precook(ref, n)
        reflen.append(rl)
        # FIX: was six.iteritems(counts); six is a third-party Py2/3 shim and
        # dict.items() is the Python 3 idiom.
        for ngram, count in counts.items():
            maxcounts[ngram] = max(maxcounts.get(ngram, 0), count)

    # Pre-reduce the effective reference length when it does not depend on the
    # candidate ("closest" does, so it is resolved in cook_test).
    if eff == "shortest":
        reflen = min(reflen)
    elif eff == "average":
        reflen = float(sum(reflen)) / len(reflen)

    return (reflen, maxcounts)


def cook_test(test, reflen_refmaxcounts, eff=None, n=4):
    """Pre-process a candidate sentence against cooked references.

    Returns a dict with the candidate length, effective reference length,
    maximum possible matches per order ("guess"), and clipped matches per
    order ("correct").
    """
    reflen, refmaxcounts = reflen_refmaxcounts
    testlen, counts = precook(test, n, True)

    result = {}

    # Pick the reference length closest to the candidate length; otherwise
    # keep whatever cook_refs produced.
    if eff == "closest":
        result["reflen"] = min((abs(l - testlen), l) for l in reflen)[1]
    else:
        result["reflen"] = reflen

    result["testlen"] = testlen

    result["guess"] = [max(0, testlen - k + 1) for k in range(1, n + 1)]

    result['correct'] = [0] * n
    # FIX: was six.iteritems(counts); Python 3 dict.items().
    for ngram, count in counts.items():
        result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count)

    return result
|
| 58 |
+
|
| 59 |
+
class Bleu(object):
    """Bleu scorer.

    Corpus-level BLEU (port of coco-caption's BleuScorer): accumulate
    (candidate, references) pairs via cook_append/__iadd__, then compute
    BLEU-1..n with a brevity penalty over the whole corpus.
    """

    __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"

    def copy(self):
        """Shallow copy sharing cooked data, with an invalidated score cache."""
        new = Bleu(n=self.n)
        new.ctest = copy.copy(self.ctest)
        new.crefs = copy.copy(self.crefs)
        new._score = None
        return new

    def __init__(self, test=None, refs=None, n=4, special_reflen=None):
        """n: maximum n-gram order. special_reflen: fixed reference length that
        overrides the per-segment effective length when given."""
        self.n = n
        self.crefs = []
        self.ctest = []
        self._ratio = None  # FIX: was declared in __slots__ but never assigned
        self.cook_append(test, refs)
        self.special_reflen = special_reflen

    def cook_append(self, test, refs):
        """Cook and store one (candidate, references) pair; either may be None."""
        if refs is not None:
            self.crefs.append(cook_refs(refs))
            if test is not None:
                cooked_test = cook_test(test, self.crefs[-1])
                self.ctest.append(cooked_test)  ## N.B.: -1
            else:
                self.ctest.append(None)  # lens of crefs and ctest have to match

        self._score = None  ## need to recompute

    def ratio(self, option=None):
        """Corpus testlen/reflen ratio (drives the brevity penalty)."""
        self.compute_score(option=option)
        return self._ratio

    def fscore(self, option=None):
        """Overall BLEU score at the highest n-gram order.

        FIX: score_ratio() previously called self.fscore(), which did not
        exist and raised AttributeError; it is now the final corpus score.
        """
        self.compute_score(option=option)
        return self._score[-1]

    def score_ratio(self, option=None):
        """Return (overall BLEU score, length ratio)."""
        return (self.fscore(option=option), self.ratio(option=option))

    def score_ratio_str(self, option=None):
        return "%.4f (%.2f)" % self.score_ratio(option)

    def reflen(self, option=None):
        """Accumulated effective reference length over the corpus."""
        self.compute_score(option=option)
        return self._reflen

    def testlen(self, option=None):
        """Accumulated candidate length over the corpus."""
        self.compute_score(option=option)
        return self._testlen

    def retest(self, new_test):
        """Replace the candidate sentence(s), keeping the cooked references."""
        if type(new_test) is str:
            new_test = [new_test]
        assert len(new_test) == len(self.crefs), new_test
        self.ctest = []
        for t, rs in zip(new_test, self.crefs):
            self.ctest.append(cook_test(t, rs))
        self._score = None

        return self

    def rescore(self, new_test):
        ''' replace test(s) with new test(s), and returns the new score.'''

        return self.retest(new_test).compute_score()

    def size(self):
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        '''add an instance (e.g., from another sentence).'''

        if type(other) is tuple:
            ## avoid creating new BleuScorer instances
            self.cook_append(other[0], other[1])
        else:
            assert self.compatible(other), "incompatible BLEUs."
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)
            self._score = None  ## need to recompute

        return self

    def compatible(self, other):
        return isinstance(other, Bleu) and self.n == other.n

    def single_reflen(self, option="average"):
        """Effective reference length of the first segment only."""
        return self._single_reflen(self.crefs[0][0], option)

    def _single_reflen(self, reflens, option=None, testlen=None):
        """Reduce a list of reference lengths to one effective length."""
        if option == "shortest":
            reflen = min(reflens)
        elif option == "average":
            reflen = float(sum(reflens)) / len(reflens)
        elif option == "longest":
            # FIX: "longest" is advertised by callers (e.g. BleuScorer's
            # valid_eff_options) but previously fell through to the assert.
            reflen = max(reflens)
        elif option == "closest":
            reflen = min((abs(l - testlen), l) for l in reflens)[1]
        else:
            assert False, "unsupported reflen option %s" % option

        return reflen

    def recompute_score(self, option=None, verbose=0):
        """Drop the cached score and recompute."""
        self._score = None
        return self.compute_score(option, verbose)

    def compute_score(self, option=None, verbose=0):
        """Compute corpus BLEU-1..n.

        Returns (bleus, bleu_list): corpus-level scores per order and
        per-sentence scores per order. NOTE(review): on a cache hit only the
        cached corpus score list is returned (no per-sentence list) — this
        upstream inconsistency is preserved for backward compatibility.
        """
        n = self.n
        small = 1e-9
        tiny = 1e-15  ## so that if guess is 0 still return 0
        bleu_list = [[] for _ in range(n)]

        if self._score is not None:
            return self._score

        if option is None:
            option = "average" if len(self.crefs) == 1 else "closest"

        self._testlen = 0
        self._reflen = 0
        totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0] * n, 'correct': [0] * n}

        # for each sentence
        for comps in self.ctest:
            testlen = comps['testlen']
            self._testlen += testlen

            if self.special_reflen is None:  ## need computation
                reflen = self._single_reflen(comps['reflen'], option, testlen)
            else:
                reflen = self.special_reflen

            self._reflen += reflen

            for key in ['guess', 'correct']:
                for k in range(n):
                    totalcomps[key][k] += comps[key][k]

            # append per image bleu score
            bleu = 1.
            for k in range(n):
                bleu *= (float(comps['correct'][k]) + tiny) \
                        / (float(comps['guess'][k]) + small)
                bleu_list[k].append(bleu ** (1. / (k + 1)))
            ratio = (testlen + tiny) / (reflen + small)  ## N.B.: avoid zero division
            if ratio < 1:
                # brevity penalty per sentence
                for k in range(n):
                    bleu_list[k][-1] *= math.exp(1 - 1 / ratio)

            if verbose > 1:
                print(comps, reflen)

        totalcomps['reflen'] = self._reflen
        totalcomps['testlen'] = self._testlen

        bleus = []
        bleu = 1.
        for k in range(n):
            bleu *= float(totalcomps['correct'][k] + tiny) \
                    / (totalcomps['guess'][k] + small)
            bleus.append(bleu ** (1. / (k + 1)))
        ratio = (self._testlen + tiny) / (self._reflen + small)  ## N.B.: avoid zero division
        if ratio < 1:
            # corpus-level brevity penalty
            for k in range(n):
                bleus[k] *= math.exp(1 - 1 / ratio)

        if verbose > 0:
            print(totalcomps)
            print("ratio:", ratio)

        self._ratio = ratio  # FIX: ratio() previously raised AttributeError
        self._score = bleus
        return self._score, bleu_list
|
DataFlow/dataflow/operators/eval/GeneralText/gen/bleu_scorer.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataflow.core import OperatorABC
|
| 2 |
+
from dataflow.utils.storage import DataFlowStorage
|
| 3 |
+
from dataflow.utils.registry import OPERATOR_REGISTRY
|
| 4 |
+
from dataflow import get_logger
|
| 5 |
+
from dataflow.operators.eval.GeneralText.gen.bleu.bleu import Bleu
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
@OPERATOR_REGISTRY.register()
class BleuScorer(OperatorABC):
    """Row-wise BLEU evaluation operator scoring `input_key` against `reference_key`."""

    def __init__(self, n=4, eff="average", special_reflen=None):
        self.logger = get_logger()
        self.logger.info(f'Initializing {self.__class__.__name__}...')
        self.score_name = 'BleuScore'
        valid_eff_options = ["shortest", "average", "longest"]
        if eff not in valid_eff_options:
            raise ValueError(f"Invalid value for 'eff'. Must be one of {valid_eff_options}, but got '{eff}'.")
        self.n = n  # Max n-gram length (default: 4)
        self.eff = eff  # [shortest, average, longest]
        self.special_reflen = special_reflen  # Special reference length if specified
        self.logger.info(f'{self.__class__.__name__} initialized.')

    def _score_func(self, eval_text, ref_text):
        # A fresh Bleu instance per row; the BLEU-1 component is reported.
        scorer = Bleu(
            test=eval_text,
            refs=[ref_text],
            n=self.n,
            special_reflen=self.special_reflen,
        )
        scores, _ = scorer.compute_score(option=self.eff)
        return scores[0]

    def eval(self, dataframe, input_key, reference_key):
        """Return one BLEU score per (candidate, reference) row pair."""
        candidates = dataframe[input_key]
        references = dataframe[reference_key]
        self.logger.info(f"Evaluating {self.score_name}...")
        scores = [
            self._score_func(candidate, reference)
            for candidate, reference in tqdm(zip(candidates, references), desc="BleuScorer Evaluating...")
        ]
        self.logger.info("Evaluation complete!")
        return scores

    def run(self, storage: DataFlowStorage, input_key: str, reference_key: str, output_key: str='BleuScore'):
        """Read the dataframe, score each row, and write results into `output_key`."""
        self.input_key = input_key
        self.reference_key = reference_key
        self.output_key = output_key
        dataframe = storage.read("dataframe")
        dataframe[self.output_key] = self.eval(dataframe, input_key, reference_key)
        storage.write(dataframe)
|
DataFlow/dataflow/operators/eval/GeneralText/gen/cider/__init__.py
ADDED
|
File without changes
|
DataFlow/dataflow/operators/eval/GeneralText/gen/cider/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (198 Bytes). View file
|
|
|
DataFlow/dataflow/operators/eval/GeneralText/gen/cider/__pycache__/cider.cpython-310.pyc
ADDED
|
Binary file (5.54 kB). View file
|
|
|
DataFlow/dataflow/operators/eval/GeneralText/gen/cider/cider.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import math
|
| 3 |
+
import pickle
|
| 4 |
+
import numpy as np
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
import os
|
| 7 |
+
from six.moves import xrange
|
| 8 |
+
import six
|
| 9 |
+
|
| 10 |
+
def precook(s, n=4, out=False):
    """Count all n-grams of order 1..n in sentence `s`.

    Returns a dict mapping ngram tuple -> count. `out` is unused; kept for
    interface compatibility.
    """
    tokens = s.split()
    counts = defaultdict(int)
    for order in range(1, n + 1):
        for start in range(len(tokens) - order + 1):
            counts[tuple(tokens[start:start + order])] += 1
    return counts


def cook_refs(refs, n=4):
    """Cook every reference sentence into its n-gram count dict."""
    return [precook(reference, n) for reference in refs]


def cook_test(test, n=4):
    """Cook a candidate sentence into its n-gram count dict."""
    return precook(test, n, True)
|
| 24 |
+
|
| 25 |
+
class Cider(object):
    """CIDEr scorer.

    Accumulates (candidate, references) pairs and scores them with TF-IDF
    weighted n-gram cosine similarity plus a Gaussian length penalty.
    IDF statistics come either from a precomputed table (`idf`) or from the
    reference corpus itself (df_mode == "corpus").
    """

    def copy(self):
        # Shallow copy of cooked data; document_frequency/ref_len are rebuilt on scoring.
        new = Cider(n=self.n)
        new.ctest = copy.copy(self.ctest)
        new.crefs = copy.copy(self.crefs)
        return new

    def __init__(self, test=None, refs=None, n=4, sigma=6.0, idf=None):
        # n: max n-gram order; sigma: width of the Gaussian length penalty.
        self.n = n
        self.sigma = sigma
        self.crefs = []
        self.ctest = []
        self.document_frequency = defaultdict(float)
        self.ref_len = None

        # Optional precomputed IDF table (e.g. coco-val-df): 'df' counts plus corpus size.
        if idf:
            self.document_frequency = idf['df']
            self.ref_len = np.log(float(idf['ref_len']))  # Use reference length from the IDF

        self.cook_append(test, refs)

    def cook_append(self, test, refs):
        # Either side may be None; ctest is padded with None so it stays aligned with crefs.
        if refs is not None:
            self.crefs.append(cook_refs(refs))
            if test is not None:
                self.ctest.append(cook_test(test))
            else:
                self.ctest.append(None)

    def size(self):
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        # Accept either a raw (test, refs) tuple or another Cider instance.
        if type(other) is tuple:
            self.cook_append(other[0], other[1])
        else:
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)
        return self

    def compute_doc_freq(self):
        '''Compute term frequency for reference data to generate IDF.'''
        if not self.document_frequency:  # Handle empty DF (for 'corpus' mode)
            for refs in self.crefs:
                # Count each ngram at most once per segment, regardless of per-reference counts.
                for ngram in set([ngram for ref in refs for (ngram, count) in ref.items()]):
                    self.document_frequency[ngram] += 1

    def compute_cider(self, df_mode):
        def counts2vec(cnts):
            # Build per-order TF-IDF vectors, their L2 norms, and a sentence "length".
            vec = [defaultdict(float) for _ in range(self.n)]
            length = 0
            norm = [0.0 for _ in range(self.n)]
            for (ngram, term_freq) in cnts.items():
                df = np.log(max(1.0, self.document_frequency[ngram]))
                n = len(ngram) - 1  # 0-based n-gram order
                # TF * IDF, with IDF = log(corpus size) - log(df).
                vec[n][ngram] = float(term_freq) * (self.ref_len - df)
                norm[n] += pow(vec[n][ngram], 2)

                # NOTE(review): this accumulates BIGRAM (n == 1) frequencies as the
                # sentence length; the natural measure would be unigrams (n == 0).
                # The upstream coco-caption port behaves the same — confirm before changing.
                if n == 1:
                    length += term_freq
            norm = [np.sqrt(n) for n in norm]
            return vec, norm, length

        def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
            # Per-order cosine similarity, damped by a Gaussian penalty on the length delta.
            delta = float(length_hyp - length_ref)
            val = np.array([0.0 for _ in range(self.n)])
            for n in range(self.n):
                for (ngram, count) in vec_hyp[n].items():
                    # Clipped dot product (hypothesis counts clipped to reference counts).
                    val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]

                if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
                    val[n] /= (norm_hyp[n] * norm_ref[n])

                val[n] *= np.e**(-(delta**2) / (2 * self.sigma**2))
            return val

        if df_mode == "corpus":
            self.ref_len = np.log(float(len(self.crefs)))  # Use total references in corpus as ref_len

        scores = []
        for test, refs in zip(self.ctest, self.crefs):
            vec, norm, length = counts2vec(test)
            score = np.array([0.0 for _ in range(self.n)])
            for ref in refs:
                vec_ref, norm_ref, length_ref = counts2vec(ref)
                score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
            # Average over n-gram orders and references, then scale by 10 (CIDEr convention).
            score_avg = np.mean(score)
            score_avg /= len(refs)
            score_avg *= 10.0
            scores.append(score_avg)
        return scores

    def compute_score(self, df_mode, option=None, verbose=0):
        '''Compute the CIDEr score based on df_mode (corpus or IDF-based).'''
        self.compute_doc_freq()

        if df_mode == "corpus":
            if not self.document_frequency:  # Handle the case where DF is empty
                raise ValueError("Document frequency is empty. Please check the corpus data.")

            # NOTE(review): min_required_data is computed but unused; the related
            # sanity check below is deliberately commented out upstream.
            min_required_data = max(self.document_frequency.values())
            # print(min_required_data)# For corpus mode, we require at least one reference
            # if len(self.ctest) < min_required_data:
            #     raise ValueError(f"Insufficient test data: {len(self.ctest)} samples, but at least {min_required_data} are required.")

        score = self.compute_cider(df_mode)
        # Returns (mean corpus score, per-sentence score array).
        return np.mean(np.array(score)), np.array(score)
|
DataFlow/dataflow/operators/eval/GeneralText/gen/cider_scorer.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import pickle
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
from dataflow.core import OperatorABC
|
| 6 |
+
from dataflow.utils.storage import DataFlowStorage
|
| 7 |
+
from dataflow.utils.registry import OPERATOR_REGISTRY
|
| 8 |
+
from dataflow import get_logger
|
| 9 |
+
from dataflow.operators.eval.GeneralText.gen.cider.cider import Cider
|
| 10 |
+
|
| 11 |
+
def load_idf(idf_path):
    """Load a pickled IDF (document-frequency) table from *idf_path*."""
    with open(idf_path, 'rb') as handle:
        return pickle.load(handle, encoding='utf-8')
|
| 15 |
+
|
| 16 |
+
@OPERATOR_REGISTRY.register()
class CiderScorer(OperatorABC):
    """Operator that scores generated text against references with CIDEr.

    Parameters
    ----------
    n : int
        Maximum n-gram length (default 4).
    sigma : float
        Sigma for the Gaussian length penalty (default 6.0).
    df_mode : str
        "corpus" to derive document frequencies from the references
        themselves; any other value names a precomputed IDF table
        (e.g. "coco-val-df") that is loaded from ``idf_path``.
    idf_path : str
        Path to the pickled IDF file; only read when df_mode != "corpus".
    """

    def __init__(self, n=4, sigma=6.0, df_mode="coco-val-df", idf_path="./dataflow/operators/eval/GeneralText/gen/cider/coco-val-df.p"):
        self.logger = get_logger()
        self.logger.info(f'Initializing {self.__class__.__name__}...')
        self.score_name = 'CiderScore'
        self.n = n  # Max n-gram length (default: 4)
        self.sigma = sigma  # Sigma for Gaussian penalty (default: 6.0)
        self.df_mode = df_mode
        if self.df_mode != "corpus":
            # The idf file can be downloaded at https://github.com/ramavedantam/coco-caption/blob/master/data/coco-val-df.p
            # Put the file in the correct idf_path
            self.idf = load_idf(idf_path)
        else:
            self.idf = None  # No need to load IDF for 'corpus' mode
        self.logger.info(f'{self.__class__.__name__} initialized.')

    def _score_func(self, eval_text, ref_text):
        """Score one candidate against a single reference; returns mean CIDEr."""
        cider_scorer = Cider(
            test=eval_text,
            refs=[ref_text],
            n=self.n,
            sigma=self.sigma,
            idf=self.idf  # Pass IDF (None if using 'corpus')
        )
        # BUG FIX: the non-corpus branch previously passed the literal
        # 'coco-val-df', silently ignoring any user-supplied df_mode.
        # self.idf is None exactly when self.df_mode == "corpus", so honoring
        # self.df_mode here is both correct and backward-compatible.
        cider_score, _ = cider_scorer.compute_score(df_mode='corpus' if self.idf is None else self.df_mode)
        return cider_score

    def eval(self, dataframe, input_key, reference_key):
        """Return one CIDEr score per row, pairing input_key with reference_key."""
        eval_data = dataframe[input_key]
        ref_data = dataframe[reference_key]
        self.logger.info(f"Evaluating {self.score_name}...")
        scores = [self._score_func(eval_text, ref_text) for eval_text, ref_text in tqdm(zip(eval_data, ref_data), desc="CiderScorer Evaluating...")]
        self.logger.info("Evaluation complete!")
        return scores

    def run(self, storage: DataFlowStorage, input_key: str, reference_key: str, output_key: str='CiderScore'):
        """Read the dataframe from storage, score each row, write it back.

        The per-row scores are stored in a new ``output_key`` column.
        """
        self.input_key = input_key
        self.reference_key = reference_key
        self.output_key = output_key
        dataframe = storage.read("dataframe")
        scores = self.eval(dataframe, input_key, reference_key)
        dataframe[self.output_key] = scores
        storage.write(dataframe)
|
DataFlow/dataflow/operators/eval/GeneralText/models/Kenlm/__pycache__/model.cpython-310.pyc
ADDED
|
Binary file (5.21 kB). View file
|
|
|
DataFlow/dataflow/operators/eval/GeneralText/models/Kenlm/model.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import unicodedata
|
| 4 |
+
from typing import Dict
|
| 5 |
+
|
| 6 |
+
import kenlm
|
| 7 |
+
import sentencepiece
|
| 8 |
+
|
| 9 |
+
class SentencePiece:
    """Thin wrapper around a SentencePiece tokenizer model file."""

    def __init__(
        self,
        model: str,
    ):
        # Load the .sp.model file; str() tolerates Path-like arguments.
        self.sp = sentencepiece.SentencePieceProcessor()
        self.sp.load(str(model))

    def do(self, text: str) -> str:
        """Tokenize *text* into pieces and return them joined by single spaces.

        BUG FIX (annotations only): the original signature claimed
        ``(text: dict) -> dict`` but the method visibly receives a string
        (``encode_as_pieces``) and returns ``" ".join(...)``, i.e. a str.
        """
        tokenized = self.sp.encode_as_pieces(text)
        return " ".join(tokenized)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class KenlmModel:
    """KenLM perplexity scorer with cc_net-style text normalization.

    Normalization (lowercasing, accent stripping, digit folding, punctuation
    mapping) mirrors the cc_net preprocessing pipeline so that scores are
    comparable with models trained on cc_net-normalized text.
    """

    # Matches a single decimal digit; used to fold all digits to "0".
    digit_re: re.Pattern = re.compile(r"\d")
    # Unicode punctuation -> ASCII replacement table (cc_net convention).
    unicode_punct: Dict[str, str] = {
        ",": ",",
        "。": ".",
        "、": ",",
        "„": '"',
        "”": '"',
        "“": '"',
        "«": '"',
        "»": '"',
        "1": '"',
        "」": '"',
        "「": '"',
        "《": '"',
        "》": '"',
        "´": "'",
        "∶": ":",
        ":": ":",
        "?": "?",
        "!": "!",
        "(": "(",
        ")": ")",
        ";": ";",
        "–": "-",
        "—": " - ",
        ".": ". ",
        "~": "~",
        "’": "'",
        "…": "...",
        "━": "-",
        "〈": "<",
        "〉": ">",
        "【": "[",
        "】": "]",
        "%": "%",
        "►": "-",
    }
    unicode_punct_re = re.compile(f"[{''.join(unicode_punct.keys())}]")
    # C0 and C1 control characters (0-31, 127-159).
    non_printing_chars_re = re.compile(
        f"[{''.join(map(chr, list(range(0,32)) + list(range(127,160))))}]"
    )
    kenlm_model_dir = None
    sentence_piece_model_dir = None

    def __init__(
        self,
        model_dataset: str,
        language: str,
        lower_case: bool = False,
        remove_accents: bool = False,
        normalize_numbers: bool = True,
        punctuation: int = 1,
    ):
        # Expects "<language>.arpa.bin" and "<language>.sp.model" under
        # model_dataset — presumably cc_net-style artifacts; confirm layout.
        self.model = kenlm.Model(os.path.join(model_dataset, f"{language}.arpa.bin"))
        self.tokenizer = SentencePiece(os.path.join(model_dataset, f"{language}.sp.model"))
        self.accent = remove_accents
        self.case = lower_case
        self.numbers = normalize_numbers
        self.punct = punctuation

    @classmethod
    def from_pretrained(
        cls,
        model_dataset: str,
        language: str,
    ):
        """Alternate constructor using the default normalization settings."""
        return cls(
            model_dataset,
            language,
            False,
            False,
            True,
            1,
        )

    def pp(self, log_score, length):
        """Convert a (base-10) total log-score and token count to perplexity."""
        return 10.0 ** (-log_score / length)

    def get_perplexity(self, doc: str, normalize_cc_net: bool = True):
        """Return the (rounded) KenLM perplexity of *doc*.

        When ``normalize_cc_net`` is True the document is normalized with the
        same settings the instance was constructed with before tokenization.
        """
        if normalize_cc_net:
            doc = self.normalize(
                doc,
                accent=self.accent,
                case=self.case,
                numbers=self.numbers,
                punct=self.punct,
            )
        # Tokenize (after normalizing): See https://github.com/facebookresearch/cc_net/blob/bda555bd1cf1ee2e0b925363e62a61cd46c8b60d/cc_net/mine.py#L352 for full pipeline
        doc = self.tokenizer.do(doc)
        doc_log_score, doc_length = 0, 0
        for line in doc.split("\n"):
            log_score = self.model.score(line)
            # +1 presumably accounts for the end-of-sentence token that
            # kenlm scores implicitly — TODO confirm.
            length = len(line.split()) + 1
            doc_log_score += log_score
            doc_length += length
        return round(self.pp(doc_log_score, doc_length), 1)

    def normalize(
        self,
        line: str,
        accent: bool = True,
        case: bool = True,
        numbers: bool = True,
        punct: int = 1,
    ) -> str:
        """Apply cc_net text normalization: case, accents, digits, punctuation.

        punct == 1 maps unicode punctuation to ASCII; punct == 2 removes it.
        """
        line = line.strip()
        if not line:
            return line
        if case:
            line = line.lower()
        if accent:
            line = self.strip_accents(line)
        if numbers:
            line = self.digit_re.sub("0", line)
        if punct == 1:
            line = self.replace_unicode_punct(line)
        elif punct == 2:
            line = self.remove_unicode_punct(line)
        line = self.remove_non_printing_char(line)
        return line

    def strip_accents(self, line: str) -> str:
        """Strips accents from a piece of text."""
        nfd = unicodedata.normalize("NFD", line)
        output = [c for c in nfd if unicodedata.category(c) != "Mn"]
        # BUG FIX: the original wrote `if len(output) == line:` (int vs str,
        # always False). The intended fast path is "nothing was stripped";
        # that holds exactly when no combining mark was filtered out of the
        # decomposed form, i.e. len(output) == len(nfd).
        if len(output) == len(nfd):
            return line
        return "".join(output)

    def replace_unicode_punct(self, text: str) -> str:
        """Map each unicode punctuation character to its ASCII replacement."""
        return "".join(self.unicode_punct.get(c, c) for c in text)

    def remove_unicode_punct(self, text: str) -> str:
        """More aggressive version of replace_unicode_punct but also faster."""
        return self.unicode_punct_re.sub("", text)

    def remove_non_printing_char(self, text: str) -> str:
        """Delete C0/C1 control characters from *text*."""
        return self.non_printing_chars_re.sub("", text)
|
DataFlow/dataflow/operators/eval/GeneralText/models/Qurating/__pycache__/qurater_annotate.cpython-310.pyc
ADDED
|
Binary file (7.02 kB). View file
|
|
|
DataFlow/dataflow/operators/eval/GeneralText/models/Qurating/modeling/__pycache__/modeling_flash_llama.cpython-310.pyc
ADDED
|
Binary file (25 kB). View file
|
|
|
DataFlow/dataflow/operators/eval/GeneralText/models/Qurating/modeling/modeling_flash_llama.py
ADDED
|
@@ -0,0 +1,853 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 5 |
+
# and OPT implementations in this library. It has been modified from its
|
| 6 |
+
# original forms to accommodate minor architectural differences compared
|
| 7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
""" PyTorch LLaMA model."""
|
| 21 |
+
from typing import List, Optional, Tuple, Union, Any
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
import torch.utils.checkpoint
|
| 26 |
+
from torch import nn
|
| 27 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 28 |
+
|
| 29 |
+
import torch.distributed as dist
|
| 30 |
+
|
| 31 |
+
from transformers.activations import ACT2FN
|
| 32 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
| 33 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 34 |
+
from transformers.utils import logging
|
| 35 |
+
from transformers.models.llama.configuration_llama import LlamaConfig
|
| 36 |
+
|
| 37 |
+
def try_import_flash_attention():
    # Probe that flash_attn (and its rotary kernel) is importable, raising a
    # descriptive ImportError with install instructions otherwise.
    #
    # NOTE(review): these imports are local to this function's scope, so they
    # do NOT bind flash_attn_kvpacked_func etc. at module level; the module
    # code below references those names directly and appears to rely on them
    # being imported elsewhere — confirm against the callers of this function.
    try:
        from flash_attn import flash_attn_kvpacked_func, flash_attn_varlen_kvpacked_func, flash_attn_with_kvcache
        from flash_attn.bert_padding import unpad_input, pad_input
        from flash_attn.layers.rotary import apply_rotary_emb_func
    except ImportError as e:
        # The rotary kernel ships as a separate csrc subpackage; distinguish
        # that failure mode so the user gets the right install command.
        if 'flash_attn.layers.rotary' in str(e):
            raise ImportError('Please install RoPE kernels: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary`')
        else:
            raise ImportError('Please install flash_attention dependency in GPU environment')
|
| 47 |
+
from dataflow import get_logger
|
| 48 |
+
|
| 49 |
+
logger = logging.get_logger(__name__)
|
| 50 |
+
|
| 51 |
+
# @torch.jit.script
def rmsnorm_func(hidden_states, weight, variance_epsilon):
    """RMS-normalize over the last dim in fp32, scale by *weight*, and cast
    back to the input's original dtype."""
    orig_dtype = hidden_states.dtype
    x = hidden_states.to(torch.float32)
    inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + variance_epsilon)
    return (weight * (x * inv_rms)).to(orig_dtype)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class LlamaRMSNorm(nn.Module):
    """RMS layer normalization with a learned per-feature scale.

    LlamaRMSNorm is equivalent to T5LayerNorm.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # Stored as a non-persistent buffer so it follows .to()/.cuda() moves
        # without appearing in the state dict.
        self.register_buffer("variance_epsilon", torch.tensor(eps), persistent=False)

    def forward(self, hidden_states):
        # Delegate to the shared free function (fp32 math, cast back).
        return rmsnorm_func(hidden_states, self.weight, self.variance_epsilon)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class FlashRotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(self, dim: int, base=10000.0, interleaved=False, scale_base=None,
                 scaling_factor=1.0, pos_idx_in_fp32=True, device=None):
        """
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
            otherwise they might be in lower precision.
            This option was added because previously (before 2023-07-02), when we construct
            the position indices, we use the dtype of self.inv_freq. In most cases this would
            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
            self.inv_freq would be bf16, and the position indices are also in bf16.
            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
            embeddings for some positions will coincide.
            To maintain compatibility with models previously trained in pure bf16,
            we add this option.
        scaling_factor: RotaryEmbedding extended with linear scaling.
        """
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = self._compute_inv_freq(device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.interleaved = interleaved
        self.scale_base = scale_base
        self.scaling_factor = scaling_factor
        # XPos per-dimension scale; only materialized when scale_base is set.
        scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
                 / (1.4 * dim) if scale_base is not None else None)
        self.register_buffer("scale", scale)

        # Lazily built cos/sin lookup tables; regenerated whenever the needed
        # sequence length, device, or dtype changes (see _update_cos_sin_cache).
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _compute_inv_freq(self, device=None):
        # Standard RoPE inverse frequencies: 1 / base^(2i/dim) for even i.
        return 1 / (self.base ** (torch.arange(0, self.dim, 2, device=device,
                                               dtype=torch.float32) / self.dim))


    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Reset the tables if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        # NOTE(review): on the first call _cos_cached is None; the condition
        # relies on the seqlen comparison short-circuiting first (assumes
        # seqlen > 0 on the first call) — confirm.
        if (seqlen > self._seq_len_cached or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                t /= self.scaling_factor
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self.inv_freq.to(torch.float32)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                t /= self.scaling_factor
                inv_freq = self.inv_freq
            # Don't do einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                # XPos: positions are centered at seqlen // 2 before scaling.
                power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                          - seqlen // 2) / self.scale_base)
                scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(self,
                q: torch.Tensor, k: torch.Tensor,
                seqlen_offset: int = 0,
                unpadded_lengths: Optional[Tuple[torch.Tensor]] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        q: (batch, seqlen, nheads, headdim)
        k: (batch, seqlen, nheads, headdim)
        seqlen_offset: can be used in generation where the qkv being passed in is only the last
            token in the batch.
        """
        if unpadded_lengths is not None:
            # Varlen path: sequences are packed; lengths come as (cu_seqlens, max_seqlen).
            cu_seqlens, max_seqlen = unpadded_lengths
        else:
            cu_seqlens, max_seqlen = None, q.shape[1]
        # Make sure the cached cos/sin tables cover offset + current length.
        self._update_cos_sin_cache(max_seqlen + seqlen_offset, device=q.device, dtype=q.dtype)

        if self.scale is None:
            # Apply the fused rotary kernel in place to q and k, skipping the
            # first `seqlen_offset` cached positions (generation case).
            return apply_rotary_emb_func(
                q, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                self.interleaved, True,  # inplace=True,
                cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
            ), apply_rotary_emb_func(
                k, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                self.interleaved, True,  # inplace=True
                cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
            )
        else:
            # XPos (scale is not None) is cached above but not applied here.
            assert False
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class LlamaMLP(nn.Module):
    """Gated feed-forward block: down_proj(act(gate_proj(x)) * up_proj(x))."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # Three bias-free projections: two into the intermediate width
        # (gate and up), one back down to the hidden width.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
@torch.jit.script
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each entry of the second-to-last axis n_rep times (contiguously),
    e.g. to expand grouped KV heads up to the full query-head count."""
    if n_rep == 1:
        return hidden_states
    out_shape = list(hidden_states.shape[:-2]) + [-1] + [hidden_states.shape[-1]]
    expand_dims = [-1] * (hidden_states.dim() - 1) + [n_rep] + [-1]
    expanded = hidden_states.unsqueeze(-2).expand(expand_dims)
    return expanded.reshape(out_shape)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        # GQA support: fall back to MHA when num_key_value_heads is absent.
        self.num_key_value_heads = getattr(config, "num_key_value_heads", self.num_heads)
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # sqrt(head_dim), used as the softmax temperature (1/norm_factor below).
        self.register_buffer(
            "norm_factor",
            torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
            persistent=False,
        )

        # Only linear RoPE scaling is supported.
        if not getattr(self.config, "rope_scaling", None):
            scaling_factor = 1
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            assert scaling_type == 'linear'
        theta = getattr(self.config, "rope_theta", 10000)
        self.rotary_emb = FlashRotaryEmbedding(
            self.head_dim, base=theta, interleaved=False, scaling_factor=scaling_factor,
        )

        self.distributed_attn_func = flash_attn_kvpacked_func

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # (bsz, seq, hidden) -> (bsz, heads, seq, head_dim), contiguous.
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        unpadded_lengths: Optional[Tuple[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # Returns (attn_output, attn_weights or None, (past_kv, new_len) or None).
        h_size = hidden_states.size(-1)

        has_layer_past = past_key_value is not None

        # The cache is a pair: (kv tensor, number of valid cached positions).
        if has_layer_past:
            past_kv = past_key_value[0]
            past_len = past_key_value[1]
        else:
            past_len = 0

        # NOTE: Hack to include position_ids, assuming they are increasing uniformly per block
        if position_ids is not None:
            past_len += position_ids.min()

        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Split the projection outputs into per-head views.
        q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
        k = k.view(*k.shape[:-1], self.num_key_value_heads, self.head_dim)
        v = v.view(*v.shape[:-1], self.num_key_value_heads, self.head_dim)

        # Rotary embedding offset by the cached length (generation case).
        q, k = self.rotary_emb(q, k, past_len, unpadded_lengths)

        # Pack K and V together and expand KV heads to match the query heads.
        kv = torch.stack([k, v], -3)
        kv = repeat_kv(kv, self.num_key_value_groups)

        # Cache QKV values
        if has_layer_past:
            new_len = past_len+q.size(1)
            if new_len > past_kv.size(1):
                # Grow the cache in chunks of 256 positions (uninitialized tail).
                past_kv = torch.cat([past_kv, torch.empty(hidden_states.size(0), 256, 2, kv.size(3), kv.size(4), dtype=kv.dtype, device=kv.device)], 1)
            past_kv[:, past_len:new_len] = kv
            kv = past_kv[:, :new_len]
        else:
            past_kv = kv

        if unpadded_lengths is not None:
            # varlen, ignore padding tokens, efficient for large batch with many paddings
            assert attention_mask is not None
            cu_seqlens, max_seqlen = unpadded_lengths

            attn_outputs = flash_attn_varlen_kvpacked_func(
                q, kv,
                cu_seqlens, cu_seqlens,
                max_seqlen, max_seqlen,
                dropout_p=0.0, softmax_scale=1.0/self.norm_factor,
                causal=True, return_attn_probs=output_attentions
            )
        # elif use_cache and past_key_value is not None:
        #     attn_outputs = flash_attn_with_kvcache(
        #         q,
        #         kv[:, :, 0],
        #         kv[:, :, 1],
        #         softmax_scale=1.0/self.norm_factor,
        #         causal=True,
        #     )
        else:
            attn_outputs = flash_attn_kvpacked_func(
                q, kv,
                dropout_p=0.0,
                softmax_scale=1.0/self.norm_factor,
                causal=True,
                return_attn_probs=output_attentions,
            )
        past_key_value = (past_kv, past_len+q.size(1)) if use_cache else None


        # With return_attn_probs=True the kernel returns a tuple; element 0 is
        # the output and element 2 the attention probabilities.
        attn_output = attn_outputs[0] if output_attentions else attn_outputs
        attn_output = attn_output.reshape(*attn_output.shape[:-2], h_size)
        attn_weights = attn_outputs[2] if output_attentions else None

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
class LlamaDecoderLayer(nn.Module):
    """One pre-norm transformer block: self-attention and MLP sub-layers,
    each preceded by RMSNorm and wrapped in a residual connection."""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Marker consumed by FSDP auto-wrap policies elsewhere in the project.
        self._fsdp_wrap = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        unpadded_lengths: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            unpadded_lengths (`Tuple(torch.Tensor)`, *optional*): `(cu_seqlens, max_seqlen)` for flash-attn
                variable-length attention, forwarded untouched to the attention module.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            use_cache (`bool`, *optional*):
                If set to `True`, key/value states are returned to speed up decoding.

        Returns:
            Tuple of `(hidden_states[, attn_weights][, present_key_value])`, the optional
            entries included when `output_attentions` / `use_cache` are set.
        """
        # --- self-attention sub-layer (pre-norm + residual) ---
        attn_residual = hidden_states
        normed_input = self.input_layernorm(hidden_states)
        attn_output, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=normed_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            unpadded_lengths=unpadded_lengths,
        )
        hidden_states = attn_residual + attn_output

        # --- feed-forward sub-layer (pre-norm + residual) ---
        mlp_residual = hidden_states
        hidden_states = mlp_residual + self.mlp(self.post_attention_layernorm(mlp_residual))

        outputs = (hidden_states,)
        if output_attentions:
            outputs = outputs + (self_attn_weights,)
        if use_cache:
            outputs = outputs + (present_key_value,)
        return outputs
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
class LlamaPreTrainedModel(PreTrainedModel):
    """Common HF plumbing for the Llama variants in this file: ties the model
    family to `LlamaConfig` and defines weight initialization."""

    config_class = LlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        """Initialize `nn.Linear`/`nn.Embedding` weights from N(0, initializer_range);
        zero out linear biases and the embedding row of the padding index."""
        std = self.config.initializer_range
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        try_import_flash_attention()
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack.

        When the batch contains real padding and no KV cache is requested, the
        input is unpadded once up front (flash-attn varlen path) and re-padded
        after the last layer; each layer then receives `(cu_seqlens, max_seqlen)`
        via `unpadded_lengths`.
        """
        # Resolve per-call flags against the config defaults.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds (exactly one must be given)
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds
        bsz = hidden_states.size(0)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # Unpad only when the mask actually masks something and no cache is kept:
        # the varlen flash-attn path is efficient for large batches with many paddings.
        if (
            ((attention_mask is not None) and (not attention_mask.all().item()))
            and not use_cache
        ):
            try:  # for flash-attn latest version (returns 5 values)
                hidden_states, unpad_indices, cu_seqlens, max_seqlen, _ = unpad_input(hidden_states, attention_mask)
            # BUGFIX: narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # propagate; older flash-attn (2.3.3) returns 4 values and lands here.
            except Exception:
                hidden_states, unpad_indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)
            unpadded_lengths = (cu_seqlens, max_seqlen)
        else:
            unpadded_lengths = None

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                # Recorded hidden states are always padded back to (bsz, seq, dim).
                if unpadded_lengths is not None:
                    all_hidden_states += (pad_input(hidden_states, unpad_indices, bsz, max_seqlen),)
                else:
                    all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                # No cache under checkpointing: past_key_value=None, use_cache=False.
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    decoder_layer,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                    unpadded_lengths,
                    output_attentions,
                    False,
                    use_reentrant=False
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    unpadded_lengths=unpadded_lengths,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # Cache is the last tuple element; its index shifts when attentions are returned.
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        if unpadded_lengths is not None:
            hidden_states = pad_input(hidden_states, unpad_indices, bsz, max_seqlen)
        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
class LlamaForCausalLM(LlamaPreTrainedModel):
    """Llama decoder with a tied language-modeling head on top."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        try_import_flash_attention()
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        avg_valid_labels_per_chunk: Optional[float] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            avg_valid_labels_per_chunk (`float`, *optional*):
                Accepted for caller compatibility; not used by this implementation.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        # Cast logits to fp32 so the loss below is computed in full precision.
        logits = self.lm_head(hidden_states).float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Build the kwargs for one generation step: with a cache, only the last
        token is fed; `inputs_embeds` are used only on the very first step."""
        if past_key_values:
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder the cache along the batch dimension for beam search.

        BUGFIX: each layer's cache in this file is `(past_kv, past_len)` where
        `past_len` is a plain int (see LlamaAttention), so `index_select` must
        only be applied to tensor entries; the previous version crashed on the
        int during beam search.
        """
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx.to(past_state.device))
                    if torch.is_tensor(past_state) else past_state
                    for past_state in layer_past
                ),
            )
        return reordered_past
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
class LlamaForSequenceClassification(LlamaPreTrainedModel):
    """Llama decoder with a linear classification head on top.

    The score for each sequence is taken from the hidden state of its last
    non-padding token (located via `config.pad_token_id`).
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        # Classification head: projects the pooled hidden state to num_labels scores.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
        try_import_flash_attention()
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        # Per-token scores; pooled to one score vector per sequence below.
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        # Without a pad token we cannot locate each sequence's last real token,
        # so only batch size 1 (pool at the final position) is supported.
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # Index of the last non-pad token per sequence.
                # NOTE(review): assumes right padding — with left padding this
                # picks the wrong position; confirm against the tokenizer setup.
                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
            else:
                # With inputs_embeds there are no token ids to inspect; fall
                # back to the final position.
                sequence_lengths = -1

        # Gather one (num_labels,) score vector per sequence at its pooling index.
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer the problem type once from num_labels / label dtype and
            # persist it on the config (standard HF behavior).
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
|
DataFlow/dataflow/operators/eval/GeneralText/models/Qurating/qurater_annotate.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datasets import load_from_disk, load_dataset, concatenate_datasets
|
| 2 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
| 3 |
+
from .modeling.modeling_flash_llama import LlamaForSequenceClassification
|
| 4 |
+
import torch
|
| 5 |
+
import argparse
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
class TokenizeAndChunk:
    """Batched `datasets.map` callable: tokenize each document (or reuse
    pre-tokenized ids) and split every token sequence into consecutive chunks
    of at most `tokens` ids.
    """

    def __init__(self, tokenizer_name, text_field, tokens_field, tokens, model_cache_dir=None):
        # `model_cache_dir` now defaults to None so existing 4-argument call
        # sites (and unpickling via __setstate__) work.
        self.tokens = tokens
        self.tokenizer_name = tokenizer_name
        self.text_field = text_field
        self.tokens_field = tokens_field
        self.model_cache_dir = model_cache_dir

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True, cache_dir=model_cache_dir)
        self.tokenizer.pad_token_id = 0

    def __getstate__(self):
        # Pickle only the constructor arguments; the tokenizer is rebuilt in
        # __setstate__. BUGFIX: `model_cache_dir` was previously omitted, so
        # __setstate__ -> __init__(**state) raised TypeError after pickling
        # (e.g. datasets.map with num_proc > 1).
        return {
            "tokenizer_name": self.tokenizer_name,
            "text_field": self.text_field,
            "tokens_field": self.tokens_field,
            "tokens": self.tokens,
            "model_cache_dir": self.model_cache_dir,
        }

    def __setstate__(self, state):
        self.__init__(**state)

    def tokenize_and_chunk(self, source_tokens):
        """Split each token-id sequence into chunks of at most `self.tokens`
        ids; returns (per-sequence chunk id lists, per-sequence chunk lengths)."""
        chunks_token_ids = []
        chunks_token_counts = []

        for seq in source_tokens:
            chunks = torch.tensor(seq, dtype=torch.long).split(self.tokens)
            chunks_token_ids.append([chunk.tolist() for chunk in chunks])
            chunks_token_counts.append([len(x) for x in chunks])

        return chunks_token_ids, chunks_token_counts

    def __call__(self, example):
        # Prefer pre-tokenized ids when the batch carries them; otherwise
        # tokenize the raw text without truncation, padding, or special tokens.
        if self.tokens_field in example:
            source_tokens = example[self.tokens_field]
        else:
            source_tokens = self.tokenizer(example[self.text_field], truncation=False, padding=False, add_special_tokens=False).input_ids

        chunks_token_ids, chunks_token_counts = self.tokenize_and_chunk(source_tokens)

        assert len(example[self.text_field]) == len(chunks_token_ids)
        assert len(example[self.text_field]) == len(chunks_token_counts)

        return {
            "chunks_token_ids": chunks_token_ids,
            "chunks_token_counts": chunks_token_counts,
        }
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class ModelAnnotator:
    """Batched `datasets.map` callable: score pre-chunked token sequences with
    a `LlamaForSequenceClassification` head and aggregate per-chunk scores
    (plus a length-weighted average) back onto each source document.
    """

    def __init__(self, model_name, labels, device_batch_size, device=None, model_cache_dir=None):
        self.model_name = model_name
        self.labels = labels
        self.device_batch_size = device_batch_size
        self.model_cache_dir = model_cache_dir

        self.model = LlamaForSequenceClassification.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
            cache_dir=model_cache_dir)
        # Chunks are zero-padded in score_chunks, so pad id must be 0.
        self.model.config.pad_token_id = 0
        self.model.eval()

        # `device=None` auto-selects GPU when available, so the object is
        # usable without an explicit device argument.
        self.device = device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device {self.device}")
        self.model.to(self.device)

        self.num_labels = len(labels)
        assert self.num_labels == self.model.config.num_labels, f"Number of labels ({self.num_labels}) does not match model config ({self.model.config.num_labels})"

    def __getstate__(self):
        # BUGFIX: `device` and `model_cache_dir` were previously dropped, so
        # __setstate__ -> __init__(**state) raised TypeError after pickling.
        return {
            "model_name": self.model_name,
            "labels": self.labels,
            "device_batch_size": self.device_batch_size,
            "device": self.device,
            "model_cache_dir": self.model_cache_dir,
        }

    def __setstate__(self, state):
        self.__init__(**state)

    @torch.inference_mode()
    def score_chunks(self, chunks_token_ids, chunks_token_counts):
        """Score every chunk; returns a float32 tensor of shape
        (num_chunks, num_labels).

        Chunks are processed in length-sorted device batches so each batch is
        only zero-padded up to its own longest member.
        """
        sorted_indices = torch.argsort(chunks_token_counts)

        scores = torch.zeros(len(chunks_token_ids), self.num_labels, dtype=torch.float32)

        for batch_indices in sorted_indices.split(self.device_batch_size):
            max_len = chunks_token_counts[batch_indices].max()

            input_ids = torch.zeros((len(batch_indices), max_len), dtype=torch.long)
            attention_mask = torch.zeros((len(batch_indices), max_len), dtype=torch.long)

            for i, j in enumerate(batch_indices):
                seq = chunks_token_ids[j]
                input_ids[i, :len(seq)] = seq
                attention_mask[i, :len(seq)] = 1

            outputs = self.model(input_ids.to(self.device), attention_mask=attention_mask.to(self.device), use_cache=False)
            # Write results back in original chunk order.
            scores[batch_indices] = outputs.logits.float().cpu()
        return scores

    def __call__(self, example, indices):
        """Score one batch of documents; `example` must carry the
        `chunks_token_ids` / `chunks_token_counts` columns from TokenizeAndChunk."""
        num_seqs = len(indices)

        # Flatten all chunks, remembering which source document each came from.
        source_ids = [i for i, counts in enumerate(example["chunks_token_counts"]) for _ in range(len(counts))]
        chunks_token_ids = [torch.tensor(chunk, dtype=torch.long) for chunks in example["chunks_token_ids"] for chunk in chunks]
        flattened_chunks_token_counts = torch.tensor([chunk for chunks in example["chunks_token_counts"] for chunk in chunks], dtype=torch.long)

        flattened_scores = self.score_chunks(chunks_token_ids, flattened_chunks_token_counts)

        chunk_token_counts = example["chunks_token_counts"]
        # chunk_scores[label][doc] -> list of that doc's per-chunk scores.
        chunk_scores = [[[] for _ in range(num_seqs)] for _ in range(self.num_labels)]

        for source_id, score in zip(source_ids, flattened_scores):
            for label in range(self.num_labels):
                chunk_scores[label][source_id].append(score[label].item())

        output = {
            "index": indices,
            "chunk_lengths": chunk_token_counts,
            "length": [sum(counts) for counts in chunk_token_counts],
        }

        for i, label in enumerate(self.labels):
            output[f"{label}_chunks"] = chunk_scores[i]
            # Document score = token-count-weighted mean of its chunk scores.
            output[f"{label}_average"] = [
                np.average(scores, weights=token_counts).item()
                for scores, token_counts in zip(chunk_scores[i], chunk_token_counts)
            ]

        return output
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=str)
    parser.add_argument("output", type=str)

    parser.add_argument("-F", "--data_files", type=str, nargs="+", default=[])
    parser.add_argument("-S", "--shard", type=int, nargs=2, default=[0, 1])
    parser.add_argument("-M", "--model", type=str, required=True)
    parser.add_argument("-t", "--tokens", type=int, default=512)
    parser.add_argument("--map_batch_size", type=int, default=512)
    parser.add_argument("-b", "--device_batch_size", type=int, default=16)
    parser.add_argument("-w", "--num_workers", type=int, default=1)
    parser.add_argument("--text_field", type=str, default="text")
    parser.add_argument("--tokens_field", type=str, default="input_ids")
    parser.add_argument("--labels", type=str, nargs="+")
    # BUGFIX: TokenizeAndChunk / ModelAnnotator take cache-dir and device
    # arguments that the previous calls below omitted (TypeError at runtime).
    parser.add_argument("--model_cache_dir", type=str, default=None)
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu")

    args = parser.parse_args()
    print(args)

    # "json" input reads the given --data_files; anything else is a saved dataset dir.
    if args.input == "json":
        dataset = load_dataset("json", data_files=args.data_files, split="train")
    else:
        dataset = load_from_disk(args.input)

    # Keep the untouched shard so the original columns can be re-attached below.
    src_dataset = dataset.shard(args.shard[1], args.shard[0], contiguous=True)
    dataset = src_dataset

    print(dataset)
    print("Total number of examples:", len(dataset))
    dataset = dataset.map(
        TokenizeAndChunk(args.model, args.text_field, args.tokens_field, args.tokens, args.model_cache_dir),
        batched=True,
        batch_size=args.map_batch_size,
        num_proc=args.num_workers,
        remove_columns=dataset.column_names)

    print("After tokenization: Total number of examples:", len(dataset))
    dataset = dataset.map(
        ModelAnnotator(args.model, args.labels, args.device_batch_size, args.device, args.model_cache_dir),
        batched=True,
        with_indices=True,
        batch_size=args.map_batch_size,
        remove_columns=dataset.column_names)

    # Re-attach the original columns next to the new annotation columns.
    dataset = concatenate_datasets([dataset, src_dataset], axis=1)

    print("After annotation: Total number of examples:", len(dataset))

    print(f"Saving to {args.output}")
    dataset.save_to_disk(args.output)
|
DataFlow/dataflow/operators/eval/GeneralText/models/Superfiltering/__pycache__/data_analysis.cpython-310.pyc
ADDED
|
Binary file (1.51 kB). View file
|
|
|
DataFlow/dataflow/operators/eval/GeneralText/models/Superfiltering/data_analysis.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import torch
|
| 4 |
+
import argparse
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Bare prompt templates (no system prompt or chat scaffolding): just the raw
# instruction, optionally followed by the input, separated by newlines.
PROMPT_DICT_NONE = {
    "prompt_input": (
        "{instruction}\n{input}\n"
    ),
    "prompt_no_input": (
        "{instruction}\n"
    ),
}
|
| 18 |
+
# Used to get the ppl and emb for the whole input
|
| 19 |
+
def get_perplexity_and_embedding_whole_text(tokenizer, model, text, max_length, device):
    """Score `text` with `model` as its own label sequence.

    The text is tokenized (truncated to `max_length`) and fed to the model with
    `labels=input_ids`, so the loss is the mean negative log-likelihood over
    the whole input. Returns `(perplexity, loss)` as Python floats.
    """
    encoded = tokenizer.encode(
        text, return_tensors="pt", truncation=True, max_length=max_length
    ).to(device)

    with torch.no_grad():
        result = model(encoded, labels=encoded.contiguous())

    nll = result.loss
    ppl = torch.exp(nll)

    return ppl.to('cpu').item(), nll.to('cpu').item()
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Used to get the ppl and emb for part of input, used in conditional version, and token-wise loss
|
| 32 |
+
def get_perplexity_and_embedding_part_text(tokenizer, model, text, target_span, max_length, device):
    """Score only `target_span` conditioned on the text preceding it.

    Tokens before the last occurrence of `target_span` in `text` are masked
    out of the loss (label -100), so loss/perplexity cover the span alone.
    Returns `(perplexity, loss)` as Python floats, or the `(0, 0)` sentinel
    when the span is absent or scoring fails.
    """
    try:
        input_ids = tokenizer.encode(text, return_tensors="pt", truncation=True, max_length=max_length).to(device)

        start_index = text.rfind(target_span)
        if start_index == -1:
            # BUGFIX: rfind's -1 previously sliced text[:-1] and masked almost
            # the whole input, silently scoring the wrong tokens. Use the same
            # sentinel as the failure path instead.
            return 0, 0
        start_token = len(tokenizer.encode(text[:start_index]))

        labels = input_ids.clone()
        labels[0, :start_token] = -100  # ignore the conditioning prefix in the loss

        with torch.no_grad():
            outputs = model(input_ids, labels=labels)

        loss = outputs.loss
        perplexity = torch.exp(loss)

        return perplexity.to('cpu').item(), loss.to('cpu').item()

    except Exception:
        # BUGFIX: narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate; any scoring failure keeps the best-effort (0, 0).
        return 0, 0
|
DataFlow/dataflow/operators/eval/GeneralText/models/__pycache__/debertav3_scorer.cpython-310.pyc
ADDED
|
Binary file (3.61 kB). View file
|
|
|