diff --git a/DataFlow/dataflow/operators/process/AgenticRAG/__pycache__/ContentChooser.cpython-310.pyc b/DataFlow/dataflow/operators/process/AgenticRAG/__pycache__/ContentChooser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0d0e59e4dea2293be59bba06a7751ee51224903 Binary files /dev/null and b/DataFlow/dataflow/operators/process/AgenticRAG/__pycache__/ContentChooser.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/process/AgenticRAG/__pycache__/__init__.cpython-310.pyc b/DataFlow/dataflow/operators/process/AgenticRAG/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09d3fa21913ee7e43f45b11c5b832d44e97b8fa2 Binary files /dev/null and b/DataFlow/dataflow/operators/process/AgenticRAG/__pycache__/__init__.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/process/GeneralText/deduplicators/__pycache__/__init__.cpython-310.pyc b/DataFlow/dataflow/operators/process/GeneralText/deduplicators/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c87e7e161ecf0e831c4c52d65765811d0548cb78 Binary files /dev/null and b/DataFlow/dataflow/operators/process/GeneralText/deduplicators/__pycache__/__init__.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/process/GeneralText/deduplicators/__pycache__/ccnet_deduplicator.cpython-310.pyc b/DataFlow/dataflow/operators/process/GeneralText/deduplicators/__pycache__/ccnet_deduplicator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef15b208a4d2bedc3b2de0d6c92ad72de15536a0 Binary files /dev/null and b/DataFlow/dataflow/operators/process/GeneralText/deduplicators/__pycache__/ccnet_deduplicator.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/process/GeneralText/deduplicators/__pycache__/hash_deduplicator.cpython-310.pyc b/DataFlow/dataflow/operators/process/GeneralText/deduplicators/__pycache__/hash_deduplicator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d08db98581d1afe53bf2a7ea619472cc18cbab48 Binary files /dev/null and b/DataFlow/dataflow/operators/process/GeneralText/deduplicators/__pycache__/hash_deduplicator.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/process/GeneralText/deduplicators/__pycache__/minhash_deduplicator.cpython-310.pyc b/DataFlow/dataflow/operators/process/GeneralText/deduplicators/__pycache__/minhash_deduplicator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e39bcb8ef667d904c91bd07d79ecd0f79db4af39 Binary files /dev/null and b/DataFlow/dataflow/operators/process/GeneralText/deduplicators/__pycache__/minhash_deduplicator.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/process/GeneralText/deduplicators/__pycache__/ngramhash_deduplicator.cpython-310.pyc b/DataFlow/dataflow/operators/process/GeneralText/deduplicators/__pycache__/ngramhash_deduplicator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..980059eb29003d905d7bdd01dd4166bfc51791f7 Binary files /dev/null and b/DataFlow/dataflow/operators/process/GeneralText/deduplicators/__pycache__/ngramhash_deduplicator.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/process/GeneralText/deduplicators/__pycache__/sem_deduplicator.cpython-310.pyc b/DataFlow/dataflow/operators/process/GeneralText/deduplicators/__pycache__/sem_deduplicator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d109d0c91bbf26ae9a5af2c8d4544ae5434619a7 Binary files /dev/null and b/DataFlow/dataflow/operators/process/GeneralText/deduplicators/__pycache__/sem_deduplicator.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/process/GeneralText/deduplicators/__pycache__/simhash_deduplicator.cpython-310.pyc b/DataFlow/dataflow/operators/process/GeneralText/deduplicators/__pycache__/simhash_deduplicator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99cd0f38acdbf622bd85b16c90b5a7699ba2192d Binary files /dev/null and b/DataFlow/dataflow/operators/process/GeneralText/deduplicators/__pycache__/simhash_deduplicator.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/process/GeneralText/filters/__pycache__/debertav3_filter.cpython-310.pyc b/DataFlow/dataflow/operators/process/GeneralText/filters/__pycache__/debertav3_filter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf53bb24216a6e799c3a51ea66ddbd2f91fffd30 Binary files /dev/null and b/DataFlow/dataflow/operators/process/GeneralText/filters/__pycache__/debertav3_filter.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/process/GeneralText/filters/__pycache__/fineweb_edu_filter.cpython-310.pyc b/DataFlow/dataflow/operators/process/GeneralText/filters/__pycache__/fineweb_edu_filter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce1cb07c521f49e2489d3b0d4ad606f8a446fc0a Binary files /dev/null and b/DataFlow/dataflow/operators/process/GeneralText/filters/__pycache__/fineweb_edu_filter.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/process/GeneralText/filters/__pycache__/heuristics.cpython-310.pyc b/DataFlow/dataflow/operators/process/GeneralText/filters/__pycache__/heuristics.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7408b7a4fbe92b6a8cab0bcd19fd37aa0d950451 Binary files /dev/null and b/DataFlow/dataflow/operators/process/GeneralText/filters/__pycache__/heuristics.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/process/GeneralText/filters/__pycache__/lexical_diversity_filter.cpython-310.pyc b/DataFlow/dataflow/operators/process/GeneralText/filters/__pycache__/lexical_diversity_filter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1a2d144b668bf1627d6a0d6a3288e372a7390a6 Binary files /dev/null and b/DataFlow/dataflow/operators/process/GeneralText/filters/__pycache__/lexical_diversity_filter.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/process/GeneralText/filters/__pycache__/perplexity_filter.cpython-310.pyc b/DataFlow/dataflow/operators/process/GeneralText/filters/__pycache__/perplexity_filter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..542a5dc59d009ea1cbb1df43d6c68e0d2b5d27aa Binary files /dev/null and b/DataFlow/dataflow/operators/process/GeneralText/filters/__pycache__/perplexity_filter.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/process/GeneralText/filters/__pycache__/perspective_filter.cpython-310.pyc b/DataFlow/dataflow/operators/process/GeneralText/filters/__pycache__/perspective_filter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3aab3311bae51da37d08b0bdbc9acc5fd1d544b6 Binary files /dev/null and b/DataFlow/dataflow/operators/process/GeneralText/filters/__pycache__/perspective_filter.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/process/Reasoning/__pycache__/AnswerFormatterFilter.cpython-310.pyc b/DataFlow/dataflow/operators/process/Reasoning/__pycache__/AnswerFormatterFilter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df081e3be223d56f151f60e68880b3df051daf07 Binary files /dev/null and b/DataFlow/dataflow/operators/process/Reasoning/__pycache__/AnswerFormatterFilter.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/process/Reasoning/__pycache__/AnswerGroundTruthFilter.cpython-310.pyc b/DataFlow/dataflow/operators/process/Reasoning/__pycache__/AnswerGroundTruthFilter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b23c59c36ca036c787b8569058ce01a4b3c112a5 Binary files /dev/null and b/DataFlow/dataflow/operators/process/Reasoning/__pycache__/AnswerGroundTruthFilter.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/process/Reasoning/__pycache__/AnswerJudger_MathVerify.cpython-310.pyc b/DataFlow/dataflow/operators/process/Reasoning/__pycache__/AnswerJudger_MathVerify.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b201594a5557cec3a0ef8112536640fca0b5161 Binary files /dev/null and b/DataFlow/dataflow/operators/process/Reasoning/__pycache__/AnswerJudger_MathVerify.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/process/Reasoning/__pycache__/AnswerNgramFilter.cpython-310.pyc b/DataFlow/dataflow/operators/process/Reasoning/__pycache__/AnswerNgramFilter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8f90d507b98d66c12f3e899aa3d42d9c88edac3 Binary files /dev/null and b/DataFlow/dataflow/operators/process/Reasoning/__pycache__/AnswerNgramFilter.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/process/Reasoning/__pycache__/AnswerPipelineRoot.cpython-310.pyc b/DataFlow/dataflow/operators/process/Reasoning/__pycache__/AnswerPipelineRoot.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b348d82bf8dd22bdc4ace9b1efe6e2c6c64bcbaf Binary files /dev/null and b/DataFlow/dataflow/operators/process/Reasoning/__pycache__/AnswerPipelineRoot.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/process/Reasoning/__pycache__/AnswerTokenLengthFilter.cpython-310.pyc b/DataFlow/dataflow/operators/process/Reasoning/__pycache__/AnswerTokenLengthFilter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2da42cf77bf53758304f34b122c4b4311837727b Binary files /dev/null and b/DataFlow/dataflow/operators/process/Reasoning/__pycache__/AnswerTokenLengthFilter.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/process/Reasoning/__pycache__/QuestionFilter.cpython-310.pyc b/DataFlow/dataflow/operators/process/Reasoning/__pycache__/QuestionFilter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4fdaad3b92aa32744e2fce672aa5a92a540059c Binary files /dev/null and b/DataFlow/dataflow/operators/process/Reasoning/__pycache__/QuestionFilter.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/process/Reasoning/__pycache__/__init__.cpython-310.pyc b/DataFlow/dataflow/operators/process/Reasoning/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c205573b2affab5076ad1e344377110e6e10fd3d Binary files /dev/null and b/DataFlow/dataflow/operators/process/Reasoning/__pycache__/__init__.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/process/__pycache__/__init__.cpython-310.pyc b/DataFlow/dataflow/operators/process/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3e601985881b8626782bda41ff6304d1dcc2de6 Binary files /dev/null and b/DataFlow/dataflow/operators/process/__pycache__/__init__.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/refine/GeneralText/__init__.py b/DataFlow/dataflow/operators/refine/GeneralText/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/__init__.cpython-310.pyc b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3a8151cde738bb7244fdf3c1dcfd273cb6a8f13 Binary files /dev/null and b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/__init__.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/html_entity_refiner.cpython-310.pyc b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/html_entity_refiner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cb4c61bb16c0ce2718bfe591c207b298779538d Binary files /dev/null and b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/html_entity_refiner.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/html_url_remover_refiner.cpython-310.pyc b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/html_url_remover_refiner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea175549ecce8a03a404cb7d85af46983de5b86f Binary files /dev/null and b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/html_url_remover_refiner.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/lowercase_refiner.cpython-310.pyc b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/lowercase_refiner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8cd30f3578f03947fa932572a2264b73e11f3ff6 Binary files /dev/null and b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/lowercase_refiner.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/ner_refiner.cpython-310.pyc b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/ner_refiner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38a2c730b7844a3f9a4ba7b710535b4527ac868a Binary files /dev/null and b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/ner_refiner.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/pii_anonymize_refiner.cpython-310.pyc b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/pii_anonymize_refiner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bda9f7e88b422f7bc246cbff5e3e4a456d03fce Binary files /dev/null and b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/pii_anonymize_refiner.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/ref_removal_refiner.cpython-310.pyc b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/ref_removal_refiner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a76c5cf95da86935b1bf2adfb0eaa3e6bcebb213 Binary files /dev/null and b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/ref_removal_refiner.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/remove_contractions_refiner.cpython-310.pyc b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/remove_contractions_refiner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bdf8bff7cdeac177fa64101b8a07438e8990e8cf Binary files /dev/null and b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/remove_contractions_refiner.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/remove_emoji_refiner.cpython-310.pyc b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/remove_emoji_refiner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b9cd20a70e069ed141477ca79f1d386cf1b398a Binary files /dev/null and b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/remove_emoji_refiner.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/remove_extra_spaces_refiner.cpython-310.pyc b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/remove_extra_spaces_refiner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb4cadc03c575195823cbebd036c0f3d9edaf347 Binary files /dev/null and b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/remove_extra_spaces_refiner.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/remove_image_ref_refiner.cpython-310.pyc b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/remove_image_ref_refiner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69aa6301ba1ed68dd5b8154d7a161b0a57738e5d Binary files /dev/null and b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/remove_image_ref_refiner.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/remove_number_refiner.cpython-310.pyc b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/remove_number_refiner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d140d73f07ca5a861884e3b6d2b76bb0f1ab91e3 Binary files /dev/null and b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/remove_number_refiner.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/remove_punctuation_refiner.cpython-310.pyc b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/remove_punctuation_refiner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bda079eb245f8f6dc10c12e6ff8f0bf16bb039c Binary files /dev/null and b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/remove_punctuation_refiner.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/remove_repetitions_punctuation_refiner.cpython-310.pyc b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/remove_repetitions_punctuation_refiner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cd464b4720dc1656d93de974b13401364d99e3b Binary files /dev/null and b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/remove_repetitions_punctuation_refiner.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/remove_stopwords_refiner.cpython-310.pyc b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/remove_stopwords_refiner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..486451bc149b102b68eef26962cb7caba6172183 Binary files /dev/null and b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/remove_stopwords_refiner.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/spelling_correction_refiner.cpython-310.pyc b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/spelling_correction_refiner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..681966a6400ffd6c9e593f774892792c05734328 Binary files /dev/null and b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/spelling_correction_refiner.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/stemming_lemmatization_refiner.cpython-310.pyc b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/stemming_lemmatization_refiner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a88b8661a2dab806d68daa04ac4f109bd7ad8e7b Binary files /dev/null and b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/stemming_lemmatization_refiner.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/text_normalization_refiner.cpython-310.pyc b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/text_normalization_refiner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0239a89eb6d54d1b70501928c74d14de5132591f Binary files /dev/null and b/DataFlow/dataflow/operators/refine/GeneralText/__pycache__/text_normalization_refiner.cpython-310.pyc differ diff --git a/DataFlow/dataflow/operators/refine/GeneralText/html_entity_refiner.py b/DataFlow/dataflow/operators/refine/GeneralText/html_entity_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..a6198769c12b872b543eb2120b602032bfb76533 --- /dev/null +++ b/DataFlow/dataflow/operators/refine/GeneralText/html_entity_refiner.py @@ -0,0 +1,70 @@ +import re +from tqdm import tqdm +from dataflow import get_logger +from dataflow.core import OperatorABC +from dataflow.utils.storage import DataFlowStorage +from dataflow.utils.registry import OPERATOR_REGISTRY + +@OPERATOR_REGISTRY.register() +class HtmlEntityRefiner(OperatorABC): + def __init__(self, html_entities: list = [ + "nbsp", "lt", "gt", "amp", "quot", "apos", "hellip", "ndash", "mdash", + "lsquo", "rsquo", "ldquo", "rdquo" + ]): + self.logger = get_logger() + self.logger.info(f"Initializing {self.__class__.__name__} ...") + # 从参数中获取自定义 HTML 实体列表,如果未提供则使用默认列表 + self.html_entities = html_entities + + # 构建正则表达式模式,匹配所有定义的 HTML 实体 + # 包括以下几种形式: + # 1. &实体名; + # 2. &实体名; (全角 &) + # 3. &实体名; (中文分号) + # 4. &实体名; (全角 & + 中文分号) + entity_patterns = [] + for entity in self.html_entities: + # &实体名; + entity_patterns.append(fr'&{entity};') + # &实体名; (全角 &) + entity_patterns.append(fr'&{entity};') + # &实体名; (中文分号) + entity_patterns.append(fr'&{entity};') + # &实体名; (全角 & + 中文分号) + entity_patterns.append(fr'&{entity};') + + # 编译正则表达式 + self.html_entity_regex = re.compile('|'.join(entity_patterns)) + + @staticmethod + def get_desc(lang): + return "去除文本中的HTML实体" if lang == "zh" else "Remove HTML entities from the text." + + def run(self, storage: DataFlowStorage, input_key: str): + self.input_key = input_key + self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...") + dataframe = storage.read("dataframe") + numbers = 0 + refined_data = [] + for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"): + modified = False + original_text = item + refined_text = original_text + + # 使用正则表达式替换所有匹配的HTML实体为空字符串 + refined_text = self.html_entity_regex.sub('', refined_text) + + # 检查文本是否被修改 + if original_text != refined_text: + item = refined_text + modified = True + self.logger.debug(f"Modified text for key '{self.input_key}': Original: {original_text[:30]}... -> Refined: {refined_text[:30]}...") + + refined_data.append(item) + if modified: + numbers += 1 + self.logger.debug(f"Item modified, total modified so far: {numbers}") + self.logger.info(f"Refining Complete. Total modified items: {numbers}") + dataframe[self.input_key] = refined_data + output_file = storage.write(dataframe) + return [self.input_key] \ No newline at end of file diff --git a/DataFlow/dataflow/operators/refine/GeneralText/html_url_remover_refiner.py b/DataFlow/dataflow/operators/refine/GeneralText/html_url_remover_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..16e0cb4a6494f312cb71a753d0664a25dcd7d26b --- /dev/null +++ b/DataFlow/dataflow/operators/refine/GeneralText/html_url_remover_refiner.py @@ -0,0 +1,46 @@ +import re +from tqdm import tqdm +from dataflow import get_logger +from dataflow.core import OperatorABC +from dataflow.utils.storage import DataFlowStorage +from dataflow.utils.registry import OPERATOR_REGISTRY + +@OPERATOR_REGISTRY.register() +class HtmlUrlRemoverRefiner(OperatorABC): + def __init__(self): + self.logger = get_logger() + self.logger.info(f"Initializing {self.__class__.__name__} ...") + + @staticmethod + def get_desc(lang: str = "zh"): + return "去除文本中的URL和HTML标签" if lang == "zh" else "Remove URLs and HTML tags from the text." + + def run(self, storage: DataFlowStorage, input_key: str): + self.input_key = input_key + self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...") + dataframe = storage.read("dataframe") + numbers = 0 + refined_data = [] + for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"): + modified = False + original_text = item + refined_text = original_text + + # Remove URLs + refined_text = re.sub(r'https?:\/\/\S+[\r\n]*', '', refined_text, flags=re.MULTILINE) + # Remove HTML tags + refined_text = re.sub(r'<.*?>', '', refined_text) + + if original_text != refined_text: + item = refined_text + modified = True + self.logger.debug(f"Modified text for key '{self.input_key}': Original: {original_text[:30]}... -> Refined: {refined_text[:30]}...") + + refined_data.append(item) + if modified: + numbers += 1 + self.logger.debug(f"Item modified, total modified so far: {numbers}") + self.logger.info(f"Refining Complete. Total modified items: {numbers}") + dataframe[self.input_key] = refined_data + output_file = storage.write(dataframe) + return [self.input_key] \ No newline at end of file diff --git a/DataFlow/dataflow/operators/refine/GeneralText/lowercase_refiner.py b/DataFlow/dataflow/operators/refine/GeneralText/lowercase_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..fbf5619927403d88737e6c5d218fe10e6f2c6ede --- /dev/null +++ b/DataFlow/dataflow/operators/refine/GeneralText/lowercase_refiner.py @@ -0,0 +1,36 @@ +import re +from tqdm import tqdm +from dataflow import get_logger +from dataflow.core import OperatorABC +from dataflow.utils.storage import DataFlowStorage +from dataflow.utils.registry import OPERATOR_REGISTRY + +@OPERATOR_REGISTRY.register() +class LowercaseRefiner(OperatorABC): + def __init__(self): + self.logger = get_logger() + self.logger.info(f"Initializing {self.__class__.__name__} ...") + + def run(self, storage: DataFlowStorage, input_key: str): + self.input_key = input_key + self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...") + dataframe = storage.read("dataframe") + numbers = 0 + refined_data = [] + for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"): + modified = False + original_text = item + lower_text = original_text.lower() + if original_text != lower_text: + item = lower_text + modified = True + self.logger.debug(f"Modified text for key '{self.input_key}': Original: {original_text[:30]}... -> Refined: {lower_text[:30]}...") + + refined_data.append(item) + if modified: + numbers += 1 + self.logger.debug(f"Item modified, total modified so far: {numbers}") + self.logger.info(f"Refining Complete. Total modified items: {numbers}") + dataframe[self.input_key] = refined_data + output_file = storage.write(dataframe) + return [self.input_key] \ No newline at end of file diff --git a/DataFlow/dataflow/operators/refine/GeneralText/ner_refiner.py b/DataFlow/dataflow/operators/refine/GeneralText/ner_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..c7fb5655762df3d96b24002eddf91e21b10d3c93 --- /dev/null +++ b/DataFlow/dataflow/operators/refine/GeneralText/ner_refiner.py @@ -0,0 +1,79 @@ +import spacy +from tqdm import tqdm +from dataflow import get_logger +from dataflow.core import OperatorABC +from dataflow.utils.storage import DataFlowStorage +from dataflow.utils.registry import OPERATOR_REGISTRY + +ENTITY_LABELS = { + "PERSON": "[PERSON]", + "ORG": "[ORG]", + "GPE": "[GPE]", + "LOC": "[LOC]", + "PRODUCT": "[PRODUCT]", + "EVENT": "[EVENT]", + "DATE": "[DATE]", + "TIME": "[TIME]", + "MONEY": "[MONEY]", + "PERCENT": "[PERCENT]", + "QUANTITY": "[QUANTITY]", + "ORDINAL": "[ORDINAL]", + "CARDINAL": "[CARDINAL]", + "NORP": "[NORP]", + "FAC": "[FAC]", + "LAW": "[LAW]", + "LANGUAGE": "[LANGUAGE]", + "WORK_OF_ART": "[WORK_OF_ART]", + "LAW": "[LAW]", + "ORDINAL": "[ORDINAL]", + "CARDINAL": "[CARDINAL]", + "PERCENT": "[PERCENT]", + "QUANTITY": "[QUANTITY]", + "DATE": "[DATE]", + "TIME": "[TIME]", + "URL": "[URL]", + "EMAIL": "[EMAIL]", + "MONEY": "[MONEY]", + "FAC": "[FAC]", + "PRODUCT": "[PRODUCT]", + "EVENT": "[EVENT]", + "WORK_OF_ART": "[WORK_OF_ART]", + "LANGUAGE": "[LANGUAGE]", + "NORP": "[NORP]" +} + +@OPERATOR_REGISTRY.register() +class NERRefiner(OperatorABC): + def __init__(self): + self.logger = get_logger() + self.logger.info(f"Initializing {self.__class__.__name__} ...") + self.nlp = spacy.load("en_core_web_sm") + + def run(self, storage: DataFlowStorage, input_key: str): + self.input_key = input_key + self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...") + dataframe = storage.read("dataframe") + numbers = 0 + refined_data = [] + for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"): + modified = False + original_text = item + refined_text = original_text + + doc = self.nlp(refined_text) + for ent in doc.ents: + if ent.label_ in ENTITY_LABELS : + refined_text = refined_text.replace(ent.text, f"[{ent.label_}]") + + if original_text != refined_text: + item = refined_text + modified = True + + refined_data.append(item) + if modified: + numbers += 1 + self.logger.debug(f"Item modified, total modified so far: {numbers}") + self.logger.info(f"Refining Complete. Total modified items: {numbers}") + dataframe[self.input_key] = refined_data + output_file = storage.write(dataframe) + return [self.input_key] \ No newline at end of file diff --git a/DataFlow/dataflow/operators/refine/GeneralText/pii_anonymize_refiner.py b/DataFlow/dataflow/operators/refine/GeneralText/pii_anonymize_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..8b5d43469e66376ce6828684073f94a8a1bb7995 --- /dev/null +++ b/DataFlow/dataflow/operators/refine/GeneralText/pii_anonymize_refiner.py @@ -0,0 +1,62 @@ +from tqdm import tqdm +from transformers import AutoModelForTokenClassification, AutoTokenizer +from presidio_analyzer.nlp_engine import TransformersNlpEngine +from presidio_analyzer import AnalyzerEngine +from presidio_anonymizer import AnonymizerEngine +from dataflow import get_logger +from dataflow.core import OperatorABC +from dataflow.utils.storage import DataFlowStorage +from dataflow.utils.registry import OPERATOR_REGISTRY + +@OPERATOR_REGISTRY.register() +class PIIAnonymizeRefiner(OperatorABC): + def __init__(self, lang='en', device='cuda', model_cache_dir='./dataflow_cache', model_name='dslim/bert-base-NER', ): + self.logger = get_logger() + self.logger.info(f"Initializing {self.__class__.__name__} ...") + self.lang = lang + self.device = device + self.model_cache_dir = model_cache_dir + self.model_name = model_name + model_name = 'dslim/bert-base-NER' + self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=self.model_cache_dir) + self.model = AutoModelForTokenClassification.from_pretrained(model_name, cache_dir=self.model_cache_dir).to(self.device) + model_config = [{ + "lang_code": self.lang, + "model_name": { + "spacy": "en_core_web_sm", + "transformers": model_name + } + }] + + self.nlp_engine = TransformersNlpEngine(models=model_config) + self.analyzer = AnalyzerEngine(nlp_engine=self.nlp_engine) + self.anonymizer = AnonymizerEngine() + + @staticmethod + def get_desc(lang: str = "zh"): + return "去除文本中的URL和HTML标签" if lang == "zh" else "Remove URLs and HTML tags from the text." + + def run(self, storage: DataFlowStorage, input_key: str): + self.input_key = input_key + dataframe = storage.read("dataframe") + self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...") + anonymized_count = 0 + refined_data = [] + for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"): + modified = False + original_text = item + results = self.analyzer.analyze(original_text, language=self.lang) + anonymized_text = self.anonymizer.anonymize(original_text, results) + if original_text != anonymized_text.text: + item = anonymized_text.text + modified = True + self.logger.debug(f"Modified text for key '{self.input_key}': Original: {original_text[:30]}... -> Refined: {anonymized_text.text[:30]}...") + + refined_data.append(item) + if modified: + anonymized_count += 1 + self.logger.debug(f"Item modified, total modified so far: {anonymized_count}") + self.logger.info(f"Refining Complete. Total modified items: {anonymized_count}") + dataframe[self.input_key] = refined_data + output_file = storage.write(dataframe) + return [self.input_key] \ No newline at end of file diff --git a/DataFlow/dataflow/operators/refine/GeneralText/ref_removal_refiner.py b/DataFlow/dataflow/operators/refine/GeneralText/ref_removal_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..836fd79da6ff821e208f565d01f01168ea5d6d31 --- /dev/null +++ b/DataFlow/dataflow/operators/refine/GeneralText/ref_removal_refiner.py @@ -0,0 +1,62 @@ +import re +from tqdm import tqdm +from dataflow import get_logger +from dataflow.core import OperatorABC +from dataflow.utils.storage import DataFlowStorage +from dataflow.utils.registry import OPERATOR_REGISTRY + +@OPERATOR_REGISTRY.register() +class ReferenceRemoverRefiner(OperatorABC): + def __init__(self): + self.logger = get_logger() + self.logger.info(f"Initializing {self.__class__.__name__}...") + + @staticmethod + def get_desc(lang): + return "删除文本中未闭合的引用标签和引用链接" if lang == "zh" else "Remove unclosed reference tags and citation links from the text." + + def run(self, storage: DataFlowStorage, input_key: str): + self.input_key = input_key + dataframe = storage.read("dataframe") + self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...") + numbers = 0 + # 定义要删除的模式 - 更全面的版本 + # 1. 所有标签及其内容(包括各种不完整形式) + ref_pattern = re.compile( + r']*>.*?|' # 完整的ref标签 + r']*>[^<]*$|' # 不完整的ref标签(没有闭合) + r']*>.*?/br' # ref标签后跟/br(如你示例中的情况) + ) + + # 2. 所有{{cite}}模板及其内容(包括各种不完整形式) + cite_pattern = re.compile( + r'\{\{cite\s+\w+\|[^}]*\}\}|' # 完整的cite模板 + r'\{\{cite\s+\w+\|[^}]*$' # 不完整的cite模板(没有闭合) + ) + + refined_data = [] + for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"): + modified = False + original_text = item + refined_text = original_text + + # 删除所有未闭合的ref标签 + refined_text, ref_count = ref_pattern.subn('', refined_text) + + # 删除所有不完整的cite模板 + refined_text, cite_count = cite_pattern.subn('', refined_text) + + # 检查是否有任何修改 + if ref_count > 0 or cite_count > 0: + modified = True + numbers += 1 + self.logger.debug(f"Item modified, removed {ref_count} ref tags and {cite_count} cite templates") + + refined_data.append(item) + if modified: + numbers += 1 + self.logger.debug(f"Item modified, total modified so far: {numbers}") + self.logger.info(f"Refining Complete. Total modified items: {numbers}") + dataframe[self.input_key] = refined_data + output_file = storage.write(dataframe) + return [self.input_key] \ No newline at end of file diff --git a/DataFlow/dataflow/operators/refine/GeneralText/remove_contractions_refiner.py b/DataFlow/dataflow/operators/refine/GeneralText/remove_contractions_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..5e27be943a9312321d8fea74f7d3a3d106c12538 --- /dev/null +++ b/DataFlow/dataflow/operators/refine/GeneralText/remove_contractions_refiner.py @@ -0,0 +1,36 @@ +import contractions +from tqdm import tqdm +from dataflow import get_logger +from dataflow.core import OperatorABC +from dataflow.utils.storage import DataFlowStorage +from dataflow.utils.registry import OPERATOR_REGISTRY + +@OPERATOR_REGISTRY.register() +class RemoveContractionsRefiner(OperatorABC): + def __init__(self): + self.logger = get_logger() + self.logger.info(f"Initializing {self.__class__.__name__} ...") + + def run(self, storage: DataFlowStorage, input_key: str): + self.input_key = input_key + dataframe = storage.read("dataframe") + self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...") + numbers = 0 + refined_data = [] + for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"): + modified = False + original_text = item + expanded_text = contractions.fix(original_text) + if original_text != expanded_text: + item = expanded_text + modified = True + self.logger.debug(f"Modified text for key '{self.input_key}': Original: {original_text[:30]}... -> Refined: {expanded_text[:30]}...") + + refined_data.append(item) + if modified: + numbers += 1 + self.logger.debug(f"Item modified, total modified so far: {numbers}") + self.logger.info(f"Refining Complete. Total modified items: {numbers}") + dataframe[self.input_key] = refined_data + output_file = storage.write(dataframe) + return [self.input_key] \ No newline at end of file diff --git a/DataFlow/dataflow/operators/refine/GeneralText/remove_emoji_refiner.py b/DataFlow/dataflow/operators/refine/GeneralText/remove_emoji_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..a19e434cf2570d7cfde86c3f2d704736bac86d48 --- /dev/null +++ b/DataFlow/dataflow/operators/refine/GeneralText/remove_emoji_refiner.py @@ -0,0 +1,58 @@ +import re +from tqdm import tqdm +from dataflow import get_logger +from dataflow.core import OperatorABC +from dataflow.utils.storage import DataFlowStorage +from dataflow.utils.registry import OPERATOR_REGISTRY + + +@OPERATOR_REGISTRY.register() +class RemoveEmojiRefiner(OperatorABC): + def __init__(self): + self.logger = get_logger() + self.refiner_name = 'RemoveEmojiRefiner' + self.logger.info(f"Initializing {self.__class__.__name__} ...") + + # Emoji pattern for matching emojis in the text + self.emoji_pattern = re.compile( + "[" + u"\U0001F600-\U0001F64F" # Emoticons + u"\U0001F300-\U0001F5FF" # Miscellaneous symbols and pictographs + u"\U0001F680-\U0001F6FF" # Transport and map symbols + u"\U0001F1E0-\U0001F1FF" # Flags + u"\U00002702-\U000027B0" # Dingbats + u"\U000024C2-\U0001F251" # Enclosed characters + "]+", + flags=re.UNICODE + ) + + @staticmethod + def get_desc(lang: str = "zh"): + return "去除文本中的表情符号" if lang == "zh" else "Remove emojis from the text." + + def run(self, storage: DataFlowStorage, input_key: str): + dataframe = storage.read("dataframe") + numbers = 0 + refined_data = [] + self.logger.info(f"Running {self.__class__.__name__} with input_key = {input_key}...") + + for item in tqdm(dataframe[input_key], desc=f"Implementing {self.refiner_name}"): + modified = False + original_text = item + no_emoji_text = self.emoji_pattern.sub(r'', original_text) + + if original_text != no_emoji_text: + item = no_emoji_text + modified = True + self.logger.debug(f"Modified text for key '{input_key}': Original: {original_text[:30]}... -> Refined: {no_emoji_text[:30]}...") + + refined_data.append(item) + if modified: + numbers += 1 + self.logger.debug(f"Item modified, total modified so far: {numbers}") + + dataframe[input_key] = refined_data + storage.write(dataframe) + self.logger.info(f"Refining Complete. Total modified items: {numbers}") + + return [input_key] diff --git a/DataFlow/dataflow/operators/refine/GeneralText/remove_extra_spaces_refiner.py b/DataFlow/dataflow/operators/refine/GeneralText/remove_extra_spaces_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..9d6f5051540329261c848ae4d7f0b7c1a91b5e3e --- /dev/null +++ b/DataFlow/dataflow/operators/refine/GeneralText/remove_extra_spaces_refiner.py @@ -0,0 +1,44 @@ +import re +from tqdm import tqdm +from dataflow.utils.storage import DataFlowStorage +from dataflow import get_logger +from dataflow.core import OperatorABC +from dataflow.utils.registry import OPERATOR_REGISTRY + +@OPERATOR_REGISTRY.register() +class RemoveExtraSpacesRefiner(OperatorABC): + def __init__(self): + self.logger = get_logger() + self.logger.info(f"Initializing {self.__class__.__name__} ...") + + @staticmethod + def get_desc(lang: str = "zh"): + return "去除文本中的多余空格" if lang == "zh" else "Remove extra spaces in the text." + + def run(self, storage: DataFlowStorage, input_key: str): + self.input_key = input_key + dataframe = storage.read("dataframe") + self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...") + numbers = 0 + refined_data = [] + + for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"): + modified = False + original_text = item + refined_text = " ".join(original_text.split()) # Remove extra spaces + + if original_text != refined_text: + item = refined_text + modified = True + self.logger.debug(f"Modified text for key '{self.input_key}': Original: {original_text[:30]}... -> Refined: {refined_text[:30]}...") + + refined_data.append(item) + if modified: + numbers += 1 + self.logger.debug(f"Item modified, total modified so far: {numbers}") + + dataframe[self.input_key] = refined_data + storage.write(dataframe) + self.logger.info(f"Refining Complete. Total modified items: {numbers}") + + return [self.input_key] diff --git a/DataFlow/dataflow/operators/refine/GeneralText/remove_image_ref_refiner.py b/DataFlow/dataflow/operators/refine/GeneralText/remove_image_ref_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..478322f9a797561a72a7220ed0ca1ea3163e92e5 --- /dev/null +++ b/DataFlow/dataflow/operators/refine/GeneralText/remove_image_ref_refiner.py @@ -0,0 +1,57 @@ +import re +from tqdm import tqdm +from dataflow import get_logger +from dataflow.core import OperatorABC +from dataflow.utils.storage import DataFlowStorage +from dataflow.utils.registry import OPERATOR_REGISTRY + +@OPERATOR_REGISTRY.register() +class RemoveImageRefsRefiner(OperatorABC): + def __init__(self): + self.logger = get_logger() + self.image_pattern = re.compile( + r'!\[\]\(images\/[0-9a-fA-F]\.jpg\)|' + r'[a-fA-F0-9]+\.[a-zA-Z]{3,4}\)|' + r'!\[\]\(images\/[a-f0-9]|' + r'图\s+\d+-\d+:[\u4e00-\u9fa5a-zA-Z0-9]+|' + r'(?:[0-9a-zA-Z]+){7,}|' # 正则5 + r'(?:[一二三四五六七八九十零壹贰叁肆伍陆柒捌玖拾佰仟万亿]+){5,}|' # 正则6(汉字数字) + r"u200e|" + r"÷|\? :|" + r"[�□]|\{\/U\}|" + r"U\+26[0-F][0-D]|U\+273[3-4]|U\+1F[3-6][0-4][0-F]|U\+1F6[8-F][0-F]" + ) + self.logger.info(f"Initializing {self.__class__.__name__} ...") + + @staticmethod + def get_desc(lang: str = "zh"): + return "去除文本中的图片引用" if lang == "zh" else "Remove image references in text." + + def run(self, storage: DataFlowStorage, input_key: str): + self.input_key = input_key + dataframe = storage.read("dataframe") + self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...") + numbers = 0 + refined_data = [] + for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"): + modified = False + original_text = item + # 移除所有图片引用格式[1,2](@ref) + cleaned_text = self.image_pattern.sub('', original_text) + + if original_text != cleaned_text: + item = cleaned_text + modified = True + # 调试日志:显示修改前后的对比 + self.logger.debug(f"Modified text for key '{self.input_key}':") + self.logger.debug(f"Original: {original_text[:100]}...") + self.logger.debug(f"Refined : {cleaned_text[:100]}...") + + refined_data.append(item) + if modified: + numbers += 1 + self.logger.debug(f"Item modified, total modified so far: {numbers}") + self.logger.info(f"Refining Complete. Total modified items: {numbers}") + dataframe[self.input_key] = refined_data + output_file = storage.write(dataframe) + return [self.input_key] \ No newline at end of file diff --git a/DataFlow/dataflow/operators/refine/GeneralText/remove_number_refiner.py b/DataFlow/dataflow/operators/refine/GeneralText/remove_number_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..10c4c633613a137d0bc876136caa8a44d1a49194 --- /dev/null +++ b/DataFlow/dataflow/operators/refine/GeneralText/remove_number_refiner.py @@ -0,0 +1,35 @@ +from tqdm import tqdm +from dataflow import get_logger +from dataflow.core import OperatorABC +from dataflow.utils.storage import DataFlowStorage +from dataflow.utils.registry import OPERATOR_REGISTRY + +@OPERATOR_REGISTRY.register() +class RemoveNumberRefiner(OperatorABC): + def __init__(self): + self.logger = get_logger() + self.logger.info(f"Initializing {self.__class__.__name__} ...") + + def run(self, storage: DataFlowStorage, input_key: str): + self.input_key = input_key + dataframe = storage.read("dataframe") + self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...") + numbers = 0 + refined_data = [] + for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"): + modified = False + original_text = item + no_number_text = ''.join([char for char in original_text if not char.isdigit()]) + if original_text != no_number_text: + item = no_number_text + modified = True + self.logger.debug(f"Modified text for key '{self.input_key}': Original: {original_text[:30]}... -> Refined: {no_number_text[:30]}...") + + refined_data.append(item) + if modified: + numbers += 1 + self.logger.debug(f"Item modified, total modified so far: {numbers}") + self.logger.info(f"Refining Complete. Total modified items: {numbers}") + dataframe[self.input_key] = refined_data + output_file = storage.write(dataframe) + return [self.input_key] \ No newline at end of file diff --git a/DataFlow/dataflow/operators/refine/GeneralText/remove_punctuation_refiner.py b/DataFlow/dataflow/operators/refine/GeneralText/remove_punctuation_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..33805ffbb4ec47da2ac15ad1d9e944b47dcc9b99 --- /dev/null +++ b/DataFlow/dataflow/operators/refine/GeneralText/remove_punctuation_refiner.py @@ -0,0 +1,38 @@ +import string +from tqdm import tqdm +from dataflow import get_logger +from dataflow.core import OperatorABC +from dataflow.utils.storage import DataFlowStorage +from dataflow.utils.registry import OPERATOR_REGISTRY + +@OPERATOR_REGISTRY.register() +class RemovePunctuationRefiner(OperatorABC): + def __init__(self): + self.logger = get_logger() + self.logger.info(f"Initializing {self.__class__.__name__} ...") + self.punct_to_remove = string.punctuation + + def run(self, storage: DataFlowStorage, input_key: str): + self.input_key = input_key + self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...") + dataframe = storage.read("dataframe") + numbers = 0 + refined_data = [] + for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"): + modified = False + original_text = item + no_punct_text = original_text.translate(str.maketrans('', '', self.punct_to_remove)) + + if original_text != no_punct_text: + item = no_punct_text + modified = True + self.logger.debug(f"Modified text for key '{self.input_key}': Original: {original_text[:30]}... -> Refined: {no_punct_text[:30]}...") + + refined_data.append(item) + if modified: + numbers += 1 + self.logger.debug(f"Item modified, total modified so far: {numbers}") + self.logger.info(f"Refining Complete. Total modified items: {numbers}") + dataframe[self.input_key] = refined_data + output_file = storage.write(dataframe) + return [self.input_key] \ No newline at end of file diff --git a/DataFlow/dataflow/operators/refine/GeneralText/remove_repetitions_punctuation_refiner.py b/DataFlow/dataflow/operators/refine/GeneralText/remove_repetitions_punctuation_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..bae6c004cb17032f7c5d2affb763f638fba18969 --- /dev/null +++ b/DataFlow/dataflow/operators/refine/GeneralText/remove_repetitions_punctuation_refiner.py @@ -0,0 +1,40 @@ +import re +import string +from tqdm import tqdm +from dataflow import get_logger +from dataflow.core import OperatorABC +from dataflow.utils.storage import DataFlowStorage +from dataflow.utils.registry import OPERATOR_REGISTRY + +@OPERATOR_REGISTRY.register() +class RemoveRepetitionsPunctuationRefiner(OperatorABC): + def __init__(self): + self.logger = get_logger() + self.logger.info(f"Initializing {self.__class__.__name__} ...") + self.punct_to_remove = string.punctuation + + def run(self, storage: DataFlowStorage, input_key: str): + self.input_key = input_key + self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...") + dataframe = storage.read("dataframe") + numbers = 0 + refined_data = [] + for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"): + modified = False + original_text = item + no_extra_punct_text = re.sub(r'([^\w\s_])\1+|(_)\2+', r'\1\2', original_text) + + if original_text != no_extra_punct_text: + item = no_extra_punct_text + modified = True + + self.logger.debug(f"Modified text for key '{self.input_key}': Original: {original_text[:30]}... -> Refined: {no_extra_punct_text[:30]}...") + + refined_data.append(item) + if modified: + numbers += 1 + self.logger.debug(f"Item modified, total modified so far: {numbers}") + self.logger.info(f"Refining Complete. Total modified items: {numbers}") + dataframe[self.input_key] = refined_data + output_file = storage.write(dataframe) + return [self.input_key] \ No newline at end of file diff --git a/DataFlow/dataflow/operators/refine/GeneralText/remove_stopwords_refiner.py b/DataFlow/dataflow/operators/refine/GeneralText/remove_stopwords_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..be631e9fae0428e29a9921f6cf6e0c26e23d37cf --- /dev/null +++ b/DataFlow/dataflow/operators/refine/GeneralText/remove_stopwords_refiner.py @@ -0,0 +1,48 @@ +import nltk +from nltk.corpus import stopwords +from tqdm import tqdm +from dataflow import get_logger +from dataflow.core import OperatorABC +from dataflow.utils.storage import DataFlowStorage +from dataflow.utils.registry import OPERATOR_REGISTRY + +@OPERATOR_REGISTRY.register() +class RemoveStopwordsRefiner(OperatorABC): + def __init__(self, model_cache_dir: str = './dataflow_cache'): + self.logger = get_logger() + self.logger.info(f"Initializing {self.__class__.__name__} ...") + self.model_cache_dir = model_cache_dir + nltk.data.path.append(self.model_cache_dir) + nltk.download('stopwords', download_dir=self.model_cache_dir) + + def remove_stopwords(self, text): + words = text.split() + stopwords_list = set(stopwords.words('english')) + refined_words = [word for word in words if word.lower() not in stopwords_list] + return " ".join(refined_words) + + + def run(self, storage: DataFlowStorage, input_key: str): + self.input_key = input_key + self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...") + dataframe = storage.read("dataframe") + numbers = 0 + refined_data = [] + for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"): + modified = False + original_text = item + refined_text = self.remove_stopwords(original_text) + + if original_text != refined_text: + item = refined_text + modified = True + self.logger.debug(f"Modified text for key '{self.input_key}': Original: {original_text[:30]}... -> Refined: {refined_text[:30]}...") + + refined_data.append(item) + if modified: + numbers += 1 + self.logger.debug(f"Item modified, total modified so far: {numbers}") + self.logger.info(f"Refining Complete. Total modified items: {numbers}") + dataframe[self.input_key] = refined_data + output_file = storage.write(dataframe) + return [self.input_key] \ No newline at end of file diff --git a/DataFlow/dataflow/operators/refine/GeneralText/text_normalization_refiner.py b/DataFlow/dataflow/operators/refine/GeneralText/text_normalization_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..48e79a0c7f12e94d78cc3cdab6799f018cbc3ce7 --- /dev/null +++ b/DataFlow/dataflow/operators/refine/GeneralText/text_normalization_refiner.py @@ -0,0 +1,55 @@ +import re +from datetime import datetime +from tqdm import tqdm +from dataflow import get_logger +from dataflow.core import OperatorABC +from dataflow.utils.storage import DataFlowStorage +from dataflow.utils.registry import OPERATOR_REGISTRY + +@OPERATOR_REGISTRY.register() +class TextNormalizationRefiner(OperatorABC): + def __init__(self): + self.logger = get_logger() + self.logger.info(f"Initializing {self.__class__.__name__} ...") + + def run(self, storage: DataFlowStorage, input_key: str): + self.input_key = input_key + dataframe = storage.read("dataframe") + self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...") + numbers = 0 + refined_data = [] + for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"): + modified = False + original_text = item + refined_text = original_text + + refined_text = re.sub(r'(\d{1,2})[/.](\d{1,2})[/.](\d{2,4})', r'\3-\2-\1', refined_text) + date_patterns = [ + (r'\b(\w+)\s+(\d{1,2}),\s+(\d{4})\b', '%B %d, %Y'), + (r'\b(\d{1,2})\s+(\w+)\s+(\d{4})\b', '%d %B %Y') + ] + for pattern, date_format in date_patterns: + match = re.search(pattern, refined_text) + if match: + date_str = match.group(0) + try: + parsed_date = datetime.strptime(date_str, date_format) + refined_text = refined_text.replace(date_str, parsed_date.strftime('%Y-%m-%d')) + except ValueError: + pass + + refined_text = re.sub(r'\$\s?(\d+)', r'\1 USD', refined_text) + + if original_text != refined_text: + item = refined_text + modified = True + self.logger.debug(f"Modified text for key '{self.input_key}': Original: {original_text[:30]}... -> Refined: {refined_text[:30]}...") + + refined_data.append(item) + if modified: + numbers += 1 + self.logger.debug(f"Item modified, total modified so far: {numbers}") + self.logger.info(f"Refining Complete. Total modified items: {numbers}") + dataframe[self.input_key] = refined_data + output_file = storage.write(dataframe) + return [self.input_key] \ No newline at end of file diff --git a/DataFlow/dataflow/scripts/pipelines/AgenticRAGPipeline/.gitkeep b/DataFlow/dataflow/scripts/pipelines/AgenticRAGPipeline/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/DataFlow/dataflow/scripts/pipelines/AgenticRAGPipeline/bash/pipeline_full.sh b/DataFlow/dataflow/scripts/pipelines/AgenticRAGPipeline/bash/pipeline_full.sh new file mode 100644 index 0000000000000000000000000000000000000000..6e9db18f7cb7de1a621e62ae787e0f3f74e79829 --- /dev/null +++ b/DataFlow/dataflow/scripts/pipelines/AgenticRAGPipeline/bash/pipeline_full.sh @@ -0,0 +1,16 @@ +# # ------------------------------ Question ------------------------------# +# # Step 0, Initial Clustering and filter +echo -e "\033[32m===== [Step 0] Filter =====\033[0m" +python pipeline_step.py --yaml_path dataflow/scripts/pipelines/AgenticRAGPipeline/yaml/process/ContentChooser.yaml --step_name ContentChooser + +#Step 1, Prompt Synthesis +#echo -e "\033[32m===== [Step 1] Prompt Synthesis =====\033[0m" +#python pipeline_step.py --yaml_path dataflow/scripts/pipelines/AgenticRAGPipeline/yaml/generate/AutoPromptGenerator.yaml --step_name AutoPromptGenerator + +# Step 2, QA Synthesis +#echo -e "\033[32m===== [Step 2] QA Synthesis =====\033[0m" +#python pipeline_step.py --yaml_path dataflow/scripts/pipelines/AgenticRAGPipeline/yaml/generate/QAGenerator.yaml --step_name QAGenerator + +# Step 3, QA Scorer +#echo -e "\033[32m===== [Step 3] QA Scorer =====\033[0m" +#python pipeline_step.py --yaml_path dataflow/scripts/pipelines/AgenticRAGPipeline/yaml/generate/QAScorer.yaml --step_name QAScorer \ No newline at end of file diff --git a/DataFlow/dataflow/scripts/pipelines/AgenticRAGPipeline/yaml/generate/AutoPromptGenerator.yaml b/DataFlow/dataflow/scripts/pipelines/AgenticRAGPipeline/yaml/generate/AutoPromptGenerator.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/DataFlow/dataflow/scripts/pipelines/AgenticRAGPipeline/yaml/generate/QAGenerator.yaml b/DataFlow/dataflow/scripts/pipelines/AgenticRAGPipeline/yaml/generate/QAGenerator.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/DataFlow/dataflow/scripts/pipelines/AgenticRAGPipeline/yaml/generate/QAScorer.yaml b/DataFlow/dataflow/scripts/pipelines/AgenticRAGPipeline/yaml/generate/QAScorer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/DataFlow/dataflow/scripts/pipelines/AgenticRAGPipeline/yaml/process/ContentChooser.yaml b/DataFlow/dataflow/scripts/pipelines/AgenticRAGPipeline/yaml/process/ContentChooser.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e158703b89991129f2879ca209a2db4037e7a72d --- /dev/null +++ b/DataFlow/dataflow/scripts/pipelines/AgenticRAGPipeline/yaml/process/ContentChooser.yaml @@ -0,0 +1,6 @@ +input_key: "text" +embedding_model_path: "/mnt/public/data/lh/models/hub/gte-Qwen2-7B-instruct" +num_samples: 5 +method: "kcenter" +input_file: "dataflow/example/AgenticRAGPipeline/pipeline_small_chunk.json" +output_file: "dataflow/example/AgenticRAGPipeline/pipeline_step1_out.json" \ No newline at end of file diff --git a/DataFlow/dataflow/scripts/pipelines/CodePipeline/.gitkeep b/DataFlow/dataflow/scripts/pipelines/CodePipeline/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/DataFlow/dataflow/scripts/pipelines/ReasoningPipeline/yaml/SFT/generate/AnswerGenerator.yaml b/DataFlow/dataflow/scripts/pipelines/ReasoningPipeline/yaml/SFT/generate/AnswerGenerator.yaml new file mode 100644 index 0000000000000000000000000000000000000000..268ec180c43124f047e94630426422ee392da4c0 --- /dev/null +++ b/DataFlow/dataflow/scripts/pipelines/ReasoningPipeline/yaml/SFT/generate/AnswerGenerator.yaml @@ -0,0 +1,9 @@ +input_file: "dataflow/example/ReasoningPipeline/pipeline_math_step6out_gt.jsonl" +output_file: "dataflow/example/ReasoningPipeline/pipeline_math_step7out_gt.jsonl" +input_key: "instruction" +output_key: "generated_cot" +generator_type: "request" +api_url: "http://123.129.219.111:3000/v1/chat/completions" +max_workers: 100 +model_name: "gpt-4o" +system_prompt: "" diff --git a/DataFlow/dataflow/scripts/pipelines/ReasoningPipeline/yaml/SFT/generate/PseudoAnswerGenerator.yaml b/DataFlow/dataflow/scripts/pipelines/ReasoningPipeline/yaml/SFT/generate/PseudoAnswerGenerator.yaml new file mode 100644 index 0000000000000000000000000000000000000000..054b0f35d60c86051f115fa67e9d7e013a92074c --- /dev/null +++ b/DataFlow/dataflow/scripts/pipelines/ReasoningPipeline/yaml/SFT/generate/PseudoAnswerGenerator.yaml @@ -0,0 +1,13 @@ +input_file: "dataflow/example/ReasoningPipeline/pipeline_math_step6out_no_gt.jsonl" +output_file: "dataflow/example/ReasoningPipeline/pipeline_math_step7out_no_gt.jsonl" +input_key: "instruction" +output_key_answer: "pseudo_answers" # 输出的伪答案的key +output_key_answer_value: "pseudo_answer_value" # 输出的伪答案的值 +output_key_solutions: "pseudo_solutions" # 输出的正确伪答案的所有完整response +output_key_correct_solution_example: "pseudo_correct_solution_example" # 一个正确的response +generator_type: "request" +max_times : 3 +api_url: "http://123.129.219.111:3000/v1/chat/completions" +max_workers: 100 +model_name: "gpt-4o" +system_prompt: "" \ No newline at end of file diff --git a/DataFlow/dataflow/scripts/pipelines/ReasoningPipeline/yaml/SFT/generate/QuestionCategoryClassifier.yaml b/DataFlow/dataflow/scripts/pipelines/ReasoningPipeline/yaml/SFT/generate/QuestionCategoryClassifier.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1308dd54e681d2dc5f3730e51f09d0a11fd53ed6 --- /dev/null +++ b/DataFlow/dataflow/scripts/pipelines/ReasoningPipeline/yaml/SFT/generate/QuestionCategoryClassifier.yaml @@ -0,0 +1,9 @@ +input_file: "dataflow/example/ReasoningPipeline/pipeline_math_step4out.jsonl" +output_file: "dataflow/example/ReasoningPipeline/pipeline_math_step5out.jsonl" +input_key: "instruction" +# output_key: "question_category" +generator_type: "request" +api_url: "http://123.129.219.111:3000/v1/chat/completions" +max_workers: 100 +model_name: "gpt-4o" +system_prompt: "" diff --git a/DataFlow/dataflow/scripts/pipelines/ReasoningPipeline/yaml/SFT/generate/QuestionGenerator.yaml b/DataFlow/dataflow/scripts/pipelines/ReasoningPipeline/yaml/SFT/generate/QuestionGenerator.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9d145a2e8743e5e9b98061abe90a45f7f35e7a3e --- /dev/null +++ b/DataFlow/dataflow/scripts/pipelines/ReasoningPipeline/yaml/SFT/generate/QuestionGenerator.yaml @@ -0,0 +1,10 @@ +input_file: "dataflow/example/ReasoningPipeline/pipeline_math_step1out.jsonl" +output_file: "dataflow/example/ReasoningPipeline/pipeline_math_step2out.jsonl" +input_key: "instruction" +# output_key: "unused" # output仍然保存在input_key中,output_key目前仅用于统一接口 +generator_type: "request" +num_prompts: 3 +api_url: "http://123.129.219.111:3000/v1/chat/completions" +max_workers: 100 +model_name: "gpt-4o" +system_prompt: "" \ No newline at end of file diff --git a/DataFlow/dataflow/scripts/pipelines/ReasoningPipeline/yaml/SFT/process/AnswerPipelineRoot.yaml b/DataFlow/dataflow/scripts/pipelines/ReasoningPipeline/yaml/SFT/process/AnswerPipelineRoot.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9ad3936ea21cc077e3edf4c7b84405f2d5e44148 --- /dev/null +++ b/DataFlow/dataflow/scripts/pipelines/ReasoningPipeline/yaml/SFT/process/AnswerPipelineRoot.yaml @@ -0,0 +1,6 @@ +input_file: "dataflow/example/ReasoningPipeline//pipeline_math_step5out.jsonl" +input_key: "data" +input_answer_key: "output" +input_gt_key: "golden_answer" +output_file_with_gt: "dataflow/example/ReasoningPipeline//pipeline_math_step6out_gt.jsonl" +output_file_without_gt: "dataflow/example/ReasoningPipeline//pipeline_math_step6out_no_gt.jsonl" diff --git a/DataFlow/dataflow/scripts/pipelines/ReasoningPipeline/yaml/SFT/process/ReasonerAnsSelection.yaml b/DataFlow/dataflow/scripts/pipelines/ReasoningPipeline/yaml/SFT/process/ReasonerAnsSelection.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f28215aa199300537fed02300e2a9570c1bb8eca --- /dev/null +++ b/DataFlow/dataflow/scripts/pipelines/ReasoningPipeline/yaml/SFT/process/ReasonerAnsSelection.yaml @@ -0,0 +1,6 @@ +output_file: "dataflow/example/ReasoningPipeline/pipeline_math_step10out_gt.jsonl" +input_file: 'dataflow/example/ReasoningPipeline/pipeline_math_step9out_gt.jsonl' # Local data path, supports json, jsonl, parquet formats +formatter: "TextFormatter" # Data loader type +compare_method: math_verify # exact or math_verify +test_answer_key: "generated_cot" +gt_answer_key: "golden_answer" \ No newline at end of file diff --git a/DataFlow/dataflow/scripts/pipelines/ReasoningPipeline/yaml/SFT/process/ReasonerNgramFilter.yaml b/DataFlow/dataflow/scripts/pipelines/ReasoningPipeline/yaml/SFT/process/ReasonerNgramFilter.yaml new file mode 100644 index 0000000000000000000000000000000000000000..79ca6e43564152597e60f0630c85b5dd43cd0d96 --- /dev/null +++ b/DataFlow/dataflow/scripts/pipelines/ReasoningPipeline/yaml/SFT/process/ReasonerNgramFilter.yaml @@ -0,0 +1,7 @@ +output_file: "dataflow/example/ReasoningPipeline/pipeline_math_step11out_gt.jsonl" +input_file: 'dataflow/example/ReasoningPipeline/pipeline_math_step10out_gt.jsonl' # Local data path, supports json, jsonl, parquet formats +min_score: 0.1 +max_score: 1.0 +ngrams: 5 +question_key: "instruction" +answer_key: "generated_cot" diff --git a/DataFlow/dataflow/scripts/pipelines/ReasoningPipeline/yaml/SFT/process/ReasonerNgramFilter_withoutGT.yaml b/DataFlow/dataflow/scripts/pipelines/ReasoningPipeline/yaml/SFT/process/ReasonerNgramFilter_withoutGT.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bc2c273ae794b0293f104798f3ee847a1c334a66 --- /dev/null +++ b/DataFlow/dataflow/scripts/pipelines/ReasoningPipeline/yaml/SFT/process/ReasonerNgramFilter_withoutGT.yaml @@ -0,0 +1,7 @@ +output_file: "dataflow/example/ReasoningPipeline/pipeline_math_step10out_no_gt.jsonl" +input_file: 'dataflow/example/ReasoningPipeline/pipeline_math_step9out_no_gt.jsonl' # Local data path, supports json, jsonl, parquet formats +min_score: 0.1 +max_score: 1.0 +ngrams: 5 +question_key: "instruction" +answer_key: "pseudo_correct_solution_example" diff --git a/DataFlow/dataflow/scripts/pipelines/TextPipeline/yaml/eval/ngram_filter.yaml b/DataFlow/dataflow/scripts/pipelines/TextPipeline/yaml/eval/ngram_filter.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0c36bebce960a45ff243a3aa932c2532c9c26dd8 --- /dev/null +++ b/DataFlow/dataflow/scripts/pipelines/TextPipeline/yaml/eval/ngram_filter.yaml @@ -0,0 +1,4 @@ +input_file: ./pt_input.jsonl +output_file: ./step_1_ngram_scorer.jsonl +input_key: "raw_content" +ngrams: 5 \ No newline at end of file diff --git a/DataFlow/dataflow/scripts/pipelines/TextPipeline/yaml/eval/ngram_scorer.yaml b/DataFlow/dataflow/scripts/pipelines/TextPipeline/yaml/eval/ngram_scorer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0c36bebce960a45ff243a3aa932c2532c9c26dd8 --- /dev/null +++ b/DataFlow/dataflow/scripts/pipelines/TextPipeline/yaml/eval/ngram_scorer.yaml @@ -0,0 +1,4 @@ +input_file: ./pt_input.jsonl +output_file: ./step_1_ngram_scorer.jsonl +input_key: "raw_content" +ngrams: 5 \ No newline at end of file diff --git a/ernie/ERNIE/ernie/__pycache__/configuration.cpython-311.pyc b/ernie/ERNIE/ernie/__pycache__/configuration.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0500434b5a66375b991eeb0e3095b15dea324cb0 Binary files /dev/null and b/ernie/ERNIE/ernie/__pycache__/configuration.cpython-311.pyc differ diff --git a/ernie/ERNIE/ernie/__pycache__/modeling.cpython-311.pyc b/ernie/ERNIE/ernie/__pycache__/modeling.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e483d90a4bc224fd2f203e6a7100491dfcfef84a Binary files /dev/null and b/ernie/ERNIE/ernie/__pycache__/modeling.cpython-311.pyc differ diff --git a/ernie/ERNIE/ernie/__pycache__/modeling_moe_pp.cpython-311.pyc b/ernie/ERNIE/ernie/__pycache__/modeling_moe_pp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f45dc43f3b8e9faaf0dd18aa839e2129da2592b Binary files /dev/null and b/ernie/ERNIE/ernie/__pycache__/modeling_moe_pp.cpython-311.pyc differ diff --git a/ernie/ERNIE/ernie/__pycache__/sequence_parallel_utils.cpython-311.pyc b/ernie/ERNIE/ernie/__pycache__/sequence_parallel_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3895d46e021cd4d41ff30b9ca161d04c790a87f5 Binary files /dev/null and b/ernie/ERNIE/ernie/__pycache__/sequence_parallel_utils.cpython-311.pyc differ diff --git a/ernie/ERNIE/ernie/__pycache__/tokenizer.cpython-311.pyc b/ernie/ERNIE/ernie/__pycache__/tokenizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c05c4ae3eb6248138b50eb47235a7797e9a2b5b7 Binary files /dev/null and b/ernie/ERNIE/ernie/__pycache__/tokenizer.cpython-311.pyc differ diff --git a/ernie/ERNIE/ernie/dataset/__pycache__/base.cpython-311.pyc b/ernie/ERNIE/ernie/dataset/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbeb44fb323190af031ce296246b86fc88d1d6cb Binary files /dev/null and b/ernie/ERNIE/ernie/dataset/__pycache__/base.cpython-311.pyc differ diff --git a/ernie/ERNIE/ernie/dataset/__pycache__/data_utils.cpython-311.pyc b/ernie/ERNIE/ernie/dataset/__pycache__/data_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ebd44600d52959e1cf6b480254eaf4a41143154 Binary files /dev/null and b/ernie/ERNIE/ernie/dataset/__pycache__/data_utils.cpython-311.pyc differ diff --git a/ernie/ERNIE/ernie/dataset/__pycache__/dpo.cpython-311.pyc b/ernie/ERNIE/ernie/dataset/__pycache__/dpo.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57f77c9e279e85d025e9a3496cb2f9552044529f Binary files /dev/null and b/ernie/ERNIE/ernie/dataset/__pycache__/dpo.cpython-311.pyc differ diff --git a/ernie/ERNIE/ernie/dataset/__pycache__/finetuning.cpython-311.pyc b/ernie/ERNIE/ernie/dataset/__pycache__/finetuning.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f50ab45e24e7c0f04dc72c6062d2117edb15c589 Binary files /dev/null and b/ernie/ERNIE/ernie/dataset/__pycache__/finetuning.cpython-311.pyc differ diff --git a/ernie/ERNIE/ernie/dataset/base.py b/ernie/ERNIE/ernie/dataset/base.py new file mode 100644 index 0000000000000000000000000000000000000000..9f80a3625074c9fa4f54c461cf67c15754650f29 --- /dev/null +++ b/ernie/ERNIE/ernie/dataset/base.py @@ -0,0 +1,249 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Basic datasets implement. """ + +import glob +import gzip +import json +import random +from contextlib import contextmanager +from functools import partial + +import numpy as np +from paddle.io import IterableDataset +from paddleformers.utils.log import logger + +from ernie.dataset.hf import hf_parser + + +@contextmanager +def open_file(filename): + """Construct a file handler that can read normal or gzip-compressed files. + + The handler automatically detects compression based on file extension. + + Args: + filename (str): Path to the target file, which may end with .gz for gzip compression. + + Returns: + Generator[TextIO]: A file object generator that yields lines from the file. + """ + if filename.endswith(".gz"): + fp = gzip.open(filename, "rt") + else: + fp = open(filename) + yield fp + fp.close() + + +class FileDataset(IterableDataset): + """Single file dataset that supports line processing and optional shuffling.""" + + def __init__(self, filename, process_fn=None, shuffle_file=False): + """Initialize the single file dataset. + + Args: + filename (str): Path to the input data file. + process_fn (callable, optional): Function to preprocess each line. + shuffle_file (bool): Whether to shuffle lines before iteration. + """ + self._filename = filename + self._process_fn = process_fn + self._shuffle_file = shuffle_file + + def __iter__(self): + """Iterate through the dataset with optional shuffling and processing. + + Yields: + dict: Processed examples from the file, skipping invalid entries. + """ + with open_file(self._filename) as fin: + if self._shuffle_file: + lines = fin.readlines() + np.random.shuffle(lines) + else: + lines = fin + for lineno, line in enumerate(lines): + try: + ex = json.loads(line) + except Exception as e: + logger.warning(f"Skip loading error data at line {lineno} of {self._filename}. Error message: {e}") + continue + if self._process_fn is not None: + try: + ex = self._process_fn(ex, self._filename) + except Exception as e: + logger.warning( + f"Skip parsing error data at line {lineno} of {self._filename}. Error message: {e}" + ) + continue + # ignore invalid example + if ex is None: + continue + elif isinstance(ex, list): + yield from ex + else: + yield ex + + +class FileListDataset(IterableDataset): + """Multiple files dataset supporting file list and glob patterns.""" + + def __init__( + self, + filename, + file_format="filelist", + process_fn=None, + shuffle_file=False, + shuffle_files=False, + ): + """Initialize the file list dataset. + + Args: + filename (str): Path to file containing file list or glob pattern. + file_format (str): 'filelist' for list file or 'glob' for pattern matching. + process_fn (callable, optional): Function to preprocess each line. + shuffle_file (bool): Shuffle lines within each file. + shuffle_files (bool): Shuffle order of files during iteration. + """ + if file_format == "filelist": + self._filenames = [] + with open(filename) as fin: + for line in fin: + cols = line.strip().split("\t") + self._filenames.append(cols[0]) + elif file_format == "glob": + self._filenames = sorted(glob.glob(filename)) + else: + raise ValueError(f"Unsupported file_format: {file_format}") + + self._sub_datasets = [] + for fname in self._filenames: + self._sub_datasets.append(FileDataset(fname, process_fn=process_fn, shuffle_file=shuffle_file)) + + self._shuffle_files = shuffle_files + + def __iter__(self): + """Iterate through multiple files with optional shuffling. + + Yields: + dict: Processed examples from all files in specified order. + """ + if self._shuffle_files: + # NOTE(hehuang) stateful shuffle + sub_datasets = self._sub_datasets + np.random.shuffle(self._sub_datasets) + else: + sub_datasets = self._sub_datasets + for ds in sub_datasets: + yield from ds + + +class MultiSourceDataset(IterableDataset): + """Dataset that combines multiple data sources with probability sampling.""" + + def __init__( + self, + task_dataset_path, + task_dataset_prob, + sub_dataset_type=["erniekit"], + random_seed=11, + process_fn=None, + shuffle_file=False, + shuffle_files=False, + ): + """Initialize the multi-source dataset. + + Args: + task_dataset_path (list): List contains path of data sources. + task_dataset_prob (list): List contains probabilities of data sources. + sub_dataset_type (list): List of type of sub-dataset ('erniekit', 'filelist', 'glob', or 'alpaca'). + random_seed (int): Seed for reproducible sampling. + process_fn (callable, optional): Function to preprocess each example. + shuffle_file (bool): Shuffle lines within each file. + shuffle_files (bool): Shuffle order of files during iteration. + """ + tasks = [] + for i in range(len(task_dataset_path)): + tasks.append({"prob": task_dataset_prob[i], "filepath": task_dataset_path[i]}) + # filter zero probability task + tasks = [task for task in tasks if task["prob"] > 0] + self._task_group = tasks + for idx, task in enumerate(self._task_group): + each_sub_dataset_type = sub_dataset_type[idx] + if hf_parser.is_hf_dataset(task["filepath"]): + task["dataset"] = hf_parser.create_hf_dataset( + repo_id=task["filepath"], + process_fn=( + partial(process_fn, task_name=task["task_name"]) if "task_name" in task else process_fn + ), + shuffle_file=shuffle_file, + ) + continue + + if each_sub_dataset_type == "erniekit": + task["dataset"] = FileDataset( + task["filepath"], + process_fn=( + partial(process_fn, task_name=task["task_name"]) if "task_name" in task else process_fn + ), + shuffle_file=shuffle_file, + ) + elif each_sub_dataset_type in ["filelist", "glob"]: + task["dataset"] = FileListDataset( + task["train_filelist"], + file_format=each_sub_dataset_type, + process_fn=( + partial(process_fn, task_name=task["task_name"]) if "task_name" in task else process_fn + ), + shuffle_file=shuffle_file, + shuffle_files=shuffle_files, + ) + elif each_sub_dataset_type in ["alpaca"]: + task["dataset"] = hf_parser.create_dataset_from_file( + file_path=task["filepath"], + formatting="alpaca", + doc_formatting="auto", + process_fn=( + partial(process_fn, task_name=task["task_name"]) if "task_name" in task else process_fn + ), + shuffle_file=shuffle_file, + ) + else: + raise NotImplementedError(f"Cannot support {each_sub_dataset_type} now.") + sum_prob = sum([task["prob"] for task in self._task_group]) + for task in self._task_group: + task["prob_origin"] = task["prob"] + task["prob"] = task["prob"] / sum_prob + + self.random_seed = random_seed + + def __iter__(self): + """Iterate through examples from multiple sources with probability sampling. + + Yields: + dict: Processed examples from randomly selected data sources. + """ + rng = random.Random(self.random_seed) + probs = [task["prob"] for task in self._task_group] + # Initialize task iterator + for task in self._task_group: + task["iterator"] = iter(task["dataset"]) + while True: + task = rng.choices(self._task_group, weights=probs)[0] + try: + yield from task["iterator"] + except StopIteration: + task["iterator"] = iter(task["dataset"]) + yield from task["iterator"] diff --git a/ernie/ERNIE/ernie/dataset/data_utils.py b/ernie/ERNIE/ernie/dataset/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3388eae5491505cd275c888ca42268f1a01e9ed4 --- /dev/null +++ b/ernie/ERNIE/ernie/dataset/data_utils.py @@ -0,0 +1,200 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Useful data utility.""" + +from dataclasses import dataclass +from typing import Dict, List, Tuple + +import numpy as np +from paddleformers.utils.log import logger + +INF = 1000000 +OPT_MULTI_OF = 256 + + +@dataclass +class Example: + """Data format for raw SFT (Supervised Fine-Tuning) examples.""" + + request: Dict + system: str + label: List[int] + is_system: int + source: str + + +def pad_batch_data( + insts, + pad_idx=0, + return_pos=False, + max_seq_len=None, + return_input_mask=False, + return_max_len=False, + return_num_token=False, + return_seq_lens=False, +): + """ + Pad the instances to the max sequence length in batch, and generate the + corresponding position data and attention bias. + """ + return_list = [] + max_len = ( + max_seq_len if max_seq_len is not None else max(len(inst) for inst in insts) + ) + # Any token included in dict can be used to pad, since the paddings' loss + # will be masked out by weights and make no effect on parameter gradients. + + inst_data = np.array( + [inst + list([pad_idx] * (max_len - len(inst))) for inst in insts] + ) + return_list += [inst_data.astype("int64").reshape([-1, max_len])] + + # position data + if return_pos: + inst_pos = np.array( + [ + list(range(0, len(inst))) + [pad_idx] * (max_len - len(inst)) + for inst in insts + ] + ) + + return_list += [inst_pos.astype("int64").reshape([-1, max_len])] + + if return_input_mask: + # This is used to avoid attention on paddings. + input_mask_data = np.array( + [[1] * len(inst) + [0] * (max_len - len(inst)) for inst in insts] + ) + input_mask_data = np.expand_dims(input_mask_data, axis=-1) + return_list += [input_mask_data.astype("float32")] + + if return_max_len: + return_list += [max_len] + + if return_num_token: + num_token = 0 + for inst in insts: + num_token += len(inst) + return_list += [num_token] + + if return_seq_lens: + seq_lens = np.array([len(inst) for inst in insts]) + return_list += [seq_lens.astype("int64").reshape([-1, 1])] + + return return_list if len(return_list) > 1 else return_list[0] + + +def convert_to_tokens_for_pt( + dial: List[dict], + tokenizer, + max_src_len, +): + """Convert a dial to tokens for PT model.""" + # content_1+"\n"+content_2+"\n"+content_3 + sentence = "\n".join([x["content"] for x in dial]) + tokens = tokenizer.tokenize(sentence) + if len(tokens) > max_src_len: + logger.warning( + f"The length of text ({len(tokens)}) cannot " + f"be greater than max input length \ + ({max_src_len}). \ + We will truncate it." + ) + # NOTE: LLM lost in middle + tokens = tokens[: max_src_len // 2] + tokens[-max_src_len:] + + return tokens + + +def convert_to_tokens_for_sft( + dial: List[dict], + tokenizer, + max_src_len, +): + """ + Convert dialogue format into token sequences for supervised fine-tuning (SFT). + + Args: + dial: Dialogue history as list of message dictionaries with: + - role: "system", "knowledge", "user" or "assistant" + - content: Text content + tokenizer: Tokenizer instance for text processing + max_src_len: Maximum allowed length for source tokens + + Returns: + List of processed tokens ready for model input + """ + encoded_messages = tokenizer.encode_chat_inputs({"messages": dial}) + + num_reserved_tokens_for_each_dialog = 1 # only break_turn_token or end_token + num_reserved_tokens_for_each_turn = 8 + + cur_len = num_reserved_tokens_for_each_dialog + + turn_index = len(encoded_messages) - 1 + + tokens = [] + tokens = encoded_messages[turn_index][0] + turn_index -= 1 + + while turn_index >= 0: + tokens_src, tokens_target = encoded_messages[turn_index] + if len(tokens_src) + len(tokens_target) > ( + max_src_len + 1 - cur_len - num_reserved_tokens_for_each_turn + ): + break + + tokens = tokens_src + tokens_target + tokens + cur_len = len(tokens) + turn_index -= 1 + + return tokens + + +def convert_to_input_ids( + dials: List[List[dict]], + tokenizer, + data_format, + max_src_len, +) -> Tuple[List[List[int]], int]: + """Convert batch dialogue into input_ids. + + The API support multiple data format: `pt`, `sft. + + Args: + dials (List[List[dict]]): A batch of dialogue. + tokenizer (Ernie4_5_Tokenizer): The used tokenizer. + data_format (str): The data format for converting dialogue to input_ids, + support `base`, `chat`. + max_src_len (int): The maximum length of input_ids. + + Returns: + input_ids (List[List[int]]): The raw input_ids with truncation, but without padding. + num_input_tokens (int): The total input tokens in a batch. + + Raises: + ValueError: Invalid data format. + """ + input_ids = [] + num_input_tokens = 0 + for dial in dials: + if data_format == "base": + tokens = convert_to_tokens_for_pt(dial, tokenizer, max_src_len) + input_ids.append(tokenizer.convert_tokens_to_ids(tokens)) + elif data_format == "chat": + input_ids.append(convert_to_tokens_for_sft(dial, tokenizer, max_src_len)) + else: + raise ValueError(f"Unsupported data format: {data_format}") + num_input_tokens += len(input_ids[-1]) + return input_ids, num_input_tokens diff --git a/ernie/ERNIE/ernie/dataset/dpo.py b/ernie/ERNIE/ernie/dataset/dpo.py new file mode 100644 index 0000000000000000000000000000000000000000..3afe00c69e9a1749fb3c36e91c329a4b7911be8b --- /dev/null +++ b/ernie/ERNIE/ernie/dataset/dpo.py @@ -0,0 +1,842 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DPO dataset.""" + +import random +from copy import deepcopy +from dataclasses import dataclass +from typing import List, Optional + +import numpy as np +from paddle.io import IterableDataset, get_worker_info +from paddleformers.utils.log import logger +from scipy.linalg import block_diag + +from ernie.dataset.base import MultiSourceDataset + +LOGGER_COUNT = 0 + + +@dataclass +class Example: + """Dataset example.""" + + chosen: str + rejected: str + source: str + session_start_index: int + score_delta: float + + +@dataclass +class Sequence: + """Sequence.""" + + input_ids: Optional[List[int]] + position_ids: Optional[List[int]] + attention_mask: Optional[List[List[int]]] + attn_mask_start_row_indices: Optional[List[int]] + chosen_labels: List[int] + rejected_labels: List[int] + response_index: List[int] + score_delta: float + + +def create_dataset(**dataset_config): + """Create DPO dataset. + + Args: + **dataset_config: Configuration parameters including: + - task_dataset_path (str): Path of each dataset + - task_dataset_prob (str): Prob of each dataset + - sub_dataset_type (str): type of each dataset + - tokenizer: Text tokenization module + - max_seq_len (int): Total sequence length limit + - max_prompt_len (int): Total prompt length + - num_samples_each_epoch (int): number of sample per training epoch + - is_valid (bool, optional): Validation mode flag. Defaults to False + - random_seed (int): Reproduction seed for shuffling + - greedy_intokens (bool): Greedy intokens strategy + - buffer_size (int): Preloading buffer capacity + - use_attn_mask_start_row_indices (bool): Sparse attention mode + - mask_out_eos_token (bool): EOS loss masking + + Returns: + SequenceDataset: Configured dataset pipeline with: + - Multi-source data loading + - Dynamic sequence generation + - Session-aware processing (when enabled) + """ + task_dataset_path = [ + path + for path in str(dataset_config["task_group"]).replace(" ", "").split(",") + if path != "" + ] + task_dataset_prob = [ + float(prob) + for prob in str(dataset_config["task_group_prob"]).replace(" ", "").split(",") + if prob != "" + ] + sub_dataset_type = [ + type_ + for type_ in str(dataset_config["sub_dataset_type"]).replace(" ", "").split(",") + if type_ != "" + ] + + if not (len(task_dataset_path) == len(task_dataset_prob) == len(sub_dataset_type)): + raise ValueError( + "The len of dataset path, prob, type are inconsistent, please check the configuration." + ) + + if len(task_dataset_path) == 0: + raise ValueError( + "The len of dataset path is zero, please check the configuration." + ) + + example_dataset = MultiSourceDataset( + task_dataset_path=task_dataset_path, + task_dataset_prob=task_dataset_prob, + sub_dataset_type=sub_dataset_type, + process_fn=process_session_example, + ) + sequence_dataset = SequenceDataset( + dataset=example_dataset, + tokenizer=dataset_config["tokenizer"], + max_seq_len=dataset_config["max_seq_len"], + max_prompt_len=dataset_config["max_prompt_len"], + num_samples_each_epoch=dataset_config["num_samples_each_epoch"], + is_valid=dataset_config.get("is_valid", False), + random_seed=dataset_config["random_seed"], + random_shuffle=dataset_config["random_shuffle"], + greedy_intokens=dataset_config["greedy_intokens"], + buffer_size=dataset_config["buffer_size"], + use_attn_mask_start_row_indices=dataset_config.pop( + "use_attn_mask_start_row_indices", True + ), + mask_out_eos_token=dataset_config["mask_out_eos_token"], + ) + return sequence_dataset + + +def collate_fn( + batch, + tokenizer, + max_seq_len=None, + use_sparse_head_and_loss_fn=True, + use_fused_head_and_loss_fn=True, + use_response_score_delta=False, +): + """Convert batch data into tensor for DPO. + + Args: + batch (List[List[Sequence]]): Batch of input sequences containing multiple data samples. + Each sample is a list of Sequence objects containing tokenized data components. + tokenizer (Tokenizer): Text tokenizer for processing sequence components. + max_seq_len (int, optional): Maximum sequence length for padding/truncation. + If None, will raise ValueError. Defaults to None. + use_sparse_head_and_loss_fn (bool, optional): Whether to use sparse indexing for loss calculation. + Enables memory-efficient indexing for large sequences. Defaults to True. + use_fused_head_and_loss_fn (bool, optional): Whether to use fused kernel to calculate lm head and loss. + Optimizes for memory access patterns. Defaults to True. + + Returns: + Dict[str, np.ndarray]: Processed tensor dictionary containing: + - input_ids (int32): Padded token ids [batch_size, max_seq_len] + - position_ids (int32): Position ids [batch_size, max_seq_len] + - chosen_labels (int32): Preferred response labels [batch_size, max_seq_len] + - rejected_labels (int32): Unpreferred response labels [batch_size, max_seq_len] + - response_indexs (int32): Response span indices [batch_size, 4] + - attention_mask (float32, optional): Attention mask matrix [batch_size, 1, max_seq_len, max_seq_len] + - attn_mask_start_row_indices (int32, optional): Sparse attention row indices [batch_size, max_seq_len] + """ + if max_seq_len is None: + raise ValueError("max_seq_len is None.") + + input_dict = { + "input_ids": [], + "position_ids": [], + "chosen_labels": [], + "rejected_labels": [], + "response_indexs": [], + } + if use_response_score_delta: + input_dict["score_deltas"] = [] + + sequence = batch[0][0] + if sequence.attn_mask_start_row_indices is not None: + input_dict["attn_mask_start_row_indices"] = [] + use_attn_mask_start_row_indices = True + elif sequence.attention_mask is not None: + input_dict["attention_mask"] = [] + use_attn_mask_start_row_indices = False + else: + raise ValueError( + "attention_mask and attn_mask_start_row_indices are both None." + ) + sequence_sum_flatten = 0 + for i, sequences in enumerate(batch): + difference = max_seq_len - sum( + [len(sequence.input_ids) for sequence in sequences] + ) + + input_dict["input_ids"].append( + sum([sequence.input_ids for sequence in sequences], []) + [0] * difference + ) + input_dict["position_ids"].append( + sum([sequence.position_ids for sequence in sequences], []) + + [0] * difference + ) + input_dict["chosen_labels"].append( + sum([sequence.chosen_labels for sequence in sequences], []) + + [0] * difference + ) + input_dict["rejected_labels"].append( + sum([sequence.rejected_labels for sequence in sequences], []) + + [0] * difference + ) + if use_attn_mask_start_row_indices: + start_row_indices = [] + sequence_sum = 0 + for sequence in sequences: + start_row_indices += [ + indice + sequence_sum + for indice in sequence.attn_mask_start_row_indices + ] + sequence_sum += len(sequence.input_ids) + input_dict["attn_mask_start_row_indices"].append( + [start_row_indices + list(range(start_row_indices[-1], max_seq_len))] + ) + else: + input_dict["attention_mask"].append( + # (s,s) -> (1,s,s) + np.expand_dims( + # pad to max_loength + np.pad( + # block attention_mask + block_diag( + *[sequence.attention_mask for sequence in sequences] + ), + pad_width=((0, difference), (0, difference)), + mode="constant", + constant_values=False, + ), + axis=0, + ) + ) + sequence_sum = 0 + for sequence in sequences: + # bs, chosen_response_start_index, rejeted_response_start_index, rejeted_response_end_index + 1 + if use_sparse_head_and_loss_fn: + response_index = [ + i, + sequence_sum_flatten, + sequence.response_index[1] + - sequence.response_index[0] + + sequence_sum_flatten, + sequence.response_index[2] + - sequence.response_index[0] + + sequence_sum_flatten, + ] + sequence_sum_flatten += ( + sequence.response_index[2] - sequence.response_index[0] + ) + elif use_fused_head_and_loss_fn: + response_index = [ + i, + sequence.response_index[0] + sequence_sum_flatten, + sequence.response_index[1] + sequence_sum_flatten, + sequence.response_index[2] + sequence_sum_flatten, + ] + sequence_sum_flatten += len(sequence.input_ids) + else: + response_index = [ + i, + sequence.response_index[0] + sequence_sum, + sequence.response_index[1] + sequence_sum, + sequence.response_index[2] + sequence_sum, + ] + sequence_sum += len(sequence.input_ids) + input_dict["response_indexs"].append(response_index) + if use_response_score_delta: + input_dict["score_deltas"].append(sequence.score_delta) + + for key in input_dict: + if key == "attention_mask": + input_dict[key] = np.array(input_dict[key], dtype=np.float32) + elif key == "attn_mask_start_row_indices": + input_dict[key] = np.array(input_dict[key], dtype=np.int32) + else: + input_dict[key] = np.array(input_dict[key]) + return input_dict + + +def process_session_example(data, input_file): + """Convert raw format example to Example. + + Args: + data (dict): Raw session data dictionary containing: + - src (str/list): Multi-turn dialogue context (user inputs sequence) + - tgt (str/list): Assistant responses sequence (must be 1 shorter than src) + - response (List[List[str]]): Pair of multi-turn response candidates [each is list of strings] + - sort (List[int]): Ranking scores for response pairs [length must be 2] + - system (str, optional): System-level instruction for dialogue + input_file (str): Source file path for data provenance tracking + + Returns: + Example: Standardized data container with fields: + - src (list): Full context sequence (with system prompt if exists) + - tgt (list): Expected response sequence + - is_system (int): System prompt presence flag (0/1) + - chosen/rejected (list): Selected best/worst multi-turn responses + - source: Original data file path + - data_format: Format identifier "sft" + """ + if isinstance(data["src"], str): + data["src"] = [data["src"]] + if isinstance(data["tgt"], str): + data["tgt"] = [data["tgt"]] + if len(data["src"]) != len(data["tgt"]) + 1: + raise ValueError( + f"Data format error. src length must be tgt length + 1. " + f"But got src_length:{len(data['src'])} tgt_length:{len(data['tgt'])}" + ) + if (len(data["response"]) != 2) or (len(data["response"]) != len(data["sort"])): + raise ValueError( + f"Response and sort length must be 2. " + f"But got response_length:{len(data['response'])} sort_length:{len(data['sort'])}." + ) + if data["sort"][0] == data["sort"][1]: + raise ValueError( + f"Sort field must be different." f" But got 'sort':{data['sort']}" + ) + if isinstance(data["response"][0], str) and isinstance(data["response"][1], str): + data["response"] = [[data["response"][0]], [data["response"][1]]] + for response in data["response"]: + if not isinstance(response, list): + raise ValueError( + f"Session level response should be List[List[str]], but got List of {type(response)}" + ) + if len(response) % 2 != 1: + raise ValueError( + "The number of responses should be even, but an odd number of responses were obtained." + ) + for r in response: + if len(r.strip()) < 1: + raise ValueError( + f"Response field must be longer than 1." + f" But got 'response':{data['response']}." + ) + + if len(data["response"][0]) < 1 or len(data["response"][1]) < 1: + raise ValueError( + f"Ignore empty response." f" But got 'response':{data['response']}." + ) + if data["sort"][0] > data["sort"][1]: + chosen = data["response"][0] + rejected = data["response"][1] + else: + chosen = data["response"][1] + rejected = data["response"][0] + + if "is_system" not in data: + # If is_system is 1, it indicates that the sample includes system settings + # and no other sample should be concatenated before it. + data["is_system"] = 0 + + if data["is_system"] == 1: + data["system"] = data["src"][0] + data["src"] = data["src"][1:] + data["tgt"] = data["tgt"][1:] + + if "system" in data: + if not isinstance(data["system"], str): + raise ValueError("System field must be a string.") + + # convert to OpenAI format + data["messages"] = [] + if "system" in data: + data["messages"].append({"role": "system", "content": data["system"]}) + for idx in range(len(data["src"])): + data["messages"].append({"role": "user", "content": data["src"][idx]}) + if idx != len(data["src"]) - 1: + data["messages"].append({"role": "assistant", "content": data["tgt"][idx]}) + + chosen_m, rejected_m = data["messages"], deepcopy(data["messages"]) + session_start_index = ( + len(data["messages"]) + if data["messages"][0]["role"] != "system" + else len(data["messages"]) - 1 + ) + for idx in range(len(chosen)): + if idx % 2 == 0: + # assistant + chosen_m.append({"role": "assistant", "content": chosen[idx]}) + rejected_m.append({"role": "assistant", "content": rejected[idx]}) + else: + # user + chosen_m.append({"role": "user", "content": chosen[idx]}) + rejected_m.append({"role": "user", "content": rejected[idx]}) + + return Example( + chosen={"messages": chosen_m}, + rejected={"messages": rejected_m}, + session_start_index=session_start_index, + source=input_file, + score_delta=1.0, + ) + + +class InfiniteDataset(IterableDataset): + """Load infinite data from original dataset with shuffle. + + Args: + dataset (IterableDataset): Source dataset to wrap. Will be fully + materialized into a list for repeated access. + rng (random.Random, optional): Custom random number generator for + controlling shuffle behavior. Defaults to new random.Random(). + """ + + def __init__(self, dataset, rng=None, random_shuffle=True): + """Initialize InfiniteDataset. + + Args: + dataset (Iterable): The original dataset to wrap. + rng (Random, optional): Random number generator for shuffling. + random_shuffle (bool): Whether to enable random shuffling. + """ + self.data = list(iter(dataset)) + self.indices = list(range(len(self.data))) + if rng is None: + rng = random.Random() + self.rng = rng + self.random_shuffle = random_shuffle + + def __iter__(self): + while True: + if self.random_shuffle: + self.rng.shuffle(self.indices) + for i in self.indices: + yield self.data[i] + + +class SequenceDataset(IterableDataset): + """Stateful dataset for generating token sequences from multi-source examples. + + Args: + dataset (MultiSourceDataset): Source dataset containing examples to process + tokenizer (Tokenizer): Tokenizer for text processing and token conversion + max_seq_len (int, optional): Maximum sequence length. Defaults to 4096 + max_prompt_len (int, optional): Maximum prompt context length. Defaults to 2048 + num_samples_each_epoch (int, optional): number of sample per epoch. Defaults to 1e5 + is_valid (bool, optional): Validation mode flag (disable randomization). Defaults to False + random_seed (int, optional): Seed for reproducible shuffling. Defaults to 11 + random_shuffle (bool, optional): Enable random shuffling. Defaults to True + greedy_intokens (bool, optional): Greedy intokens strategy. Defaults to False + buffer_size (int, optional): Preload buffer size for optimization. Defaults to 500 + use_attn_mask_start_row_indices (bool, optional): Use sparse attention indexing. Defaults to True + mask_out_eos_token (bool, optional): Exclude EOS from loss calculation. Defaults to True + """ + + def __init__( + self, + dataset: MultiSourceDataset, + tokenizer, + max_seq_len: int = 4096, + max_prompt_len: int = 2048, + num_samples_each_epoch: int = 100000, + is_valid: bool = False, + random_seed: int = 11, + random_shuffle: bool = True, + greedy_intokens: bool = False, + buffer_size: int = 500, + use_attn_mask_start_row_indices: bool = True, + mask_out_eos_token: bool = True, + ): + self.example_dataset = dataset + self.tokenizer = tokenizer + self.start_token = tokenizer.bos_token + self.end_token = tokenizer.eos_token + self.break_token = tokenizer.sep_token + self.break_turn_token = tokenizer.cls_token + self.sys_start_token = getattr(tokenizer, "sys_start_token", None) + self.sys_end_token = getattr(tokenizer, "sys_end_token", None) + + self.max_seq_len = max_seq_len + self.max_prompt_len = max_prompt_len + if self.max_prompt_len > self.max_seq_len: + raise ValueError( + f"max_prompt_len should be less than max_seq_len, but got {self.max_prompt_len} > {self.max_seq_len}" + ) + self.is_valid = is_valid + self.random_seed = random_seed + self.rng = random.Random(random_seed) + self.random_shuffle = random_shuffle + self.greedy_intokens = greedy_intokens + self.buffer_size = buffer_size + self.origin_dataset_num = 0 + self.use_attn_mask_start_row_indices = use_attn_mask_start_row_indices + self.mask_out_eos_token = mask_out_eos_token + + # For new data concatenation mode + self.begin_of_query = self.tokenizer.tokenize("User: ") + self.begin_of_response = self.tokenizer.tokenize("\nAssistant: ") + self.end_of_response = "<|end_of_sentence|>" + self.begin_token = "<|begin_of_sentence|>" # Same effect as sys_start_token + self.newline_token = self.tokenizer.tokenize( + "\n" + ) # Same effect as sys_end_token + + if not is_valid: + for task in self.example_dataset._task_group: + task["target_num_each_epoch"] = int( + task["prob"] * num_samples_each_epoch + ) + inner_dataset = InfiniteDataset( + task["dataset"], self.rng, self.random_shuffle + ) + task["iterator"] = iter(inner_dataset) + task["num_examples"] = len(inner_dataset.data) + logger.info( + f"{task['filepath']}: task prob: {task['prob']}, " + f"ori number of examples: {len(inner_dataset.data)}, " + f"target_num_each_epoch: {task['target_num_each_epoch']}" + ) + self.origin_dataset_num += len(inner_dataset.data) + + self.epoch_index = 0 + + def __iter_func(self): + """ + The __iter_func method implements iteration over the dataset. + Each iteration returns a Sequence-type element. + Within the current epoch, samples are randomly generated using epoch_rng and are only valid for that epoch. + If multiple workers exist, data is partitioned according to worker ID. + + Args: + None (no parameters) + + Returns: + Sequence (class): A Sequence-type element containing input IDs, input masks, and labels. + + Raises: + No exceptions raised. + """ + # epoch_rng only use in this epoch. + epoch_rng = np.random.RandomState(self.epoch_index) + worker_info = get_worker_info() + + # prepare epoch data + examples_all = [] + batch_sequence, cur_len = [], 0 + for task in self.example_dataset._task_group: + if self.is_valid: + examples = [ex for ex in task["dataset"]] + self.origin_dataset_num += len(examples) + else: + examples = [ + next(task["iterator"]) for _ in range(task["target_num_each_epoch"]) + ] + if self.random_shuffle: + epoch_rng.shuffle(examples) + if worker_info is not None: + examples = examples[worker_info.id :: worker_info.num_workers] + examples_all.extend(examples) + if self.random_shuffle: + epoch_rng.shuffle(examples_all) + if not self.greedy_intokens: + for example in examples_all: + sequence = self._postprocess_sequence(example) + if sequence is None: + continue + + if cur_len + len(sequence.input_ids) <= self.max_seq_len: + batch_sequence.append(sequence) + cur_len += len(sequence.input_ids) + else: + yield batch_sequence + batch_sequence, cur_len = [sequence], len(sequence.input_ids) + + if len(batch_sequence) > 0: + yield batch_sequence + else: + sequence_buffer = [] + buffer_size = self.buffer_size + for example in examples_all: + sequence = self._postprocess_sequence(example) + if sequence is None: + continue + sequence_buffer.append(sequence) + + if len(sequence_buffer) == buffer_size: + sequence_pack = self._generate_greedy_packs(sequence_buffer) + for pack in sequence_pack: + yield pack + sequence_buffer = [] + if len(sequence_buffer) > 0: + sequence_pack = self._generate_greedy_packs(sequence_buffer) + for pack in sequence_pack: + yield pack + + self.epoch_index += 1 + + def __iter__(self): + """ + Rewrite the __iter__ method to implement dataset iteration. + Each iteration returns a Sequence-type element. + """ + if self.is_valid: + yield from self.__iter_func() + else: + while True: + yield from self.__iter_func() + + def _generate_greedy_packs(self, sequences): + """Generate sequence packs using greedy bin packing algorithm for efficient batching. + + Args: + sequences (List[Sequence]): List of input sequences containing: + - input_ids (List[int]): Tokenized sequence + [Other sequence attributes...] + + Returns: + List[List[Sequence]]: Packed sequences grouped into batches where: + - Each sublist represents a batch + - Sum of sequence lengths in batch <= self.max_seq_len + - Batches ordered by descending remaining capacity + """ + left_len_list = np.array([]) + sequence_pack = [] + for sequence in sequences: + sequence_len = len(sequence.input_ids) + if len(left_len_list) > 0: + max_left_len_index = left_len_list.argmax() + + if ( + len(left_len_list) == 0 + or left_len_list[max_left_len_index] < sequence_len + ): + sequence_pack.append([sequence]) + left_len_list = np.append( + left_len_list, np.array([self.max_seq_len - sequence_len]) + ) + else: + sequence_pack[max_left_len_index].append(sequence) + left_len_list[max_left_len_index] -= sequence_len + return sequence_pack + + def __postprocess_before_concat(self, example): + """Process multi-turn conversation data into tokenized sequences with dynamic truncation. + + Args: + example (Example): Input data object containing: + - src (List[str]): Conversation history prompts + - tgt (List[str]): Corresponding responses + - chosen/rejected (List[str]): Preferred/unpreferred response paths + - is_system (int): System prompt presence flag + - system (str): System settings + + Returns: + tuple: (prompt_ids, response_ids_list, label_ids_list, response_lens, total_len) containing: + - prompt_token_ids (List[int]): Main conversation context token ids + - response_token_ids_list (List[List[int]]): [chosen_path, rejected_path] response token ids + - response_label_ids_list (List[List[int]]): Each response label ids(mask included) + - response_len_list (List[int]): Valid response token length(special token excluded) + - cur_len (int): Final input ids length + """ + prompt_token_ids = [] + + cur_len = 0 + + # encoded_messages: List[List[int]] + chosen_encoded_messages = self.tokenizer.encode_chat_inputs(example.chosen) + rejected_encoded_messages = self.tokenizer.encode_chat_inputs(example.rejected) + + # chosen/rejected response + response_token_ids_list = [] + response_label_ids_list = [] + response_len_list = [] + for responses in [ + chosen_encoded_messages[example.session_start_index // 2 :], + rejected_encoded_messages[example.session_start_index // 2 :], + ]: + responses_token_ids = [] + responses_label_ids = [] + response_len = 0 + for i, response in enumerate(responses): + q, a = response + label_ids, res = [], [] + + if i != 0: + # prompt + label_ids += [0] * (len(q) - 1) + res += q + + # response + if self.mask_out_eos_token: + label_ids += a[:-1] + [0, 0] + response_len += len(a) - 1 + res += a + else: + label_ids += a + [0] + response_len += len(a) + res += a + responses_token_ids += res + responses_label_ids += label_ids + response_token_ids_list.append(responses_token_ids) + response_label_ids_list.append(responses_label_ids) + response_len_list.append(response_len) + + cur_len += sum(map(len, response_token_ids_list)) + + # create at least one turn + turn_index = len(chosen_encoded_messages) - 1 + while turn_index >= 0: + if turn_index == len(chosen_encoded_messages) - 1: + cur_turn_token = chosen_encoded_messages[turn_index][0] + else: + cur_turn_token = ( + chosen_encoded_messages[turn_index][0] + + chosen_encoded_messages[turn_index][1] + ) + + if cur_len + len(cur_turn_token) > self.max_seq_len: + break + + prompt_token_ids = cur_turn_token + prompt_token_ids + cur_len += len(cur_turn_token) + turn_index -= 1 + + # at least one turn + if turn_index == len(chosen_encoded_messages) - 1: + sub_src = example.chosen[0]["content"].strip()[:5] + global LOGGER_COUNT + LOGGER_COUNT += 1 + if LOGGER_COUNT <= 5: + logger.warning( + f"[SKIP] max_seq_len({self.max_seq_len}) is insufficient to include " + f"even one turn, example_output:'{{'src':[{sub_src}, ……]}}'" + ) + return (None,) * 5 + + if cur_len > self.max_seq_len: + logger.warning(f"[SKIP] Example is too long: {example}") + return (None,) * 5 + + return ( + prompt_token_ids, + response_token_ids_list, + response_label_ids_list, + response_len_list, + cur_len, + ) + + def _postprocess_sequence(self, example): + """Assemble processed components into final training sequence with attention controls. + + Args: + example (Example): Input data object containing raw fields: + - data_format (str): Specifies processing mode ("ec3_completion" or others) + - [Other fields depending on data_format] + + Returns: + Sequence: Processed training sequence containing: + - input_ids (List[int]): Concatenated token IDs [prompt + chosen + rejected] + - position_ids (List[int]): Position indices with special structure: + * prompt positions: 0~N + * chosen positions: N~N+M + * rejected positions: N~N+K (reuses prompt start index) + - chosen_labels (List[int]): Masked labels for chosen response: + * 0 for prompt/rejected sections + * Shifted response tokens for chosen + - rejected_labels (List[int]): Masked labels for rejected response + - response_index (List[int]): Span indices [start, chosen_end, total_end] + - attention controls (mask or indices): + * attention_mask (np.ndarray): Causal mask matrix if enabled + * attn_mask_start_row_indices (List[int]): Sparse attention indices + - score_delta (float): Score delta between chosen and rejected responses + """ + # sequence: system + knowledge_tokens + prompt + chosen + reject + ( + prompt_token_ids, + response_token_ids_list, + response_label_ids_list, + response_len_list, + cur_len, + ) = self.__postprocess_before_concat(example) + + # The sequnece is too long, just return None + if prompt_token_ids is None: + return None + # 1.concat all tokens + # 1.1 input_ids + input_ids = ( + prompt_token_ids + response_token_ids_list[0] + response_token_ids_list[1] + ) + if cur_len != len(input_ids): + logger.warning(f"[SKIP] code bug: {example}") + return None + + # 1.2. position_ids + prompt_len = len(prompt_token_ids) + chosen_len = len(response_token_ids_list[0]) + rejected_len = len(response_token_ids_list[1]) + position_ids = ( + list(range(prompt_len)) # prompt + + list(range(prompt_len, prompt_len + chosen_len)) # chosen + + list(range(prompt_len, prompt_len + rejected_len)) # rejected + ) + + # 1.3 labels + chosen_labels = ( + [0] * (prompt_len - 1) + + response_label_ids_list[0] + + [0] * len(response_token_ids_list[1]) + ) + rejected_labels = ( + [0] * (prompt_len - 1) + + [0] * len(response_token_ids_list[0]) + + response_label_ids_list[1] + ) + + # 1.4 response index + # support use_sparse_head_and_loss_fn only + response_index = [0, response_len_list[0], sum(response_len_list)] + + # 1.5 attention mask + if self.use_attn_mask_start_row_indices: + attn_mask_start_row_indices = ( + [cur_len] * (prompt_len) + + [prompt_len + chosen_len] * chosen_len + + [cur_len] * rejected_len + ) + attention_mask = None + else: + attention_mask = np.tri(cur_len, cur_len, dtype=bool) + attention_mask[ + (prompt_len + chosen_len) :, + prompt_len : (prompt_len + chosen_len), + ] = False + attn_mask_start_row_indices = None + # 2. return sequence + return Sequence( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + attn_mask_start_row_indices=attn_mask_start_row_indices, + chosen_labels=chosen_labels, + rejected_labels=rejected_labels, + response_index=response_index, + score_delta=example.score_delta, + ) diff --git a/ernie/ERNIE/ernie/dataset/finetuning.py b/ernie/ERNIE/ernie/dataset/finetuning.py new file mode 100644 index 0000000000000000000000000000000000000000..18baa36c12273a2784b874e8b213603c09358c9c --- /dev/null +++ b/ernie/ERNIE/ernie/dataset/finetuning.py @@ -0,0 +1,678 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import random +from dataclasses import dataclass +from typing import List + +import numpy as np +from paddle.io import IterableDataset, get_worker_info +from paddleformers.utils.log import logger + +from ernie.dataset.base import MultiSourceDataset +from ernie.dataset.data_utils import ( + Example, + pad_batch_data, +) + +LOGGER_COUNT = 0 + + +@dataclass +class Sequence: + """Encapsulated sequence class.""" + + token_ids: List[int] + position_ids: List[int] + labels: List[int] + loss_mask: List[int] + num_examples: int + + +def create_dataset(**dataset_config): + """Create SFT dataset based on configuration parameters. + + Args: + dataset_config (dict): Configuration dictionary containing parameters like: + + Returns: + SequenceDataset: Configured sequence dataset for SFT tasks + """ + task_dataset_path = [ + path + for path in str(dataset_config["task_group"]).replace(" ", "").split(",") + if path != "" + ] + task_dataset_prob = [ + float(prob) + for prob in str(dataset_config["task_group_prob"]).replace(" ", "").split(",") + if prob != "" + ] + sub_dataset_type = [ + type_ + for type_ in str(dataset_config["sub_dataset_type"]).replace(" ", "").split(",") + if type_ != "" + ] + + if not (len(task_dataset_path) == len(task_dataset_prob) == len(sub_dataset_type)): + raise ValueError( + "The len of dataset path, prob, type are inconsistent, please check the configuration." + ) + + if len(task_dataset_path) == 0: + raise ValueError( + "The len of dataset path is zero, please check the configuration." + ) + + example_dataset = MultiSourceDataset( + task_dataset_path=task_dataset_path, + task_dataset_prob=task_dataset_prob, + sub_dataset_type=sub_dataset_type, + process_fn=process_example, + ) + sequence_dataset = SequenceDataset( + dataset=example_dataset, + tokenizer=dataset_config["tokenizer"], + max_seq_len=dataset_config["max_seq_len"], + num_samples_each_epoch=dataset_config["num_samples_each_epoch"], + is_valid=dataset_config.get("is_valid", False), + random_seed=dataset_config["random_seed"], + random_shuffle=dataset_config["random_shuffle"], + greedy_intokens=dataset_config["greedy_intokens"], + ) + return sequence_dataset + + +def create_indexed_dataset(data_file_prefix): + """Create indexed dataset from raw data files. + + Args: + data_file_prefix (str): Path prefix for raw data files + + Returns: + IndexedDataset: Preprocessed dataset with memory-efficient indexing + """ + from paddleformers.data.indexed_dataset import ( + make_sft_dataset as make_sft_indexed_dataset, + ) + + indexed_dataset = make_sft_indexed_dataset( + path=data_file_prefix, + dataclass=Sequence, + ) + return indexed_dataset + + +def collate_fn(batch: List[List[Sequence]], tokenizer, model_args, max_seq_len: int): + """Convert batch of sequences into training tensors. + + Args: + batch (List[List[Sequence]]): Batch of input sequences + tokenizer: Tokenizer for text conversion + model_args: Model configuration parameters + max_seq_len (int): Maximum sequence length for padding + + Returns: + dict: Dictionary containing: + - input_ids: Padded token IDs + - labels: Shifted labels for prediction + - loss_mask: Mask for computing loss + """ + input_keys = ["input_ids", "labels", "loss_mask"] + if model_args.use_attn_mask_start_row_indices: + input_keys.append("attn_mask_start_row_indices") + else: + input_keys.append("attention_mask") + return_list = [] + for batch_sequence in batch: + original_token_ids = [seq.token_ids for seq in batch_sequence] + token_ids = [sum(original_token_ids, [])] + loss_mask = [sum([seq.loss_mask for seq in batch_sequence], [])] + labels = [sum([seq.labels for seq in batch_sequence], [])] + + # padding + padded_token_ids = pad_batch_data( + token_ids, pad_idx=tokenizer.pad_token_id, max_seq_len=max_seq_len + ) + padded_labels = pad_batch_data( + labels, pad_idx=tokenizer.pad_token_id, max_seq_len=max_seq_len + ) + padded_loss_mask = pad_batch_data(loss_mask, pad_idx=0, max_seq_len=max_seq_len) + padded_labels = np.where(padded_loss_mask == 1, padded_labels, -100) + return_list.append( + [ + padded_token_ids, + padded_labels, + padded_loss_mask, + ] + ) + + if model_args.use_attn_mask_start_row_indices: + return_list[-1].append( + gen_attn_mask_start_row_indices(original_token_ids, max_seq_len) + ) + else: + return_list[-1].append(gen_self_attn_mask(original_token_ids, max_seq_len)) + + return_list = [np.concatenate(tensor_list) for tensor_list in zip(*return_list)] + input_dict = dict(zip(input_keys, return_list)) + return input_dict + + +def process_example(data, input_file): + """Convert raw data example into training example. + + Args: + data (dict): Raw example data with: + input_file (str): Source file path + + Returns: + Example: Processed example for sequence generation + """ + # We have the code completion dataset, which has the following fields + if isinstance(data["src"], str): + data["src"] = [data["src"]] + if isinstance(data["tgt"], str): + data["tgt"] = [data["tgt"]] + + if len(data["src"]) == 0 or len(data["tgt"]) == 0: + raise ValueError("Ignore example with empty src or empty tgt.") + + for item in data["src"] + data["tgt"]: + if len(item.strip()) == 0: + raise ValueError("Ignore example with empty string in str / tgt field.") + + if "label" not in data: + data["label"] = [1] * len(data["src"]) + + if not (len(data["src"]) == len(data["tgt"]) == len(data["label"])): + raise ValueError("The length of src & tgt & label must be equal.") + + if "is_system" not in data: + # If is_system is 1, it indicates that the sample includes system settings + # and no other sample should be concatenated before it. + data["is_system"] = 0 + + if data["is_system"] == 1: + data["system"] = data["src"][0] + data["src"] = data["src"][1:] + data["tgt"] = data["tgt"][1:] + data["label"] = data["label"][1:] + + # update "system" + if "system" in data: + if not isinstance(data["system"], str): + raise ValueError("System field must be a string.") + data["is_system"] = 1 + + # convert to OpenAI format + data["messages"] = [] + if "system" in data: + data["messages"].append({"role": "system", "content": data["system"]}) + for q, a in zip(data["src"], data["tgt"]): + data["messages"].append({"role": "user", "content": q.strip()}) + data["messages"].append({"role": "assistant", "content": a.strip()}) + + return Example( + request={"messages": data["messages"]}, + system=data["system"] if data["is_system"] else "", + label=data["label"], + is_system=data["is_system"], + source=input_file, + ) + + +class InfiniteDataset(IterableDataset): + """Infinite iterable dataset with shuffle support. + + This dataset supports continuous iteration and optional random shuffling. + """ + + def __init__(self, dataset, rng=None, random_shuffle=True): + """Initialize InfiniteDataset. + + Args: + dataset (Iterable): The original dataset to wrap. + rng (Random, optional): Random number generator for shuffling. + random_shuffle (bool): Whether to enable random shuffling. + """ + self.data = list(iter(dataset)) + self.indices = list(range(len(self.data))) + if rng is None: + rng = random.Random() + self.rng = rng + self.random_shuffle = random_shuffle + + def __iter__(self): + """Infinite iterator with optional shuffling. + + Yields: + object: The next data sample from the dataset. + """ + while True: + if self.random_shuffle: + self.rng.shuffle(self.indices) + for i in self.indices: + yield self.data[i] + + +class SequenceDataset(IterableDataset): + """Dataset for creating sequences from multi-source examples. + + This is a stateful dataset that handles sequence generation and packing. + """ + + def __init__( + self, + dataset: MultiSourceDataset, + tokenizer, + max_seq_len: int = 4096, + num_samples_each_epoch: int = 100000, + is_valid: bool = False, + random_seed: int = 11, + random_shuffle: bool = True, + greedy_intokens: bool = False, + ): + """Initialize SequenceDataset. + + Args: + dataset (MultiSourceDataset): The multi-source example dataset. + tokenizer: Tokenizer for text processing. + max_seq_len (int): Maximum sequence length. + num_samples_each_epoch (int): Target samples per epoch. + is_valid (bool): Flag for validation mode. + random_seed (int): Seed for random number generation. + random_shuffle (bool): Enable random shuffling. + greedy_intokens (bool): Use greedy in-token packing strategy. + """ + + self.example_dataset = dataset + self.tokenizer = tokenizer + self.start_token = tokenizer.bos_token # "" + self.end_token = tokenizer.eos_token # "" + self.break_token = tokenizer.sep_token # "" + self.break_turn_token = tokenizer.cls_token # "" + self.sys_start_token = getattr(tokenizer, "sys_start_token", None) + self.sys_end_token = getattr(tokenizer, "sys_end_token", None) + self.max_seq_len = max_seq_len + self.is_valid = is_valid + self.random_seed = random_seed + self.rng = random.Random(random_seed) + self.random_shuffle = random_shuffle + self.greedy_intokens = greedy_intokens + self.origin_dataset_num = 0 + + # For new data concatenation mode + self.begin_of_query = self.tokenizer.tokenize("User: ") + self.begin_of_response = self.tokenizer.tokenize("\nAssistant: ") + self.end_of_response = "<|end_of_sentence|>" + self.end_of_response_id = self.tokenizer._convert_token_to_id( + [self.end_of_response] + )[0] + self.begin_token = "<|begin_of_sentence|>" # Same effect as sys_start_token + self.begin_token_id = self.tokenizer._convert_token_to_id([self.begin_token])[0] + self.newline_token = self.tokenizer.tokenize( + "\n" + ) # Same effect as sys_end_token + + if not is_valid: + for task in self.example_dataset._task_group: + task["target_num_each_epoch"] = int( + task["prob"] * num_samples_each_epoch + ) + inner_dataset = InfiniteDataset( + task["dataset"], self.rng, self.random_shuffle + ) + task["iterator"] = iter(inner_dataset) + task["num_examples"] = len(inner_dataset.data) + logger.info( + f"{task['filepath']}: task prob: {task['prob']}, " + f"ori number of examples: {len(inner_dataset.data)}, " + f"target_num_each_epoch: {task['target_num_each_epoch']}" + ) + self.origin_dataset_num += len(inner_dataset.data) + else: + self.random_shuffle = False + self.greedy_intokens = 0 + self.estimate = False + # The number of valid samples and skipped samples in estimation + self.unused_samples = 0 + self.used_samples = 0 + # If used_estimate_samples exceeds max_estimate_samples,stop estimating. + self.used_estimate_samples = 0 + self.max_estimate_samples = 0 + if not is_valid: + # Set max estimate samples to dataset num examples in default + if len(self.example_dataset._task_group) > 1: + for task in self.example_dataset._task_group: + self.max_estimate_samples += np.ceil( + task["num_examples"] * task["prob_origin"] + ) + else: + self.max_estimate_samples = self.example_dataset._task_group[0][ + "num_examples" + ] + self.epoch_index = 0 + + def __iter_func(self): + """Core iterator function for sequence generation. + + Returns: + Sequence: A processed sequence containing token IDs and labels. + """ + # epoch_rng only use in this epoch. + epoch_rng = np.random.RandomState(self.epoch_index) + worker_info = get_worker_info() + + # prepare epoch data + logger.info("prepare SequenceDataset ...") + examples_all = [] + for task in self.example_dataset._task_group: + if self.is_valid: + examples = [ex for ex in task["dataset"]] + self.origin_dataset_num += len(examples) + else: + examples = [ + next(task["iterator"]) for _ in range(task["target_num_each_epoch"]) + ] + if self.random_shuffle: + epoch_rng.shuffle(examples) + if worker_info is not None: + examples = examples[worker_info.id :: worker_info.num_workers] + examples_all.extend(examples) + if self.random_shuffle: + epoch_rng.shuffle(examples_all) + logger.info( + f"prepare SequenceDataset done: total number of examples is {len(examples_all)}" + ) + + batch_sequence, cur_len = [], 0 + + if self.is_valid: + examples_all = examples_all[::-1] + + if not self.greedy_intokens: + # base + for example in examples_all[::-1]: + actual_example_num = 1 + sequence = self._postprocess_sequence(example, actual_example_num) + if sequence is None: + if self.estimate: + self.unused_samples += actual_example_num + continue + if self.estimate: + self.used_samples += actual_example_num + if cur_len + len(sequence.token_ids) <= self.max_seq_len: + batch_sequence.append(sequence) + cur_len += len(sequence.token_ids) + else: + yield batch_sequence + batch_sequence, cur_len = [sequence], len(sequence.token_ids) + + if self.estimate: + self.used_estimate_samples += actual_example_num + if self.used_estimate_samples >= self.max_estimate_samples: + # Yield left batch sequence before estimation ends + if len(batch_sequence) > 0: + yield batch_sequence + self.used_estimate_samples = 0 + # Set flag to False and yield empty list to signal the end of estimation + self.estimate = False + yield [] + if len(batch_sequence) > 0: + yield batch_sequence + else: + # Pseudo multiple rounds + group greedy intokens. + buffer_size = 500 + examples = [] + actual_example_num_list = [] + i = 0 + for example in examples_all[::-1]: + actual_example_num = 1 + if i < buffer_size: + examples.append(example) + actual_example_num_list.append(actual_example_num) + i += 1 + else: + # Running greedy strategy in examples. + generate_packs = self._generate_greedy_packs( + examples, actual_example_num_list + ) + for pack in generate_packs: + if len(pack) > 0: + yield pack + examples = [example] + i = 1 + + if self.estimate: + self.used_estimate_samples += actual_example_num + # Stop estimation if the number of samples used in estimation is larger than max_estimate_samples + if self.used_estimate_samples >= self.max_estimate_samples: + # Yield left packs before estimation ends + if len(examples) > 0: + generate_packs = self._generate_greedy_packs( + examples, actual_example_num_list + ) + for pack in generate_packs: + if len(pack) > 0: + yield pack + # Set flag to False and yield empty list to signal the end of estimation + self.estimate = False + yield [] + + if len(examples) > 0: + generate_packs = self._generate_greedy_packs( + examples, actual_example_num_list + ) + for pack in generate_packs: + if len(pack) > 0: + yield pack + + self.epoch_index += 1 + + def __iter__(self): + """Iterator interface for the dataset. + + Yields: + Sequence: The generated sequences. + """ + if self.is_valid: + yield from self.__iter_func() + else: + while True: + yield from self.__iter_func() + + def _postprocess_sequence(self, example, actual_example_num): + """Process code completion examples into token sequences. + + Args: + example: The input example containing code components. + actual_example_num (int): Number of examples used. + + Returns: + Sequence: Processed sequence or None if invalid. + """ + + encoded_messages = self.tokenizer.encode_chat_inputs(example.request) + + num_reserved_tokens_for_each_dialog = 1 # only break_turn_token or end_token + num_reserved_tokens_for_each_turn = 8 + + cur_len = num_reserved_tokens_for_each_dialog + + turn_index = len(encoded_messages) - 1 + + tokens = [] + loss_mask = [] + while turn_index >= 0: + tokens_src, tokens_target = encoded_messages[turn_index] + if len(tokens_src) + len(tokens_target) > ( + self.max_seq_len + 1 - cur_len - num_reserved_tokens_for_each_turn + ): + break + + tokens = tokens_src + tokens_target + tokens + + loss_mask = ( + [0] * (len(tokens_src) - 1) + + [example.label[turn_index]] * (len(tokens_target) + 1) + + loss_mask + ) + assert len(tokens) == len(loss_mask), f"{len(tokens)}-{len(loss_mask)}" + + cur_len = len(tokens) + + turn_index -= 1 + + # Not even one turn can be added, so need to do warning and skip this example + if ( + len(tokens) + <= num_reserved_tokens_for_each_dialog + num_reserved_tokens_for_each_turn + ): + try: + # For print log + sub_src = example.src[0].strip()[:5] + sub_tgt = example.tgt[-1].strip()[-5:] + global LOGGER_COUNT + LOGGER_COUNT += 1 + if LOGGER_COUNT <= 5: + logger.warning( + f"even one turn, example_output:'{{'src':[{sub_src}, ……],'tgt':[……{sub_tgt}]}}'" + ) + except Exception as _: + logger.warning(f"[SKIP] wrong example: {example}") + + return None + + # Maybe left truncated, so need to add begin_token + if tokens[0] != self.begin_token_id: + tokens = [self.begin_token_id] + tokens + loss_mask = [0] + loss_mask + + if len(tokens) > self.max_seq_len: + raise RuntimeError(f"token_ids is too long: {len(tokens)}") + + # Add EOS token at the end + del tokens[-1] + del loss_mask[-1] + labels = tokens[1:] + [self.tokenizer.eos_token_id] + + # end_of_response is a special token that indicates the end of the turn. + # end_token is a special token that indicates the end of the answer. + labels = [ + la if la != self.end_of_response_id else self.tokenizer.eos_token_id + for la in labels + ] + + pos_ids = list(range(len(tokens))) + + if sum(loss_mask) == 0: + logger.warning(f"[SKIP] all labels set to 0: {example}") + return None + + assert len(tokens) == len(loss_mask), f"{len(tokens)}-{len(loss_mask)}" + assert len(tokens) == len(labels), f"{len(tokens)}-{len(labels)}" + return Sequence( + token_ids=tokens, + position_ids=pos_ids, + labels=labels, + loss_mask=loss_mask, + num_examples=actual_example_num, + ) + + def _generate_greedy_packs(self, examples, actual_example_num_list): + """Generate packed sequences using greedy strategy. + + Args: + examples: List of examples to pack. + actual_example_num_list: List of example counts. + + Returns: + list: List of packed sequences. + """ + + left_len = np.zeros([len(examples)]) - 1 + left_len[0] = ( + self.max_seq_len + ) # At the beginning, only the first pack is valid. + generate_packs = [[] for i in range(len(examples))] + index = 0 + left_index = 0 + + while index < len(examples): + sequence = self._postprocess_sequence( + examples[index], actual_example_num_list[index] + ) + if sequence is None: + if self.estimate: + self.unused_samples += actual_example_num_list[index] + index += 1 + continue + + max_left_index = left_len.argmax() + # Put the current sequence into the largest left space valid pack. + if len(sequence.token_ids) <= left_len[max_left_index]: + generate_packs[max_left_index].append(sequence) + left_len[max_left_index] -= len(sequence.token_ids) + if self.estimate: + self.used_samples += actual_example_num_list[index] + index += 1 + else: + left_index += 1 + left_len[left_index] = self.max_seq_len + + return generate_packs + + +def gen_self_attn_mask(batch_token_ids: List[List[int]], max_seq_len: int): + """Generate self-attention mask for multi-sequence batches. + + Args: + batch_token_ids (List[List[int]]): List of token ID sequences. + max_seq_len (int): Maximum sequence length. + + Returns: + ndarray: 4D attention mask array. + """ + input_mask_data = np.zeros((1, 1, max_seq_len, max_seq_len), dtype="float32") + offset = 0 + for index, token_ids in enumerate(batch_token_ids): + cur_len = len(token_ids) + b = np.tril(np.ones([cur_len, cur_len]), 0) + input_mask_data[0, 0, offset : offset + cur_len, offset : offset + cur_len] = b + offset += cur_len + return input_mask_data + + +def gen_attn_mask_start_row_indices(batch_token_ids: List[List[int]], max_seq_len: int): + """Generate row indices for flash attention masks. + + Args: + batch_token_ids (List[List[int]]): List of token ID sequences. + max_seq_len (int): Maximum sequence length. + + Returns: + ndarray: Row indices array with dtype int32. + """ + offset = 0 + attn_mask_start_row_indices = [] + for token_ids in batch_token_ids: + cur_len = len(token_ids) + attn_mask_start_row_indices.extend([offset + cur_len] * cur_len) + offset += cur_len + if offset < max_seq_len: + attn_mask_start_row_indices.extend(list(range(offset, max_seq_len))) + # NOTE(hehuang): The dtype of attn_mask_start_row_indices must be np.int32 + return np.array(attn_mask_start_row_indices, dtype=np.int32)[None, None] diff --git a/ernie/ERNIE/ernie/dataset/hf/__pycache__/errors.cpython-311.pyc b/ernie/ERNIE/ernie/dataset/hf/__pycache__/errors.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a92a1e63314360ec18e99fe262aeaebc57af530 Binary files /dev/null and b/ernie/ERNIE/ernie/dataset/hf/__pycache__/errors.cpython-311.pyc differ diff --git a/ernie/ERNIE/ernie/dataset/hf/__pycache__/hf_parser.cpython-311.pyc b/ernie/ERNIE/ernie/dataset/hf/__pycache__/hf_parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2059b7b9e2fc817c5f93741b6ad65aa1fcdfe8d8 Binary files /dev/null and b/ernie/ERNIE/ernie/dataset/hf/__pycache__/hf_parser.cpython-311.pyc differ diff --git a/ernie/ERNIE/ernie/dataset/hf/__pycache__/parse_config.cpython-311.pyc b/ernie/ERNIE/ernie/dataset/hf/__pycache__/parse_config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c01802518342f1ca5f7b32c288bcf4fe1b9212a8 Binary files /dev/null and b/ernie/ERNIE/ernie/dataset/hf/__pycache__/parse_config.cpython-311.pyc differ diff --git a/ernie/ERNIE/ernie/dataset/hf/data_info.json b/ernie/ERNIE/ernie/dataset/hf/data_info.json new file mode 100644 index 0000000000000000000000000000000000000000..1a05bdd628c5adbadb1e7b3e16039314540a079d --- /dev/null +++ b/ernie/ERNIE/ernie/dataset/hf/data_info.json @@ -0,0 +1,309 @@ +{ + "llamafactory/alpaca_en": { + "hf_hub_url": "llamafactory/alpaca_en", + "formatting": "alpaca", + "doc_formatting": "json", + "file_name": "alpaca_data_en_52k.json", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output" + } + }, + "llamafactory/alpaca_zh": { + "hf_hub_url": "llamafactory/alpaca_zh", + "formatting": "alpaca", + "doc_formatting": "json", + "file_name": "alpaca_data_zh_51k.json", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output" + } + }, + "llamafactory/alpaca_gpt4_en": { + "hf_hub_url": "llamafactory/alpaca_gpt4_en", + "formatting": "alpaca", + "doc_formatting": "json", + "file_name": "alpaca_gpt4_data_en.json", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output" + } + }, + "llamafactory/alpaca_gpt4_zh": { + "hf_hub_url": "llamafactory/alpaca_gpt4_zh", + "formatting": "alpaca", + "doc_formatting": "json", + "file_name": "alpaca_gpt4_data_zh.json", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output" + } + }, + "BelleGroup/train_2M_CN": { + "hf_hub_url": "BelleGroup/train_2M_CN", + "formatting": "alpaca", + "doc_formatting": "jsonl", + "file_name": "train_2M_CN.json", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output" + } + }, + "BelleGroup/train_1M_CN": { + "hf_hub_url": "BelleGroup/train_1M_CN", + "formatting": "alpaca", + "doc_formatting": "jsonl", + "file_name": "Belle_open_source_1M.json", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output" + } + }, + "BelleGroup/train_0.5M_CN": { + "hf_hub_url": "BelleGroup/train_0.5M_CN", + "formatting": "alpaca", + "doc_formatting": "jsonl", + "file_name": "Belle_open_source_0.5M.json", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output" + } + }, + "BelleGroup/generated_chat_0.4M": { + "hf_hub_url": "BelleGroup/generated_chat_0.4M", + "formatting": "alpaca", + "doc_formatting": "jsonl", + "file_name": "generated_chat_0.4M.json", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output" + } + }, + "BelleGroup/school_math_0.25M": { + "hf_hub_url": "BelleGroup/school_math_0.25M", + "formatting": "alpaca", + "doc_formatting": "jsonl", + "file_name": "school_math_0.25M.json", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output" + } + }, + "sahil2801/CodeAlpaca-20k": { + "hf_hub_url": "sahil2801/CodeAlpaca-20k", + "formatting": "alpaca", + "doc_formatting": "json", + "file_name": "code_alpaca_20k.json", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output" + } + }, + "TIGER-Lab/MathInstruct": { + "hf_hub_url": "TIGER-Lab/MathInstruct", + "formatting": "alpaca", + "doc_formatting": "json", + "file_name": "MathInstruct.json", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output" + } + }, + "YeungNLP/firefly-train-1.1M": { + "hf_hub_url": "YeungNLP/firefly-train-1.1M", + "formatting": "alpaca", + "doc_formatting": "jsonl", + "file_name": "firefly-train-1.1M.jsonl", + "columns": { + "prompt": "input", + "response": "target" + } + }, + "suolyer/webqa": { + "hf_hub_url": "suolyer/webqa", + "formatting": "alpaca", + "doc_formatting": "jsonl", + "file_name": "train.json", + "columns": { + "query": "input", + "response": "output" + } + }, + "zxbsmk/webnovel_cn": { + "hf_hub_url": "zxbsmk/webnovel_cn", + "formatting": "alpaca", + "doc_formatting": "json", + "file_name": "novel_cn_token512_50k.json", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output" + } + }, + "AstraMindAI/SFT-Nectar": { + "hf_hub_url": "AstraMindAI/SFT-Nectar", + "formatting": "alpaca", + "doc_formatting": "json", + "file_name": "sft_data_structured.json", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output" + } + }, + "hfl/stem_zh_instruction": { + "hf_hub_url": "hfl/stem_zh_instruction", + "formatting": "alpaca", + "doc_formatting": "jsonl", + "file_name": "bio_50282.json", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output" + } + }, + "llamafactory/OpenO1-SFT": { + "hf_hub_url": "llamafactory/OpenO1-SFT", + "formatting": "alpaca", + "doc_formatting": "jsonl", + "file_name": "OpenO1-SFT-Pro.jsonl", + "columns": { + "prompt": "prompt", + "response": "response" + } + }, + "Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT": { + "hf_hub_url": "Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT", + "formatting": "alpaca", + "doc_formatting": "jsonl", + "file_name": "distill_r1_110k_sft.jsonl", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output" + } + }, + "mayflowergmbh/oasst_de": { + "hf_hub_url": "mayflowergmbh/oasst_de", + "formatting": "alpaca", + "doc_formatting": "json", + "file_name": "oasst_de.json", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output", + "history": "history" + } + }, + "mayflowergmbh/dolly-15k_de": { + "hf_hub_url": "mayflowergmbh/dolly-15k_de", + "formatting": "alpaca", + "doc_formatting": "json", + "file_name": "dolly_de.json", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output" + } + }, + "mayflowergmbh/alpaca-gpt4_de": { + "hf_hub_url": "mayflowergmbh/alpaca-gpt4_de", + "formatting": "alpaca", + "doc_formatting": "json", + "file_name": "alpaca_gpt4_data_de.json", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output" + } + }, + "mayflowergmbh/openschnabeltier_de": { + "hf_hub_url": "mayflowergmbh/openschnabeltier_de", + "formatting": "alpaca", + "doc_formatting": "json", + "file_name": "openschnabeltier.json", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output" + } + }, + "mayflowergmbh/evol-instruct_de": { + "hf_hub_url": "mayflowergmbh/evol-instruct_de", + "formatting": "alpaca", + "doc_formatting": "json", + "file_name": "evol_instruct_de.json", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output", + "history": "history" + } + }, + "mayflowergmbh/dolphin_de": { + "hf_hub_url": "mayflowergmbh/dolphin_de", + "formatting": "alpaca", + "doc_formatting": "json", + "file_name": "dolphin.json", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output" + } + }, + "mayflowergmbh/booksum_de": { + "hf_hub_url": "mayflowergmbh/booksum_de", + "formatting": "alpaca", + "doc_formatting": "json", + "file_name": "booksum.json", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output" + } + }, + "mayflowergmbh/airoboros-3.0_de": { + "hf_hub_url": "mayflowergmbh/airoboros-3.0_de", + "formatting": "alpaca", + "doc_formatting": "json", + "file_name": "airoboros_3.json", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output" + } + }, + "mayflowergmbh/ultra-chat_de": { + "hf_hub_url": "mayflowergmbh/ultra-chat_de", + "formatting": "alpaca", + "doc_formatting": "json", + "file_name": "ultra_chat_german.json", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output" + } + }, + "pleisto/wikipedia-cn-20230720-filtered": { + "hf_hub_url": "pleisto/wikipedia-cn-20230720-filtered", + "formatting": "alpaca", + "doc_formatting": "json", + "file_name": "wikipedia-cn-20230720-filtered.json", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output" + } + } +} diff --git a/ernie/ERNIE/ernie/dataset/hf/errors.py b/ernie/ERNIE/ernie/dataset/hf/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..bc47c2726577e27ebd005ade0e0c6e3cbf2383fa --- /dev/null +++ b/ernie/ERNIE/ernie/dataset/hf/errors.py @@ -0,0 +1,74 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Parsing exceptions""" + + +class DataSetNoFilePathNorRepoIDError(Exception): + """Exception class for no file_path nor repo_id found when create_dataset.""" + + def __init__(self, msg): + """ + Init exception class for no file_path nor repo_id found when create_dataset + Args: + msg (str): exception message + """ + super().__init__(msg) + + +class DataSetFileNotFoundError(Exception): + """Exception class for no dataset file found.""" + + def __init__(self, msg): + """ + Init exception class for no dataset file found + Args: + msg (str): exception message + """ + super().__init__(msg) + + +class DataSetFileCannotOpenError(Exception): + """Exception class for cannot open the file.""" + + def __init__(self, msg): + """ + Init exception class for cannot open the file + Args: + msg (str): exception message + """ + super().__init__(msg) + + +class DataSetParseError(Exception): + """Exception class for parsing error.""" + + def __init__(self, msg): + """ + Init exception class for parsing error + Args: + msg (str): exception message + """ + super().__init__(msg) + + +class DataSetFormattingNotSupportedError(Exception): + """Exception class for formatting not supported.""" + + def __init__(self, msg): + """ + Init exception class for formatting not supported + Args: + msg (str): exception message + """ + super().__init__(msg) diff --git a/ernie/ERNIE/ernie/dataset/hf/hf_parser.py b/ernie/ERNIE/ernie/dataset/hf/hf_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..4d7a87e0ed8b42d2772473eb6d8cad9bbca74b23 --- /dev/null +++ b/ernie/ERNIE/ernie/dataset/hf/hf_parser.py @@ -0,0 +1,440 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" HuggingFace datasets implement. """ +import glob +import json +import os +import random + +from huggingface_hub import snapshot_download +from paddle.io import IterableDataset + +from ernie.dataset.hf import errors, parse_config + + +class BaseDatasetParser(IterableDataset): + """Base class for file parser.""" + + def __init__(self, file_path, formatting, doc_formatting, columns, process_fn=None, shuffle_file=False): + super().__init__() + self.file_path = file_path + self.file_name = os.path.basename(file_path) + self.formatting = formatting + self.doc_formatting = doc_formatting + self.columns = columns + self.r_columns = {} + for k, v in self.columns.items(): + self.r_columns[v] = k + + self.data = [] + self.failed_row = 0 + + self.process_fn = process_fn + self.shuffle_file = shuffle_file + + self.output_file_name = self.file_name + ".ernie.json" + self.output_file_path = os.path.join(parse_config.DATASET_OUTPUT_ROOT, self.output_file_name) + self.output_json_indent = parse_config.DEFAULT_OUTPUT_JSON_INDENT + + def _alpaca_to_erine(self, item): + """Transform alpaca formatted data to ernie formatting""" + src = [ + item.get("prompt", "") + item.get("query", ""), + ] + tgt = [ + item.get("response", ""), + ] + history = [] + system = item.get("system", "") + is_system = False + if system is None or str(system) == "": + history = list(zip(src[:-1], tgt[:-1])) + system = "" + else: + history = list(zip(src[:-1], tgt[:-1])) + src = [system, *src] + tgt = (["", *tgt],) + is_system = True + output = { + # "alpaca": item, + "src": src, + "tgt": tgt, + "history": history, + "system": system, + "is_system": is_system, + } + return output + + def __iter__(self): + """Iterator function for dataset.""" + self.run() + if self.shuffle_file: + random.shuffle(self.data) + for item in self.data: + ex = self._alpaca_to_erine(item) + if self.process_fn is not None: + try: + ex = self.process_fn(ex, self.file_name) + except Exception as e: + print(f"Skip parsing error data in {self.file_name}. Error message: {e}") + continue + # ignore invalid example + if ex is None: + continue + yield ex + + def scan_dataset_file(self): + """ + Scan files under dataset folder and return the first one filename. + """ + files = glob.glob(os.path.join(self.download_path, "*")) + filenames_under_workspace = sorted([filepath.split(os.sep)[-1] for filepath in files]) + filenames = [] + for filename in filenames_under_workspace: + if filename.lower() == 'readme.md': + continue + filenames.append(filename) + if len(filenames) == 0: + msg = f"{self.repo_id} cannot find dataset files after scan, please check or define in dataset_config.py" + raise errors.DataSetFileNotFoundError(msg) + elif len(filenames) > 1: + msg = ( + f"{self.repo_id} cannot find more than one file in dataset files after scan. " + f"please check and define in dataset_config.py. Scanned files: {filenames}" + ) + return filenames[0] + + def check_and_fill_row_alpaca(self, row): + """ + For alpaca formatting, check and fill default value for the essential field like: + prompt/instruction, query/input, response/output. + """ + has_data = False + for key in row: + if row[key] is not None and len(row[key]) > 0: + has_data = True + for key in parse_config.DEFAULT_COLUMN_VALUE_MAPPING: + if key not in row: + row[key] = "" + return has_data + + def check_row(self, row): + """ + Check if the data meets the format requirements. + """ + if self.formatting == "alpaca": + return self.check_and_fill_row_alpaca(row) + return True + + def append_data(self, row): + """ + Append the correct row into data. + """ + if not isinstance(row, dict): + return + if self.check_row(row): + self.data.append(row) + else: + self.failed_row += 1 + + def add_dict_row(self, dict_row): + """ + Mapping the raw dict into the columns. + """ + row = {} + for input_key, output_key in self.r_columns.items(): + row[output_key] = dict_row.get(input_key, parse_config.DEFAULT_COLUMN_VALUE_MAPPING.get(output_key, None)) + return row + + def add_str_row(self, str_row): + """ + Mapping the raw json string into the columns. + """ + line = str_row.strip() + if len(line) == 0: + return None + try: + input = json.loads(str_row) + return self.add_dict_row(input) + except json.decoder.JSONDecodeError as ee: + msg = f"Unformatted json-line: {str_row}, stop" + raise errors.DataSetParseError(msg) + + def parse_json_file(self): + """ + Parse the json-format file into data. + + Returns: + bool (bool): True means success. False means failed. + + Raises: + errors.DataSetFileCannotOpenError (OSError): Cannot open the file. + errors.DataSetParseError (json.decoder.JSONDecodeError): Cannot open the file using json parser. + """ + try: + with open(self.file_path) as fp: + json_data = json.load(fp) + if isinstance(json_data, list): + for item in json_data: + self.append_data(self.add_dict_row(item)) + elif isinstance(json_data, dict): + self.data.append(self.add_dict_row(json_data)) + else: + return False + except OSError as oe: + msg = f"Cannot open dataset file: {self.file_path}" + raise errors.DataSetFileCannotOpenError(msg) + except json.decoder.JSONDecodeError as ee: + msg = f"Unformatted json file: {self.file_path}, stop" + raise errors.DataSetParseError(msg) + return True + + def parse_json_lines_file(self): + """ + Parse jsonl format, which every line is a json string. + + Returns: + bool (bool): True means success. False means failed. + + Raises: + errors.DataSetFileCannotOpenError (OSError): Cannot open the file. + errors.DataSetParseError (json.decoder.JSONDecodeError): Cannot open the file using json parser. + """ + line = "" + try: + with open(self.file_path) as fp: + for line in fp: + self.append_data(self.add_str_row(line)) + except OSError as oe: + msg = f"Cannot open dataset file: {self.file_path}" + raise errors.DataSetFileCannotOpenError(msg) + except json.decoder.JSONDecodeError as ee: + print(f"bad line:{line}, {ee}") + msg = f"Unformatted json file: {self.file_path}, stop" + raise errors.DataSetParseError(msg) + return True + + def output_json(self): + """ + Output data into file which is json format. + """ + if not parse_config.DEBUG_DATASET_OUTPUT_FORMATTED_FILE: + return + if not os.path.exists(parse_config.DATASET_OUTPUT_ROOT): + os.makedirs(parse_config.DATASET_OUTPUT_ROOT) + with open(self.output_file_path, "w") as ofp: + ofp.write(json.dumps(self.data, ensure_ascii=False, indent=2)) + print(f"[DEBUG]Output parsed result as ernie-formatted json at {self.output_file_path}") + + def parse(self): + """ + Parse the dataset files. + """ + if self.doc_formatting == "json": + self.parse_json_file() + elif self.doc_formatting == "jsonl": + self.parse_json_lines_file() + elif self.doc_formatting == "auto": + for func in [self.parse_json_file, self.parse_json_lines_file]: + if self.doc_formatting != "auto": + break + try: + if self.parse_json_file(): + self.doc_formatting = "json" + except Exception: + continue + print( + f"{self.file_name} read {len(self.data)} items successfully and " + f"{self.failed_row} failed from {self.file_path}, doc formatting:{self.doc_formatting}" + ) + + def check_dataset_filename(self): + """ + Check if file exists. + """ + if self.file_path == "": + msg = "file_path should not be empty" + raise errors.DataSetFileNotFoundError(msg) + if not os.path.isfile(os.path.join(self.file_path)): + print(f"Checking file_path:{self.file_path}") + msg = f"cannot find dataset file:{self.file_path}" + raise errors.DataSetFileNotFoundError(msg) + + def run(self): + """ + Parse the dataset from file. + """ + self.check_dataset_filename() + self.parse() + # self.output_json() + + +class HFBaseParser(BaseDatasetParser): + """Hugging Face Base Dataset parser class.""" + + def __init__(self, repo_id, config_map, process_fn=None, shuffle_file=False): + """Init a HFBaseParser from one dataset in data_info.json""" + self.repo_id = repo_id + self.download_path = os.path.join(parse_config.DATASET_DOWNLOAD_ROOT, repo_id) + self.file_name = config_map.get("file_name", "") + self.file_path = os.path.join(self.download_path, self.file_name) + + self.formatting = config_map.get("formatting", "alpaca") + self.doc_formatting = config_map.get("doc_formatting", parse_config.DEFAULT_DOC_FORMATTING) + self.columns = config_map.get("columns", parse_config.DEFAULT_ALPACA_COLUMNS_MAPPING) + super().__init__(self.file_path, self.formatting, self.doc_formatting, self.columns, process_fn, shuffle_file) + + self.output_file_name = repo_id.replace("/", ".") + ".json" + self.output_file_path = os.path.join(parse_config.DATASET_OUTPUT_ROOT, self.output_file_name) + self.output_json_indent = parse_config.DEFAULT_OUTPUT_JSON_INDENT + + def _base_download(self): + """ + Download dataset from hugging-face. + """ + snapshot_download(repo_id=self.repo_id, repo_type="dataset", local_dir=self.download_path) + + def download(self): + """ + Download dataset function. + """ + self._base_download() + + def check_dataset_filename(self): + """ + Check if file exists. + """ + if self.file_name == "": + msg = f"file_name should be defined for {self.repo_id} in data_info.json" + raise errors.DataSetFileNotFoundError(msg) + if not os.path.isfile(os.path.join(self.file_path)): + print(f"Checking file_path:{self.file_path}") + msg = ( + f"{self.repo_id} cannot find dataset file:{self.file_name} " + f'under path: "{self.download_path}". Please check data_info.json.' + ) + raise errors.DataSetFileNotFoundError(msg) + + def run(self): + """ + Download and parse the dataset. + """ + self.download() + self.check_dataset_filename() + self.parse() + # self.output_json() + + +class HFScanParser(HFBaseParser): + """Dataset parser which scan the dataset files without definition in data_info.json""" + + def __init__(self, repo_id, process_fn=None, shuffle_file=False): + """ + Init a HFScanParser to parse the dataset which is not defined in data_info.json. + + Args: + repo_id (str): repo id for hugging-face hub. + """ + super().__init__(repo_id, {}, process_fn, shuffle_file) + + def parse(self): + """ + Parse the dataset file which is not defined in data_info.json. + Firstly try to parse as a json file, then jsonl file. + + Raises: + errors.DataSetFileCannotOpenError (OSError): Cannot open the file. + errors.DataSetParseError (json.decoder.JSONDecodeError): Cannot open the file using json parser. + """ + try: + self.parse_json_file() + except errors.DataSetParseError as e: + self.parse_json_lines_file() + except errors.DataSetFileCannotOpenError as e: + raise e + print( + f"{self.repo_id} read {len(self.data)} items successfully and " + f"{self.failed_row} failed from {self.file_path}" + ) + + def run(self): + """ + Download and parse the dataset. Some parameters are defined after dataset has been downloaded. + """ + self.download() + self.file_name = self.scan_dataset_file() + print(f'Find {self.file_name} under {self.download_path} when scanning "*.json".') + self.file_path = os.path.join(self.download_path, self.file_name) + self.check_dataset_filename() + self.parse() + self.output_json() + + +def load_data_info(): + """ + Load the data_info.json to get all defined dataset info. + """ + with open(parse_config.DATA_INFO_FILE) as fp: + data_info = json.load(fp) + return data_info + + +hf_repo_config_map = load_data_info() + + +def is_hf_dataset(repo_id): + """ + Check if the data_info configuration of the repo-id. + """ + global hf_repo_config_map + return hf_repo_config_map.get(repo_id, None) is not None + + +def create_hf_dataset(repo_id, process_fn=None, shuffle_file=True): + """ + Create a hugging-face repo dataset. + """ + global hf_repo_config_map + config_map = hf_repo_config_map.get(repo_id, None) + if config_map: + parser = HFBaseParser(repo_id, config_map, process_fn, shuffle_file) + else: + parser = HFScanParser(repo_id, process_fn, shuffle_file) + return parser + + +def create_dataset_from_file( + file_path, formatting="alpaca", doc_formatting="json", process_fn=None, shuffle_file=True +): + """ + Create dataset from file function. + + Args: + file_path (str): the file path of dataset. + formatting (str): formatting of the dataset, e.g. alpaca, sharegpt. + doc_formatting (str): document formatting of the dataset, e.g. json, jsonl. + + Returns: + parser (IterableDataset): The iterable dataset object. + + """ + if formatting not in parse_config.DEFAULT_DATASET_COLUMNS_MAPPING: + msg = ( + f"{formatting} is not supported." + f"Please use one of [{', '.join(list(parse_config.DEFAULT_DATASET_COLUMNS_MAPPING.keys()))}]" + ) + raise errors.DataSetFormattingNotSupportedError(f"{formatting}") + columns = parse_config.DEFAULT_DATASET_COLUMNS_MAPPING[formatting] + parser = BaseDatasetParser(file_path, formatting, doc_formatting, columns, process_fn, shuffle_file) + return parser diff --git a/ernie/ERNIE/ernie/dataset/hf/parse_config.py b/ernie/ERNIE/ernie/dataset/hf/parse_config.py new file mode 100644 index 0000000000000000000000000000000000000000..3dbf76b91657a2321164f4105fbd0d2201563f18 --- /dev/null +++ b/ernie/ERNIE/ernie/dataset/hf/parse_config.py @@ -0,0 +1,32 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Config for parse dataset to same format""" +import os + +DATASET_WORKROOT = os.getenv("ERNIE_DATASET_WORKROOT", os.path.abspath(os.path.join(os.path.dirname(__file__)))) +DATASET_DOWNLOAD_ROOT = os.path.join(DATASET_WORKROOT, "download") +DATASET_OUTPUT_ROOT = os.path.join(DATASET_WORKROOT, "output") + +DATA_INFO_FILE = os.path.join(DATASET_WORKROOT, "data_info.json") +DEFAULT_DOC_FORMATTING = "json" + +DEFAULT_ALPACA_COLUMNS_MAPPING = {"prompt": "instruction", "query": "input", "response": "output", "system": "system"} +DEFAULT_COLUMN_VALUE_MAPPING = {"prompt": "", "query": "", "response": ""} +DEFAULT_DATASET_COLUMNS_MAPPING = {"alpaca": DEFAULT_ALPACA_COLUMNS_MAPPING} + +DEFAULT_OUTPUT_JSON_INDENT = 2 + +ALPACA_COLUMNS_EMPTY_CHECK_LIST = ["prompt", "query", "response"] + +DEBUG_DATASET_OUTPUT_FORMATTED_FILE = True diff --git a/ernie/ERNIE/ernie/dfnrope/__pycache__/__init__.cpython-311.pyc b/ernie/ERNIE/ernie/dfnrope/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fda847f32408b4c46c136910c0a2895a293c14a8 Binary files /dev/null and b/ernie/ERNIE/ernie/dfnrope/__pycache__/__init__.cpython-311.pyc differ diff --git a/ernie/ERNIE/ernie/dfnrope/__pycache__/activation.cpython-311.pyc b/ernie/ERNIE/ernie/dfnrope/__pycache__/activation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75c64429ef226c2fe6a18580d3769974d143d1f3 Binary files /dev/null and b/ernie/ERNIE/ernie/dfnrope/__pycache__/activation.cpython-311.pyc differ diff --git a/ernie/ERNIE/ernie/dfnrope/__pycache__/configuration.cpython-311.pyc b/ernie/ERNIE/ernie/dfnrope/__pycache__/configuration.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d99bc65b6e5868839700d6329e292da1d91caf7c Binary files /dev/null and b/ernie/ERNIE/ernie/dfnrope/__pycache__/configuration.cpython-311.pyc differ diff --git a/ernie/ERNIE/ernie/dfnrope/__pycache__/modeling.cpython-311.pyc b/ernie/ERNIE/ernie/dfnrope/__pycache__/modeling.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15e4bd2aa7a7977b10f1313b9612023dae0bef98 Binary files /dev/null and b/ernie/ERNIE/ernie/dfnrope/__pycache__/modeling.cpython-311.pyc differ diff --git a/ernie/ERNIE/ernie/dfnrope/__pycache__/modeling_pp.cpython-311.pyc b/ernie/ERNIE/ernie/dfnrope/__pycache__/modeling_pp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..103914871b7d3edd450d0b421f462e1c07da8a07 Binary files /dev/null and b/ernie/ERNIE/ernie/dfnrope/__pycache__/modeling_pp.cpython-311.pyc differ diff --git a/ernie/ERNIE/ernie/dfnrope/configuration.py b/ernie/ERNIE/ernie/dfnrope/configuration.py new file mode 100644 index 0000000000000000000000000000000000000000..2e718204f01ffecdba1ef00cb55a102f9ebc32ac --- /dev/null +++ b/ernie/ERNIE/ernie/dfnrope/configuration.py @@ -0,0 +1,85 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Ernie model configuration""" + +from paddleformers.transformers.configuration_utils import PretrainedConfig + +__all__ = [ + "DFNRopeVisionTransformerConfig", +] + + +class DFNRopeVisionTransformerConfig(PretrainedConfig): + """ + Configuration class for DFNRopeVisionTransformer model. + This class inherits from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + """ + + model_type = "DFNRope_vision_transformer" + + def __init__( + self, + depth=32, + embed_dim=1280, + hidden_size=3584, + hidden_act="quick_gelu", + mlp_ratio=4, + num_heads=16, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + attn_implementation="eager", # new added + recompute=False, + attn_sep=False, + vit_num_recompute_layers=10000, + **kwargs, + ): + """ + Initialize DFNRopeVisionTransformer model configuration with default or specified parameters. + + Args: + depth (int): Number of transformer layers in the model. + embed_dim (int): Dimensionality of the embedding layer. + hidden_size (int): Dimensionality of the feedforward network. + hidden_act (str): Activation function for the feedforward network. + mlp_ratio (float): Ratio between the number of input features and + the number of output features in the feedforward network. + num_heads (int): Number of attention heads in each attention layer. + in_channels (int): Number of channels in the input image. + patch_size (int): + Size of patches in the input image. Defaults to 14. + spatial_merge_size (int): + Spatial merge size for the spatial transformer module. Defaults to 2. + attn_implementation (str): Attention implementation type. Defaults to "eager". + recompute (bool): Whether to use recompute. Defaults to False. + attn_sep (bool): Whether to separate attention computation into two stages. Defaults to False. + vit_num_recompute_layers (int): Number of recomputed layers for ViT. Defaults to + """ + super().__init__(**kwargs) + + self.depth = depth + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.attn_implementation = attn_implementation + self.recompute = recompute + self.attn_sep = attn_sep + self.vit_num_recompute_layers = vit_num_recompute_layers diff --git a/ernie/ERNIE/ernie/dfnrope/modeling.py b/ernie/ERNIE/ernie/dfnrope/modeling.py new file mode 100644 index 0000000000000000000000000000000000000000..719bac4dfc2d4114d5bf01fccd2c63841b0b3028 --- /dev/null +++ b/ernie/ERNIE/ernie/dfnrope/modeling.py @@ -0,0 +1,512 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import paddle +import paddle.distributed as dist +import paddle.nn.functional as F +from paddle import nn +from paddle.distributed.fleet.utils import recompute +from paddle.nn.functional.flash_attention import ( + flash_attn_unpadded as flash_attn_varlen_func, +) +from paddleformers.transformers.model_utils import PretrainedModel +from paddleformers.utils.log import logger + +from ..distributed import get_hcg +from .activation import ACT2FN +from .configuration import DFNRopeVisionTransformerConfig + + +class _AllToAll(paddle.autograd.PyLayer): + @staticmethod + def forward( + ctx, + input, + group, + output_split_sizes=None, + input_split_sizes=None, + ): + """ + All-to-all communication in the group. + + Args: + ctx (Any): Context object. + input (Tensor): Input tensor. + group (Group): The group object. + + Returns: + Tensor: Output tensor. + """ + + ctx.group = group + ctx.input_split_sizes = input_split_sizes + ctx.output_split_sizes = output_split_sizes + # return input + if dist.get_world_size(group) <= 1: + return input + if input_split_sizes is None and output_split_sizes is None: + output = paddle.empty_like(input) + task = dist.stream.alltoall_single(output, input, None, None, group, True, True) + task.wait() + else: + out_sizes = [sum(output_split_sizes)] + out_sizes.extend(input.shape[1:]) + output = paddle.empty(out_sizes, dtype=input.dtype) + task = dist.stream.alltoall_single( + output, input, output_split_sizes, input_split_sizes, group, sync_op=False + ) + task.wait() + return output + + @staticmethod + def backward(ctx, *grad_output): + """ + all-to-all backward + + """ + # return grad_output + if ctx.input_split_sizes is None and ctx.output_split_sizes is None: + return _AllToAll.apply(*grad_output, ctx.group) + else: + return _AllToAll.apply(*grad_output, ctx.group, ctx.input_split_sizes, ctx.output_split_sizes) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return paddle.concat([-x2, x1], axis=-1) # shape is the same as x + + +def apply_rotary_pos_emb_vision(tensor: paddle.Tensor, freqs: paddle.Tensor) -> paddle.Tensor: + """Applies Rotary Position Embedding to the input tensors. + + Args: + tensor (paddle.Tensor): The input tensor. + freqs (paddle.Tensor): The frequencies used for the rotation. + Returns: + output (paddle.Tensor): the tensor rotated using the Rotary Position Embedding. + """ + orig_dtype = tensor.dtype + + with paddle.amp.auto_cast(False): + tensor = tensor.astype(dtype="float32") + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).tile(repeat_times=[1, 1, 2]).unsqueeze(0).astype(dtype="float32") + sin = sin.unsqueeze(1).tile(repeat_times=[1, 1, 2]).unsqueeze(0).astype(dtype="float32") + output = tensor * cos + rotate_half(tensor) * sin + output = paddle.cast(output, orig_dtype) + return output + + +def qkv_reshard_head(tensor, group): + """ + After concatenating qkv in the seq dimension, perform the split dimension conversion together + """ + parallelism = group.nranks + qkv_seqlen, head_num, head_dim = tensor.shape + tensor = tensor.transpose(perm=[1, 0, 2]).contiguous() + out = _AllToAll.apply(tensor, group) + out = paddle.split(out, parallelism, axis=0) + output_q = [] + output_k = [] + output_v = [] + for output_i in out: + outout = output_i.transpose(perm=[1, 0, 2]).contiguous() + output = paddle.split(outout, 3, axis=0) + output_q.append(output[0]) + output_k.append(output[1]) + output_v.append(output[2]) + q = paddle.concat(output_q, axis=0) + k = paddle.concat(output_k, axis=0) + v = paddle.concat(output_v, axis=0) + return q, k, v + + +class VisionFlashAttention2(nn.Layer): + """VisionFlashAttention2""" + + def __init__(self, dim: int, num_heads: int = 16) -> None: + """ + Args: + dim (int): the dimension of each token. + num_heads (int, optional): number of heads. Default: 16 + """ + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias_attr=True) + self.proj = nn.Linear(dim, dim) + self.head_dim = dim // num_heads # must added + + def forward( + self, + hidden_states: paddle.Tensor, + cu_seqlens: paddle.Tensor, + rotary_pos_emb: paddle.Tensor = None, + attn_sep=False, + ) -> paddle.Tensor: + """ + Args: + hidden_states (paddle.Tensor): hidden states + cu_seqlens (paddle.Tensor): cumulative sequence lengths, with shape [batch_size + 1] + rotary_pos_emb (paddle.Tensor, optional): rotary position embedding. Default: None + Returns: + paddle.Tensor: output tensor + """ + seq_length = hidden_states.shape[0] + qkv = self.qkv(hidden_states).reshape([seq_length, 3, self.num_heads, -1]).transpose(perm=[1, 0, 2, 3]) + q, k, v = qkv.unbind(axis=0) + + if attn_sep: + hcg = get_hcg() + mp_group = hcg.get_model_parallel_group() + qkv = paddle.concat([q, k, v], axis=0) + q, k, v = qkv_reshard_head(qkv, mp_group) + seq_length = q.shape[0] + + q = apply_rotary_pos_emb_vision(q.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + + softmax_scale = self.head_dim**-0.5 # TODO: Need to add manually + + attn_output = ( + flash_attn_varlen_func( # flash_attn_unpadded + q.astype("bfloat16"), # do not support float32 + k.astype("bfloat16"), + v.astype("bfloat16"), + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + scale=softmax_scale, # TODO: Need to add manually + )[0] + .squeeze(0) + .reshape([seq_length, -1]) + ) + if attn_sep: + out = _AllToAll.apply(attn_output, mp_group) + out = paddle.split(out, mp_group.nranks, axis=0) + attn_output = paddle.concat(out, axis=1) + attn_output = attn_output.astype(paddle.float32) + attn_output = self.proj(attn_output) + return attn_output + + +class PatchEmbed(nn.Layer): + """PatchEmbed""" + + def __init__( + self, + patch_size: int = 14, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + """ + Args: + patch_size (int, optional): patch size. Defaults to 14. + in_channels (int, optional): number of channels. Defaults to 3. + embed_dim (int, optional): embedding dimension. Defaults to 1152. + """ + super().__init__() + self.patch_size = patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + self.proj = nn.Linear(in_channels * patch_size * patch_size, embed_dim, bias_attr=False) + + def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + """ + Args: + hidden_states (paddle.Tensor): hidden states + + Returns: + paddle.Tensor: output tensor + """ + target_dtype = self.proj.weight.dtype + + hidden_states = self.proj(paddle.cast(hidden_states, dtype=target_dtype)) + + return hidden_states + + +class VisionMlp(nn.Layer): + """VisionMLP""" + + def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None: + super().__init__() + self.fc1 = nn.Linear(dim, hidden_dim) + self.act = ACT2FN[hidden_act] + self.fc2 = nn.Linear(hidden_dim, dim) + + def forward(self, x) -> paddle.Tensor: + """ + Args: + x (paddle.Tensor): input tensor + + Returns: + paddle.Tensor: VisionMLP output tensor + """ + return self.fc2(self.act(self.fc1(x))) + + +class VisionRotaryEmbedding(nn.Layer): + """VisionRotaryEmbedding""" + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + """ + Args: + dim (int): the dimension of each token. + theta (float, optional): the frequency factor. Defaults to 10000.0. + """ + super().__init__() + self.inv_freq = 1.0 / theta ** (paddle.arange(start=0, end=dim, step=2, dtype="float32") / dim) + + def forward(self, seqlen: int) -> paddle.Tensor: + """ + Args: + seqlen (int): length of sequence. + + Returns: + paddle.Tensor: rotary position embedding + """ + seq = paddle.arange(seqlen).cast(self.inv_freq.dtype) + freqs = paddle.outer(x=seq, y=self.inv_freq) + return freqs + + +class DFNRopeVisionBlock(nn.Layer): + """DFNRopeVisionBlock""" + + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + """ + Args: + config (dict): model configuration. + attn_implementation (str, optional): attention implementation. Defaults to "sdpa". + """ + super().__init__() + self.norm1 = nn.LayerNorm(config.embed_dim, epsilon=1e-6) + self.norm2 = nn.LayerNorm(config.embed_dim, epsilon=1e-6) + mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio) + + self.attn = VisionFlashAttention2(config.embed_dim, num_heads=config.num_heads) + self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act) + self.config = config + + def forward(self, hidden_states, cu_seqlens, rotary_pos_emb, attn_sep=False) -> paddle.Tensor: + """ + Args: + hidden_states(paddle.Tensor): hidden states + cu_seqlens (paddle.Tensor): cumulative sequence lengths + rotary_pos_emb: rotary position embedding + + Returns: + paddle.Tensor: output tensor + """ + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + attn_sep=attn_sep, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class PatchMerger(nn.Layer): + """PatchMerger""" + + def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: + """ + Args: + dim (int): output dimension + context_dim (int): input dimension + spatial_merge_size (int, optional): spatial merge size. Defaults to 2. + """ + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.ln_q = nn.LayerNorm(context_dim, epsilon=1e-6) + self.mlp = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size), + nn.GELU(), + nn.Linear(self.hidden_size, dim), + ) + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + """ + Args: + x (paddle.Tensor): input tensor + + Returns: + paddle.Tensor: PatchMerger output tensor + """ + x = self.mlp(self.ln_q(x).reshape([-1, self.hidden_size])) + return x + + +class DFNRopeVisionTransformerPretrainedModel(PretrainedModel): + """DFNRopeVisionTransformerPretrainedModel""" + + config_class = DFNRopeVisionTransformerConfig + + def __init__(self, config) -> None: + """ + Args: + config (dict): model configuration + """ + super().__init__(config) + self.spatial_merge_size = config.spatial_merge_size + + self.patch_embed = PatchEmbed( + patch_size=config.patch_size, + in_channels=config.in_channels, + embed_dim=config.embed_dim, + ) + + head_dim = config.embed_dim // config.num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.LayerList([DFNRopeVisionBlock(config) for _ in range(config.depth)]) + + assert ( + config.hidden_size == config.embed_dim + ), "in DFNRope, vit's config.hidden must be equal to config.embed_dim" + # self.merger = PatchMerger(dim=config.hidden_size, context_dim=config.embed_dim) + self.ln = nn.LayerNorm(config.hidden_size, epsilon=1e-6) + + def get_dtype(self) -> paddle.dtype: + """ + Returns: + paddle.dtype: data type + """ + return self.blocks[0].mlp.fc2.weight.dtype + + def rot_pos_emb(self, grid_thw, num_pad=0): + """rot_pos_emb + + Args: + grid_thw (paddle.Tensor): grid thw of input + + Returns: + paddle.Tensor: rotary position embedding + """ + pos_ids = [] + grid_hw_array = np.array(grid_thw, dtype=np.int64) + for t, h, w in grid_hw_array: + hpos_ids = np.arange(h).reshape(-1, 1) + hpos_ids = np.tile(hpos_ids, (1, w)) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = np.transpose(hpos_ids, (0, 2, 1, 3)) + hpos_ids = hpos_ids.flatten() + + wpos_ids = np.arange(w).reshape(1, -1) + wpos_ids = np.tile(wpos_ids, (h, 1)) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = np.transpose(wpos_ids, (0, 2, 1, 3)) + wpos_ids = wpos_ids.flatten() + + stacked_ids = np.stack([hpos_ids, wpos_ids], axis=-1) + tiled_ids = np.tile(stacked_ids, (t, 1)) + pos_ids.append(tiled_ids) + + pos_ids = np.concatenate(pos_ids, axis=0) + if num_pad > 0: + pos_ids = np.concatenate([pos_ids, np.zeros((num_pad, 2), dtype=pos_ids.dtype)]) + max_grid_size = np.amax(grid_hw_array[:, 1:]) + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(start_axis=1) + return rotary_pos_emb + + def forward(self, hidden_states: paddle.Tensor, grid_thw: paddle.Tensor, num_pad=0) -> paddle.Tensor: + """ + Args: + hidden_states (paddle.Tensor): input tensor + grid_thw (paddle.Tensor): grid thw of input + num_pad (int): number of padding tokens + + Returns: + paddle.Tensor: output tensor + """ + hidden_states = self.patch_embed(hidden_states) + + rotary_pos_emb = self.rot_pos_emb(grid_thw, num_pad=num_pad) + + cu_seqlens = paddle.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + axis=0, dtype="int32" + ) + + if num_pad > 0: + cu_seqlens = F.pad(cu_seqlens, (1, 1), value=0) + cu_seqlens[-1] = cu_seqlens[-2] + num_pad + else: + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + attn_sep = getattr(self.config, "attn_sep", False) + vit_num_recompute_layers = getattr(self.config, "vit_num_recompute_layers", self.config.depth) + + for idx, blk in enumerate(self.blocks): + if self.config.recompute and self.training and idx < vit_num_recompute_layers: + hidden_states = recompute(blk, hidden_states, cu_seqlens, rotary_pos_emb, attn_sep) + else: + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + attn_sep=attn_sep, + ) + + # ret = self.merger(hidden_states) + # ret = hidden_states + ret = self.ln(hidden_states) # add norm + return ret + + def extract_feature(self, hidden_states: paddle.Tensor, grid_thw: paddle.Tensor) -> paddle.Tensor: + """ + Args: + hidden_states (paddle.Tensor): input tensor + grid_thw (paddle.Tensor): grid thw of input + + Returns: + paddle.Tensor: output tensor + """ + return self.forward(hidden_states, grid_thw) + + @classmethod + def _get_tensor_parallel_mappings(cls, config, is_split=True): + """ + dummy + """ + return {} + + def set_state_dict(self, state_dict, *args, **kwargs): + """ + Args: + state_dict (Mapping[str, Any]): state_dict + """ + ret = super().set_state_dict(state_dict, *args, **kwargs) + logger.info(f"dfn rope set_state_dict: {ret}") diff --git a/ernie/ERNIE/ernie/distributed/__init__.py b/ernie/ERNIE/ernie/distributed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..95974ed0c74e0d6ba2919f576ae74599909f0f1e --- /dev/null +++ b/ernie/ERNIE/ernie/distributed/__init__.py @@ -0,0 +1,133 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Distributed utils +""" +import paddle + +# Import / Override modules for specific devices +if paddle.is_compiled_with_xpu(): + from .common_dist_utils import ( + AllGatherVarlenOp, + RRColumnSequenceParallelLinear, + RRRowSequenceParallelLinear, + get_hcg, + mark_as_sequence_parallel_parameter, + sequence_parallel_sparse_mask_labels, + ) + from .xpu_dist_utils import ( + ColumnParallelLinear, + ColumnSequenceParallelLinear, + GatherOp, + RowParallelLinear, + RowSequenceParallelLinear, + ScatterOp, + ) +else: + from .common_dist_utils import ( + AllGatherVarlenOp, + ColumnParallelLinear, + ColumnSequenceParallelLinear, + GatherOp, + RowParallelLinear, + RowSequenceParallelLinear, + RRColumnSequenceParallelLinear, + RRRowSequenceParallelLinear, + ScatterOp, + get_hcg, + mark_as_sequence_parallel_parameter, + sequence_parallel_sparse_mask_labels, + ) + +__all__ = [ + "ColumnParallelLinear", + "ColumnSequenceParallelLinear", + "RowParallelLinear", + "RowSequenceParallelLinear", + "GatherOp", + "ScatterOp", + "mark_as_sequence_parallel_parameter", + "ParallelCrossEntropy", + "get_rng_state_tracker", + "parallel_matmul", + "RRColumnSequenceParallelLinear", + "RRRowSequenceParallelLinear", + "AllGatherVarlenOp", + "sequence_parallel_sparse_mask_labels", + "get_hcg", +] + + +def parallel_matmul( + x, + y, + bias=None, + transpose_y=False, + tensor_parallel_degree=1, + tensor_parallel_output=True, + fuse_linear=False, + training=None, +): + """ + Parallel matmul wrapper. + + Args: + x (Tensor): Input tensor. + y (Tensor): Weight tensor. + bias (Tensor, optional): Bias tensor. Default is None. + transpose_y (bool, optional): Whether to transpose y. Default is False. + tensor_parallel_degree (int, optional): Tensor parallel degree. Default is 1. + tensor_parallel_output (bool, optional): Whether to output tensor parallel. Default is True. + fuse_linear (bool, optional): Whether to fuse linear. Default is False. + training (bool, optional): Training state. Default is None. + Returns: + Tensor: Output tensor. + """ + if paddle.is_compiled_with_xpu(): + from .common_dist_utils import _parallel_matmul as default_parallel_matmul + from .xpu_dist_utils import parallel_matmul as xpu_parallel_matmul + + if xpu_parallel_matmul is not None: + return xpu_parallel_matmul()( + x, + y, + bias=bias, + transpose_y=transpose_y, + tensor_parallel_degree=tensor_parallel_degree, + tensor_parallel_output=tensor_parallel_output, + fused_linear=fuse_linear, + ) + else: + return default_parallel_matmul( + x, + y, + bias=bias, + transpose_y=transpose_y, + tensor_parallel_degree=tensor_parallel_degree, + tensor_parallel_output=tensor_parallel_output, + fuse_linear=fuse_linear, + ) + else: + from .common_dist_utils import _parallel_matmul + + return _parallel_matmul( + x, + y, + bias=bias, + transpose_y=transpose_y, + tensor_parallel_degree=tensor_parallel_degree, + tensor_parallel_output=tensor_parallel_output, + fuse_linear=fuse_linear, + ) diff --git a/ernie/ERNIE/ernie/distributed/__pycache__/__init__.cpython-311.pyc b/ernie/ERNIE/ernie/distributed/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c68b668232cca3cfad91e2d6fbd72bbb167a6b9 Binary files /dev/null and b/ernie/ERNIE/ernie/distributed/__pycache__/__init__.cpython-311.pyc differ diff --git a/ernie/ERNIE/ernie/distributed/__pycache__/common_dist_utils.cpython-311.pyc b/ernie/ERNIE/ernie/distributed/__pycache__/common_dist_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72531b3f788220ae1ef64ea7337652a8099d9af0 Binary files /dev/null and b/ernie/ERNIE/ernie/distributed/__pycache__/common_dist_utils.cpython-311.pyc differ diff --git a/ernie/ERNIE/ernie/distributed/common_dist_utils.py b/ernie/ERNIE/ernie/distributed/common_dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cf99b307ab0f457947faf20479da70ebd0d8d60d --- /dev/null +++ b/ernie/ERNIE/ernie/distributed/common_dist_utils.py @@ -0,0 +1,688 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Common distributed utils. +""" + +import paddle +import paddle.nn.functional as F +from paddle import distributed as dist +from paddle.autograd import PyLayer +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel import ( + ColumnParallelLinear, + RowParallelLinear, +) +from paddle.distributed.fleet.utils.sequence_parallel_utils import ( + AllGatherOp, + ColumnSequenceParallelLinear, + GatherOp, + ReduceScatterOp, + RowSequenceParallelLinear, + ScatterOp, + all_gather, + mark_as_sequence_parallel_parameter, + scatter, +) +from paddle.incubate.tensor.manipulation import create_async_load + +from ..refined_recompute.utils import RefinedRecomputeFunction + +__all__ = [ + "get_hcg", + "_parallel_matmul", + "scatter_axis", + "mp_slice", + "all_gather_varlen", + "ColumnParallelLinear", + "ColumnSequenceParallelLinear", + "RowParallelLinear", + "RowSequenceParallelLinear", + "GatherOp", + "ScatterOp", + "mark_as_sequence_parallel_parameter", + "RRColumnSequenceParallelLinear", + "RRRowSequenceParallelLinear", + "AllGatherVarlenOp", + "sequence_parallel_sparse_mask_labels", + "get_async_loader", + "hack_offload_wait", + "hack_reload_wait", + "all_gather_group", + "reduce_scatter_group", +] + + +def get_hcg(): + """ + Get hybrid communicate group. + """ + return fleet.get_hybrid_communicate_group() + + +def _parallel_matmul( + x, + y, + bias=None, + transpose_y=False, + tensor_parallel_degree=1, + tensor_parallel_output=True, + fuse_linear=False, +): + """ + Performs parallel matrix multiplication with tensor model parallelism support. + + Args: + x (paddle.Tensor): Input tensor with shape [batch_size, seq_len, hidden_size] + y (Union[paddle.Tensor, EagerParamBase]): Weight matrix which can be: + - Regular tensor + - Distributed parameter in tensor parallel mode + bias (Optional[paddle.Tensor]): Optional bias tensor + transpose_y (bool): Whether to transpose the 'y' matrix before multiplication + tensor_parallel_degree (int): Degree of tensor model parallelism (default: 1) + tensor_parallel_output (bool): Whether to keep output in tensor parallel format + or gather across devices (default: True) + fuse_linear (bool): Whether to use fused linear operation for optimization + + Returns: + paddle.Tensor + + Raises: + AssertionError: If tensor parallel is enabled but weight is not distributed + AttributeError: If called without distributed.launch context + """ + if tensor_parallel_degree > 1: + if isinstance(y, paddle.base.framework.EagerParamBase): + assert y.is_distributed + # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg' + pg = fleet.get_hybrid_communicate_group().get_model_parallel_group() + input_parallel = paddle.distributed.collective._c_identity(x, group=pg) + + if transpose_y: + logits = paddle.matmul(input_parallel, y, transpose_y=True) + if bias is not None: + logits += bias + else: + if fuse_linear: + logits = paddle.incubate.nn.functional.fused_linear(input_parallel, y, bias) + else: + logits = F.linear(input_parallel, y, bias) + + if tensor_parallel_output: + return logits + + return paddle.distributed.collective._c_concat(logits, group=pg) + + else: + if fuse_linear: + logits = paddle.incubate.nn.functional.fused_linear(x, y, bias, transpose_weight=transpose_y) + else: + logits = paddle.matmul(x, y, transpose_y=transpose_y) + if bias is not None: + logits += bias + return logits + + +def scatter_axis(input, group=None, axis=0): + """ + Uniformly splits the `input` along dimension 0 across model parallel groups. + This API is not related to `distributed.scatter`. + + Args: + input: Input tensor to be split + group: Communication group for parallel processing (default: model parallel group) + axis: Dimension along which to split (default: 0) + + Returns: + A slice of the input tensor corresponding to this rank's portion + """ + if group is None: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + parallelism = group.nranks + if parallelism == 1: + return input.clone() + rank = group.rank + seq_len = input.shape[axis] + assert seq_len % parallelism == 0, ( + f"Input sequence length {seq_len} can't be divided exactly" f" by sequence parallelism {parallelism}" + ) + interval = seq_len // parallelism + input = paddle.slice(input, axes=[axis], starts=[interval * rank], ends=[interval * (rank + 1)]) + # slice uses stride, so we maintain the memory of whole input, use assign to free the whole input + # which can avoid OOM. + input = paddle.assign(input) + return input + + +def mp_slice(x, indices=None, group=None, axis=0): + """ + Slices tensor `x` along dimension 0 according to `indices` without communication. + + Args: + x: Input tensor to be sliced + indices: List of indices defining how to slice the tensor + group: Communication group for parallel processing (default: model parallel group) + axis: Dimension along which to slice (default: 0) + + Returns: + A slice of the input tensor corresponding to this rank's portion + """ + if indices is None: + return scatter(x, group, axis) + if group is None: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + parallelism = group.nranks + if parallelism == 1: + return x + rank = group.rank + assert len(indices) == parallelism, (len(indices), parallelism) + indices = F.pad(paddle.to_tensor(indices).cumsum(0), [1, 0]) + input = paddle.slice(x, axes=[axis], starts=[indices[rank]], ends=[indices[rank + 1]]) + input = paddle.assign(input) + return input + + +def all_gather_varlen(input, indices, group=None, axis=0, sync_op=True): + """ + Variable-length version of `all_gather` that behaves similarly to `distributed.all_gather`. + + Args: + input: Local tensor to be gathered + indices: List of sizes from each rank indicating how much to gather from each + group: Communication group for parallel processing (default: model parallel group) + axis: Dimension along which to gather (only 0 is supported) + sync_op: Whether to synchronize the operation + + Returns: + A concatenated tensor containing all gathered data + """ + assert axis == 0, "only support axis=0" + if group is None: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + parallelism = group.nranks + input_sizes = [len(input)] * parallelism + output_sizes = indices + out = paddle.empty([sum(indices)] + input.shape[1:], dtype=input.dtype) + task = dist.stream.alltoall_single( + out, + paddle.concat([input] * parallelism, 0) if len(input) else input, # TODO: check this + output_sizes, # input-size + input_sizes, + group=group, + sync_op=sync_op, + use_calc_stream=sync_op, + ) + task.wait() + return out + + +class ReduceScatterGroupOp(PyLayer): + """ + Perform group reduce scatter. + """ + + @staticmethod + def forward(ctx, input, group=None): + """Forward pass: Reduce-Scatter operation + Args: + input (Tensor): Input tensor with shape [s, b, h]. + The 's' dimension will be split across model parallel group. + group (ProcessGroup): Model parallel process group, + uses global group by default. + Returns: + Tensor: Output tensor after Reduce-Scatter with shape [s/n, b, h], + each device holds partial data of the original input. + """ + ctx.group = group + return reduce_scatter_group(input, group=group) + + @staticmethod + def backward(ctx, grad): + """Backward pass: All-Gather operation + Args: + grad (Tensor): Upstream gradient with shape [s/n, b, h] + Returns: + Tensor: Full gradient after All-Gather with restored shape [s, b, h], + aggregating gradients from all devices in model parallel group. + """ + return all_gather_group(grad, group=ctx.group) + + +class AllGatherGroupOp(PyLayer): + """ + Perform group allgather. + """ + + @staticmethod + def forward(ctx, input, group=None): + """Forward pass: All-Gather operation + Args: + input (Tensor): Partitioned tensor with shape [s/n, b, h] + The 's' dimension is distributed across devices + group (ProcessGroup): Model parallel process group, + uses global group by default + Returns: + Tensor: Assembled tensor after All-Gather with shape [s, b, h], + containing full parameter from all devices + """ + ctx.group = group + return all_gather_group(input, group=group) + + @staticmethod + def backward(ctx, grad): + """Backward pass: Reduce-Scatter operation + Args: + grad (Tensor): Full gradient tensor with shape [s, b, h] + Returns: + Tensor: Scattered gradient with shape [s/n, b, h], + distributing reduced gradients to each device + """ + return reduce_scatter_group(grad, group=ctx.group) + + +class RRColumnSequenceParallelLinear(ColumnSequenceParallelLinear): + """ + ColumnSequenceParallelLinear with refined recompute. + """ + + def __init__( + self, + in_features, + out_features, + weight_attr=None, + has_bias=None, + gather_output=True, + fuse_matmul_bias=False, + mp_group=None, + use_rr=False, + name=None, + ): + """ + Initializes a ColumnSequenceParallelLinear module. + + Args: + in_features (int): The number of input features. + out_features (int): The number of output features. + weight_attr (ParamAttr, optional): The parameter attribute for the learnable + weight matrix. Default: None. + has_bias (bool, optional): Whether the layer uses a bias. By default, it is set to False. + If ``has_bias`` is set to False, no bias term is used. If ``has_bias`` is set to True, + a bias vector is used. Default: None, which means inherit the value of `has_bias` + from the current instance's `has_bias`. + gather_output (bool, optional): Whether to gather all outputs from all ranks during forward pass. + Default: True. If True, all outputs from all ranks are gathered during forward pass, which + makes sure that each example's output is produced only once. If False, all outputs are + produced on each rank separately, and the outputs from different ranks may overlap. + This can save communication time but may cause slower convergence. Default: True. + fuse_matmul_bias (bool, optional): Whether to fuse matmul and bias into one op. Default: False. + mp_group (paddle.distributed.Group, optional): The group for model parallel. Default: None. + use_rr (bool, optional): Whether to use refined rcompute. Default: False. + name (str, optional): Name for the instance to use in tracebacks. Default: None. + """ + super().__init__( + in_features=in_features, + out_features=out_features, + weight_attr=weight_attr, + has_bias=has_bias, + gather_output=gather_output, + fuse_matmul_bias=fuse_matmul_bias, + mp_group=mp_group, + name=name, + ) + + self._rr_column_ln = RefinedRecomputeFunction() if use_rr else None + if self.weight.is_distributed: + self.weight.split_axis = 1 + if has_bias and self.bias.is_distributed: + self.bias.split_axis = 0 + + def forward(self, x): + """ + Forward pass function that computes the product of the input tensor and model parameters. + + Args: + x (paddle.Tensor): Input tensor with shape (batch_size, seq_len, hidden_size) or (batch_size, hidden_size). + If sequence parallel is True, the shape is (seq_len, batch_size, hidden_size). + + Returns: + paddle.Tensor: Returns a tensor with shape (batch_size, seq_len, hidden_size) or (batch_size, hidden_size). + If sequence parallel is True, the shape is (seq_len, batch_size, hidden_size). + """ + # sequence parallelism is same as model parallelism + # if sequence parallel is true, input shape is [s, b, h] + # else input shape is [b, s, h] + if self.is_mp: + input_parallel = AllGatherOp.apply(x) + else: + input_parallel = x + + if self._rr_column_ln is not None and self.training: # in eval mode, do not use refined recompute + output = self._rr_column_ln( + self.linear, + x=input_parallel, + weight=self.weight, + bias=self.bias, + ) + else: + output = self.linear(input_parallel, self.weight, self.bias, name=self._name) + return output + + +class RRRowSequenceParallelLinear(RowSequenceParallelLinear): + """ + RowSequenceParallelLinear with refined recompute. + """ + + def __init__( + self, + in_features, + out_features, + weight_attr=None, + has_bias=True, + input_is_parallel=False, + fuse_matmul_bias=False, + mp_group=None, + use_rr=False, + name=None, + ): + """ + Args: + in_features (int): The number of input features. + out_features (int): The number of output features. + weight_attr (ParamAttr, optional): The parameter attribute for the learnable + weight matrix. Defaults to None. If it is None, the system will + generate a default Attribute object. + has_bias (bool, optional): Whether the layer uses a bias term. Defaults to True. + input_is_parallel (bool, optional): Whether the input is parallel. Defaults to False. + fuse_matmul_bias (bool, optional): Whether to fuse matmul and bias into one kernel. Defaults to False. + mp_group (Group, optional): Model parallel group. Defaults to None. + use_rr (bool, optional): Whether to use refined rr. Defaults to False. + name (str, optional): Name of the layer. Defaults to None. + """ + super().__init__( + in_features=in_features, + out_features=out_features, + weight_attr=weight_attr, + has_bias=has_bias, + input_is_parallel=input_is_parallel, + fuse_matmul_bias=fuse_matmul_bias, + mp_group=mp_group, + name=name, + ) + + self._rr_row_ln = RefinedRecomputeFunction() if use_rr else None + + if self.weight.is_distributed: + self.weight.split_axis = 0 + + def forward(self, x): + """ + Forward pass function that computes the product of the input tensor and model parameters. + + Args: + x (paddle.Tensor): Input tensor with shape (batch_size, in_features). + + Returns: + paddle.Tensor: Returns a tensor with shape (batch_size, out_features). + """ + input_parallel = x + if self.is_mp: + if self.mp_scale is not None: + bias = self.mp_scale(self.bias, self.world_size) + else: + bias = None + + def linear_reduce_scatter(input, weight, bias=None, name=None): + output = self.linear(input, weight=weight, bias=bias, name=name) + return ReduceScatterOp.apply(output) + + if self._rr_row_ln is not None and self.training: # in eval mode, do not use refined recompute + output_ = self._rr_row_ln( + linear_reduce_scatter, + input_parallel, + self.weight, + bias=bias, + name=self._name, + ) + else: + output_ = linear_reduce_scatter(input_parallel, self.weight, bias=bias, name=self._name) + + # if self.bias is not none, sequence parallel will use + # register_hook to all_reduce self.bias + if bias is None and self.bias is not None: + output = output_ + self.bias + else: + output = output_ + else: + output = self.linear(input_parallel, self.weight, self.bias, name=self._name) + return output + + +class AllGatherVarlenOp(PyLayer): + """ + A custom PyLayer that performs variable-length allgather operation. + + This operation handles tensors with different shapes across ranks by: + 1. Gathering shape information from all ranks + 2. Padding tensors to maximum size + 3. Performing allgather + 4. Reconstructing the original variable-length tensors + """ + + @staticmethod + def forward(ctx, input): + """Forward pass for variable-length allgather operation. + + Args: + ctx: PyLayer context for saving state + input (Tensor): Input tensor to be gathered (may have different sizes across ranks) + + Returns: + Tensor: Concatenated output from all ranks with original lengths + """ + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + + shape0 = paddle.to_tensor([input.shape[0]]) + shape0_all = paddle.empty(shape=[group.nranks], dtype=shape0.dtype) + dist.stream.all_gather(shape0_all, shape0, group=group, use_calc_stream=True) + shape0_all = shape0_all.numpy() + max_shape0 = shape0_all.max() + + indices = [] + for idx, s in enumerate(shape0_all): + offset = idx * max_shape0 + indices.extend(list(range(offset, offset + s))) + indices = paddle.to_tensor(indices) + + padding = max_shape0 - input.shape[0] + + ctx.shape0 = input.shape[0] + ctx.max_shape0 = max_shape0 + ctx.shape0_all = shape0_all + ctx.padding = padding + ctx.indices = indices + + if padding > 0: + input_shape = input.shape + input_shape[0] = padding + padding_tensor = paddle.empty(shape=input_shape, dtype=input.dtype) + input = paddle.concat([input, padding_tensor], axis=0) + output = all_gather(input) + output = paddle.gather(output, indices, axis=0) + + return output + + @staticmethod + def backward(ctx, grad): + """Backward pass for variable-length allgather operation. + + Args: + ctx: PyLayer context with saved state + grad (Tensor): Gradient flowing back through the graph + + Returns: + Tensor: Scattered gradient with original variable lengths + """ + input_shape = grad.shape + input_shape[0] = ctx.max_shape0 * ctx.shape0_all.shape[0] + output = paddle.zeros(shape=input_shape, dtype=grad.dtype) + + # grad = paddle.put_along_axis(output, ctx.indices, grad, axis=0) + grad = paddle.scatter(output, ctx.indices, grad) + grad = scatter(grad) + + if ctx.padding > 0: + grad = grad[: ctx.shape0] + return grad + + +def sequence_parallel_sparse_mask_labels(labels, ignore_label=-100): + """ + Processes sparse labels in sequence parallel training by gathering non-ignored labels across all ranks. + + This function handles the case where labels may contain ignored values (typically -100) by: + 1. Distributing labels across model parallel ranks + 2. Identifying and gathering only valid (non-ignored) labels + 3. Performing a variable-length allgather operation to collect all valid labels + + Args: + labels (paddle.Tensor): The input label tensor which may contain ignore_label values. + Shape should be compatible with model parallel distribution. + ignore_label (int, optional): The value used to indicate labels that should be ignored. + Defaults to -100 (common convention in NLP tasks). + + Returns: + tuple: Contains two elements: + - labels_all_gather (paddle.Tensor): Concatenated tensor of all non-ignored labels + from all model parallel ranks. + - tgt_index (paddle.Tensor): Indices of the non-ignored labels in the local rank's + portion of the original labels tensor. + + Note: + - This function assumes sequence parallel training is being used. + - If a rank has no valid labels (all ignored), it will still contribute one dummy label + (index 0) to maintain consistency in the distributed computation. + - The returned tgt_index can be used to reconstruct the original label positions. + """ + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + labels = labels.flatten() + labels_local = paddle.split(labels, group.nranks)[group.rank] + + tgt_index = paddle.nonzero(labels_local != ignore_label).reshape([-1]) + if tgt_index.numel() == 0: + tgt_index = paddle.to_tensor([0]) + + labels_local_gather = paddle.gather(labels_local, tgt_index, axis=0) + labels_all_gather = AllGatherVarlenOp.apply(labels_local_gather) + return labels_all_gather, tgt_index + + +async_loader = None + + +def get_async_loader(): + """get_async_loader""" + global async_loader + if not hasattr(fleet.fleet, "_hcg"): + if async_loader is None: + async_loader = create_async_load() + return async_loader + + hcg = get_hcg() + if not hasattr(hcg, "async_loader"): + hcg.async_loader = create_async_load() + return hcg.async_loader + + +def hack_offload_wait(task): + """hack_offload_wait""" + task.cpu_wait() + + +def hack_reload_wait(task): + """hack_offload_wait""" + task.cuda_wait() + + +def all_gather_group(input, group=None, axis=0): + """Perform collective all-gather operation across a process group with axis control. + + Functional Behavior: + - Aggregates input tensors from all processes in the specified group + - Supports concatenation along arbitrary dimensions (axis parameter) + - Optimizes for axis=0 via direct shape expansion to avoid concatenation overhead + + Args: + input (Tensor): Local tensor to be gathered (shape: [..., D, ...]) + group (ProcessGroup): Communication group (defaults to model parallel group) + axis (int): Concatenation dimension (default=0) + + Returns: + Tensor: Concatenated tensor combining inputs from all processes: + - When axis=0: shape [D*N, ...] (N = group size) + - Otherwise: shape [..., D*N, ...] along specified axis + """ + if group is None: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + parallelism = group.nranks + if parallelism == 1: + return input.clone() + output_shape = input.shape + if axis == 0: + output_shape[axis] = output_shape[axis] * parallelism + output = paddle.empty(shape=output_shape, dtype=input.dtype) + dist.stream.all_gather(output, input, group=group, use_calc_stream=True) + return output + outputs = [paddle.empty(output_shape, dtype=input.dtype) for _ in range(parallelism)] + dist.stream.all_gather(outputs, input, group=group, use_calc_stream=True) + output = paddle.concat(outputs, axis=axis) + return output + + +def reduce_scatter_group(input, group=None): + """Perform reduce-scatter collective operation across a process group. + + Functional Behavior: + - Aggregates (sums) input tensors across all processes in the group + - Scatters the reduced result equally to all participants + - Operates along the first dimension (axis=0) of the input tensor + + Args: + input (Tensor): Local tensor to reduce (shape: [N*K, ...] where N=group_size) + group (ProcessGroup): Communication group (defaults to model parallel group) + + Returns: + Tensor: Scattered portion of reduced tensor with shape [K, ...] + """ + if group is None: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + parallelism = group.nranks + if parallelism == 1: + return input.clone() + output_shape = input.shape + assert ( + input.shape[0] % parallelism == 0 + ), f"Input sequence length {input.shape[0]} can't be divided exactly by sequence parallelism {parallelism}" + output_shape[0] = output_shape[0] // parallelism + output = paddle.empty(shape=output_shape, dtype=input.dtype) + dist.stream.reduce_scatter(output, input, op=dist.ReduceOp.SUM, group=group, use_calc_stream=True) + return output diff --git a/ernie/ERNIE/ernie/distributed/xpu_dist_utils.py b/ernie/ERNIE/ernie/distributed/xpu_dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4142bf314dd5cf10331272a2418890118fca4cb9 --- /dev/null +++ b/ernie/ERNIE/ernie/distributed/xpu_dist_utils.py @@ -0,0 +1,52 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +XPU distributed utils. +""" + +try: + from paddle_xpu.layers.nn import ( + ColumnParallelLinear, + RowParallelLinear, + parallel_matmul, + ) + from paddle_xpu.layers.nn.sequence_parallel import ( + ColumnSequenceParallelLinear, + GatherOp, + RowSequenceParallelLinear, + ScatterOp, + ) +except ImportError: + from paddle.distributed.fleet.meta_parallel import ( + ColumnParallelLinear, + RowParallelLinear, + ) + from paddle.distributed.fleet.utils.sequence_parallel_utils import ( + ColumnSequenceParallelLinear, + GatherOp, + RowSequenceParallelLinear, + ScatterOp, + ) + + parallel_matmul = None + +__all__ = [ + "ColumnParallelLinear", + "RowParallelLinear", + "ColumnSequenceParallelLinear", + "RowSequenceParallelLinear", + "GatherOp", + "ScatterOp", +] diff --git a/ernie/ERNIE/ernie/fusion_ops/__init__.py b/ernie/ERNIE/ernie/fusion_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0047712cd0507bfdeba0ee901357ef26b4195357 --- /dev/null +++ b/ernie/ERNIE/ernie/fusion_ops/__init__.py @@ -0,0 +1,82 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Fusion operators. +""" +import paddle +from paddle.incubate.nn.functional import fused_rms_norm_ext +from paddle.incubate.nn.functional import fused_rotary_position_embedding as fused_rope +from paddle.incubate.nn.functional import swiglu as fused_swiglu + +from .common_fusion_ops import Linear, matmul + +if paddle.device.is_compiled_with_custom_device('npu'): + from .npu_fusion_ops import npu_cal_aux_loss_func as cal_aux_loss +else: + from paddle.incubate.nn.functional import cal_aux_loss + +__all__ = [ + 'fused_rope', + 'fused_swiglu', + 'fused_rms_norm_ext', + 'Linear', + 'matmul', + 'cal_aux_loss', +] + + +def fusion_flash_attention( + q, + k, + v, + training_mode, + attention_probs_dropout_prob, + use_sparse_flash_attn, + attention_mask=None, + attn_mask_start_row_indices=None, + seq_length=None, + use_var_len_flash_attn=False, + rr_flash_attn=None, +): + """ + Args: + q (Tensor): Query tensor. + k (Tensor): Key tensor. + v (Tensor): Value tensor. + training_mode (bool): Whether in training mode. + attention_probs_dropout_prob (float): Dropout probability for attention probabilities. + use_sparse_flash_attn (bool): Whether to use sparse flash attention. + attention_mask (Tensor, optional): Attention mask. Defaults to None. + attn_mask_start_row_indices (Tensor, optional): Start row indices for attention mask. Defaults to None. + seq_length (int, optional): Sequence length. Defaults to None. + use_var_len_flash_attn (bool, optional): Whether to use variable length flash attention. Defaults to False. + rr_flash_attn (bool, optional): Whether to use round-robin flash attention. Defaults to None. + + Returns: + Tensor: Output tensor after applying fusion flash attention. + """ + from .common_fusion_ops import _fusion_flash_attention + + return _fusion_flash_attention( + q, + k, + v, + training_mode=training_mode, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_sparse_flash_attn=use_sparse_flash_attn, + attention_mask=attention_mask, + attn_mask_start_row_indices=attn_mask_start_row_indices, + rr_flash_attn=rr_flash_attn, + ) diff --git a/ernie/ERNIE/ernie/fusion_ops/__pycache__/__init__.cpython-311.pyc b/ernie/ERNIE/ernie/fusion_ops/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44f4cac5763b0726dcd318b02c34b82633d81179 Binary files /dev/null and b/ernie/ERNIE/ernie/fusion_ops/__pycache__/__init__.cpython-311.pyc differ diff --git a/ernie/ERNIE/ernie/fusion_ops/__pycache__/common_fusion_ops.cpython-311.pyc b/ernie/ERNIE/ernie/fusion_ops/__pycache__/common_fusion_ops.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7b33811faacc605affc31dbc079336d5cc3af61 Binary files /dev/null and b/ernie/ERNIE/ernie/fusion_ops/__pycache__/common_fusion_ops.cpython-311.pyc differ diff --git a/ernie/ERNIE/ernie/fusion_ops/common_fusion_ops.py b/ernie/ERNIE/ernie/fusion_ops/common_fusion_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..5fac7d8eae3ab5620f2e5d99a885404030c2e0d4 --- /dev/null +++ b/ernie/ERNIE/ernie/fusion_ops/common_fusion_ops.py @@ -0,0 +1,164 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Common fusion operators. +""" + +import paddle +import paddle.nn.functional as F +from paddle import matmul, tensor + +if paddle.is_compiled_with_xpu(): + try: + from paddle_xpu.layers.nn import Linear + except ImportError: + from paddle.nn import Linear +else: + from paddle.nn import Linear +from paddle.nn.functional.flash_attention import flashmask_attention + +__all__ = [ + "matmul", + "Linear", +] + + +def _fusion_flash_attention( + q, + k, + v, + training_mode, + attention_probs_dropout_prob, + use_sparse_flash_attn, + attention_mask=None, + attn_mask_start_row_indices=None, + rr_flash_attn=None, +): + """ + Performs fused flash attention with multiple implementation variants. + + Args: + q (paddle.Tensor): Query tensor with shape [batch, heads, seq_len, dim_head] + k (paddle.Tensor): Key tensor with shape [batch, heads, seq_len, dim_head] + v (paddle.Tensor): Value tensor with shape [batch, heads, seq_len, dim_head] + training_mode (bool): Whether in training mode (affects dropout) + attention_probs_dropout_prob (float): Dropout probability for attention weights + use_sparse_flash_attn (bool): Whether to use sparse flash attention optimization + attention_mask (Optional[paddle.Tensor]): Dense attention mask (default: None) + attn_mask_start_row_indices (Optional[paddle.Tensor]): Sparse mask indices (default: None) + rr_flash_attn (Optional[Callable]): Recomputation wrapper for flash attention (default: None) + + Returns: + Tuple[paddle.Tensor, Optional[paddle.Tensor]]: + - Output tensor with shape [batch, seq_len, heads*dim_head] + - Attention weights (None for flash attention implementations) + + Raises: + Warning: If sparse flash attention is requested but unavailable + ValueError: If invalid combination of mask inputs is provided + """ + + version = paddle.version.full_version + if attn_mask_start_row_indices is not None: + if use_sparse_flash_attn: + if rr_flash_attn is None: + out = flashmask_attention( + q, + k, + v, + startend_row_indices=attn_mask_start_row_indices.unsqueeze(-1), + causal=True, + ) + else: + out = rr_flash_attn( + flashmask_attention, + q, + k, + v, + startend_row_indices=attn_mask_start_row_indices.unsqueeze(-1), + causal=True, + ) + else: + attention_mask = _gen_from_sparse_attn_mask_indices(attn_mask_start_row_indices, q.dtype) + if rr_flash_attn is None: + out = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attention_mask, + is_causal=False, + ) + else: + out = rr_flash_attn( + F.scaled_dot_product_attention, + q, + k, + v, + attn_mask=attention_mask, + is_causal=False, + ) + weights = None + else: + if rr_flash_attn is None: + out = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attention_mask, + is_causal=attention_mask is None and q.shape[1] != 1, + ) + weights = None + else: + out = rr_flash_attn( + F.scaled_dot_product_attention, + q, + k, + v, + attn_mask=attention_mask, + is_causal=attention_mask is None and q.shape[1] != 1, + ) + weights = None + + out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) + return out, weights + + +def _gen_from_sparse_attn_mask_indices(attn_mask_start_row_indices, dtype): + """ + Recover 4-D attention_mask from attn_mask_start_row_indices. + + Args: + attn_mask_start_row_indices (paddle.Tensor): The start row indices for the attention mask. + dtype (str): The data type of the tensor. + + Returns: + paddle.Tensor: The dense attention mask recovered from attn_mask_start_row_indices. + """ + batch_size, _, max_seq_len = attn_mask_start_row_indices.shape + base = paddle.arange(max_seq_len, dtype="int32").unsqueeze(1).expand([batch_size, -1, max_seq_len]).unsqueeze(1) + mask_indices = attn_mask_start_row_indices.unsqueeze(1) + + tril = paddle.tril( + paddle.ones([max_seq_len, max_seq_len], dtype="bool").expand([batch_size, 1, max_seq_len, max_seq_len]) + ) + attention_mask = paddle.logical_and(base < mask_indices, tril) + attention_mask = paddle.scale( + x=attention_mask.astype(dtype), + scale=1000000.0, + bias=-1.0, + bias_after_scale=False, + ) + + return attention_mask diff --git a/ernie/ERNIE/ernie/fusion_ops/npu_fusion_ops.py b/ernie/ERNIE/ernie/fusion_ops/npu_fusion_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6a6838515823190d0f7bdb2916630dc750cf1774 --- /dev/null +++ b/ernie/ERNIE/ernie/fusion_ops/npu_fusion_ops.py @@ -0,0 +1,86 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +npu fusion operators. + +""" +import paddle +import paddle.distributed as dist +import paddle.nn.functional as F + + +def npu_combining(x, combine_weights, scatter_index, hard_gate=False): + """ + Args: + x: Tensor[seq, dim] + combine_weights: [seq, k] + scatter_index: ** [seq, k] ** + Returns: + y: Tensor[s, dim] + """ + x_gatherd = F.embedding(scatter_index, x) # [s,k,dim] + if hard_gate: + return x_gatherd.squeeze(-2) + y = (combine_weights.unsqueeze(-1) * x_gatherd).sum(1) + return y + + +def npu_cal_aux_loss_func( + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + moe_k, + global_aux_loss=False, + rank=None, + group=None, + clip_min=1e-6, +): + """cal_aux_loss_func""" + if tokens_mask is not None and tokens_mask.dtype != gate_prob.dtype: + tokens_mask = tokens_mask.astype(gate_prob.dtype) + + scale = None + if dispatch_tokens_mask is not None: + seqlen_float = dispatch_tokens_mask.astype(gate_prob.dtype).sum() + if tokens_mask is not None and gate_prob.shape[0] != dispatch_tokens_mask.shape[0]: + scale = seqlen_float / paddle.clip(tokens_mask.sum(), min=1e-6) + elif tokens_mask is not None: + seqlen_float = tokens_mask.sum() + else: + seqlen_float = gate_prob.numel().astype(gate_prob.dtype) / num_experts + seqlen_float = paddle.clip(seqlen_float, min=1e-6) + if len(dispatch_mask.shape) == 2: + dispatch_mask = dispatch_mask.sum(0) + ce = dispatch_mask.astype(gate_prob.dtype).detach() / seqlen_float + me = paddle.sum(gate_prob, axis=0) / seqlen_float + + if global_aux_loss: + me_list, ce_list = [], [] + dist.all_gather(me_list, me, group=group) + dist.all_gather(ce_list, ce, group=group) + me_list[rank] = me + ce_list[rank] = ce + me = paddle.stack(me_list).mean(0) + ce = paddle.stack(ce_list).mean(0) + + l_aux = paddle.sum(me * ce) * num_experts + if use_group: + l_aux = l_aux / moe_k + if scale is not None: + l_aux = l_aux + (scale - 1) * l_aux.detach() + return l_aux, None, None diff --git a/ernie/ERNIE/ernie/loss/__pycache__/dpo.cpython-311.pyc b/ernie/ERNIE/ernie/loss/__pycache__/dpo.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba548d7cb76da2c14865233568eece82f3e14be2 Binary files /dev/null and b/ernie/ERNIE/ernie/loss/__pycache__/dpo.cpython-311.pyc differ diff --git a/ernie/ERNIE/ernie/loss/dpo.py b/ernie/ERNIE/ernie/loss/dpo.py new file mode 100644 index 0000000000000000000000000000000000000000..914c9c2f3ab2f75abe77160e179d74f9fc698355 --- /dev/null +++ b/ernie/ERNIE/ernie/loss/dpo.py @@ -0,0 +1,322 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DPO Criterion""" + +import paddle +import paddle.nn.functional as F +from paddleformers.transformers import fused_head_and_loss_fn +from paddleformers.trl import DPOCriterion +from paddleformers.utils import infohub + +from ..distributed.common_dist_utils import ( + AllGatherVarlenOp, + GatherOp, + sequence_parallel_sparse_mask_labels, +) +from ..modeling import parallel_matmul + +LOOP_CHUNK_SIZE = 1024 + + +class ErnieDPOCriterion(DPOCriterion): + """DPO Criterion""" + + def dpo_logps( + self, + logits, + chosen_labels, + rejected_labels, + response_indexs, + average_log_prob=False, + ): + """DPO logprobs""" + labels = chosen_labels + rejected_labels + hidden_states, weight, bias, transpose_y = logits + + if self.config.use_sparse_head_and_loss_fn: + if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel: + labels, sparse_tgt_idx = sequence_parallel_sparse_mask_labels(labels, 0) + + hidden_states = paddle.gather(hidden_states, sparse_tgt_idx, axis=0) + hidden_states = AllGatherVarlenOp.apply(hidden_states) + else: + labels = labels.flatten() + sparse_tgt_idx = paddle.nonzero(labels != 0).flatten() + labels = paddle.gather(labels, sparse_tgt_idx, axis=0) + + hidden_states = hidden_states.reshape([-1, hidden_states.shape[-1]]) + hidden_states = paddle.gather(hidden_states, sparse_tgt_idx, axis=0) + elif self.config.use_fused_head_and_loss_fn: + if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel: + hidden_states = GatherOp.apply(hidden_states) + hidden_states = hidden_states.reshape( + [ + -1, + self.config.max_sequence_length, + hidden_states.shape[-1], + ] + ) + + if self.config.use_fused_head_and_loss_fn: + per_token_logps = -fused_head_and_loss_fn( + hidden_states, + weight, + bias, + labels, + None, + transpose_y, + self.config.vocab_size, + self.config.tensor_parallel_degree, + self.config.tensor_parallel_output, + self.config.fuse_linear, + LOOP_CHUNK_SIZE, + return_token_loss=True, + ignore_index=0, + ) + else: + logits = parallel_matmul( + hidden_states, + weight, + bias=bias, + transpose_y=self.config.tie_word_embeddings, + tensor_parallel_output=self.config.tensor_parallel_output, + fuse_linear=self.config.fuse_linear, + ) + logits = logits.astype("float32") + per_token_logps = -self.logprobs(logits, labels) + + if len(response_indexs.shape) == 3: + response_indexs = response_indexs[0] + + if self.config.use_sparse_head_and_loss_fn: + chosen_logps = paddle.stack( + [ + ( + paddle.gather( + per_token_logps.reshape([-1]), + paddle.arange(response_index[1], response_index[2], dtype=paddle.int32), + axis=0, + ).sum() + if response_index[3] != 0 + else paddle.to_tensor(100.0) + ) + for response_index in response_indexs + ], + axis=0, + ) + rejected_logps = paddle.stack( + [ + ( + paddle.gather( + per_token_logps.reshape([-1]), + paddle.arange(response_index[2], response_index[3], dtype=paddle.int32), + axis=0, + ).sum() + if response_index[3] != 0 + else paddle.to_tensor(100.0) + ) + for response_index in response_indexs + ], + axis=0, + ) + else: + chosen_logps = paddle.stack( + [ + ( + paddle.gather( + paddle.gather(per_token_logps, response_index[0], axis=0), + paddle.arange(response_index[1], response_index[2], dtype=paddle.int32), + axis=0, + ).sum() + if response_index[3] != 0 + else paddle.to_tensor(100.0) + ) + for response_index in response_indexs + ], + axis=0, + ) + rejected_logps = paddle.stack( + [ + ( + paddle.gather( + paddle.gather(per_token_logps, response_index[0], axis=0), + paddle.arange(response_index[2], response_index[3], dtype=paddle.int32), + axis=0, + ).sum() + if response_index[3] != 0 + else paddle.to_tensor(100.0) + ) + for response_index in response_indexs + ], + axis=0, + ) + sft_loss = -chosen_logps.sum() / (chosen_labels != 0).sum() + if average_log_prob: + chosen_response_length = response_indexs[:, 2] - response_indexs[:, 1] + rejected_response_length = response_indexs[:, 3] - response_indexs[:, 2] + chosen_logps /= chosen_response_length.astype("float32") + rejected_logps /= rejected_response_length.astype("float32") + elif self.dpo_config.normalize_logps: + avg_response_length = (response_indexs[:, 3] - response_indexs[:, 1]) / 2 + chosen_response_length = response_indexs[:, 2] - response_indexs[:, 1] + rejected_response_length = response_indexs[:, 3] - response_indexs[:, 2] + chosen_logps *= avg_response_length / chosen_response_length.astype("float32") + rejected_logps *= avg_response_length / rejected_response_length.astype("float32") + return ( + chosen_logps, + rejected_logps, + sft_loss * self.dpo_config.sft_loss_ratio, + ) + + def dpo_loss( + self, + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + score_deltas, + ): + """DPO Loss""" + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = reference_chosen_logps - reference_rejected_logps + logits = pi_logratios - ref_logratios + + if self.dpo_config.loss_type == "sigmoid": + if self.dpo_config.offset_alpha > 0: + logits = logits - self.dpo_config.offset_alpha / self.dpo_config.beta * paddle.log(score_deltas + 1e-6) + loss = ( + -F.log_sigmoid(self.dpo_config.beta * logits) * (1 - self.dpo_config.label_smoothing) + - F.log_sigmoid(-self.dpo_config.beta * logits) * self.dpo_config.label_smoothing + ) + elif self.dpo_config.loss_type == "hinge": + loss = F.relu(1 - self.dpo_config.beta * logits) + elif self.dpo_config.loss_type == "simpo": + gamma_logratios = self.dpo_config.simpo_gamma / self.dpo_config.beta + logits -= gamma_logratios + loss = ( + -F.log_sigmoid(self.dpo_config.beta * logits) * (1 - self.dpo_config.label_smoothing) + - F.log_sigmoid(-self.dpo_config.beta * logits) * self.dpo_config.label_smoothing + ) + elif self.dpo_config.loss_type == "ipo": + # eqn (17) of the paper where beta is the regularization parameter + # for the IPO loss, denoted by tau in the paper. + loss = (logits - 1 / (2 * self.dpo_config.beta)) ** 2 + elif self.dpo_config.loss_type == "dpop": + loss = -F.log_sigmoid(self.dpo_config.beta * logits) + positive_reg = reference_chosen_logps - policy_chosen_logps + loss += self.dpo_config.dpop_lambda * paddle.clip(positive_reg, min=0) + elif self.dpo_config.loss_type == "kto_pair": + # eqn (7) of the HALOs paper + chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clip(min=0) + rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clip(min=0) + + chosen_logratios = policy_chosen_logps - reference_chosen_logps + rejected_logratios = policy_rejected_logps - reference_rejected_logps + # As described in the KTO report, the KL term for chosen (rejected) is + # estimated using the rejected (chosen) half. + loss = paddle.concat( + ( + 1 - F.sigmoid(self.dpo_config.beta * (chosen_logratios - rejected_KL)), + 1 - F.sigmoid(self.dpo_config.beta * (chosen_KL - rejected_logratios)), + ), + 0, + ) + elif self.dpo_config.loss_type == "sppo_hard": + # In the paper (https://arxiv.org/pdf/2405.00675), SPPO employs a soft probability approach, + # estimated using the PairRM score. The probability calculation is conducted outside of + # the trainer class. The version described here is the hard probability version, where P + # in Equation (4.7) of Algorithm 1 is set to 1 for the winner and 0 for the loser. + a = policy_chosen_logps - reference_chosen_logps + b = policy_rejected_logps - reference_rejected_logps + + loss = (a - 0.5 / self.dpo_config.beta) ** 2 + (b + 0.5 / self.dpo_config.beta) ** 2 + elif self.dpo_config.loss_type == "nca_pair": + chosen_rewards = (policy_chosen_logps - reference_chosen_logps) * self.dpo_config.beta + rejected_rewards = (policy_rejected_logps - reference_rejected_logps) * self.dpo_config.beta + loss = ( + -F.log_sigmoid(chosen_rewards) + - 0.5 * F.log_sigmoid(-chosen_rewards) + - 0.5 * F.log_sigmoid(-rejected_rewards) + ) + elif self.dpo_config.loss_type == "or": + # Derived from Eqs. (4) and (7) from https://arxiv.org/abs/2403.07691 by using + # log identities and exp(log(P(y|x)) = P(y|x) + log_odds = (policy_chosen_logps - policy_rejected_logps) - ( + paddle.log1p(-paddle.exp(policy_chosen_logps)) - paddle.log1p(-paddle.exp(policy_rejected_logps)) + ) + loss = -F.log_sigmoid(log_odds) + else: + raise ValueError( + f"Unknown loss type: {self.dpo_config.loss_type}. " + "Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair'," + "'sppo_hard', 'nca_pair', 'dpop', 'or', 'simpo']" + ) + return loss.mean() * self.dpo_config.pref_loss_ratio + + def forward( + self, + logits, + labels, + ): + """Forward""" + if self.dpo_config.offset_alpha > 0: + ( + chosen_labels, + rejected_labels, + response_indexs, + score_deltas, + reference_chosen_logps, + reference_rejected_logps, + ) = labels + else: + ( + chosen_labels, + rejected_labels, + response_indexs, + reference_chosen_logps, + reference_rejected_logps, + ) = labels + score_deltas = None + + if self.dpo_config.loss_type in ["ipo", "or", "simpo"]: + average_log_prob = True + else: + average_log_prob = False + if reference_chosen_logps is None or reference_rejected_logps is None: + reference_chosen_logps, reference_rejected_logps, sft_loss = self.dpo_logps( + logits, chosen_labels, rejected_labels, response_indexs, average_log_prob + ) + if self.use_infohub: + infohub.reference_chosen_logps.append(reference_chosen_logps) + infohub.reference_rejected_logps.append(reference_rejected_logps) + # pipeline mode requires return loss when self._compute_loss is True + return paddle.zeros([1]) + else: + return reference_chosen_logps, reference_rejected_logps + policy_chosen_logps, policy_rejected_logps, sft_loss = self.dpo_logps( + logits, chosen_labels, rejected_labels, response_indexs, average_log_prob + ) + dpo_loss = self.dpo_loss( + policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, score_deltas + ) + loss = dpo_loss + sft_loss + if self.use_infohub: + infohub.policy_chosen_logps.append(policy_chosen_logps.detach()) + infohub.policy_rejected_logps.append(policy_rejected_logps.detach()) + infohub.sft_loss.append(sft_loss.detach()) + infohub.dpo_loss.append(dpo_loss.detach()) + return loss + else: + return policy_chosen_logps, policy_rejected_logps, sft_loss, dpo_loss, loss diff --git a/ernie/ERNIE/ernie/refined_recompute/__pycache__/utils.cpython-311.pyc b/ernie/ERNIE/ernie/refined_recompute/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4b5dd5b9df9cb6113478729c8dd5ef78d6b8aa1 Binary files /dev/null and b/ernie/ERNIE/ernie/refined_recompute/__pycache__/utils.cpython-311.pyc differ diff --git a/ernie/ERNIE/ernie/utils/__pycache__/common_utils.cpython-311.pyc b/ernie/ERNIE/ernie/utils/__pycache__/common_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57027c373bc144f8f65773139414964855960534 Binary files /dev/null and b/ernie/ERNIE/ernie/utils/__pycache__/common_utils.cpython-311.pyc differ diff --git a/ernie/ERNIE/ernie/utils/__pycache__/moe_utils.cpython-311.pyc b/ernie/ERNIE/ernie/utils/__pycache__/moe_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35ec0d245eb454f200b5c21be9078b3cd8ce7096 Binary files /dev/null and b/ernie/ERNIE/ernie/utils/__pycache__/moe_utils.cpython-311.pyc differ diff --git a/ernie/ERNIE/ernie/utils/common_utils.py b/ernie/ERNIE/ernie/utils/common_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b25d67ae70721dc772e3bcaaf90b194ded1faeb7 --- /dev/null +++ b/ernie/ERNIE/ernie/utils/common_utils.py @@ -0,0 +1,289 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + +import numpy as np +import paddle +from paddleformers.utils.log import logger + +MODEL_LIB_NAMES = [ + "ernie.modeling", + "ernie.modeling_moe", + "ernie.modeling_moe_pp", +] + +MAX_BSZ = 256 +MAX_DRAFT_TOKENS = 6 +LOCAL_RANK = int(os.getenv("PADDLE_RANK_IN_NODE", 0)) + + +def calculate_effective_tokens(training_args, train_dataset, max_seq_len): + """ + Calculate the effective tokens during training. + + Args: + training_args: Configuration object containing training parameters. + train_dataset: Dataset used for training. + max_seq_len: Maximum sequence length of the model. + + Returns: + tuple: Contains total_effective_tokens (int) and total_tokens (int). + """ + total_effective_tokens = 0 + try: + data_parallel_degree = training_args.data_parallel_degree + except: + data_parallel_degree = 1 + if training_args.sharding_parallel_degree > 1: + sharding_parallel_degree = training_args.sharding_parallel_degree + else: + sharding_parallel_degree = 1 + + total_batch = ( + training_args.max_steps + * training_args.per_device_train_batch_size + * training_args.gradient_accumulation_steps + * sharding_parallel_degree + * data_parallel_degree + ) + for i, data in enumerate(train_dataset): + if i == total_batch: + break + for dd in data: + total_effective_tokens += len(dd.token_ids) + total_tokens = total_batch * max_seq_len + + return total_effective_tokens, total_tokens + + +def estimate_training(train_dataset, data_args, training_args, model_args): + """ + Estimate required training steps based on dataset. + + Args: + train_dataset: Dataset used for training estimation. + data_args: Configuration object containing data parameters. + training_args: Configuration object containing training parameters. + model_args: Configuration object containing model parameters. + + Returns: + dict: Contains estimated training steps and related parameters. + """ + train_dataset.estimate = True + logger.info("Start to estimate max training steps...") + + train_dataset_path_list = [path for path in str(data_args.train_dataset_path).replace(" ", "").split(',')] + if len(train_dataset_path_list) > 1: + logger.warning("Suggest to use max_steps instead of num_train_epochs for multi source dataset.") + logger.info( + "Multi source dataset detected, number of samples will be estimated by following rule. " + "num_samples = (source1_num_samples * prob1 + source2_num_samples * prob2 + ...) * epochs" + ) + + max_samples = train_dataset.max_estimate_samples + + if training_args.max_estimate_samples != -1: + # Set estimate samples to max_estimate_samples + logger.warning("The results between sampling and non-sampling methods may differ.") + train_dataset.max_estimate_samples = min( + training_args.max_estimate_samples, train_dataset.max_estimate_samples + ) + + if train_dataset.max_estimate_samples > 0: + train_batches = 0 + train_tokens = 0 + for sequences in train_dataset: + if not train_dataset.estimate: + break + train_batches += 1 + for sequence in sequences: + train_tokens += len(sequence.token_ids) + + train_tokens *= training_args.num_train_epochs + train_batches *= training_args.num_train_epochs + global_batch_size = ( + training_args.per_device_train_batch_size + * training_args.gradient_accumulation_steps + * max(training_args.data_parallel_degree, 1) + * max(training_args.sharding_parallel_degree, 1) + ) + max_steps = train_batches / global_batch_size + + if max_samples != train_dataset.max_estimate_samples: + max_steps *= max_samples / train_dataset.max_estimate_samples + train_tokens *= max_samples / train_dataset.max_estimate_samples + train_dataset.used_samples *= max_samples / train_dataset.max_estimate_samples + train_dataset.unused_samples *= max_samples / train_dataset.max_estimate_samples + + max_steps = int(np.ceil(max_steps)) + + res = { + "num_train_epochs": int(training_args.num_train_epochs), + "max_steps": max_steps, + "train_tokens": int(train_tokens), + "global_batch_size": int(global_batch_size), + "gradient_accumulation_steps": training_args.gradient_accumulation_steps, + "warmup_steps": int(np.ceil(0.1 * max_steps)), + "per_device_train_batch_size": int(training_args.per_device_train_batch_size), + "tensor_parallel_degree": int(training_args.tensor_parallel_degree), + "pipeline_parallel_degree": int(training_args.pipeline_parallel_degree), + "sharding_parallel_degree": int(training_args.sharding_parallel_degree), + "seed": training_args.seed, + "num_samples_each_epoch": data_args.num_samples_each_epoch, + "max_seq_len": int(data_args.max_seq_len), + "valid": True, + "train_samples": int(max_samples * training_args.num_train_epochs), + "estimate_samples": int(train_dataset.max_estimate_samples), + "actual_train_samples": int(train_dataset.used_samples * training_args.num_train_epochs), + "skip_samples": int(train_dataset.unused_samples * training_args.num_train_epochs), + } + if hasattr(training_args, "num_of_gpus"): + res["num_of_gpus"] = training_args.num_of_gpus + + if train_batches / training_args.num_train_epochs / global_batch_size < 1: + logger.warning("This dataset is too small, you'd better enlarge your dataset.") + res["valid"] = False + + if getattr(training_args, "estimation_output_file", None): + with open(training_args.estimation_output_file, "w", encoding="utf-8") as f: + json.dump(res, f) + + return max_steps + else: + res = { + "num_train_epochs": int(training_args.num_train_epochs), + "max_steps": 0, + "gradient_accumulation_steps": training_args.gradient_accumulation_steps, + "train_tokens": 0, + "per_device_train_batch_size": int(training_args.per_device_train_batch_size), + "tensor_parallel_degree": int(training_args.tensor_parallel_degree), + "pipeline_parallel_degree": int(training_args.pipeline_parallel_degree), + "sharding_parallel_degree": int(training_args.sharding_parallel_degree), + "num_samples_each_epoch": data_args.num_samples_each_epoch, + "max_seq_len": int(data_args.max_seq_len), + "seed": data_args.seed, + "valid": False, + "train_samples": 0, + } + if hasattr(training_args, "num_of_gpus"): + res["num_of_gpus"] = training_args.num_of_gpus + + if getattr(training_args, "estimation_output_file", None): + with open(training_args.estimation_output_file, "w", encoding="utf-8") as f: + json.dump(res, f) + + logger.error("No valid data found, please check your dataset format.") + return 0 + + +def check_refined_recompute(rr, sequence_parallel, lora=False): + """ + Update refined recompute configuration. + + Args: + rr: Original recompute configuration (dict). + sequence_parallel: Boolean indicating if sequence parallel is enabled. + lora: Boolean indicating if LoRA is used. + + Returns: + dict: Updated recompute configuration. + """ + if len(rr) > 0: + rr = {} + logger.error("Currently do not support refine recompute; to be supported soon.") + + for op_name in rr.keys(): + if op_name in ["mlp_row_ln", "attention_row_ln", "attention_column_ln", "mlp_column_ln"]: + if not sequence_parallel: + logger.warning( + f"Currently, the `{op_name}` op is only supported " + "when `sequence_parallel=True`. This refined recompute op will be ignored." + ) + continue + if lora: + logger.warning( + "Currently, LoRA does not support refined recompute " + f"for the `{op_name}` op. This refined recompute op will be ignored." + ) + continue + + +def save_stop_info(args, stop_step, outside_eval, outside_predict): + """ + Save training stop information to JSON file. + + Args: + args: Command line arguments. + stop_step: Step number when training stopped. + outside_eval: Number of external evaluations performed. + outside_predict: Number of external predictions made. + + Returns: + None + """ + + process_index = paddle.distributed.get_rank() if LOCAL_RANK != -1 else 0 + if process_index != 0: + return + + output_path = args.logging_dir + eval_turns = 0 + outside_eval + predict_turns = 0 + outside_predict + if args.do_eval: + eval_turns += stop_step // args.eval_steps + + data = { + "stop_step": stop_step, + "eval_turns": eval_turns, + "predict_turns": predict_turns, + } + os.makedirs(output_path, exist_ok=True) + file_path = os.path.join(output_path, "stop_step.json") + with open(file_path, 'w') as json_file: + json.dump(data, json_file) + logger.info(f"Saving stop info into {file_path}") + return + + +def add_start_docstrings(*docstr): + """ + Decorator to prepend docstrings to function documentation. + + Args: + *docstr: Variable length argument list of docstrings. + + Returns: + function: Decorator function. + """ + + def docstring_decorator(fn): + fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") + return fn + + return docstring_decorator + + +def infer_save_test_case(cases: list[list[dict]], file: str): + """save test to result file + + Args: + cases (list[list[dict]]): the content of case + file (str): the path of saved file + """ + with open(file, "a+", encoding="utf-8") as f: + for case in cases: + raw = json.dumps(case, ensure_ascii=False) + f.write(raw + "\n") diff --git a/ernie/ERNIE/ernie/utils/mm_data_utils.py b/ernie/ERNIE/ernie/utils/mm_data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7e6c193fbb394b6380fdae990c05fcec1d83d51f --- /dev/null +++ b/ernie/ERNIE/ernie/utils/mm_data_utils.py @@ -0,0 +1,51 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MMSpecialTokensConfig class""" + +import logging + +logger = logging.getLogger(__name__) + +__all__ = ("MMSpecialTokensConfig", "DATATYPE_2_ID", "IDTYPES_2_ID", "IMAGETYPES_2_ID") + +DATATYPE_2_ID = {"mm": 0, "lm": 1, "audio": 2} +IDTYPES_2_ID = {"text": 0, "image": 1, "video": 2, "audio": 3} +IMAGETYPES_2_ID = {"image": 0, "video": 1, "padded_image": 2} + + +class MMSpecialTokensConfig: + """_summary_""" + + use_ocr_specialtoken = True + use_crop_specialtoken = True + coor_num = 1001 + image_placeholder = "<|IMAGE_PLACEHOLDER|>" + audio_placeholder = "<|AUDIO_PLACEHOLDER|>" + crop = ["<|CROP_COL_SEP|>", "<|CROP_ROW_SEP|>", "<|IMAGE_SEP|>"] + ocr_coor = [f"<|LOC_{i}|>" for i in range(coor_num)] + ocr_begin_end = ["<|LOC_BEGIN|>", "<|LOC_END|>", "<|LOC_SEP|>"] + mm_begin_end = ["<|BOI|>", "<|EOI|>", "<|BOA|>", "<|EOA|>", "<|BOV|>", "<|EOV|>"] + + @classmethod + def get_special_tokens_info(cls): + """_summary_ + + Returns: + _type_: _description_ + """ + return { + k: getattr(cls, k) + for k in ["image_placeholder", "audio_placeholder", "crop", "ocr_coor", "ocr_begin_end", "mm_begin_end"] + } diff --git a/ernie/ERNIE/ernie/utils/moe_utils.py b/ernie/ERNIE/ernie/utils/moe_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0b582e01d8b6f1abbc18205938d655cde1e71341 --- /dev/null +++ b/ernie/ERNIE/ernie/utils/moe_utils.py @@ -0,0 +1,51 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MoE utils""" + +from paddle.distributed import fleet + + +def distributed_optimizer_for_moe( + optimizer, + use_moe=False, +): + """ + Create a distributed optimizer with MoE (Mixture of Experts) support. + + Args: + optimizer: Base optimizer to decorate. + use_moe (bool): Whether to enable MoE expert parallel. + + Returns: + HybridParallelOptimizer: Configured optimizer for distributed training. + """ + + if not use_moe: + return fleet.distributed_optimizer(optimizer) + + from ernie.moe.distributed.hybrid_parallel_optimizer import ( + HybridParallelOptimizer as MoEHybridParallelOptimizer, + ) + + fleet_env = fleet.fleet + fleet_env.user_defined_optimizer = optimizer + hp_optim = MoEHybridParallelOptimizer(optimizer, fleet_env._hcg, fleet_env._user_defined_strategy) + + if fleet_env._user_defined_strategy.hybrid_configs["pp_configs"].dp_comm_overlap: + hp_optim._dp_enable = False + + if fleet_env._user_defined_strategy.hybrid_configs["pp_configs"].sharding_comm_overlap: + hp_optim._sharding_enable = False + return hp_optim diff --git a/ernie/ERNIE/ernie/utils/peft_utils.py b/ernie/ERNIE/ernie/utils/peft_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6308b25a2dc0ff6443b65ca222a5014d796844de --- /dev/null +++ b/ernie/ERNIE/ernie/utils/peft_utils.py @@ -0,0 +1,68 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PEFT utils""" +from paddleformers.peft import LoRAConfig, LoRAModel +from paddleformers.utils.log import logger + + +def initialize_lora_model( + model, training_args, model_args, resume_from_checkpoint, dtype +): + """Initialize LoRAModel""" + + logger.info("Start to wrap model with LoRA config ...") + if model_args.lora_path is None or resume_from_checkpoint: + # If resume from checkpoint, LoRA adatper will be overwritten in the checkpoint loading process. + target_modules = [ + ".*qkv_proj.*", + ".*o_proj.*", + ".*up_gate_proj.*", + ".*down_proj.*", + ] + if model_args.rslora_plus: + model_args.rslora = True + model_args.lora_plus_scale = 4 + model_args.lora_alpha = 4 + + if training_args.weight_quantize_algo is not None: + if model_args.rslora or model_args.lora_plus_scale != 1.0: + logger.info("Weight quantization is not supported in LoRA+ and RsLoRA.") + if model_args.lora_alpha == -1: + if model_args.rslora: + model_args.lora_alpha = 4 + else: + model_args.lora_alpha = 2 * model_args.lora_rank + lora_config = LoRAConfig( + target_modules=target_modules, + r=model_args.lora_rank, + lora_alpha=model_args.lora_alpha, + rslora=model_args.rslora, + lora_plus_scale=model_args.lora_plus_scale, + tensor_parallel_degree=training_args.tensor_parallel_degree, + dtype=dtype, + head_dim=model.config.hidden_size // model.config.num_attention_heads, + base_model_name_or_path=model_args.model_name_or_path, + ) + model = LoRAModel(model, lora_config) + else: + model = LoRAModel.from_pretrained( + model=model, + lora_path=model_args.lora_path, + ) + + model.mark_only_lora_as_trainable() + model.print_trainable_parameters() + logger.info("Wraping model with LoRA config successfully !") + return model diff --git a/ernie/ERNIE/ernie/utils/seed_utils.py b/ernie/ERNIE/ernie/utils/seed_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d9e126f27dc16f066b4f2d4c7900cc5baf9883b4 --- /dev/null +++ b/ernie/ERNIE/ernie/utils/seed_utils.py @@ -0,0 +1,102 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Set random seed for reproducibility in hybrid parallel training. +""" +import random + +import numpy as np +import paddle +import paddle.distributed as dist +from paddle.distributed.fleet import fleet +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker + + +def set_seed(seed): + """set random seed for reproducibility in hybrid parallel training.""" + # NOTE(shenliang03): For parameter init seed: + # seed: dp/mp_undistributed_paramter/sharding is same; others is different + # For compute seed(dropout): + # global seed: only mp group is same. + # local seed: all groups are different + + if hasattr(fleet, "_hcg"): # 混合并行下,才分开设置local-seed和global-seed + # obtain rank message of hybrid parallel + hcg = fleet.get_hybrid_communicate_group() + + mp_rank = hcg.get_model_parallel_rank() + mp_size = hcg.get_model_parallel_world_size() + + pp_rank = hcg.get_stage_id() + pp_size = hcg.get_pipe_parallel_world_size() + + dp_rank = hcg.get_data_parallel_rank() + dp_size = hcg.get_data_parallel_world_size() + + sharding_rank = hcg.get_sharding_parallel_rank() + sharding_size = hcg.get_sharding_parallel_world_size() + else: + mp_rank, mp_size = 0, 1 + pp_rank, pp_size = 0, 1 + dp_rank, dp_size = dist.get_rank(), dist.get_world_size() + sharding_rank, sharding_size = 0, 1 + + # NOTE: the commented seeds are set only for precision validation + # 与框架中的实现对齐, + # 无论是否启用混合并行,都设置 model_parallel_rng 用于同步初始化参数 + model_parallel_rng = seed + 1 + mp_rank * pp_size + pp_rank + + seed += ( + 1 * dp_rank + ) # EB4框架中数据流并不需要全局seed,。此处操作对数据没什么影响,对组网也没什么影响。只是为了兼容 fleet 传统而设置。 + random.seed(seed) + np.random.seed(seed) + + # seed = mp_rank + + # pp_rank * (mp_size) + + # dp_rank * (mp_size * pp_size) + + # sharding_rank * (mp_size * pp_size * dp_size) + # seed offset is order to avoid conflicts with the parameter initialization seed + + seed_offset = seed + 1024 + paddle.distributed.get_world_size() + global_seed = ( + seed_offset + + pp_rank * (mp_size) + + dp_rank * (mp_size * pp_size) + + sharding_rank * (mp_size * pp_size * dp_size) + ) + + seed_offset += paddle.distributed.get_world_size() + local_seed = ( + seed_offset + + mp_rank + + pp_rank * (mp_size) + + dp_rank * (mp_size * pp_size) + + sharding_rank * (mp_size * pp_size * dp_size) + ) + + tracker = get_rng_state_tracker() + tracker.add("global_seed", global_seed) + tracker.add("local_seed", local_seed) + if "model_parallel_rng" not in tracker.states_: + tracker.add("model_parallel_rng", model_parallel_rng) + paddle.seed(global_seed) + + print( + f""" + The global seed is set to {global_seed} and local seed is set to {local_seed}. + mp_init_seed={model_parallel_rng} + """ + ) diff --git a/ernie/ERNIE/examples/configs/ERNIE-4.5-300B-A47B/run_chat.yaml b/ernie/ERNIE/examples/configs/ERNIE-4.5-300B-A47B/run_chat.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6a41229408cfc3ecec03b5f3d76e333964c4ac2a --- /dev/null +++ b/ernie/ERNIE/examples/configs/ERNIE-4.5-300B-A47B/run_chat.yaml @@ -0,0 +1,13 @@ +### model +model_name_or_path: baidu/ERNIE-4.5-300B-A47B-Paddle +tensor_parallel_degree: 8 +output_dir: None + +### server +max_model_len: 16384 +port: 8188 + +### chat +max_new_tokens: 8192 +top_p: 0.7 +temperature: 0.95 diff --git a/ernie/ERNIE/examples/configs/ERNIE-4.5-300B-A47B/run_eval.yaml b/ernie/ERNIE/examples/configs/ERNIE-4.5-300B-A47B/run_eval.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fd21204403600a8ab3cb1d771b0ccf5c2f7a341f --- /dev/null +++ b/ernie/ERNIE/examples/configs/ERNIE-4.5-300B-A47B/run_eval.yaml @@ -0,0 +1,53 @@ +### data +eval_dataset_type: "erniekit" +eval_dataset_path: "./examples/data/sft-eval.jsonl" +eval_dataset_prob: "1.0" +max_seq_len: 8192 +num_samples_each_epoch: 6000000 + +### model +model_name_or_path: baidu/ERNIE-4.5-300B-A47B-Paddle +moe_group: mp +fine_tuning: LoRA +lora_rank: 32 +fuse_rope: True + +### eval +seed: 23 +do_train: False +do_eval: True +distributed_dataloader: False +dataloader_num_workers: 1 +batch_size: 1 +logging_dir: ./vdl_log +output_dir: ./output +disable_tqdm: True + +# performance +tensor_parallel_degree: 8 +pipeline_parallel_degree: 1 +sharding_parallel_degree: 1 +sharding: stage1 +sequence_parallel: True +pipeline_parallel_config: disable_partial_send_recv enable_clear_every_step_cache +recompute: True +compute_type: bf16 +fp16_opt_level: O2 +disable_ckpt_quant: True +amp_master_grad: True +amp_custom_white_list: + - lookup_table + - lookup_table_v2 + - flash_attn + - matmul + - matmul_v2 + - fused_gemm_epilogue +amp_custom_black_list: + - reduce_sum + - softmax_with_cross_entropy + - c_softmax_with_cross_entropy + - elementwise_div + - sin + - cos +unified_checkpoint: True +unified_checkpoint_config: async_save diff --git a/ernie/ERNIE/examples/configs/ERNIE-4.5-300B-A47B/run_export.yaml b/ernie/ERNIE/examples/configs/ERNIE-4.5-300B-A47B/run_export.yaml new file mode 100644 index 0000000000000000000000000000000000000000..785747bf029614c5a01a947a068c19fdc65415f5 --- /dev/null +++ b/ernie/ERNIE/examples/configs/ERNIE-4.5-300B-A47B/run_export.yaml @@ -0,0 +1,18 @@ +### model +model_name_or_path: baidu/ERNIE-4.5-300B-A47B-Paddle +fine_tuning: LoRA + +### split +max_shard_size: 5 +hf_hub_id: null +output_dir: ./output + +### performance +tensor_parallel_degree: 8 +pipeline_parallel_degree: 1 +sharding_parallel_degree: 1 +sharding: stage1 +pipeline_parallel_config: disable_partial_send_recv enable_clear_every_step_cache +sequence_parallel: True +compute_type: bf16 +fp16_opt_level: O2 diff --git a/ernie/ERNIE/examples/configs/ERNIE-4.5-300B-A47B/sft/run_sft_32k.yaml b/ernie/ERNIE/examples/configs/ERNIE-4.5-300B-A47B/sft/run_sft_32k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dd3d6cff335e23cae1692459769fc1dabd00b87d --- /dev/null +++ b/ernie/ERNIE/examples/configs/ERNIE-4.5-300B-A47B/sft/run_sft_32k.yaml @@ -0,0 +1,84 @@ +### data +train_dataset_type: "erniekit" +eval_dataset_type: "erniekit" +train_dataset_path: "./examples/data/sft-train.jsonl" +train_dataset_prob: "1.0" +eval_dataset_path: "./examples/data/sft-eval.jsonl" +eval_dataset_prob: "1.0" +max_seq_len: 32768 +num_samples_each_epoch: 6000000 + +### model +model_name_or_path: baidu/ERNIE-4.5-300B-A47B-Paddle +moe_group: mp +fine_tuning: Full +fuse_rope: True +use_sparse_head_and_loss_fn: True + +### finetuning +# base +stage: SFT +seed: 23 +do_train: True +do_eval: True +distributed_dataloader: False +dataloader_num_workers: 1 +batch_size: 1 +num_train_epochs: 1 +max_steps: 100 +max_evaluate_steps: 10000 +eval_steps: 10000 +evaluation_strategy: steps +save_steps: 10000000 +save_total_limit: 5 +save_strategy: steps +logging_steps: 1 +release_grads: True +gradient_accumulation_steps: 8 +logging_dir: ./vdl_log +output_dir: ./output +disable_tqdm: True + +# train +warmup_steps: 20 +learning_rate: 1.0e-5 +lr_scheduler_type: cosine +min_lr: 1.0e-6 +layerwise_lr_decay_bound: 1.0 + +# optimizer +weight_decay: 0.1 +adam_epsilon: 1.0e-8 +adam_beta1: 0.9 +adam_beta2: 0.95 +offload_optim: True + +# performance +tensor_parallel_degree: 8 +pipeline_parallel_degree: 14 +sharding_parallel_degree: 1 +sharding: stage1 +sequence_parallel: True +pipeline_parallel_config: enable_delay_scale_loss enable_release_grads disable_partial_send_recv +recompute: True +recompute_use_reentrant: True +compute_type: bf16 +fp16_opt_level: O2 +disable_ckpt_quant: True +amp_master_grad: True +amp_custom_white_list: + - lookup_table + - lookup_table_v2 + - flash_attn + - matmul + - matmul_v2 + - fused_gemm_epilogue +amp_custom_black_list: + - reduce_sum + - softmax_with_cross_entropy + - c_softmax_with_cross_entropy + - elementwise_div + - sin + - cos +unified_checkpoint: True +# unified_checkpoint_config: async_save diff --git a/ernie/ERNIE/examples/configs/ERNIE-4.5-300B-A47B/sft/run_sft_fp8_8k.yaml b/ernie/ERNIE/examples/configs/ERNIE-4.5-300B-A47B/sft/run_sft_fp8_8k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cf513ee1fff8a4bff1e9c1cf2167f1d3156f9cf5 --- /dev/null +++ b/ernie/ERNIE/examples/configs/ERNIE-4.5-300B-A47B/sft/run_sft_fp8_8k.yaml @@ -0,0 +1,94 @@ +### data +train_dataset_type: "erniekit" +eval_dataset_type: "erniekit" +train_dataset_path: "./examples/data/sft-train.jsonl" +train_dataset_prob: "1.0" +eval_dataset_path: "./examples/data/sft-eval.jsonl" +eval_dataset_prob: "1.0" +max_seq_len: 8192 +num_samples_each_epoch: 6000000 + +### model +model_name_or_path: baidu/ERNIE-4.5-300B-A47B-Paddle +moe_group: mp +fine_tuning: Full +fuse_rope: True +use_sparse_head_and_loss_fn: True +# Not yet support MTP training, so set num_nextn_predict_layers to 0. +num_nextn_predict_layers: 0 + +### finetuning +# base +stage: SFT +seed: 23 +do_train: True +do_eval: True +distributed_dataloader: False +dataloader_num_workers: 1 +batch_size: 1 +num_train_epochs: 1 +max_steps: 100 +max_evaluate_steps: 10000 +eval_steps: 10000 +evaluation_strategy: steps +save_steps: 10000000 +save_total_limit: 5 +save_strategy: steps +logging_steps: 1 +release_grads: True +gradient_accumulation_steps: 8 +logging_dir: ./vdl_log +output_dir: ./output +disable_tqdm: True + +# train +warmup_steps: 20 +learning_rate: 1.0e-5 +lr_scheduler_type: cosine +min_lr: 1.0e-6 +layerwise_lr_decay_bound: 1.0 + +# optimizer +optim: adamw_custom +optim_shard_num: 8 +weight_decay: 0.1 +adam_epsilon: 1.0e-8 +adam_beta1: 0.9 +adam_beta2: 0.95 + +# performance +tensor_parallel_degree: 8 +pipeline_parallel_degree: 2 +sharding_parallel_degree: 1 +sharding: stage1 +sequence_parallel: True +pp_seg_method: [0,29,57] +pipeline_parallel_config: enable_delay_scale_loss enable_release_grads disable_partial_send_recv +recompute: True +recompute_use_reentrant: True +compute_type: fp8 +fp16_opt_level: O2 +disable_ckpt_quant: True +amp_custom_white_list: + - lookup_table + - lookup_table_v2 + - flash_attn + - matmul + - matmul_v2 + - fused_gemm_epilogue +amp_custom_black_list: + - reduce_sum + - softmax_with_cross_entropy + - c_softmax_with_cross_entropy + - elementwise_div + - sin + - cos +unified_checkpoint: True +unified_checkpoint_config: ignore_merge_optimizer +apply_hadamard: True +use_lowprecision_moment: True +tensorwise_offload_optimizer: True +# Save optim requires a total of 2 TB of storage across all nodes. +# Skip saving/loading optimizer states (if not resuming training) +# ignore_save_lr_and_optim: True +# ignore_load_lr_and_optim: True diff --git a/ernie/ERNIE/examples/configs/ERNIE-4.5-300B-A47B/sft/run_sft_lora_32k.yaml b/ernie/ERNIE/examples/configs/ERNIE-4.5-300B-A47B/sft/run_sft_lora_32k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3ad62d36aab67d87c3e3f797cb08cc6cb8cd3b2a --- /dev/null +++ b/ernie/ERNIE/examples/configs/ERNIE-4.5-300B-A47B/sft/run_sft_lora_32k.yaml @@ -0,0 +1,88 @@ +### data +train_dataset_type: "erniekit" +eval_dataset_type: "erniekit" +train_dataset_path: "./examples/data/sft-train.jsonl" +train_dataset_prob: "1.0" +eval_dataset_path: "./examples/data/sft-eval.jsonl" +eval_dataset_prob: "1.0" +max_seq_len: 32768 +num_samples_each_epoch: 6000000 + +### model +model_name_or_path: baidu/ERNIE-4.5-300B-A47B-Paddle +moe_group: mp +fine_tuning: LoRA +lora_rank: 32 +lora_alpha: -1 +lora_plus_scale: 1.0 +rslora: False +fuse_rope: True +use_sparse_head_and_loss_fn: True + +### finetuning +# base +stage: SFT +seed: 23 +do_train: True +do_eval: True +distributed_dataloader: False +dataloader_num_workers: 1 +batch_size: 1 +num_train_epochs: 1 +max_steps: 100 +max_evaluate_steps: 10000 +eval_steps: 10000 +evaluation_strategy: steps +save_steps: 10000000 +save_total_limit: 5 +save_strategy: steps +logging_steps: 1 +release_grads: True +gradient_accumulation_steps: 8 +logging_dir: ./vdl_log +output_dir: ./output +disable_tqdm: True + +# train +warmup_steps: 20 +learning_rate: 3.0e-4 +lr_scheduler_type: cosine +min_lr: 1.0e-6 +layerwise_lr_decay_bound: 1.0 + +# optimizer +weight_decay: 0.1 +adam_epsilon: 1.0e-8 +adam_beta1: 0.9 +adam_beta2: 0.999 +offload_optim: True + +# performance +tensor_parallel_degree: 8 +pipeline_parallel_degree: 2 +sharding_parallel_degree: 1 +sharding: stage1 +sequence_parallel: True +pipeline_parallel_config: disable_partial_send_recv enable_clear_every_step_cache +recompute: True +recompute_use_reentrant: True +compute_type: bf16 +fp16_opt_level: O2 +disable_ckpt_quant: True +amp_master_grad: True +amp_custom_white_list: + - lookup_table + - lookup_table_v2 + - flash_attn + - matmul + - matmul_v2 + - fused_gemm_epilogue +amp_custom_black_list: + - reduce_sum + - softmax_with_cross_entropy + - c_softmax_with_cross_entropy + - elementwise_div + - sin + - cos +unified_checkpoint: True +unified_checkpoint_config: async_save diff --git a/ernie/ERNIE/examples/configs/ERNIE-4.5-300B-A47B/sft/run_sft_lora_8k.yaml b/ernie/ERNIE/examples/configs/ERNIE-4.5-300B-A47B/sft/run_sft_lora_8k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aad584ce4694d6abbe448a2a2d3d9d300342a180 --- /dev/null +++ b/ernie/ERNIE/examples/configs/ERNIE-4.5-300B-A47B/sft/run_sft_lora_8k.yaml @@ -0,0 +1,88 @@ +### data +train_dataset_type: "erniekit" +eval_dataset_type: "erniekit" +train_dataset_path: "./examples/data/sft-train.jsonl" +train_dataset_prob: "1.0" +eval_dataset_path: "./examples/data/sft-eval.jsonl" +eval_dataset_prob: "1.0" +max_seq_len: 8192 +num_samples_each_epoch: 6000000 + +### model +model_name_or_path: baidu/ERNIE-4.5-300B-A47B-Paddle +moe_group: mp +fine_tuning: LoRA +lora_rank: 32 +lora_alpha: -1 +lora_plus_scale: 1.0 +rslora: False +fuse_rope: True +use_sparse_head_and_loss_fn: True + +### finetuning +# base +stage: SFT +seed: 23 +do_train: True +do_eval: True +distributed_dataloader: False +dataloader_num_workers: 1 +batch_size: 1 +num_train_epochs: 1 +max_steps: 100 +max_evaluate_steps: 10000 +eval_steps: 10000 +evaluation_strategy: steps +save_steps: 10000000 +save_total_limit: 5 +save_strategy: steps +logging_steps: 1 +release_grads: True +gradient_accumulation_steps: 8 +logging_dir: ./vdl_log +output_dir: ./output +disable_tqdm: True + +# train +warmup_steps: 20 +learning_rate: 3.0e-4 +lr_scheduler_type: cosine +min_lr: 1.0e-6 +layerwise_lr_decay_bound: 1.0 + +# optimizer +weight_decay: 0.1 +adam_epsilon: 1.0e-8 +adam_beta1: 0.9 +adam_beta2: 0.999 +offload_optim: True + +# performance +tensor_parallel_degree: 8 +pipeline_parallel_degree: 2 +sharding_parallel_degree: 1 +sharding: stage1 +sequence_parallel: True +pipeline_parallel_config: disable_partial_send_recv enable_clear_every_step_cache +recompute: True +recompute_use_reentrant: True +compute_type: bf16 +fp16_opt_level: O2 +disable_ckpt_quant: True +amp_master_grad: True +amp_custom_white_list: + - lookup_table + - lookup_table_v2 + - flash_attn + - matmul + - matmul_v2 + - fused_gemm_epilogue +amp_custom_black_list: + - reduce_sum + - softmax_with_cross_entropy + - c_softmax_with_cross_entropy + - elementwise_div + - sin + - cos +unified_checkpoint: True +unified_checkpoint_config: async_save diff --git a/ernie/ERNIE/examples/post-training/dpo/dpo_estimate_training.py b/ernie/ERNIE/examples/post-training/dpo/dpo_estimate_training.py new file mode 100644 index 0000000000000000000000000000000000000000..c3e19308e1712c0e7dbeea150153928eeb91a806 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/dpo/dpo_estimate_training.py @@ -0,0 +1,221 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Estimate DPO """ + +import json +import os +import sys + +import numpy as np +import paddle +from paddleformers.trainer import PdArgumentParser +from paddleformers.utils.log import logger + +# isort: off +# fmt: off +from ernie.tokenizer import Ernie4_5_Tokenizer +from ernie.configuration import Ernie4_5_MoeConfig +# isort: on + +from ernie.dataset.dpo import create_dataset + + +def calculate_acc_steps(num_samples, train_batch, dataset_world_size, per_device_train_batch_size): + """calculate_acc_steps + + Args: + num_samples (int): Total training samples in dataset + train_batch (int): Target global batch size + dataset_world_size (int): Number of dataset parallel training devices + per_device_train_batch_size (int): Batch size per GPU/device + + Returns: + int: Number of gradient accumulation steps needed to achieve: + - Global batch size target + - Full dataset coverage + """ + samples_per_batch = per_device_train_batch_size * dataset_world_size * num_samples / train_batch + if num_samples < 100: + recommend_bs = 8 + elif num_samples < 1000: + recommend_bs = 16 + elif num_samples < 10000: + recommend_bs = 32 + elif num_samples < 100000: + recommend_bs = 64 + else: + recommend_bs = 128 + return min(np.ceil(recommend_bs / samples_per_batch), 32) + + +def dpo_estimate_training(tokenizer, data_args, training_args, config, train_dataset=None): + """ dpo_estimate_training + + Args: + tokenizer (PreTrainedTokenizer): Text tokenization + data_args (DataArguments): Datasets configuration + training_args (TrainingArguments): Training configuration + config (PretrainedConfig): Model configuration + train_dataset (Dataset, optional): Preloaded dataset + + Returns: + training_args (TrainingArguments): Training configuration with max_steps setting + res (Dict): Training estimate results + """ + + if training_args.should_save or training_args.should_save_model_state: + os.makedirs(training_args.output_dir, exist_ok=True) + if train_dataset is None: + dataset_config = { + "tokenizer": tokenizer, + "max_seq_len": data_args.max_seq_len, + "max_prompt_len": data_args.max_prompt_len, + "random_seed": training_args.seed, + "num_replicas": 1, + "rank": 0, + "num_samples_each_epoch": data_args.num_samples_each_epoch, + "random_shuffle": data_args.random_shuffle, + "greedy_intokens": data_args.greedy_intokens, + "buffer_size": data_args.buffer_size, + "mask_out_eos_token": data_args.mask_out_eos_token, + } + train_dataset = create_dataset( + task_group=data_args.train_dataset_path, + task_group_prob=data_args.train_dataset_prob, + sub_dataset_type=data_args.train_dataset_type, + **dataset_config + ) + if len(train_dataset.example_dataset._task_group) > 1: + logger.warning("Suggest to use max_steps instead of num_train_epochs for multi source dataset") + logger.info( + "Multi source dataset detected, number of samples will be estimated by following rule. " + "num_samples= (source1_num_samples * prob1 + source2_num_samples * prob2 + ...)*epochs" + ) + max_samples = 0 + for task in train_dataset.example_dataset._task_group: + max_samples += np.ceil(task['num_examples'] * task["prob_origin"]) + else: + max_samples = train_dataset.example_dataset._task_group[0]['num_examples'] + if max_samples > 0 : + if training_args.num_of_gpus > 0: + dataset_world_size = ( + training_args.num_of_gpus + // max(1, training_args.tensor_parallel_degree) + // max(1, training_args.pipeline_parallel_degree)) + if dataset_world_size < 1: + raise ValueError("dataset_world_size must be positive, please verify your config") + else: + dataset_world_size = training_args.dataset_world_size + + num_samples = 0 + train_tokens = 0 + train_batch = 0 + for sequences in train_dataset: + if num_samples >= max_samples: + break + train_batch += 1 + for sequence in sequences: + train_tokens += len(sequence.input_ids) + num_samples += 1 + if training_args.gradient_accumulation_steps < 0: + training_args.gradient_accumulation_steps = calculate_acc_steps( + num_samples, train_batch, dataset_world_size, training_args.per_device_train_batch_size) + max_samples *= training_args.num_train_epochs + train_tokens *= training_args.num_train_epochs + train_batch *= training_args.num_train_epochs + global_batch_size = ( + training_args.per_device_train_batch_size + * training_args.gradient_accumulation_steps + * dataset_world_size + ) + if training_args.num_of_gpus < 0: + training_args.num_of_gpus = paddle.distributed.get_world_size() + + training_args.max_steps = np.ceil(train_batch / global_batch_size) + total_tokens = training_args.max_steps * data_args.max_seq_len * global_batch_size + res = { + "num_train_epochs": int(training_args.num_train_epochs), + "max_steps": int(training_args.max_steps), + "train_samples": int(max_samples), + "gradient_accumulation_steps": int(training_args.gradient_accumulation_steps), + "num_of_gpus": int(training_args.num_of_gpus), + "per_device_train_batch_size": int(training_args.per_device_train_batch_size), + "pipeline_parallel_degree": int(max(1, training_args.pipeline_parallel_degree)), + "tensor_parallel_degree": int(max(1, training_args.tensor_parallel_degree)), + "seed": int(training_args.seed), + "num_samples_each_epoch": int(data_args.num_samples_each_epoch), + "max_seq_len": int(data_args.max_seq_len), + "max_prompt_len": int(data_args.max_prompt_len), + "total_tokens": int(total_tokens), + "train_tokens": int(train_tokens), + "valid": True, + } + if train_batch / training_args.num_train_epochs / global_batch_size < 1: + logger.warning("This dataset is too small, you'd better enlarge your dataset.") + res["valid"] = False + else: + training_args.max_steps = 0 + logger.error("No valid data found, please check your dataset format.") + res = { + "num_train_epochs": int(training_args.num_train_epochs), + "max_steps": int(training_args.max_steps), + "train_samples": 0, + "gradient_accumulation_steps": int(training_args.gradient_accumulation_steps), + "num_of_gpus": int(training_args.num_of_gpus), + "per_device_train_batch_size": int(training_args.per_device_train_batch_size), + "pipeline_parallel_degree": int(max(1, training_args.pipeline_parallel_degree)), + "tensor_parallel_degree": int(max(1, training_args.tensor_parallel_degree)), + "seed": int(training_args.seed), + "num_samples_each_epoch": 6000000, + "max_seq_len": int(data_args.max_seq_len), + "max_prompt_len": int(data_args.max_prompt_len), + "valid": False, + } + + + logger.info(f"training argument: {res}") + # NOTE(gongenlei): if not int, broadcast will overflow + training_args.max_steps = int(training_args.max_steps) + with open(os.path.join(training_args.output_dir, "dpo_train_args.json"), "w", encoding="utf-8") as f: + json.dump(res, f) + return training_args, res + + +if __name__ == "__main__": + from dpo_utils import ( + DataArgument, + DPOConfig, + DPOTrainingArguments, + ModelArgument, + ) + parser = PdArgumentParser((ModelArgument, DataArgument, DPOTrainingArguments, DPOConfig)) + if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"): + model_args, data_args, training_args, dpo_config = parser.parse_json_file_and_cmd_lines() + else: + model_args, data_args, training_args, dpo_config = parser.parse_args_into_dataclasses() + training_args.pipeline_parallel_degree = training_args.pipeline_degree + training_args.tensor_parallel_degree = training_args.tensor_degree + training_args.seed = 42 + + if training_args.gradient_accumulation_steps == -1: + logger.info("gradient_accumulation_steps will estimate automaically.") + if paddle.distributed.get_world_size() > 1: + raise NotImplementedError("DPO estimate training does not support multi-node.") + + if training_args.num_of_gpus < 0: + raise ValueError(f"num_of_gpus must be positive, but got num_of_gpus={training_args.num_of_gpus}") + tokenizer = Ernie4_5_Tokenizer.from_pretrained(model_args.model_name_or_path) + config = Ernie4_5_MoeConfig.from_pretrained(model_args.model_name_or_path) + dpo_estimate_training(tokenizer, data_args, training_args, config) diff --git a/ernie/ERNIE/examples/post-training/dpo/dpo_train.py b/ernie/ERNIE/examples/post-training/dpo/dpo_train.py new file mode 100644 index 0000000000000000000000000000000000000000..53066c615599bc28b9c13421e767ba5d7160927d --- /dev/null +++ b/ernie/ERNIE/examples/post-training/dpo/dpo_train.py @@ -0,0 +1,536 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Training DPO """ +import gc +import importlib.util +import os +import sys +import time +import json +from functools import partial + +if importlib.util.find_spec("triton") is not None: + try: + import use_triton_in_paddle + + use_triton_in_paddle.make_triton_compatible_with_paddle() + except Exception as _: + raise RuntimeError( + "Triton is installed, but not yet compatible with Paddle. " + "Please run 'python -m pip install use-triton-in-paddle' to enable Triton support in Paddle." + ) + +import paddle +from paddleformers.peft import LoRAConfig, LoRAModel +from paddleformers.trainer import ( + IntervalStrategy, + PdArgumentParser, + get_last_checkpoint, + set_seed, +) +from paddleformers.trainer.trainer_utils import ShardingOption +from paddleformers.utils.log import logger + +from ernie.callbacks import LayerwiseDropoutCallback +from ernie.configuration import Ernie4_5_MoeConfig +from ernie.dataset.dpo import collate_fn, create_dataset +from ernie.modeling_moe import Ernie4_5_MoeForCausalLM +from ernie.modeling_moe_pp import Ernie4_5_MoeForCausalLMPipe +from ernie.tokenizer import Ernie4_5_Tokenizer +from ernie.utils.common_utils import check_refined_recompute + +# isort: off +from dpo_estimate_training import dpo_estimate_training +from dpo_trainer import ErnieMoEDPOTrainer +from dpo_utils import ( + DataArgument, + DPOConfig, + DPOTrainingArguments, + ModelArgument, + calculate_effective_tokens, +) + +# isort: on + + +def main(): + """main""" + parser = PdArgumentParser( + (ModelArgument, DataArgument, DPOTrainingArguments, DPOConfig) + ) + if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"): + model_args, data_args, training_args, dpo_config = ( + parser.parse_json_file_and_cmd_lines() + ) + else: + model_args, data_args, training_args, dpo_config = ( + parser.parse_args_into_dataclasses() + ) + + if not model_args.use_sparse_head_and_loss_fn: + model_args.use_sparse_head_and_loss_fn = True + logger.warning( + "Dpo training requires use_sparse_head_and_loss_fn=True. Set use_sparse_head_and_loss_fn to True" + ) + + if data_args.max_seq_len < 16: + data_args.max_seq_len = 16 + logger.warning( + f"max_seq_len must be greater than 16, set max_seq_len to {data_args.max_seq_len}." + ) + if data_args.max_seq_len < data_args.max_prompt_len + 10: + data_args.max_prompt_len = data_args.max_seq_len - 10 + logger.warning( + "max_seq_len must be greater than max_prompt_len + 10, " + "set max_prompt_len to {data_args.max_prompt_len}." + ) + if dpo_config.loss_type == "orpo": + dpo_config.reference_free = True + dpo_config.sft_loss_ratio = 1.0 + dpo_config.loss_type = "or" + logger.info("orpo loss_type is equal to sft_loss + pref_loss_ratio * or_loss.") + if dpo_config.loss_type in ["or", "simpo"] and not dpo_config.reference_free: + dpo_config.reference_free = True + logger.warning( + f"{dpo_config.loss_type} loss_type only supports reference_free. Set reference_free to True." + ) + if dpo_config.lora: + assert model_args.continue_training, "Continue training is required for LoRA." + if training_args.pipeline_parallel_degree > 1: + assert ( + hasattr(training_args, "pipeline_parallel_config") + and "enable_clear_every_step_cache" + in training_args.pipeline_parallel_config + ), "Should set '--pipeline_parallel_config enable_clear_every_step_cache' in bash script for pp." + if training_args.sequence_parallel: + if training_args.pipeline_parallel_degree > 1: + assert ( + hasattr(training_args, "pipeline_parallel_config") + and "disable_partial_send_recv" + in training_args.pipeline_parallel_config + ), "Should set '--pipeline_parallel_config disable_partial_send_recv' in bash script for pp with sp." + if training_args.tensor_parallel_degree <= 1: + training_args.sequence_parallel = False + logger.info("Tensor_parallel_degree = 1. Set sequence_parallel to False.") + + if dpo_config.lora and model_args.fuse_linear: + model_args.fuse_linear = False + logger.info("LoRA does not support fuse_linear. Set fuse_linear to False.") + if dpo_config.lora: + dpo_config.ref_model_update_steps = -1 + logger.warning( + "LoRA does not support ref_model_update_steps. Set ref_model_update_steps to -1." + ) + + if training_args.sharding_parallel_degree > 1: + if ( + ShardingOption.SHARD_GRAD_OP in training_args.sharding + or ShardingOption.FULL_SHARD in training_args.sharding + ): + if training_args.release_grads is True: + training_args.release_grads = False + + training_args.print_config(model_args, "Model") + training_args.print_config(data_args, "Data") + training_args.print_config(dpo_config, "DPOConfig") + + paddle.set_device(training_args.device) + + set_seed(training_args.seed) + + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: " + f"{training_args.world_size}, distributed training: {bool(training_args.local_rank != -1)}, " + f"16-bits training: {training_args.fp16 or training_args.bf16}" + ) + + last_checkpoint = None + if ( + os.path.isdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + uc_async_save = ( + training_args.unified_checkpoint + and "async_save" in training_args.unified_checkpoint_config + ) + last_checkpoint = get_last_checkpoint( + training_args.output_dir, + signal_folder=training_args.output_signal_dir, + uc_async_save=uc_async_save, + ) + + if last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Set the dtype for loading model + dtype = paddle.get_default_dtype() + if training_args.fp16_opt_level == "O2": + if training_args.fp16: + dtype = "float16" + if training_args.bf16: + dtype = "bfloat16" + + logger.info("Start to load model ...") + + # Detect torch model. + config_path = os.path.join(model_args.model_name_or_path, "config.json") + with open(config_path, "r", encoding="utf-8") as f: + config_dict = json.load(f) + if "torch_dtype" in config_dict: + raise ValueError( + "Unsupported weight format: Torch weights are not compatible with Paddle model currently." + ) + + # fuse_softmax_mask only support for rocm. + if not paddle.is_compiled_with_rocm(): + if model_args.fuse_softmax_mask: + logger.warning( + "The fuse_softmax_mask flag is only available when using the ROCM version of paddlepaddle. " + ) + model_args.fuse_softmax_mask = False + + check_refined_recompute( + training_args.refined_recompute, + training_args.sequence_parallel, + lora=dpo_config.lora, + ) + + if model_args.weight_quantize_algo is not None: + if model_args.weight_quantize_algo == "weight_only_mix": + quantization_config = dict( + weight_quantize_algo={ + "weight_only_int4": [".*mlp.experts.*"], + "weight_only_int8": [ + ".*self_attn.qkv_proj.*", + ".*self_attn.o_proj.*", + ".*mlp.up_gate_proj.*", + ".*mlp.down_proj.*", + ], + }, + ignore_modules=[".*out_linear.*"], + ) + else: + quantization_config = dict( + weight_quantize_algo=model_args.weight_quantize_algo, + ignore_modules=[".*out_linear.*"], + ) + else: + quantization_config = dict(weight_quantize_algo=model_args.weight_quantize_algo) + + model_kwargs = dict( + pretrained_model_name_or_path=model_args.model_name_or_path, + dtype=dtype, + tensor_parallel_degree=training_args.tensor_parallel_degree, + tensor_parallel_rank=training_args.tensor_parallel_rank, + virtual_pp_degree=model_args.virtual_pp_degree, + pp_seg_method=model_args.pp_seg_method, + recompute=training_args.recompute, + recompute_granularity=model_args.recompute_granularity, + use_flash_attention=model_args.use_flash_attention, + tensor_parallel_output=model_args.tensor_parallel_output, + fuse_linear=model_args.fuse_linear, + fuse_softmax_mask=model_args.fuse_softmax_mask, + fuse_rms_norm=model_args.fuse_rms_norm, + fuse_swiglu=model_args.fuse_swiglu, + fuse_gate_detach_matmul=model_args.fuse_gate_detach_matmul, + dpo_config=dpo_config, + sequence_parallel=training_args.sequence_parallel, + max_sequence_length=data_args.max_seq_len, + use_sparse_head_and_loss_fn=model_args.use_sparse_head_and_loss_fn, + no_recompute_layers=model_args.no_recompute_layers, + quantization_config=quantization_config, + use_fused_head_and_loss_fn=model_args.use_fused_head_and_loss_fn, + recompute_use_reentrant=model_args.recompute_use_reentrant, + use_sparse_flash_attn=model_args.use_sparse_flash_attn, + refined_recompute=training_args.refined_recompute, + fuse_rope=model_args.fuse_rope, + moe_group=model_args.moe_group, + hidden_dropout_prob=training_args.hidden_dropout_prob, + attention_probs_dropout_prob=training_args.attention_probs_dropout_prob, + moe_multimodal_dispatch_use_allgather=model_args.moe_multimodal_dispatch_use_allgather, + moe_group_experts=model_args.moe_group_experts, + moe_aux_loss_lambda=model_args.moe_aux_loss_lambda, + moe_orthogonal_loss_lambda=model_args.moe_orthogonal_loss_lambda, + moe_z_loss_lambda=model_args.moe_z_loss_lambda, + moe_use_hard_gate=model_args.moe_use_hard_gate, + num_acc_steps=training_args.gradient_accumulation_steps, + add_tail_layers=model_args.add_tail_layers, + num_nextn_predict_layers=0, + ) + if model_args.moe_use_aux_free is False: + model_kwargs.update({"moe_use_aux_free": False}) + config = Ernie4_5_MoeConfig.from_pretrained(**model_kwargs) + + if ( + training_args.pipeline_parallel_degree > 1 + and model_args.weight_quantize_algo is not None + and config.tie_word_embeddings + ): + raise NotImplementedError( + "Quantization is not supported for models with tied lm_head and word_embedding \ + weights when using Pipeline Parallelism (PP)." + ) + + if config.moe_num_experts is None or config.moe_num_experts == 0: + config.moe_group = ( + "dummy" if model_args.moe_group == "mp" else model_args.moe_group + ) + + if training_args.pipeline_parallel_degree > 1: + model_class = Ernie4_5_MoeForCausalLMPipe + else: + model_class = Ernie4_5_MoeForCausalLM + if model_args.continue_training: + model = model_class.from_pretrained( + model_args.model_name_or_path, config=config + ) + else: + model = model_class._from_config(config, dtype=dtype) + + if not dpo_config.reference_free and not dpo_config.lora: + ref_config = Ernie4_5_MoeConfig.from_pretrained(**model_kwargs) + if ref_config.moe_num_experts is None or ref_config.moe_num_experts == 0: + ref_config.moe_group = ( + "dummy" if model_args.moe_group == "mp" else model_args.moe_group + ) + ref_model = model_class._from_config(ref_config, dtype=dtype) + # make sure the state_dict is the same to get the same loss for first step + ref_model.set_state_dict(model.state_dict()) + else: + ref_model = None + + model.config.dpo_config = None + + if model.config.head_dim is None: + del model.config.head_dim + if ref_model is not None and ref_model.config.head_dim is None: + del ref_model.config.head_dim + + if dpo_config.lora: + logger.info("Start to wrap model with LoRA config ...") + if model_args.lora_path is None: + target_modules = [ + ".*qkv_proj.*", + ".*out_proj.*", + ".*linear1.*", + ".*linear2.*", + ] + if model_args.rslora_plus: + model_args.rslora = True + model_args.lora_plus_scale = 4 + model_args.lora_alpha = 4 + if model_args.weight_quantize_algo is not None: + if model_args.rslora or model_args.lora_plus_scale != 1.0: + logger.info( + "Weight quantization is not supported in LoRA+ and RsLoRA." + ) + if model_args.lora_alpha == -1: + if model_args.rslora: + model_args.lora_alpha = 4 + else: + model_args.lora_alpha = 2 * model_args.lora_rank + lora_config = LoRAConfig( + target_modules=target_modules, + r=model_args.lora_rank, + lora_alpha=model_args.lora_alpha, + rslora=model_args.rslora, + lora_plus_scale=model_args.lora_plus_scale, + tensor_parallel_degree=training_args.tensor_parallel_degree, + dtype=dtype, + head_dim=model.config.hidden_size // model.config.num_attention_heads, + base_model_name_or_path=model_args.model_name_or_path, + ) + model = LoRAModel(model, lora_config) + else: + model = LoRAModel.from_pretrained( + model=model, lora_path=model_args.lora_path + ) + model.print_trainable_parameters() + logger.info("Wraping model with LoRA config successfully !") + + tokenizer = Ernie4_5_Tokenizer.from_pretrained( + model_args.model_name_or_path, + ) + logger.info("Loading model & tokenizer successfully !") + + logger.info("Start to create dataset ...") + dataset_config = { + "tokenizer": tokenizer, + "max_seq_len": data_args.max_seq_len, + "max_prompt_len": data_args.max_prompt_len, + "random_seed": training_args.seed, + "num_replicas": training_args.dataset_world_size, + "rank": training_args.dataset_rank, + "num_samples_each_epoch": data_args.num_samples_each_epoch, + "random_shuffle": data_args.random_shuffle, + "greedy_intokens": data_args.greedy_intokens, + "buffer_size": data_args.buffer_size, + "use_attn_mask_start_row_indices": model_args.use_attn_mask_start_row_indices, + "mask_out_eos_token": data_args.mask_out_eos_token, + } + + if training_args.max_steps == -1: + if training_args.should_load_dataset and paddle.distributed.get_rank() == 0: + # NOTE(gongenlei): not to feed train_dataset, or the data will be wrong in next training. + training_args, _ = dpo_estimate_training( + tokenizer, data_args, training_args, config=model.config + ) + + if paddle.distributed.get_world_size() > 1: + paddle.distributed.barrier() + pd_max_steps = paddle.to_tensor([training_args.max_steps]) + paddle.distributed.broadcast(pd_max_steps, src=0) + training_args.max_steps = int(pd_max_steps.item()) + logger.info( + f"Re-setting training_args.max_steps to {training_args.max_steps} ({training_args.num_train_epochs})" + ) + if training_args.max_steps <= 0: + raise ValueError( + f"Invalid max_steps: {training_args.max_steps}. Please check your dataset" + ) + if training_args.save_strategy == IntervalStrategy.EPOCH: + training_args.save_strategy = IntervalStrategy.STEPS + training_args.save_steps = int( + training_args.max_steps / training_args.num_train_epochs + ) + if training_args.evaluation_strategy == IntervalStrategy.EPOCH: + training_args.evaluation_strategy = IntervalStrategy.STEPS + training_args.eval_steps = int( + training_args.max_steps / training_args.num_train_epochs + ) + if training_args.logging_strategy == IntervalStrategy.EPOCH: + training_args.logging_strategy = IntervalStrategy.STEPS + training_args.logging_steps = int( + training_args.max_steps / training_args.num_train_epochs + ) + + if training_args.should_load_dataset: + train_dataset = create_dataset( + task_group=data_args.train_dataset_path, + task_group_prob=data_args.train_dataset_prob, + sub_dataset_type=data_args.train_dataset_type, + **dataset_config, + ) + + if training_args.do_eval and training_args.should_load_dataset: + eval_dataset = create_dataset( + task_group=data_args.eval_dataset_path, + task_group_prob=data_args.eval_dataset_prob, + sub_dataset_type=data_args.eval_dataset_type, + is_valid=True, + **dataset_config, + ) + logger.info("Creating dataset successfully ...") + + trainer = ErnieMoEDPOTrainer( + model=model, + ref_model=ref_model, + dpo_config=dpo_config, + args=training_args, + train_dataset=( + train_dataset + if training_args.do_train and training_args.should_load_dataset + else None + ), + eval_dataset=( + eval_dataset + if training_args.do_eval and training_args.should_load_dataset + else None + ), + tokenizer=tokenizer, + data_collator=partial( + collate_fn, + tokenizer=tokenizer, + max_seq_len=data_args.max_seq_len, + use_sparse_head_and_loss_fn=model_args.use_sparse_head_and_loss_fn, + use_fused_head_and_loss_fn=model_args.use_fused_head_and_loss_fn, + use_response_score_delta=dpo_config.offset_alpha > 0.0, + ), + model_with_dpo_criterion=True, + ) + + if training_args.hidden_dropout_prob or training_args.attention_probs_dropout_prob: + trainer.add_callback(LayerwiseDropoutCallback()) + + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=last_checkpoint) + if ( + training_args.dpo_benchmark + and training_args.should_load_dataset + and paddle.distributed.get_rank() == 0 + ): + del train_dataset + gc.collect() + train_dataset = create_dataset( + task_group=data_args.train_dataset_path, + task_group_prob=data_args.train_dataset_prob, + sub_dataset_type=data_args.train_dataset_type, + **dataset_config, + ) + total_effective_tokens, total_tokens = calculate_effective_tokens( + training_args, train_dataset, data_args.max_seq_len + ) + effective_tokens_per_second = ( + total_effective_tokens / train_result.metrics["train_runtime"] + ) + total_tokens_per_second = ( + total_tokens / train_result.metrics["train_runtime"] + ) + effective_ratio = 100 * total_effective_tokens / total_tokens + logger.info( + "[timelog] {}: {:.2f} % ({}) ".format( + "Effective ratio", + effective_ratio, + time.strftime("%Y-%m-%d %H:%M:%S"), + ) + ) + logger.info( + "[timelog] {}: {:.2f} token/s ({}) ".format( + "Effective tokens per second", + effective_tokens_per_second, + time.strftime("%Y-%m-%d %H:%M:%S"), + ) + ) + logger.info( + "[timelog] {}: {:.2f} token/s ({}) ".format( + "Tokens per second", + total_tokens_per_second, + time.strftime("%Y-%m-%d %H:%M:%S"), + ) + ) + + if not training_args.dpo_benchmark: + trainer.save_model( + merge_tensor_parallel=training_args.tensor_parallel_degree > 1 + ) + if paddle.distributed.get_world_size() > 1: + paddle.distributed.barrier() + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + + if training_args.do_eval: + eval_result = trainer.evaluate() + trainer.log_metrics("eval", eval_result) + trainer.save_metrics("eval", eval_result, combined=False) + + +if __name__ == "__main__": + with paddle.amp.auto_cast(enable=False): + main() diff --git a/ernie/ERNIE/examples/post-training/dpo/dpo_trainer.py b/ernie/ERNIE/examples/post-training/dpo/dpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..55510e176387eb32873cdf831122beb0c3d8964d --- /dev/null +++ b/ernie/ERNIE/examples/post-training/dpo/dpo_trainer.py @@ -0,0 +1,172 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +DPO Trainer for Ernie-MoE model with enhanced distributed training support. +""" + +from functools import partial + +import paddle +from paddle import nn +from paddle.distributed import fleet +from paddle.distributed.communication.group import _get_global_group +from paddle.distributed.fleet.utils.sequence_parallel_utils import register_sequence_parallel_allreduce_hooks +from paddleformers.trainer import Trainer +from paddleformers.trl import DPOTrainer +from paddleformers.utils.log import logger + +from ernie.callbacks import SPGradSyncCallback +from ernie.moe.distributed.hybrid_parallel_optimizer import ( + HybridParallelClipGrad as MoEHybridParallelClipGrad, +) +from ernie.moe.moe_clip import ClipGradForMOEByGlobalNorm + + +class ErnieMoEDPOTrainer(DPOTrainer): + """ + Custom DPO trainer class for Ernie-MoE model with enhanced distributed training support. + """ + + def _wrap_model(self, model, training=True): + """Wrap model.""" + model = super()._wrap_model(model, training) + + def enable_sequence_parallel(_model): + if self.args.tensor_parallel_degree > 1 and self.args.sequence_parallel: + if self.args.use_sp_callback: + self.add_callback(SPGradSyncCallback(_model._layers)) + else: + register_sequence_parallel_allreduce_hooks( + _model, self.args.gradient_accumulation_steps, self.args.fuse_sequence_parallel_allreduce + ) + + enable_sequence_parallel(model) + return model + + def create_optimizer(self, lr_scheduler=None): + """ + Create and configure the optimizer for training. + + Args: + lr_scheduler (Optional): Learning rate scheduler for adjusting the learning rate during training. + + Returns: + paddle.optimizer.Optimizer: The configured optimizer instance with specified parameters and settings. + """ + self.static_name_to_dyg_name = {p.name: n for n, p in self.model.named_parameters()} + + if self.optimizer is None: + if self.optimizer_grouped_parameters is not None: + optimizer_params = self.optimizer_grouped_parameters + else: + optimizer_params = self.model.parameters() + + decay_parameters = [ + p.name for n, p in self.model.named_parameters() if not any(nd in n for nd in ["bias", "norm"]) + ] + + def apply_decay_param_fun(x): + return x in decay_parameters + + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) + if hasattr(optimizer_cls, "_create_master_weight") and self.args.fp16_opt_level == "O2": + optimizer_kwargs["multi_precision"] = True + + def _get_layer_lrs(x, lr_lower_bound, n_layers): + """ + Calculate layer-wise learning rates with depth-based scaling. + + Implements a learning rate schedule where layers closer to the input (lower depth) + get smaller learning rates, while deeper layers get progressively higher rates. + This follows the common practice that earlier layers typically need finer tuning. + + Args: + x (Parameter): The model parameter to calculate learning rate for + lr_lower_bound (float): Minimum learning rate (for depth=0 layers) + n_layers (int): Total number of transformer layers in the model + + Returns: + float: Computed learning rate for the given parameter + + Note: + - Special layers (embedding and head) get fixed positions in the depth hierarchy + - The depth-to-LR mapping follows a linear interpolation between lower bound and 1.0 + - TODO: Needs to consider LoRA (Low-Rank Adaptation) parameters in future + """ + name = self.static_name_to_dyg_name[x.name] + if "lm_head" in name or "ernie.norm" in name: + depth = n_layers + 2 + elif "embed_tokens" in name: + depth = 0 + elif self.dpo_config.lora and "lora" in name: + if "ernie.layers" in name: + depth = int(name.split(".")[3]) + else: + depth = int(name.split(".")[1]) + else: + if name.startswith("ernie.layers."): + depth = int(name.split(".")[2]) + else: + depth = int(name.split(".")[0]) + return lr_lower_bound + depth / (n_layers + 2) * (1 - lr_lower_bound) + + lr_ratio_func = None + layerwise_lr_decay_bound = self.args.layerwise_lr_decay_bound + assert ( + layerwise_lr_decay_bound > 0 and layerwise_lr_decay_bound <= 1 + ), f"layerwise_lr_decay_bound: {layerwise_lr_decay_bound} out of range. should be in (0, 1]" + if layerwise_lr_decay_bound < 1: + lr_ratio_func = partial( + _get_layer_lrs, + lr_lower_bound=layerwise_lr_decay_bound, + n_layers=self.model.config.num_hidden_layers, + ) + + if self.args.max_grad_norm <= 0: + grad_clip = None + elif self.args.use_expert_parallel and not self.args.use_hybrid_parallel: + + def expert_fn(p): + return getattr(p, "no_sync", False) + + grad_clip = ClipGradForMOEByGlobalNorm( + self.args.max_grad_norm, + is_expert_param_func=expert_fn, + moe_group=_get_global_group(), + ) + else: + grad_clip = nn.ClipGradByGlobalNorm(self.args.max_grad_norm) + + self.optimizer = optimizer_cls( + learning_rate=(self.lr_scheduler if lr_scheduler is None else lr_scheduler), + apply_decay_param_fun=apply_decay_param_fun, + parameters=optimizer_params, + weight_decay=self.args.weight_decay, + grad_clip=grad_clip, + lr_ratio=lr_ratio_func, + **optimizer_kwargs, + ) + + if self.args.use_expert_parallel and self.args.use_hybrid_parallel: + logger.debug('using moe-hybrid-clip under hybrid parallel') + hcg = fleet.get_hybrid_communicate_group() + self.optimizer._grad_clip = MoEHybridParallelClipGrad( + self.optimizer._grad_clip, + hcg, + moe_group=hcg.get_data_parallel_group(), + ) + + self.optimizer._dtype = paddle.get_default_dtype() + return self.optimizer diff --git a/ernie/ERNIE/examples/post-training/dpo/dpo_utils.py b/ernie/ERNIE/examples/post-training/dpo/dpo_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..925db6a58c366bc74b55a60dc2ddcb13c21d776e --- /dev/null +++ b/ernie/ERNIE/examples/post-training/dpo/dpo_utils.py @@ -0,0 +1,343 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DPO utils""" +from dataclasses import dataclass, field +from typing import Optional + +from paddleformers.trainer import IntervalStrategy, TrainingArguments + + +def add_start_docstrings(*docstr): + """Adds docstrings for a function.""" + + def docstring_decorator(fn): + fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") + return fn + + return docstring_decorator + + +@dataclass +@add_start_docstrings(TrainingArguments.__doc__) +class DPOTrainingArguments(TrainingArguments): + """DPOTrainingArguments""" + + unified_checkpoint: bool = field( + default=True, + metadata={"help": "Enable fused linear grad add strategy."}, + ) + unified_checkpoint_config: Optional[str] = field( + default="", + metadata={"help": "Configs to unify hybrid parallel checkpoint.\n"}, + ) + num_of_gpus: int = field(default=-1, metadata={"help": "Number of gpus."}) + pipeline_degree: int = field(default=1, metadata={"help": "pipeline_degree for estimate"}) + tensor_degree: int = field(default=1, metadata={"help": "tensor_degree for estimate"}) + sharding_degree: int = field(default=1, metadata={"help": "sharding_degree for estimate"}) + dpo_benchmark: bool = field( + default=False, + metadata={"help": "Whether to run benchmark by autotuner. True for from_scratch."}, + ) + dropout_warmup_steps: int = field( + default=0, + metadata={"help": "dropout warmup steps"}, + ) + hidden_dropout_prob: float = field( + default=0.0, + metadata={"help": "dropout probability for hidden layers"}, + ) + attention_probs_dropout_prob: float = field( + default=0.0, + metadata={"help": "dropout probability for attention layers"}, + ) + sequence_parallel: bool = field(default=True, metadata={"help": "Whether to use sequence_parallel"}) + layerwise_lr_decay_bound: Optional[float] = field( + default=1.0, + metadata={ + "help": "Use a large learning rate for the top layers and " + "a small learning rate for the bottom layers. 1.0: Do not use this strategy." + }, + ) + use_sp_callback: bool = field( + default=False, + metadata={ + "help": "Using the SP callback will skip the implementation of SPHook " + "to avoid redundant gradient computation." + }, + ) + + def __post_init__(self): + super().__post_init__() + if self.dpo_benchmark: + self.do_train = True + self.do_export = False + self.do_predict = False + self.do_eval = False + self.overwrite_output_dir = True + self.load_best_model_at_end = False + self.save_strategy = IntervalStrategy.NO + self.evaluation_strategy = IntervalStrategy.NO + if not self.disable_tqdm: + self.logging_steps = 1 + self.logging_strategy = IntervalStrategy.STEPS + + +@dataclass +class DataArgument: + """DataArgument""" + + train_dataset_type: str = field(default="erniekit", metadata={"help": "List contains type of training datasets."}) + train_dataset_path: str = field( + default="examples/data/sft-train.jsonl", + metadata={"help": "List contains path of training data sources."}, + ) + train_dataset_prob: str = field( + default="1.0", + metadata={"help": "List contains probabilities of training data sources."}, + ) + eval_dataset_type: str = field(default="erniekit", metadata={"help": "List contains type of eval datasets."}) + eval_dataset_path: str = field( + default="examples/data/sft-eval.jsonl", + metadata={"help": "List contains path of eval data sources."}, + ) + eval_dataset_prob: str = field( + default="1.0", + metadata={"help": "List contains probabilities of eval data sources."}, + ) + max_seq_len: int = field(default=4096, metadata={"help": "Maximum sequence length."}) + max_prompt_len: int = field(default=2048, metadata={"help": "Maximum prompt length."}) + num_samples_each_epoch: int = field( + default=6000000, + metadata={"help": "Number of samples per epoch. Used for SFT."}, + ) + random_shuffle: bool = field( + default=True, + metadata={"help": "Whether to enable authorize code for privatization. Defaults to False."}, + ) + greedy_intokens: bool = field( + default=True, + metadata={"help": "Whether apply greedy intokens."}, + ) + buffer_size: int = field( + default=500, + metadata={"help": "Buffer size for greedy_intokens strategy."}, + ) + mask_out_eos_token: bool = field(default=True, metadata={"help": "Mask out eos token"}) + + +@dataclass +class ModelArgument: + """ModelArgument""" + + model_name_or_path: str = field( + default="ernie-bot", + metadata={"help": "Pretrained model name or path to local directory."}, + ) + use_flash_attention: bool = field(default=True, metadata={"help": "Whether to use flash attention"}) + recompute_granularity: str = field( + default="full", + metadata={ + "help": "The granularity of recompute training can be selected as `full` or `full_attn` or `core_attn`." + }, + ) + virtual_pp_degree: int = field( + default=1, + metadata={"help": "virtual_pp_degree"}, + ) + pp_seg_method: str = field( + default="layer:Ernie4_5_DecoderLayer|EmptyLayer", + metadata={ + "help": ( + "The method used to segment the pipeline layers among pipeline stages. " + "`layer:Ernie4_5_DecoderLayer|EmptyLayer`, `uniform`, `[0, 30, 59]`." + ) + }, + ) + fuse_linear: bool = field(default=True, metadata={"help": "Whether to use fuse_linear"}) + fuse_softmax_mask: bool = field(default=False, metadata={"help": "Whether to fuse softmax and add"}) + fuse_rms_norm: bool = field(default=True, metadata={"help": "Whether to fuse RMSNorm for efficiency"}) + fuse_swiglu: bool = field( + default=True, metadata={"help": "Whether to fuse SwiGLU projection and activation for efficiency"} + ) + fuse_gate_detach_matmul: bool = field( + default=True, + metadata={"help": "Whether to use the fused gate-detach matmul implementation."}, + ) + tensor_parallel_output: bool = field(default=True, metadata={"help": "tensor_parallel_output"}) + use_sparse_head_and_loss_fn: bool = field( + default=False, + metadata={"help": "Whether to use sparse LM Head and loss function."}, + ) + use_sparse_flash_attn: bool = field( + default=True, + metadata={"help": "Under use attn_mask_start_row_indices=True, whether use sparse flash attention or not."}, + ) + use_attn_mask_start_row_indices: bool = field( + default=True, + metadata={"help": "Whether to use attn_mask_start_row_indices in flash attention."}, + ) + no_recompute_layers: Optional[int] = field( + default=None, + metadata={"help": "Specify the full transformer layers that should not be recomputed."}, + ) + weight_quantize_algo: str = field( + default=None, + metadata={"help": "Model weight quantization algorithm including 'nf4'(qlora), 'weight_only_int8'."}, + ) + add_tail_layers: int = field( + default=False, + metadata={"help": ("Add EmptyLayer after Ernie4_5_DecoderLayerPipe. Only for Pipeline Parallel")}, + ) + # LoRA + lora_rank: int = field(default=8, metadata={"help": "Lora rank."}) + lora_path: str = field(default=None, metadata={"help": "Initialize lora state dict."}) + rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"}) + lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+ technique"}) + lora_alpha: int = field(default=-1, metadata={"help": "lora_alpha"}) + rslora_plus: bool = field(default=False, metadata={"help": "Strengthen lora performance"}) + use_fused_head_and_loss_fn: bool = field( + default=False, + metadata={"help": "Whether to fuse LM Head and loss function."}, + ) + recompute_use_reentrant: bool = field( + default=False, + metadata={"help": "recompute_use_reentrant"}, + ) + continue_training: bool = field( + default=True, + metadata={ + "help": ( + "Whether to train from existing paddleformers model weights.\n" + "If set True, the model_name_or_path argument must exist in the paddleformers models." + ) + }, + ) + fuse_rope: bool = field( + default=False, + metadata={"help": "Whether to fuse rotary postition embedding"}, + ) + # MoE + use_recompute_moe: Optional[bool] = field(default=False, metadata={"help": "Whether to use recompute moe"}) + moe_group: Optional[str] = field( + default="dummy", metadata={"help": "MoE communication group, currently support 'mp|dummy'"} + ) + moe_multimodal_dispatch_use_allgather: Optional[str] = field( + default="v2-alltoall-unpad", metadata={"help": "moe dispatch use allgather"} + ) + moe_group_experts: Optional[bool] = field( + default=False, metadata={"help": "Whether to apply group-wise processing to expert gate logits."} + ) + moe_aux_loss_lambda: Optional[float] = field( + default=1e-5, + metadata={"help": "Lambda value for moe aux loss."}, + ) + moe_orthogonal_loss_lambda: Optional[float] = field( + default=0.0, + metadata={"help": "Lambda value for moe orthogonal loss."}, + ) + moe_z_loss_lambda: Optional[float] = field( + default=0.0, + metadata={"help": "Lambda value for moe z loss."}, + ) + moe_use_hard_gate: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether to use hard gate. If `moe_use_hard_gate` is True, a hard " + "routing strategy is used instead of a learned gating network." + }, + ) + moe_use_aux_free: Optional[bool] = field( + default=None, + metadata={ + "help": "Whether to use auxiliary‑loss‑free routing. If True, " + "load balancing (using expert bias adjustments) is used instead " + "of traditional auxiliary loss for MoE." + }, + ) + + +@dataclass +class DPOConfig: + """DPOConfig""" + + beta: float = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"}) + offset_alpha: float = field(default=0.0, metadata={"help": "the offset coefficient for score-based DPO loss"}) + simpo_gamma: float = field(default=0.5, metadata={"help": "the gamma parameter for SimPO loss"}) + normalize_logps: bool = field( + default=True, + metadata={"help": "Apply logprobs normalization."}, + ) + label_smoothing: float = field(default=0.0, metadata={"help": "label_smoothing ratio"}) + loss_type: str = field(default="sigmoid", metadata={"help": "DPO loss type"}) + pref_loss_ratio: float = field(default=1.0, metadata={"help": "DPO loss ratio"}) + sft_loss_ratio: float = field(default=0.0, metadata={"help": "SFT loss ratio"}) + dpop_lambda: float = field(default=50, metadata={"help": "dpop_lambda"}) + ref_model_update_steps: int = field(default=-1, metadata={"help": "Update ref model state dict "}) + reference_free: bool = field(default=False, metadata={"help": "No reference model."}) + lora: bool = field(default=False, metadata={"help": "Use LoRA model."}) + + def __post_init__(self): + if self.offset_alpha > 0.0: + if self.loss_type != "sigmoid": + raise ValueError( + "Only sigmoid loss_type supports score-based loss (offset_alpha > 0), " + "please set loss_type to sigmoid or set offset_alpha to 0." + ) + + +def calculate_effective_tokens(training_args, train_dataset, max_seq_len): + """ + Caculate the effective tokens during training. + + Args: + training_args (TrainingArguments): Configuration object containing: + - data_parallel_degree (int): Number of data parallel partitions + - sharding_parallel_degree (int): Number of sharding partitions + - max_steps (int): Total training iterations + - per_device_train_batch_size (int): Batch size per GPU/device + - gradient_accumulation_steps (int): Grad accumulation steps + train_dataset (IterableDataset): Training dataset with input_ids fields + max_seq_len (int): Padded sequence length + + Returns: + tuple: (effective_tokens, total_possible_tokens) where: + - effective_tokens (int): Actual processed tokens (excludes padding) + - total_possible_tokens (int): Theoretical maximum (batch_size * seq_len) + """ + total_effective_tokens = 0 + try: + data_parallel_degree = training_args.data_parallel_degree + except: + data_parallel_degree = 1 + if training_args.sharding_parallel_degree > 1: + sharding_parallel_degree = training_args.sharding_parallel_degree + else: + sharding_parallel_degree = 1 + + total_batch = ( + training_args.max_steps + * training_args.per_device_train_batch_size + * training_args.gradient_accumulation_steps + * sharding_parallel_degree + * data_parallel_degree + ) + for i, data in enumerate(train_dataset): + if i == total_batch: + break + for dd in data: + total_effective_tokens += len(dd.input_ids) + total_tokens = total_batch * max_seq_len + + return total_effective_tokens, total_tokens diff --git a/ernie/ERNIE/examples/post-training/dpo/scripts/run_0.3b_dpo_8k.sh b/ernie/ERNIE/examples/post-training/dpo/scripts/run_0.3b_dpo_8k.sh new file mode 100644 index 0000000000000000000000000000000000000000..dfc3527420bac63369cab1c3b43e7261abdbe5da --- /dev/null +++ b/ernie/ERNIE/examples/post-training/dpo/scripts/run_0.3b_dpo_8k.sh @@ -0,0 +1,113 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +master_ip=${1:-} +nnodes=${2:-1} +model_path="ERNIE4.5T_0.3B" +task="dpo_8k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3,4,5,6,7 \ + --master ${master_ip}:8080 \ + --nnodes ${nnodes} \ + ./examples/post-training/dpo/dpo_train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/dpo-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/dpo-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --max_steps 800 \ + --save_steps 100 \ + --logging_steps 1 \ + --eval_steps 20000 \ + --weight_decay 0.1 \ + --do_train \ + --do_eval \ + --evaluation_strategy epoch \ + --tensor_parallel_degree 1 \ + --tensor_parallel_config "sync_param sync_grad sync_moment" \ + --pipeline_parallel_degree $nnodes \ + --sharding_parallel_degree 1 \ + --gradient_accumulation_steps 36 \ + --sharding stage1 \ + --max_seq_len 8192 \ + --seed 42 \ + --warmup_steps 50 \ + --learning_rate 5e-7 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 0 \ + --recompute_granularity "full" \ + --dataloader_num_workers 4 \ + --distributed_dataloader 1 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache enable_delay_scale_loss enable_overlap_p2p_comm best_unbalanced_scheduler" \ + --dpo_benchmark 0 \ + --greedy_intokens 1 \ + --beta 0.1 \ + --loss_type "sigmoid" \ + --label_smoothing 0.0 \ + --pref_loss_ratio 1.0 \ + --sft_loss_ratio 0.0 \ + --ref_model_update_steps -1 \ + --sequence_parallel 1 \ + --use_attn_mask_start_row_indices 1 \ + --tensor_parallel_output 1 \ + --reference_free 0 \ + --simpo_gamma 0.5 \ + --recompute_use_reentrant 1 \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --hidden_dropout_prob 0 \ + --attention_probs_dropout_prob 0.1 \ + --dropout_warmup_steps 100 \ + --adam_epsilon 1e-8 \ + --layerwise_lr_decay_bound 0.5 \ + --use_sp_callback 1 \ + --save_total_limit 5 \ + --scale_loss 8192 \ + --release_grads 1 \ + --amp_master_grad 1 \ + --lr_scheduler_type "cosine" \ + --min_lr 5e-7 \ + --fuse_rope 1 \ + --offset_alpha 1.0 \ + --unified_checkpoint_config "async_save" diff --git a/ernie/ERNIE/examples/post-training/dpo/scripts/run_0.3b_dpo_lora_32k.sh b/ernie/ERNIE/examples/post-training/dpo/scripts/run_0.3b_dpo_lora_32k.sh new file mode 100644 index 0000000000000000000000000000000000000000..64c24c2048f17eb354673472800796d22d9ed682 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/dpo/scripts/run_0.3b_dpo_lora_32k.sh @@ -0,0 +1,118 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +master_ip=${1:-} +nnodes=${2:-1} +model_path="ERNIE4.5T_0.3B" +task="dpo_lora_8k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3,4,5,6,7 \ + --master ${master_ip}:8080 \ + --nnodes ${nnodes} \ + ./examples/post-training/dpo/dpo_train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/dpo-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/dpo-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --max_steps 800 \ + --save_steps 100 \ + --logging_steps 1 \ + --eval_steps 20000 \ + --weight_decay 0.1 \ + --do_train \ + --do_eval \ + --evaluation_strategy epoch \ + --tensor_parallel_degree 1 \ + --tensor_parallel_config "sync_param sync_grad sync_moment" \ + --pipeline_parallel_degree $nnodes \ + --sharding_parallel_degree 1 \ + --gradient_accumulation_steps 36 \ + --sharding stage1 \ + --max_seq_len 32768 \ + --seed 42 \ + --warmup_steps 50 \ + --learning_rate 5e-7 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 4 \ + --distributed_dataloader 1 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache enable_delay_scale_loss enable_overlap_p2p_comm best_unbalanced_scheduler" \ + --dpo_benchmark 0 \ + --greedy_intokens 1 \ + --beta 0.1 \ + --loss_type "sigmoid" \ + --label_smoothing 0.0 \ + --pref_loss_ratio 1.0 \ + --sft_loss_ratio 0.0 \ + --ref_model_update_steps -1 \ + --sequence_parallel 1 \ + --use_attn_mask_start_row_indices 1 \ + --tensor_parallel_output 1 \ + --reference_free 0 \ + --simpo_gamma 0.5 \ + --recompute_use_reentrant 1 \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --hidden_dropout_prob 0 \ + --attention_probs_dropout_prob 0.1 \ + --dropout_warmup_steps 100 \ + --adam_epsilon 1e-8 \ + --layerwise_lr_decay_bound 0.5 \ + --use_sp_callback 1 \ + --save_total_limit 5 \ + --scale_loss 8192 \ + --release_grads 1 \ + --amp_master_grad 1 \ + --lr_scheduler_type "cosine" \ + --min_lr 5e-7 \ + --fuse_rope 1 \ + --offset_alpha 1.0 \ + --unified_checkpoint_config "async_save" \ + --lora 1 \ + --lora_rank 32 \ + --lora_alpha 128 \ + --lora_plus_scale 12 \ + --rslora diff --git a/ernie/ERNIE/examples/post-training/dpo/scripts/run_0.3b_dpo_lora_8k.sh b/ernie/ERNIE/examples/post-training/dpo/scripts/run_0.3b_dpo_lora_8k.sh new file mode 100644 index 0000000000000000000000000000000000000000..04ee5b4febde8e1dcaabe01727add0846441f2a0 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/dpo/scripts/run_0.3b_dpo_lora_8k.sh @@ -0,0 +1,118 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +master_ip=${1:-} +nnodes=${2:-1} +model_path="ERNIE4.5T_0.3B" +task="dpo_lora_8k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3,4,5,6,7 \ + --master ${master_ip}:8080 \ + --nnodes ${nnodes} \ + ./examples/post-training/dpo/dpo_train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/dpo-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/dpo-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --max_steps 800 \ + --save_steps 100 \ + --logging_steps 1 \ + --eval_steps 20000 \ + --weight_decay 0.1 \ + --do_train \ + --do_eval \ + --evaluation_strategy epoch \ + --tensor_parallel_degree 1 \ + --tensor_parallel_config "sync_param sync_grad sync_moment" \ + --pipeline_parallel_degree $nnodes \ + --sharding_parallel_degree 1 \ + --gradient_accumulation_steps 36 \ + --sharding stage1 \ + --max_seq_len 8192 \ + --seed 42 \ + --warmup_steps 50 \ + --learning_rate 5e-7 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 0 \ + --recompute_granularity "full" \ + --dataloader_num_workers 4 \ + --distributed_dataloader 1 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache enable_delay_scale_loss enable_overlap_p2p_comm best_unbalanced_scheduler" \ + --dpo_benchmark 0 \ + --greedy_intokens 1 \ + --beta 0.1 \ + --loss_type "sigmoid" \ + --label_smoothing 0.0 \ + --pref_loss_ratio 1.0 \ + --sft_loss_ratio 0.0 \ + --ref_model_update_steps -1 \ + --sequence_parallel 1 \ + --use_attn_mask_start_row_indices 1 \ + --tensor_parallel_output 1 \ + --reference_free 0 \ + --simpo_gamma 0.5 \ + --recompute_use_reentrant 1 \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --hidden_dropout_prob 0 \ + --attention_probs_dropout_prob 0.1 \ + --dropout_warmup_steps 100 \ + --adam_epsilon 1e-8 \ + --layerwise_lr_decay_bound 0.5 \ + --use_sp_callback 1 \ + --save_total_limit 5 \ + --scale_loss 8192 \ + --release_grads 1 \ + --amp_master_grad 1 \ + --lr_scheduler_type "cosine" \ + --min_lr 5e-7 \ + --fuse_rope 1 \ + --offset_alpha 1.0 \ + --unified_checkpoint_config "async_save" \ + --lora 1 \ + --lora_rank 32 \ + --lora_alpha 128 \ + --lora_plus_scale 12 \ + --rslora diff --git a/ernie/ERNIE/examples/post-training/dpo/scripts/run_dpo_32k.sh b/ernie/ERNIE/examples/post-training/dpo/scripts/run_dpo_32k.sh new file mode 100644 index 0000000000000000000000000000000000000000..c3acd805febcab16003bd6bb2e01e62e2c2c0c15 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/dpo/scripts/run_dpo_32k.sh @@ -0,0 +1,116 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +master_ip=${1:-} +nnodes=${2:-14} +model_path="ERNIE4.5T_chat" +task="dpo_8k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3,4,5,6,7 \ + --master ${master_ip}:8080 \ + --nnodes ${nnodes} \ + ./examples/post-training/dpo/dpo_train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/dpo-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/dpo-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --max_steps 800 \ + --save_steps 100 \ + --logging_steps 1 \ + --eval_steps 20000 \ + --weight_decay 0.1 \ + --do_train \ + --do_eval \ + --evaluation_strategy epoch \ + --tensor_parallel_degree 8 \ + --tensor_parallel_config "sync_param sync_grad sync_moment" \ + --pipeline_parallel_degree $nnodes \ + --sharding_parallel_degree 1 \ + --gradient_accumulation_steps 36 \ + --sharding stage1 \ + --max_seq_len 32768 \ + --seed 42 \ + --warmup_steps 50 \ + --learning_rate 5e-7 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 4 \ + --distributed_dataloader 1 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache enable_delay_scale_loss enable_overlap_p2p_comm best_unbalanced_scheduler" \ + --dpo_benchmark 0 \ + --greedy_intokens 1 \ + --beta 0.1 \ + --loss_type "sigmoid" \ + --label_smoothing 0.0 \ + --pref_loss_ratio 1.0 \ + --sft_loss_ratio 0.0 \ + --ref_model_update_steps -1 \ + --sequence_parallel 1 \ + --use_attn_mask_start_row_indices 1 \ + --tensor_parallel_output 1 \ + --reference_free 0 \ + --simpo_gamma 0.5 \ + --recompute_use_reentrant 1 \ + --moe_group mp \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --hidden_dropout_prob 0 \ + --attention_probs_dropout_prob 0.1 \ + --dropout_warmup_steps 100 \ + --adam_epsilon 1e-8 \ + --layerwise_lr_decay_bound 0.5 \ + --use_sp_callback 1 \ + --save_total_limit 5 \ + --scale_loss 8192 \ + --release_grads 1 \ + --amp_master_grad 1 \ + --lr_scheduler_type "cosine" \ + --min_lr 5e-7 \ + --fuse_rope 1 \ + --pp_seg_method "[0,7,10,14,18,22,26,30,34,38,42,46,50,53,57]" \ + --offset_alpha 1.0 \ + --offload_optim \ + --unified_checkpoint_config "" diff --git a/ernie/ERNIE/examples/post-training/dpo/scripts/run_dpo_8k.sh b/ernie/ERNIE/examples/post-training/dpo/scripts/run_dpo_8k.sh new file mode 100644 index 0000000000000000000000000000000000000000..6abee0c25092e527e74607926f525ba667b7f5af --- /dev/null +++ b/ernie/ERNIE/examples/post-training/dpo/scripts/run_dpo_8k.sh @@ -0,0 +1,116 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +master_ip=${1:-} +nnodes=${2:-14} +model_path="ERNIE4.5T_chat" +task="dpo_8k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3,4,5,6,7 \ + --master ${master_ip}:8080 \ + --nnodes ${nnodes} \ + ./examples/post-training/dpo/dpo_train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/dpo-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/dpo-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --max_steps 800 \ + --save_steps 100 \ + --logging_steps 1 \ + --eval_steps 20000 \ + --weight_decay 0.1 \ + --do_train \ + --do_eval \ + --evaluation_strategy epoch \ + --tensor_parallel_degree 8 \ + --tensor_parallel_config "sync_param sync_grad sync_moment" \ + --pipeline_parallel_degree $nnodes \ + --sharding_parallel_degree 1 \ + --gradient_accumulation_steps 36 \ + --sharding stage1 \ + --max_seq_len 8192 \ + --seed 42 \ + --warmup_steps 50 \ + --learning_rate 5e-7 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 4 \ + --distributed_dataloader 1 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache enable_delay_scale_loss enable_overlap_p2p_comm best_unbalanced_scheduler" \ + --dpo_benchmark 0 \ + --greedy_intokens 1 \ + --beta 0.1 \ + --loss_type "sigmoid" \ + --label_smoothing 0.0 \ + --pref_loss_ratio 1.0 \ + --sft_loss_ratio 0.0 \ + --ref_model_update_steps -1 \ + --sequence_parallel 1 \ + --use_attn_mask_start_row_indices 1 \ + --tensor_parallel_output 1 \ + --reference_free 0 \ + --simpo_gamma 0.5 \ + --recompute_use_reentrant 1 \ + --moe_group mp \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --hidden_dropout_prob 0 \ + --attention_probs_dropout_prob 0.1 \ + --dropout_warmup_steps 100 \ + --adam_epsilon 1e-8 \ + --layerwise_lr_decay_bound 0.5 \ + --use_sp_callback 1 \ + --save_total_limit 5 \ + --scale_loss 8192 \ + --release_grads 1 \ + --amp_master_grad 1 \ + --lr_scheduler_type "cosine" \ + --min_lr 5e-7 \ + --fuse_rope 1 \ + --pp_seg_method "[0,7,10,14,18,22,26,30,34,38,42,46,50,53,57]" \ + --offset_alpha 1.0 \ + --offload_optim \ + --unified_checkpoint_config "" diff --git a/ernie/ERNIE/examples/post-training/dpo/scripts/run_dpo_lora_8k.sh b/ernie/ERNIE/examples/post-training/dpo/scripts/run_dpo_lora_8k.sh new file mode 100644 index 0000000000000000000000000000000000000000..284ff7087b3203c99005ef263e5d46b25e4cfd93 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/dpo/scripts/run_dpo_lora_8k.sh @@ -0,0 +1,119 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +master_ip=${1:-} +nnodes=${2:-2} +model_path="ERNIE4.5T_chat" +task="dpo_lora_8k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3,4,5,6,7 \ + --master ${master_ip}:8080 \ + --nnodes ${nnodes} \ + ./examples/post-training/dpo/dpo_train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/dpo-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/dpo-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --max_steps 800 \ + --save_steps 100 \ + --logging_steps 1 \ + --eval_steps 20000 \ + --weight_decay 0.1 \ + --do_train \ + --do_eval \ + --evaluation_strategy epoch \ + --tensor_parallel_degree 8 \ + --tensor_parallel_config "sync_param sync_grad sync_moment" \ + --pipeline_parallel_degree $nnodes \ + --sharding_parallel_degree 1 \ + --gradient_accumulation_steps 36 \ + --sharding stage1 \ + --max_seq_len 8192 \ + --seed 42 \ + --warmup_steps 50 \ + --learning_rate 5e-7 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 4 \ + --distributed_dataloader 1 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache enable_delay_scale_loss enable_overlap_p2p_comm best_unbalanced_scheduler" \ + --dpo_benchmark 0 \ + --greedy_intokens 1 \ + --beta 0.1 \ + --loss_type "sigmoid" \ + --label_smoothing 0.0 \ + --pref_loss_ratio 1.0 \ + --sft_loss_ratio 0.0 \ + --ref_model_update_steps -1 \ + --sequence_parallel 1 \ + --use_attn_mask_start_row_indices 1 \ + --tensor_parallel_output 1 \ + --reference_free 0 \ + --simpo_gamma 0.5 \ + --recompute_use_reentrant 1 \ + --moe_group mp \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --hidden_dropout_prob 0 \ + --attention_probs_dropout_prob 0.1 \ + --dropout_warmup_steps 100 \ + --adam_epsilon 1e-8 \ + --layerwise_lr_decay_bound 0.5 \ + --use_sp_callback 1 \ + --save_total_limit 5 \ + --scale_loss 8192 \ + --release_grads 1 \ + --amp_master_grad 1 \ + --lr_scheduler_type "cosine" \ + --min_lr 5e-7 \ + --fuse_rope 1 \ + --offset_alpha 1.0 \ + --unified_checkpoint_config "async_save" \ + --lora 1 \ + --lora_rank 32 \ + --lora_alpha 128 \ + --lora_plus_scale 12 \ + --rslora diff --git a/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_128k.sh b/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_128k.sh new file mode 100644 index 0000000000000000000000000000000000000000..f58a67dbb8acd5fa343ceae0e41cf85ce791b9e5 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_128k.sh @@ -0,0 +1,112 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +model_path="ERNIE-4.5-21B-A3B" +task="dpo_128k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3,4,5,6,7 \ + ./examples/post-training/dpo/dpo_train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/dpo-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/dpo-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --max_steps 800 \ + --save_steps 100 \ + --logging_steps 1 \ + --eval_steps 20000 \ + --weight_decay 0.1 \ + --do_train \ + --do_eval \ + --evaluation_strategy epoch \ + --tensor_parallel_degree 4 \ + --tensor_parallel_config "sync_param sync_grad sync_moment" \ + --pipeline_parallel_degree 2 \ + --sharding_parallel_degree 1 \ + --gradient_accumulation_steps 8 \ + --sharding stage1 \ + --max_seq_len 131072 \ + --seed 42 \ + --warmup_steps 50 \ + --learning_rate 5e-7 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 4 \ + --distributed_dataloader 1 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache enable_delay_scale_loss enable_overlap_p2p_comm best_unbalanced_scheduler" \ + --dpo_benchmark 0 \ + --greedy_intokens 1 \ + --beta 0.1 \ + --loss_type "sigmoid" \ + --label_smoothing 0.0 \ + --pref_loss_ratio 1.0 \ + --sft_loss_ratio 0.0 \ + --ref_model_update_steps -1 \ + --sequence_parallel 1 \ + --use_attn_mask_start_row_indices 1 \ + --tensor_parallel_output 1 \ + --reference_free 0 \ + --simpo_gamma 0.5 \ + --recompute_use_reentrant 1 \ + --moe_group mp \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --hidden_dropout_prob 0 \ + --attention_probs_dropout_prob 0.1 \ + --dropout_warmup_steps 100 \ + --adam_epsilon 1e-8 \ + --layerwise_lr_decay_bound 1 \ + --use_sp_callback 1 \ + --save_total_limit 5 \ + --scale_loss 8192 \ + --release_grads 1 \ + --amp_master_grad 1 \ + --lr_scheduler_type "cosine" \ + --min_lr 5e-7 \ + --fuse_rope 1 \ + --fuse_linear 1 \ + --offset_alpha 1.0 \ + --offload_optim \ + --unified_checkpoint_config "" diff --git a/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_32k.sh b/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_32k.sh new file mode 100644 index 0000000000000000000000000000000000000000..1c8cc8f305bfd32c3ed30ce55bea8c991c82cdd4 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_32k.sh @@ -0,0 +1,112 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +model_path="ERNIE-4.5-21B-A3B" +task="dpo_32k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3,4,5,6,7 \ + ./examples/post-training/dpo/dpo_train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/dpo-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/dpo-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --max_steps 800 \ + --save_steps 100 \ + --logging_steps 1 \ + --eval_steps 20000 \ + --weight_decay 0.1 \ + --do_train \ + --do_eval \ + --evaluation_strategy epoch \ + --tensor_parallel_degree 4 \ + --tensor_parallel_config "sync_param sync_grad sync_moment" \ + --pipeline_parallel_degree 2 \ + --sharding_parallel_degree 1 \ + --gradient_accumulation_steps 8 \ + --sharding stage1 \ + --max_seq_len 32768 \ + --seed 42 \ + --warmup_steps 50 \ + --learning_rate 5e-7 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 4 \ + --distributed_dataloader 1 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache enable_delay_scale_loss enable_overlap_p2p_comm best_unbalanced_scheduler" \ + --dpo_benchmark 0 \ + --greedy_intokens 1 \ + --beta 0.1 \ + --loss_type "sigmoid" \ + --label_smoothing 0.0 \ + --pref_loss_ratio 1.0 \ + --sft_loss_ratio 0.0 \ + --ref_model_update_steps -1 \ + --sequence_parallel 1 \ + --use_attn_mask_start_row_indices 1 \ + --tensor_parallel_output 1 \ + --reference_free 0 \ + --simpo_gamma 0.5 \ + --recompute_use_reentrant 1 \ + --moe_group mp \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --hidden_dropout_prob 0 \ + --attention_probs_dropout_prob 0.1 \ + --dropout_warmup_steps 100 \ + --adam_epsilon 1e-8 \ + --layerwise_lr_decay_bound 1 \ + --use_sp_callback 1 \ + --save_total_limit 5 \ + --scale_loss 8192 \ + --release_grads 1 \ + --amp_master_grad 1 \ + --lr_scheduler_type "cosine" \ + --min_lr 5e-7 \ + --fuse_rope 1 \ + --fuse_linear 1 \ + --offset_alpha 1.0 \ + --offload_optim \ + --unified_checkpoint_config "" diff --git a/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_8k.sh b/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_8k.sh new file mode 100644 index 0000000000000000000000000000000000000000..50242f6b080c563f2d8fd053a6c12a7cf31a3642 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_8k.sh @@ -0,0 +1,111 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +model_path="ERNIE-4.5-21B-A3B" +task="dpo_8k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3,4,5,6,7 \ + ./examples/post-training/dpo/dpo_train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/dpo-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/dpo-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --max_steps 800 \ + --save_steps 100 \ + --logging_steps 1 \ + --eval_steps 20000 \ + --weight_decay 0.1 \ + --do_train \ + --do_eval \ + --evaluation_strategy epoch \ + --tensor_parallel_degree 4 \ + --tensor_parallel_config "sync_param sync_grad sync_moment" \ + --pipeline_parallel_degree 2 \ + --sharding_parallel_degree 1 \ + --gradient_accumulation_steps 8 \ + --sharding stage1 \ + --max_seq_len 8192 \ + --seed 42 \ + --warmup_steps 50 \ + --learning_rate 5e-7 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 0 \ + --recompute_granularity "full" \ + --dataloader_num_workers 4 \ + --distributed_dataloader 1 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache enable_delay_scale_loss enable_overlap_p2p_comm best_unbalanced_scheduler" \ + --dpo_benchmark 0 \ + --greedy_intokens 1 \ + --beta 0.1 \ + --loss_type "sigmoid" \ + --label_smoothing 0.0 \ + --pref_loss_ratio 1.0 \ + --sft_loss_ratio 0.0 \ + --ref_model_update_steps -1 \ + --sequence_parallel 1 \ + --use_attn_mask_start_row_indices 1 \ + --tensor_parallel_output 1 \ + --reference_free 0 \ + --simpo_gamma 0.5 \ + --recompute_use_reentrant 1 \ + --moe_group mp \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --hidden_dropout_prob 0 \ + --attention_probs_dropout_prob 0.1 \ + --dropout_warmup_steps 100 \ + --adam_epsilon 1e-8 \ + --layerwise_lr_decay_bound 1 \ + --use_sp_callback 1 \ + --save_total_limit 5 \ + --scale_loss 8192 \ + --release_grads 1 \ + --amp_master_grad 1 \ + --lr_scheduler_type "cosine" \ + --min_lr 5e-7 \ + --fuse_rope 1 \ + --offset_alpha 1.0 \ + --offload_optim \ + --unified_checkpoint_config "" diff --git a/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_lora_128k.sh b/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_lora_128k.sh new file mode 100644 index 0000000000000000000000000000000000000000..724bfdd63c570f031c392d20376b284e14f80c49 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_lora_128k.sh @@ -0,0 +1,117 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +model_path="ERNIE-4.5-21B-A3B" +task="dpo_lora_128k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3 \ + ./examples/post-training/dpo/dpo_train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/dpo-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/dpo-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --max_steps 800 \ + --save_steps 100 \ + --logging_steps 1 \ + --eval_steps 20000 \ + --weight_decay 0.1 \ + --do_train \ + --do_eval \ + --evaluation_strategy epoch \ + --tensor_parallel_degree 4 \ + --tensor_parallel_config "sync_param sync_grad sync_moment" \ + --pipeline_parallel_degree 1 \ + --sharding_parallel_degree 1 \ + --gradient_accumulation_steps 8 \ + --sharding stage1 \ + --max_seq_len 131072 \ + --seed 42 \ + --warmup_steps 50 \ + --learning_rate 5e-7 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 4 \ + --distributed_dataloader 1 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache enable_delay_scale_loss enable_overlap_p2p_comm best_unbalanced_scheduler" \ + --dpo_benchmark 0 \ + --greedy_intokens 1 \ + --beta 0.1 \ + --loss_type "sigmoid" \ + --label_smoothing 0.0 \ + --pref_loss_ratio 1.0 \ + --sft_loss_ratio 0.0 \ + --ref_model_update_steps -1 \ + --sequence_parallel 1 \ + --use_attn_mask_start_row_indices 1 \ + --tensor_parallel_output 1 \ + --reference_free 0 \ + --simpo_gamma 0.5 \ + --recompute_use_reentrant 1 \ + --moe_group mp \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --hidden_dropout_prob 0 \ + --attention_probs_dropout_prob 0.1 \ + --dropout_warmup_steps 100 \ + --adam_epsilon 1e-8 \ + --layerwise_lr_decay_bound 1 \ + --use_sp_callback 1 \ + --save_total_limit 5 \ + --scale_loss 8192 \ + --release_grads 1 \ + --amp_master_grad 1 \ + --lr_scheduler_type "cosine" \ + --min_lr 5e-7 \ + --fuse_rope 1 \ + --fuse_linear 1 \ + --offset_alpha 1.0 \ + --offload_optim \ + --unified_checkpoint_config "async_save" \ + --lora 1 \ + --lora_rank 32 \ + --lora_alpha 128 \ + --lora_plus_scale 12 \ + --rslora diff --git a/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_lora_32k.sh b/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_lora_32k.sh new file mode 100644 index 0000000000000000000000000000000000000000..5638d1506b054cc3dcf9019bbbeab93db40f9e96 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_lora_32k.sh @@ -0,0 +1,117 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +model_path="ERNIE-4.5-21B-A3B" +task="dpo_lora_32k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1 \ + ./examples/post-training/dpo/dpo_train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/dpo-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/dpo-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --max_steps 800 \ + --save_steps 100 \ + --logging_steps 1 \ + --eval_steps 20000 \ + --weight_decay 0.1 \ + --do_train \ + --do_eval \ + --evaluation_strategy epoch \ + --tensor_parallel_degree 2 \ + --tensor_parallel_config "sync_param sync_grad sync_moment" \ + --pipeline_parallel_degree 1 \ + --sharding_parallel_degree 1 \ + --gradient_accumulation_steps 8 \ + --sharding stage1 \ + --max_seq_len 32768 \ + --seed 42 \ + --warmup_steps 50 \ + --learning_rate 5e-7 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 4 \ + --distributed_dataloader 1 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache enable_delay_scale_loss enable_overlap_p2p_comm best_unbalanced_scheduler" \ + --dpo_benchmark 0 \ + --greedy_intokens 1 \ + --beta 0.1 \ + --loss_type "sigmoid" \ + --label_smoothing 0.0 \ + --pref_loss_ratio 1.0 \ + --sft_loss_ratio 0.0 \ + --ref_model_update_steps -1 \ + --sequence_parallel 1 \ + --use_attn_mask_start_row_indices 1 \ + --tensor_parallel_output 1 \ + --reference_free 0 \ + --simpo_gamma 0.5 \ + --recompute_use_reentrant 1 \ + --moe_group mp \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --hidden_dropout_prob 0 \ + --attention_probs_dropout_prob 0.1 \ + --dropout_warmup_steps 100 \ + --adam_epsilon 1e-8 \ + --layerwise_lr_decay_bound 1 \ + --use_sp_callback 1 \ + --save_total_limit 5 \ + --scale_loss 8192 \ + --release_grads 1 \ + --amp_master_grad 1 \ + --lr_scheduler_type "cosine" \ + --min_lr 5e-7 \ + --fuse_rope 1 \ + --fuse_linear 1 \ + --offset_alpha 1.0 \ + --offload_optim \ + --unified_checkpoint_config "async_save" \ + --lora 1 \ + --lora_rank 32 \ + --lora_alpha 128 \ + --lora_plus_scale 12 \ + --rslora diff --git a/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_lora_8k.sh b/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_lora_8k.sh new file mode 100644 index 0000000000000000000000000000000000000000..7875def75dcdfad6584326f312d5b642463214c2 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_lora_8k.sh @@ -0,0 +1,117 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +model_path="ERNIE-4.5-21B-A3B" +task="dpo_lora_8k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1 \ + ./examples/post-training/dpo/dpo_train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/dpo-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/dpo-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --max_steps 800 \ + --save_steps 100 \ + --logging_steps 1 \ + --eval_steps 20000 \ + --weight_decay 0.1 \ + --do_train \ + --do_eval \ + --evaluation_strategy epoch \ + --tensor_parallel_degree 2 \ + --tensor_parallel_config "sync_param sync_grad sync_moment" \ + --pipeline_parallel_degree 1 \ + --sharding_parallel_degree 1 \ + --gradient_accumulation_steps 8 \ + --sharding stage1 \ + --max_seq_len 8192 \ + --seed 42 \ + --warmup_steps 50 \ + --learning_rate 5e-7 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 0 \ + --recompute_granularity "full" \ + --dataloader_num_workers 4 \ + --distributed_dataloader 1 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache enable_delay_scale_loss enable_overlap_p2p_comm best_unbalanced_scheduler" \ + --dpo_benchmark 0 \ + --greedy_intokens 1 \ + --beta 0.1 \ + --loss_type "sigmoid" \ + --label_smoothing 0.0 \ + --pref_loss_ratio 1.0 \ + --sft_loss_ratio 0.0 \ + --ref_model_update_steps -1 \ + --sequence_parallel 1 \ + --use_attn_mask_start_row_indices 1 \ + --tensor_parallel_output 1 \ + --reference_free 0 \ + --simpo_gamma 0.5 \ + --recompute_use_reentrant 1 \ + --moe_group mp \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --hidden_dropout_prob 0 \ + --attention_probs_dropout_prob 0.1 \ + --dropout_warmup_steps 100 \ + --adam_epsilon 1e-8 \ + --layerwise_lr_decay_bound 1 \ + --use_sp_callback 1 \ + --save_total_limit 5 \ + --scale_loss 8192 \ + --release_grads 1 \ + --amp_master_grad 1 \ + --lr_scheduler_type "cosine" \ + --min_lr 5e-7 \ + --fuse_rope 1 \ + --fuse_linear 1 \ + --offset_alpha 1.0 \ + --offload_optim \ + --unified_checkpoint_config "async_save" \ + --lora 1 \ + --lora_rank 32 \ + --lora_alpha 128 \ + --lora_plus_scale 12 \ + --rslora diff --git a/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_wint8mix_lora_128k.sh b/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_wint8mix_lora_128k.sh new file mode 100644 index 0000000000000000000000000000000000000000..cc7a1b09e10f6bd47161a43d5d6edc0d94b78013 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_wint8mix_lora_128k.sh @@ -0,0 +1,119 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +model_path="ERNIE-4.5-21B-A3B" +task="dpo_wint8mix_lora_128k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3 \ + ./examples/post-training/dpo/dpo_train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/dpo-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/dpo-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --max_steps 800 \ + --save_steps 100 \ + --logging_steps 1 \ + --eval_steps 20000 \ + --weight_decay 0.1 \ + --do_train \ + --do_eval \ + --save_strategy epoch \ + --evaluation_strategy epoch \ + --tensor_parallel_degree 4 \ + --tensor_parallel_config "sync_param sync_grad sync_moment" \ + --pipeline_parallel_degree 1 \ + --sharding_parallel_degree 1 \ + --gradient_accumulation_steps 8 \ + --sharding stage1 \ + --max_seq_len 131072 \ + --seed 42 \ + --warmup_steps 50 \ + --learning_rate 5e-7 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 4 \ + --distributed_dataloader 1 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache enable_delay_scale_loss enable_overlap_p2p_comm best_unbalanced_scheduler" \ + --dpo_benchmark 0 \ + --greedy_intokens 1 \ + --beta 0.1 \ + --loss_type "sigmoid" \ + --label_smoothing 0.0 \ + --pref_loss_ratio 1.0 \ + --sft_loss_ratio 0.0 \ + --ref_model_update_steps -1 \ + --sequence_parallel 1 \ + --use_attn_mask_start_row_indices 1 \ + --tensor_parallel_output 1 \ + --reference_free 0 \ + --simpo_gamma 0.5 \ + --recompute_use_reentrant 1 \ + --moe_group mp \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --hidden_dropout_prob 0 \ + --attention_probs_dropout_prob 0.1 \ + --dropout_warmup_steps 100 \ + --adam_epsilon 1e-8 \ + --layerwise_lr_decay_bound 1 \ + --use_sp_callback 1 \ + --save_total_limit 5 \ + --scale_loss 8192 \ + --release_grads 1 \ + --amp_master_grad 1 \ + --lr_scheduler_type "cosine" \ + --min_lr 5e-7 \ + --fuse_rope 1 \ + --fuse_linear 1 \ + --offset_alpha 1.0 \ + --offload_optim \ + --unified_checkpoint_config "async_save" \ + --lora 1 \ + --lora_rank 32 \ + --lora_alpha 128 \ + --lora_plus_scale 12 \ + --rslora \ + --weight_quantize_algo weight_only_mix diff --git a/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_wint8mix_lora_32k.sh b/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_wint8mix_lora_32k.sh new file mode 100644 index 0000000000000000000000000000000000000000..ba539ddcbada24a4353c716f3b70929a18b77e00 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_wint8mix_lora_32k.sh @@ -0,0 +1,119 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +model_path="ERNIE-4.5-21B-A3B" +task="dpo_wint8mix_lora_32k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1 \ + ./examples/post-training/dpo/dpo_train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/dpo-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/dpo-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --max_steps 800 \ + --save_steps 100 \ + --logging_steps 1 \ + --eval_steps 20000 \ + --weight_decay 0.1 \ + --do_train \ + --do_eval \ + --save_strategy epoch \ + --evaluation_strategy epoch \ + --tensor_parallel_degree 2 \ + --tensor_parallel_config "sync_param sync_grad sync_moment" \ + --pipeline_parallel_degree 1 \ + --sharding_parallel_degree 1 \ + --gradient_accumulation_steps 8 \ + --sharding stage1 \ + --max_seq_len 32768 \ + --seed 42 \ + --warmup_steps 50 \ + --learning_rate 5e-7 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 4 \ + --distributed_dataloader 1 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache enable_delay_scale_loss enable_overlap_p2p_comm best_unbalanced_scheduler" \ + --dpo_benchmark 0 \ + --greedy_intokens 1 \ + --beta 0.1 \ + --loss_type "sigmoid" \ + --label_smoothing 0.0 \ + --pref_loss_ratio 1.0 \ + --sft_loss_ratio 0.0 \ + --ref_model_update_steps -1 \ + --sequence_parallel 1 \ + --use_attn_mask_start_row_indices 1 \ + --tensor_parallel_output 1 \ + --reference_free 0 \ + --simpo_gamma 0.5 \ + --recompute_use_reentrant 1 \ + --moe_group mp \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --hidden_dropout_prob 0 \ + --attention_probs_dropout_prob 0.1 \ + --dropout_warmup_steps 100 \ + --adam_epsilon 1e-8 \ + --layerwise_lr_decay_bound 1 \ + --use_sp_callback 1 \ + --save_total_limit 5 \ + --scale_loss 8192 \ + --release_grads 1 \ + --amp_master_grad 1 \ + --lr_scheduler_type "cosine" \ + --min_lr 5e-7 \ + --fuse_rope 1 \ + --fuse_linear 1 \ + --offset_alpha 1.0 \ + --offload_optim \ + --unified_checkpoint_config "async_save" \ + --lora 1 \ + --lora_rank 32 \ + --lora_alpha 128 \ + --lora_plus_scale 12 \ + --rslora \ + --weight_quantize_algo weight_only_mix diff --git a/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_wint8mix_lora_8k.sh b/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_wint8mix_lora_8k.sh new file mode 100644 index 0000000000000000000000000000000000000000..85ddec123532724576eb6ef7eb9629af291eb206 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/dpo/scripts/run_lite_dpo_wint8mix_lora_8k.sh @@ -0,0 +1,119 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +model_path="ERNIE-4.5-21B-A3B" +task="dpo_wint8mix_lora_8k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1 \ + ./examples/post-training/dpo/dpo_train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/dpo-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/dpo-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --max_steps 800 \ + --save_steps 100 \ + --logging_steps 1 \ + --eval_steps 20000 \ + --weight_decay 0.1 \ + --do_train \ + --do_eval \ + --save_strategy epoch \ + --evaluation_strategy epoch \ + --tensor_parallel_degree 2 \ + --tensor_parallel_config "sync_param sync_grad sync_moment" \ + --pipeline_parallel_degree 1 \ + --sharding_parallel_degree 1 \ + --gradient_accumulation_steps 8 \ + --sharding stage1 \ + --max_seq_len 8192 \ + --seed 42 \ + --warmup_steps 50 \ + --learning_rate 5e-7 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 0 \ + --recompute_granularity "full" \ + --dataloader_num_workers 4 \ + --distributed_dataloader 1 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache enable_delay_scale_loss enable_overlap_p2p_comm best_unbalanced_scheduler" \ + --dpo_benchmark 0 \ + --greedy_intokens 1 \ + --beta 0.1 \ + --loss_type "sigmoid" \ + --label_smoothing 0.0 \ + --pref_loss_ratio 1.0 \ + --sft_loss_ratio 0.0 \ + --ref_model_update_steps -1 \ + --sequence_parallel 1 \ + --use_attn_mask_start_row_indices 1 \ + --tensor_parallel_output 1 \ + --reference_free 0 \ + --simpo_gamma 0.5 \ + --recompute_use_reentrant 1 \ + --moe_group mp \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --hidden_dropout_prob 0 \ + --attention_probs_dropout_prob 0.1 \ + --dropout_warmup_steps 100 \ + --adam_epsilon 1e-8 \ + --layerwise_lr_decay_bound 1 \ + --use_sp_callback 1 \ + --save_total_limit 5 \ + --scale_loss 8192 \ + --release_grads 1 \ + --amp_master_grad 1 \ + --lr_scheduler_type "cosine" \ + --min_lr 5e-7 \ + --fuse_rope 1 \ + --fuse_linear 1 \ + --offset_alpha 1.0 \ + --offload_optim \ + --unified_checkpoint_config "async_save" \ + --lora 1 \ + --lora_rank 32 \ + --lora_alpha 128 \ + --lora_plus_scale 12 \ + --rslora \ + --weight_quantize_algo weight_only_mix diff --git a/ernie/ERNIE/examples/post-training/sft/create_sft_data.py b/ernie/ERNIE/examples/post-training/sft/create_sft_data.py new file mode 100644 index 0000000000000000000000000000000000000000..9e6b2beec5634a737a0e5a055ef62c86f3a3184d --- /dev/null +++ b/ernie/ERNIE/examples/post-training/sft/create_sft_data.py @@ -0,0 +1,154 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +from dataclasses import fields + +import numpy as np +from paddleformers.data.indexed_dataset import SFTMMapIndexedDatasetBuilder +from paddleformers.trainer import PdArgumentParser, RuntimeTimer +from paddleformers.utils.log import logger +from sft_utils import ( + BuildDataArgument, + BuildSFTTrainingArguments, + DataGenerator, +) +from train import ModelArgument + +from ernie.configuration import Ernie4_5_Config +from ernie.dataset.finetuning import Sequence, create_dataset +from ernie.tokenizer import Ernie4_5_Tokenizer +from ernie.utils.common_utils import estimate_training + + +def main(): + """ + Convert the dataset to the MapDataset format that can be used by the SFT training. + """ + runtime_timer = RuntimeTimer("Creating SFT MapDataset") + + parser = PdArgumentParser((ModelArgument, BuildDataArgument, BuildSFTTrainingArguments)) + + if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"): + model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines() + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Avoid initializing tp, pp, and sdp to -1 in CPU environment. + training_args.pipeline_parallel_degree = training_args.pp_degree + training_args.tensor_parallel_degree = training_args.tp_degree + training_args.sharding_parallel_degree = training_args.sdp_degree + + training_args.data_parallel_degree = ( + training_args.num_of_gpus + // training_args.tensor_parallel_degree + // training_args.sharding_parallel_degree + // training_args.pipeline_parallel_degree + ) + global_batch_size = ( + training_args.per_device_train_batch_size + * training_args.gradient_accumulation_steps + * training_args.data_parallel_degree + * training_args.sharding_parallel_degree + ) + + tokenizer = Ernie4_5_Tokenizer.from_pretrained(model_args.model_name_or_path) + config = Ernie4_5_Config.from_pretrained(model_args.model_name_or_path) + + if tokenizer.vocab_size < 2**16 - 1: + save_dtype = np.uint16 + else: + save_dtype = np.int32 + + dataset_config = { + "tokenizer": tokenizer, + "max_seq_len": data_args.max_seq_len, + "random_seed": training_args.seed, + "num_samples_each_epoch": data_args.num_samples_each_epoch, + "random_shuffle": data_args.random_shuffle, + "greedy_intokens": data_args.greedy_intokens, + } + dataclass = Sequence + + if training_args.do_train and data_args.train_dataset_path: + runtime_timer.start("Create SFT Train MapDataset") + os.makedirs(os.path.join(data_args.dataset_output_dir, 'train'), exist_ok=True) + + train_output_idx_files = os.path.join(data_args.dataset_output_dir, 'train', 'index.idx') + train_dataset = create_dataset( + task_group=data_args.train_dataset_path, + task_group_prob=data_args.train_dataset_prob, + sub_dataset_type=data_args.train_dataset_type, + is_valid=False, + **dataset_config, + ) + if training_args.max_steps == -1: + training_args.estimation_output_file = ( + 'estimate_training.json' + if training_args.estimation_output_file is None + else training_args.estimation_output_file + ) + training_args.max_steps = estimate_training(train_dataset, data_args, training_args, model_args) + + train_samples = training_args.max_steps * global_batch_size + + output_file_dict = {} + train_dir = os.path.join(data_args.dataset_output_dir, 'train') + for field in fields(dataclass): + output_path = os.path.join(train_dir, f"{field.name}.bin") + output_file_dict[field.name] = output_path + + train_builder = SFTMMapIndexedDatasetBuilder(output_file_dict, save_dtype) + train_sample_generator = DataGenerator(train_dataset) + + used_samples = 0 + while used_samples < train_samples: + train_sample = next(train_sample_generator) + for sequence in train_sample: + train_builder.add_item(sequence) + + train_builder.end_document() + used_samples += 1 + train_builder.finalize(train_output_idx_files) + logger.info(f"{runtime_timer.log()}") + + if training_args.do_eval and data_args.eval_task_config: + runtime_timer.start("Create SFT Eval MapDataset") + os.makedirs(os.path.join(data_args.dataset_output_dir, 'eval'), exist_ok=True) + eval_output_idx_files = os.path.join(data_args.dataset_output_dir, 'eval', 'index.idx') + eval_dataset = create_dataset( + task_group=data_args.eval_dataset_path, + task_group_prob=data_args.eval_dataset_prob, + sub_dataset_type=data_args.eval_dataset_type, + is_valid=True, + **dataset_config, + ) + output_file_dict = {} + eval_dir = os.path.join(data_args.dataset_output_dir, 'eval') + for field in fields(dataclass): + output_path = os.path.join(eval_dir, f"{field.name}.bin") + output_file_dict[field.name] = output_path + eval_builder = SFTMMapIndexedDatasetBuilder(output_file_dict, save_dtype) + + for sequences in eval_dataset: + for sequence in sequences: + eval_builder.add_item(sequence) + eval_builder.end_document() + eval_builder.finalize(eval_output_idx_files) + logger.info(f"{runtime_timer.log()}") + + +if __name__ == "__main__": + main() diff --git a/ernie/ERNIE/examples/post-training/sft/estimate_training.py b/ernie/ERNIE/examples/post-training/sft/estimate_training.py new file mode 100644 index 0000000000000000000000000000000000000000..2844d0fba468f0e85768c539ae589754e7a5db53 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/sft/estimate_training.py @@ -0,0 +1,287 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import numpy as np +from paddleformers.trainer.argparser import strtobool +from paddleformers.utils.log import logger + +from ernie.configuration import Ernie4_5_Config +from ernie.dataset.finetuning import create_dataset +from ernie.tokenizer import Ernie4_5_Tokenizer + + +def parse_arguments(): + """Parse command line arguments using the argparse library and return a Namespace object containing parameter values. + + Args: + None + + Returns: + Namespace object containing the following parameters: + --train_dataset_path: + str type. Used to specify the path of training dataset. + Default: examples/data/sft-train.jsonl. + --train_dataset_type: + str type. Used to specify the type of training dataset. Default: erniekit. + --train_dataset_prob: + str type. Used to specify the prob of training dataset. Default: 1.0. + --model_name_or_path: str type. Used to specify the model directory or filename. Default: ./inference. + --max_seq_len: int type. Used to specify the maximum input sequence length after tokenization. Default: 4096. + --num_epochs: int type. Number of epochs to train for. No default value provided. + --per_device_train_batch_size: int type. Batch size per device for training. Default: 1. + --out_file: str type. Filename to save results. Default: estimate_training.json. + --num_of_gpus: int type. Number of GPUs to use. Default: 8. + --tensor_parallel_degree: int type. Tensor parallelism degree. Default: 8. + --gradient_accumulation_steps: + int type. (Not yet implemented, to be done) + Number of steps to accumulate gradients before backward pass and parameter update. + Default: 16. + + """ + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--train_dataset_path", + default="examples/data/sft-train.jsonl", + help="path of training datasets.", + ) + parser.add_argument( + "--train_dataset_prob", + default="1.0", + help="probabilities of training datasets.", + ) + parser.add_argument( + "--train_dataset_type", + default="erniekit", + help="type of training datasets.", + ) + parser.add_argument( + "--model_name_or_path", + default="./inference", + help="The directory of model.", + ) + parser.add_argument( + "--max_seq_len", + default=4096, + type=int, + help="The maximum total input sequence length after tokenization", + ) + parser.add_argument("--num_train_epochs", type=int, help="Number of epochs to train.") + parser.add_argument( + "--per_device_train_batch_size", + default=1, + type=int, + help="Batch size per GPU for training.", + ) + parser.add_argument( + "--out_file", + default="estimate_training.json", + help="The file to save results.", + ) + parser.add_argument("--num_of_gpus", type=int, default=8, help="The number of GPUs.") + parser.add_argument( + "--tensor_parallel_degree", + type=int, + default=8, + help="The degree of tensor_parallel.", + ) + parser.add_argument( + "--pipeline_parallel_degree", + type=int, + default=1, + help="The degree of pipeline.", + ) + parser.add_argument( + "--sharding_parallel_degree", + type=int, + default=1, + help="The degree of sharding parallel.", + ) + # TODO(gongenlei): support gradient_accumulation_steps args + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=0, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + + # Data args, should be same with training. + parser.add_argument("--seed", type=int, default=23, help="Random seed.") + parser.add_argument( + "--num_samples_each_epoch", + type=int, + default=6000000, + help="Number of samples per epoch. Used for SFT.", + ) + parser.add_argument( + "--max_estimate_samples", + type=int, + default=1e5, + help="Maximum number of samples used in estimation.", + ) + parser.add_argument( + "--greedy_intokens", + type=strtobool, + default="True", + help="Whether to use greedy intokens.", + ) + parser.add_argument( + "--random_shuffle", + type=strtobool, + default="True", + help="Whether to shuffle data.", + ) + return parser.parse_args() + + +def estimate_training(args): + """ + Estimate the number of steps required for training based on the training data. + + Args: + This arguments for the function is defined in parse_arguments(). + + Returns: + dict: Returns a dictionary containing information about the number of training steps required. This function has similar functionality to the function of the same name in utils.py. + + """ + if len(args.train_dataset_path) > 1: + logger.warning("Suggest to use max_steps instead of num_train_epochs for multi source dataset.") + logger.info( + "Multi source dataset detected, number of samples will be estimated by following rule. " + "num_samples = (source1_num_samples * prob1 + source2_num_samples * prob2 + ...) * epochs)" + ) + + tokenizer = Ernie4_5_Tokenizer.from_pretrained(args.model_name_or_path) + config = Ernie4_5_Config.from_pretrained(args.model_name_or_path) + logger.info("Start to estimate max training steps...") + dataset_config = { + "tokenizer": tokenizer, + "max_seq_len": args.max_seq_len, + "random_seed": args.seed, + "num_samples_each_epoch": args.num_samples_each_epoch, + "random_shuffle": args.random_shuffle, + "greedy_intokens": args.greedy_intokens, + } + train_dataset = create_dataset( + task_group=args.train_dataset_path, + task_group_prob=args.train_dataset_prob, + sub_dataset_type=args.train_dataset_type, + **dataset_config, + ) + train_dataset.estimate = True + + max_samples = train_dataset.max_estimate_samples + + if args.max_estimate_samples != -1: + # Set estimate samples to max_estimate_samples + logger.warning("The results between sampling and non-sampling methods may differ.") + train_dataset.max_estimate_samples = min(args.max_estimate_samples, train_dataset.max_estimate_samples) + + if train_dataset.max_estimate_samples > 0: + train_batches = 0 + train_tokens = 0 + for sequences in train_dataset: + if not train_dataset.estimate: + break + train_batches += 1 + for sequence in sequences: + train_tokens += len(sequence.token_ids) + + train_tokens *= args.num_train_epochs + train_batches *= args.num_train_epochs + if args.gradient_accumulation_steps > 0: + grad_acc_steps = args.gradient_accumulation_steps + else: + grad_acc_steps = np.round(min(max(train_tokens / 1e5, 1), 16)) + + data_parallel_degree = ( + args.num_of_gpus + // args.tensor_parallel_degree + // args.sharding_parallel_degree + // args.pipeline_parallel_degree + ) + global_batch_size = ( + args.per_device_train_batch_size * grad_acc_steps * data_parallel_degree * args.sharding_parallel_degree + ) + max_steps = np.ceil(train_batches / global_batch_size) + + if max_samples != train_dataset.max_estimate_samples: + max_steps *= max_samples / train_dataset.max_estimate_samples + train_tokens *= max_samples / train_dataset.max_estimate_samples + train_dataset.used_samples *= max_samples / train_dataset.max_estimate_samples + train_dataset.unused_samples *= max_samples / train_dataset.max_estimate_samples + + res = { + "num_train_epochs": int(args.num_train_epochs), + "max_steps": int(np.ceil(max_steps)), + "train_tokens": int(train_tokens), + "global_batch_size": int(global_batch_size), + "gradient_accumulation_steps": int(grad_acc_steps), + "warmup_steps": int(np.ceil(0.1 * max_steps)), + "num_of_gpus": int(args.num_of_gpus), + "per_device_train_batch_size": int(args.per_device_train_batch_size), + "tensor_parallel_degree": int(args.tensor_parallel_degree), + "pipeline_parallel_degree": int(args.pipeline_parallel_degree), + "sharding_parallel_degree": int(args.sharding_parallel_degree), + "seed": args.seed, + "num_samples_each_epoch": args.num_samples_each_epoch, + "max_seq_len": int(args.max_seq_len), + "valid": True, + "train_samples": int(max_samples * args.num_train_epochs), + "estimate_samples": int(train_dataset.max_estimate_samples), + "actual_train_samples": int(train_dataset.used_samples * args.num_train_epochs), + "skip_samples": int(train_dataset.unused_samples * args.num_train_epochs), + } + if train_batches / args.num_train_epochs / global_batch_size < 1: + logger.warning("This dataset is too small, you'd better enlarge your dataset.") + res["valid"] = False + else: + logger.error("No valid data found, please check your dataset format.") + res = { + "num_train_epochs": int(args.num_train_epochs), + "max_steps": 0, + "train_tokens": 0, + "num_of_gpus": int(args.num_of_gpus), + "per_device_train_batch_size": int(args.per_device_train_batch_size), + "tensor_parallel_degree": int(args.tensor_parallel_degree), + "pipeline_parallel_degree": int(args.pipeline_parallel_degree), + "sharding_parallel_degree": int(args.sharding_parallel_degree), + "seed": args.seed, + "num_samples_each_epoch": args.num_samples_each_epoch, + "max_seq_len": int(args.max_seq_len), + "valid": False, + "train_samples": 0, + } + out_file = getattr(args, 'out_file', None) + if out_file: + with open(args.out_file, "w", encoding="utf-8") as f: + json.dump(res, f) + + return res + + +if __name__ == "__main__": + args = parse_arguments() + enable_auth = False + if enable_auth: + from encryption.auth import auth_product + + product_name = auth_product(args.model_name_or_path) + + training_hp = estimate_training(args) + print(training_hp) diff --git a/ernie/ERNIE/examples/post-training/sft/scripts/create_sft_data.sh b/ernie/ERNIE/examples/post-training/sft/scripts/create_sft_data.sh new file mode 100644 index 0000000000000000000000000000000000000000..89643591c08578351b60acd0df3df06a159e064c --- /dev/null +++ b/ernie/ERNIE/examples/post-training/sft/scripts/create_sft_data.sh @@ -0,0 +1,44 @@ +#!/bin/bash +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +dataset_output_path="./sft-data" +model_path="./to-your-model-path" +mkdir -p $dataset_output_path + +python examples/post-training/sft/create_sft_data.py \ + --train_dataset_path "examples/data/sft-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/sft-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --model_name_or_path $model_path \ + --num_of_gpus 1 \ + --gradient_accumulation_steps 8 \ + --per_device_train_batch_size 1 \ + --num_samples_each_epoch 6000000 \ + --dataset_output_dir $dataset_output_path \ + --seed 23 \ + --max_seq_len 8192 \ + --max_steps 1200 \ + --num_train_epochs 1.0 \ + --do_train \ + --do_eval \ + --tp_degree 1 \ + --sdp_degree 1 \ + --pp_degree 1 \ + --random_shuffle \ + --greedy_intokens \ + --estimation_output_file estimate_training.json diff --git a/ernie/ERNIE/examples/post-training/sft/scripts/run_0.3b_sft_32k.sh b/ernie/ERNIE/examples/post-training/sft/scripts/run_0.3b_sft_32k.sh new file mode 100644 index 0000000000000000000000000000000000000000..dac4bdb539f68b53a455329cc9b6058c70957cfb --- /dev/null +++ b/ernie/ERNIE/examples/post-training/sft/scripts/run_0.3b_sft_32k.sh @@ -0,0 +1,99 @@ +#!/bin/bash +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +master_ip=${1:-} +nnodes=${2:-1} +model_path="ERNIE4.5T_0.3B" +task="sft_32k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3,4,5,6,7 \ + --master ${master_ip}:8080 \ + --nnodes ${nnodes} \ + examples/post-training/sft/train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/sft-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/sft-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_steps 100 \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --save_steps 10000000 \ + --logging_steps 1 \ + --eval_steps 10000 \ + --do_train \ + --do_eval \ + --evaluation_strategy steps \ + --tensor_parallel_degree 1 \ + --pipeline_parallel_degree 1 \ + --sharding_parallel_degree 1 \ + --sharding stage1 \ + --max_seq_len 32768 \ + --seed 23 \ + --gradient_accumulation_steps 8 \ + --warmup_steps 20 \ + --weight_decay 0.1 \ + --learning_rate 1e-5 \ + --min_lr 1e-6 \ + --num_samples_each_epoch 6000000 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 1 \ + --distributed_dataloader 0 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --use_flash_attention 1 \ + --use_sparse_head_and_loss_fn 1 \ + --use_attn_mask_start_row_indices 1 \ + --pipeline_parallel_config "enable_delay_scale_loss enable_release_grads disable_partial_send_recv" \ + --greedy_intokens 1 \ + --release_grads 1 \ + --lr_scheduler_type cosine \ + --sequence_parallel 1 \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --adam_epsilon 1e-8 \ + --amp_master_grad 1 \ + --fuse_rope 1 \ + --disable_ckpt_quant 1 \ + --recompute_use_reentrant True \ + --unified_checkpoint_config "async_save" diff --git a/ernie/ERNIE/examples/post-training/sft/scripts/run_0.3b_sft_8k.sh b/ernie/ERNIE/examples/post-training/sft/scripts/run_0.3b_sft_8k.sh new file mode 100644 index 0000000000000000000000000000000000000000..608bba42b0d85a11fd823545fcb4ede70286869b --- /dev/null +++ b/ernie/ERNIE/examples/post-training/sft/scripts/run_0.3b_sft_8k.sh @@ -0,0 +1,99 @@ +#!/bin/bash +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +master_ip=${1:-} +nnodes=${2:-1} +model_path="ERNIE4.5T_0.3B" +task="sft_8k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3,4,5,6,7 \ + --master ${master_ip}:8080 \ + --nnodes ${nnodes} \ + examples/post-training/sft/train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/sft-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/sft-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_steps 100 \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --save_steps 10000000 \ + --logging_steps 1 \ + --eval_steps 10000 \ + --do_train \ + --do_eval \ + --evaluation_strategy steps \ + --tensor_parallel_degree 1 \ + --pipeline_parallel_degree 1 \ + --sharding_parallel_degree 1 \ + --sharding stage1 \ + --max_seq_len 8192 \ + --seed 23 \ + --gradient_accumulation_steps 8 \ + --warmup_steps 20 \ + --weight_decay 0.1 \ + --learning_rate 1e-5 \ + --min_lr 1e-6 \ + --num_samples_each_epoch 6000000 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 0 \ + --recompute_granularity "full" \ + --dataloader_num_workers 1 \ + --distributed_dataloader 0 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --use_flash_attention 1 \ + --use_sparse_head_and_loss_fn 1 \ + --use_attn_mask_start_row_indices 1 \ + --pipeline_parallel_config "enable_delay_scale_loss enable_release_grads disable_partial_send_recv" \ + --greedy_intokens 1 \ + --release_grads 1 \ + --lr_scheduler_type cosine \ + --sequence_parallel 1 \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --adam_epsilon 1e-8 \ + --amp_master_grad 1 \ + --fuse_rope 1 \ + --disable_ckpt_quant 1 \ + --recompute_use_reentrant True \ + --unified_checkpoint_config "async_save" diff --git a/ernie/ERNIE/examples/post-training/sft/scripts/run_0.3b_sft_lora_32k.sh b/ernie/ERNIE/examples/post-training/sft/scripts/run_0.3b_sft_lora_32k.sh new file mode 100644 index 0000000000000000000000000000000000000000..51af753aff18a0ce3f159b9f6bb893b70a7481c9 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/sft/scripts/run_0.3b_sft_lora_32k.sh @@ -0,0 +1,100 @@ +#!/bin/bash +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +master_ip=${1:-} +nnodes=${2:-1} +model_path="ERNIE4.5T_0.3B" +task="sft_lora_32k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3,4,5,6,7 \ + --master ${master_ip}:8080 \ + --nnodes ${nnodes} \ + examples/post-training/sft/train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/sft-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/sft-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_steps 100 \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --save_steps 10000000 \ + --logging_steps 1 \ + --eval_steps 10000 \ + --weight_decay 0.01 \ + --do_train \ + --do_eval \ + --evaluation_strategy steps \ + --tensor_parallel_degree 1 \ + --pipeline_parallel_degree 1 \ + --sharding_parallel_degree 1 \ + --sharding stage1 \ + --max_seq_len 32768 \ + --seed 23 \ + --gradient_accumulation_steps 8 \ + --warmup_steps 20 \ + --learning_rate 3e-4 \ + --num_samples_each_epoch 6000000 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 1 \ + --distributed_dataloader 0 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --use_flash_attention 1 \ + --use_sparse_head_and_loss_fn 1 \ + --use_attn_mask_start_row_indices 1 \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache" \ + --greedy_intokens 1 \ + --lr_scheduler_type cosine \ + --sequence_parallel 1 \ + --release_grads 1 \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --adam_epsilon 1e-8 \ + --amp_master_grad 1 \ + --fuse_rope 1 \ + --disable_ckpt_quant 1 \ + --recompute_use_reentrant True \ + --unified_checkpoint_config "async_save" \ + --lora \ + --lora_rank 32 diff --git a/ernie/ERNIE/examples/post-training/sft/scripts/run_0.3b_sft_lora_8k.sh b/ernie/ERNIE/examples/post-training/sft/scripts/run_0.3b_sft_lora_8k.sh new file mode 100644 index 0000000000000000000000000000000000000000..b1c011d0d9120ec30cc08572159829e82c80ed30 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/sft/scripts/run_0.3b_sft_lora_8k.sh @@ -0,0 +1,100 @@ +#!/bin/bash +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +master_ip=${1:-} +nnodes=${2:-1} +model_path="ERNIE4.5T_0.3B" +task="sft_lora_8k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3,4,5,6,7 \ + --master ${master_ip}:8080 \ + --nnodes ${nnodes} \ + examples/post-training/sft/train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/sft-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/sft-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_steps 100 \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --save_steps 10000000 \ + --logging_steps 1 \ + --eval_steps 10000 \ + --weight_decay 0.01 \ + --do_train \ + --do_eval \ + --evaluation_strategy steps \ + --tensor_parallel_degree 1 \ + --pipeline_parallel_degree 1 \ + --sharding_parallel_degree 1 \ + --sharding stage1 \ + --max_seq_len 8192 \ + --seed 23 \ + --gradient_accumulation_steps 8 \ + --warmup_steps 20 \ + --learning_rate 3e-4 \ + --num_samples_each_epoch 6000000 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 0 \ + --recompute_granularity "full" \ + --dataloader_num_workers 1 \ + --distributed_dataloader 0 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --use_flash_attention 1 \ + --use_sparse_head_and_loss_fn 1 \ + --use_attn_mask_start_row_indices 1 \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache" \ + --greedy_intokens 1 \ + --lr_scheduler_type cosine \ + --sequence_parallel 1 \ + --release_grads 1 \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --adam_epsilon 1e-8 \ + --amp_master_grad 1 \ + --fuse_rope 1 \ + --disable_ckpt_quant 1 \ + --recompute_use_reentrant True \ + --unified_checkpoint_config "async_save" \ + --lora \ + --lora_rank 32 diff --git a/ernie/ERNIE/examples/post-training/sft/scripts/run_lite_sft_32k.sh b/ernie/ERNIE/examples/post-training/sft/scripts/run_lite_sft_32k.sh new file mode 100644 index 0000000000000000000000000000000000000000..81b5e368cade28e7a4ca07b476b76956e5899796 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/sft/scripts/run_lite_sft_32k.sh @@ -0,0 +1,97 @@ +#!/bin/bash +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +model_path="ERNIE-4.5-21B-A3B" +task="sft_32k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3,4,5,6,7 \ + examples/post-training/sft/train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/sft-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/sft-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_steps 100 \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --save_steps 10000000 \ + --logging_steps 1 \ + --eval_steps 10000 \ + --do_train \ + --do_eval \ + --evaluation_strategy steps \ + --tensor_parallel_degree 4 \ + --pipeline_parallel_degree 2 \ + --sharding_parallel_degree 1 \ + --sharding stage1 \ + --max_seq_len 32768 \ + --seed 23 \ + --gradient_accumulation_steps 8 \ + --warmup_steps 20 \ + --weight_decay 0.1 \ + --learning_rate 1e-5 \ + --min_lr 1e-6 \ + --num_samples_each_epoch 6000000 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 1 \ + --distributed_dataloader 0 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --use_flash_attention 1 \ + --use_sparse_head_and_loss_fn 1 \ + --use_attn_mask_start_row_indices 1 \ + --pipeline_parallel_config "enable_delay_scale_loss enable_release_grads disable_partial_send_recv" \ + --greedy_intokens 1 \ + --release_grads 1 \ + --lr_scheduler_type cosine \ + --sequence_parallel 1 \ + --moe_group "mp" \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --adam_epsilon 1e-8 \ + --amp_master_grad 1 \ + --fuse_rope 1 \ + --disable_ckpt_quant 1 \ + --recompute_use_reentrant True \ + --unified_checkpoint_config "" +# Mem: 51GB - 64GB interval runtime: 14s diff --git a/ernie/ERNIE/examples/post-training/sft/scripts/run_lite_sft_8k.sh b/ernie/ERNIE/examples/post-training/sft/scripts/run_lite_sft_8k.sh new file mode 100644 index 0000000000000000000000000000000000000000..5fe8419ffa2084fb5442745bf0a3f0e01a93d0a3 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/sft/scripts/run_lite_sft_8k.sh @@ -0,0 +1,96 @@ +#!/bin/bash +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +model_path="ERNIE-4.5-21B-A3B" +task="sft_8k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3,4,5,6,7 \ + examples/post-training/sft/train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/sft-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/sft-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_steps 100 \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --save_steps 10000000 \ + --logging_steps 1 \ + --eval_steps 10000 \ + --do_train \ + --do_eval \ + --evaluation_strategy steps \ + --tensor_parallel_degree 4 \ + --pipeline_parallel_degree 2 \ + --sharding_parallel_degree 1 \ + --sharding stage1 \ + --max_seq_len 8192 \ + --seed 23 \ + --gradient_accumulation_steps 8 \ + --warmup_steps 20 \ + --weight_decay 0.1 \ + --learning_rate 1e-5 \ + --min_lr 1e-6 \ + --num_samples_each_epoch 6000000 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 1 \ + --distributed_dataloader 0 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --use_flash_attention 1 \ + --use_sparse_head_and_loss_fn 1 \ + --use_attn_mask_start_row_indices 1 \ + --pipeline_parallel_config "enable_delay_scale_loss enable_release_grads disable_partial_send_recv" \ + --greedy_intokens 1 \ + --release_grads 1 \ + --lr_scheduler_type cosine \ + --sequence_parallel 1 \ + --moe_group "mp" \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --adam_epsilon 1e-8 \ + --amp_master_grad 1 \ + --fuse_rope 1 \ + --disable_ckpt_quant 1 \ + --recompute_use_reentrant True \ + --unified_checkpoint_config "" diff --git a/ernie/ERNIE/examples/post-training/sft/scripts/run_lite_sft_lora_128k.sh b/ernie/ERNIE/examples/post-training/sft/scripts/run_lite_sft_lora_128k.sh new file mode 100644 index 0000000000000000000000000000000000000000..f610b944e7d85ea2c8b7e0803f01ebf30fdb86e7 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/sft/scripts/run_lite_sft_lora_128k.sh @@ -0,0 +1,95 @@ +#!/bin/bash +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +model_path="ERNIE-4.5-21B-A3B" +task="sft_lora_128k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3 \ + examples/post-training/sft/train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/sft-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/sft-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_steps 100 \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --save_steps 10000000 \ + --logging_steps 1 \ + --eval_steps 10000 \ + --do_train \ + --do_eval \ + --evaluation_strategy steps \ + --tensor_parallel_degree 4 \ + --pipeline_parallel_degree 1 \ + --sharding_parallel_degree 1 \ + --sharding stage1 \ + --max_seq_len 131072 \ + --seed 23 \ + --gradient_accumulation_steps 8 \ + --warmup_steps 20 \ + --weight_decay 0.1 \ + --learning_rate 3e-4 \ + --min_lr 1e-6 \ + --num_samples_each_epoch 6000000 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 1 \ + --distributed_dataloader 0 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --use_flash_attention 1 \ + --use_sparse_head_and_loss_fn 1 \ + --use_attn_mask_start_row_indices 1 \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache" \ + --greedy_intokens 1 \ + --release_grads 1 \ + --lr_scheduler_type cosine \ + --sequence_parallel 1 \ + --moe_group "mp" \ + --amp_master_grad 1 \ + --fuse_rope 1 \ + --disable_ckpt_quant 1 \ + --recompute_use_reentrant True \ + --unified_checkpoint_config "async_save" \ + --lora \ + --lora_rank 32 diff --git a/ernie/ERNIE/examples/post-training/sft/scripts/run_lite_sft_lora_32k.sh b/ernie/ERNIE/examples/post-training/sft/scripts/run_lite_sft_lora_32k.sh new file mode 100644 index 0000000000000000000000000000000000000000..ffda6a2790653b089583c0203cc27a6344a12e29 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/sft/scripts/run_lite_sft_lora_32k.sh @@ -0,0 +1,95 @@ +#!/bin/bash +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +model_path="ERNIE-4.5-21B-A3B" +task="sft_lora_32k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1 \ + examples/post-training/sft/train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/sft-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/sft-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_steps 100 \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --save_steps 10000000 \ + --logging_steps 1 \ + --eval_steps 10000 \ + --do_train \ + --do_eval \ + --evaluation_strategy steps \ + --tensor_parallel_degree 2 \ + --pipeline_parallel_degree 1 \ + --sharding_parallel_degree 1 \ + --sharding stage1 \ + --max_seq_len 32768 \ + --seed 23 \ + --gradient_accumulation_steps 8 \ + --warmup_steps 20 \ + --weight_decay 0.1 \ + --learning_rate 3e-4 \ + --min_lr 1e-6 \ + --num_samples_each_epoch 6000000 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 1 \ + --distributed_dataloader 0 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --use_flash_attention 1 \ + --use_sparse_head_and_loss_fn 1 \ + --use_attn_mask_start_row_indices 1 \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache" \ + --greedy_intokens 1 \ + --release_grads 1 \ + --lr_scheduler_type cosine \ + --sequence_parallel 1 \ + --moe_group "mp" \ + --amp_master_grad 1 \ + --fuse_rope 1 \ + --disable_ckpt_quant 1 \ + --recompute_use_reentrant True \ + --unified_checkpoint_config "async_save" \ + --lora \ + --lora_rank 32 diff --git a/ernie/ERNIE/examples/post-training/sft/scripts/run_lite_sft_lora_8k.sh b/ernie/ERNIE/examples/post-training/sft/scripts/run_lite_sft_lora_8k.sh new file mode 100644 index 0000000000000000000000000000000000000000..8ea9432c91a57f2436aed6971dfa374510e81370 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/sft/scripts/run_lite_sft_lora_8k.sh @@ -0,0 +1,96 @@ +#!/bin/bash +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +model_path="ERNIE-4.5-21B-A3B" +task="sft_lora_8k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1 \ + examples/post-training/sft/train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/sft-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/sft-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_steps 100 \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --save_steps 10000000 \ + --logging_steps 1 \ + --eval_steps 10000 \ + --do_train \ + --do_eval \ + --evaluation_strategy steps \ + --tensor_parallel_degree 2 \ + --pipeline_parallel_degree 1 \ + --sharding_parallel_degree 1 \ + --sharding stage1 \ + --max_seq_len 8192 \ + --seed 23 \ + --gradient_accumulation_steps 8 \ + --warmup_steps 20 \ + --weight_decay 0.1 \ + --learning_rate 3e-4 \ + --min_lr 1e-6 \ + --num_samples_each_epoch 6000000 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 1 \ + --distributed_dataloader 0 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --use_flash_attention 1 \ + --use_sparse_head_and_loss_fn 1 \ + --use_attn_mask_start_row_indices 1 \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache" \ + --greedy_intokens 1 \ + --release_grads 1 \ + --lr_scheduler_type cosine \ + --sequence_parallel 1 \ + --moe_group "mp" \ + --amp_master_grad 1 \ + --moe_multimodal_dispatch_use_allgather "" \ + --fuse_rope 1 \ + --disable_ckpt_quant 1 \ + --recompute_use_reentrant True \ + --unified_checkpoint_config "async_save" \ + --lora \ + --lora_rank 32 diff --git a/ernie/ERNIE/examples/post-training/sft/scripts/run_lite_sft_wint8mix_lora_128k.sh b/ernie/ERNIE/examples/post-training/sft/scripts/run_lite_sft_wint8mix_lora_128k.sh new file mode 100644 index 0000000000000000000000000000000000000000..0254515e47a690fc23c56fe0c67a9ded443682ee --- /dev/null +++ b/ernie/ERNIE/examples/post-training/sft/scripts/run_lite_sft_wint8mix_lora_128k.sh @@ -0,0 +1,97 @@ +#!/bin/bash +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +export NCCL_DEBUG=WARN +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +model_path="ERNIE-4.5-21B-A3B" +task="sft_wint8mix_lora_128k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3 \ + examples/post-training/sft/train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/sft-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/sft-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_steps 100 \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --save_steps 10000000 \ + --logging_steps 1 \ + --eval_steps 10000 \ + --do_train \ + --do_eval \ + --evaluation_strategy steps \ + --tensor_parallel_degree 4 \ + --pipeline_parallel_degree 1 \ + --sharding_parallel_degree 1 \ + --sharding stage1 \ + --max_seq_len 131072 \ + --seed 23 \ + --gradient_accumulation_steps 8 \ + --warmup_steps 20 \ + --weight_decay 0.1 \ + --learning_rate 3e-4 \ + --min_lr 1e-6 \ + --num_samples_each_epoch 6000000 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 1 \ + --distributed_dataloader 0 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --use_flash_attention 1 \ + --use_sparse_head_and_loss_fn 1 \ + --use_attn_mask_start_row_indices 1 \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache" \ + --greedy_intokens 1 \ + --release_grads 1 \ + --lr_scheduler_type cosine \ + --sequence_parallel 1 \ + --moe_group "mp" \ + --amp_master_grad 1 \ + --fuse_rope 1 \ + --disable_ckpt_quant 1 \ + --recompute_use_reentrant True \ + --unified_checkpoint_config "async_save" \ + --lora \ + --lora_rank 32 \ + --weight_quantize_algo weight_only_mix diff --git a/ernie/ERNIE/examples/post-training/sft/scripts/run_lite_sft_wint8mix_lora_32k.sh b/ernie/ERNIE/examples/post-training/sft/scripts/run_lite_sft_wint8mix_lora_32k.sh new file mode 100644 index 0000000000000000000000000000000000000000..5142f5e5a99d947239ac0901469f7f5a7217f7d3 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/sft/scripts/run_lite_sft_wint8mix_lora_32k.sh @@ -0,0 +1,97 @@ +#!/bin/bash +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +export NCCL_DEBUG=WARN +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +model_path="ERNIE-4.5-21B-A3B" +task="sft_wint8mix_lora_32k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1 \ + examples/post-training/sft/train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/sft-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/sft-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_steps 100 \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --save_steps 10000000 \ + --logging_steps 1 \ + --eval_steps 10000 \ + --do_train \ + --do_eval \ + --evaluation_strategy steps \ + --tensor_parallel_degree 2 \ + --pipeline_parallel_degree 1 \ + --sharding_parallel_degree 1 \ + --sharding stage1 \ + --max_seq_len 32768 \ + --seed 23 \ + --gradient_accumulation_steps 8 \ + --warmup_steps 20 \ + --weight_decay 0.1 \ + --learning_rate 3e-4 \ + --min_lr 1e-6 \ + --num_samples_each_epoch 6000000 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 1 \ + --distributed_dataloader 0 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --use_flash_attention 1 \ + --use_sparse_head_and_loss_fn 1 \ + --use_attn_mask_start_row_indices 1 \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache" \ + --greedy_intokens 1 \ + --release_grads 1 \ + --lr_scheduler_type cosine \ + --sequence_parallel 1 \ + --moe_group "mp" \ + --amp_master_grad 1 \ + --fuse_rope 1 \ + --disable_ckpt_quant 1 \ + --recompute_use_reentrant True \ + --unified_checkpoint_config "async_save" \ + --lora \ + --lora_rank 32 \ + --weight_quantize_algo weight_only_mix diff --git a/ernie/ERNIE/examples/post-training/sft/scripts/run_lite_sft_wint8mix_lora_8k.sh b/ernie/ERNIE/examples/post-training/sft/scripts/run_lite_sft_wint8mix_lora_8k.sh new file mode 100644 index 0000000000000000000000000000000000000000..746385ac2f08cf66f3cbbd995bc0bbf6e79a4825 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/sft/scripts/run_lite_sft_wint8mix_lora_8k.sh @@ -0,0 +1,97 @@ +#!/bin/bash +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +export NCCL_DEBUG=WARN +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +model_path="ERNIE-4.5-21B-A3B" +task="sft_wint8mix_lora_8k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1 \ + examples/post-training/sft/train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/sft-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/sft-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_steps 100 \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --save_steps 10000000 \ + --logging_steps 1 \ + --eval_steps 10000 \ + --do_train \ + --do_eval \ + --evaluation_strategy steps \ + --tensor_parallel_degree 2 \ + --pipeline_parallel_degree 1 \ + --sharding_parallel_degree 1 \ + --sharding stage1 \ + --max_seq_len 8192 \ + --seed 23 \ + --gradient_accumulation_steps 8 \ + --warmup_steps 20 \ + --weight_decay 0.1 \ + --learning_rate 3e-4 \ + --min_lr 1e-6 \ + --num_samples_each_epoch 6000000 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 0 \ + --recompute_granularity "full" \ + --dataloader_num_workers 1 \ + --distributed_dataloader 0 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --use_flash_attention 1 \ + --use_sparse_head_and_loss_fn 1 \ + --use_attn_mask_start_row_indices 1 \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache" \ + --greedy_intokens 1 \ + --release_grads 1 \ + --lr_scheduler_type cosine \ + --sequence_parallel 1 \ + --moe_group "mp" \ + --amp_master_grad 1 \ + --fuse_rope 1 \ + --disable_ckpt_quant 1 \ + --recompute_use_reentrant True \ + --unified_checkpoint_config "async_save" \ + --lora \ + --lora_rank 32 \ + --weight_quantize_algo weight_only_mix diff --git a/ernie/ERNIE/examples/post-training/sft/scripts/run_sft_32k.sh b/ernie/ERNIE/examples/post-training/sft/scripts/run_sft_32k.sh new file mode 100644 index 0000000000000000000000000000000000000000..10eea4a04875ea874e81e3e4be95b88539704939 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/sft/scripts/run_sft_32k.sh @@ -0,0 +1,101 @@ +#!/bin/bash +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +master_ip=${1:-} +nnodes=${2:-14} +model_path="ERNIE4.5T_chat" +task="sft_32k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3,4,5,6,7 \ + --master ${master_ip}:8080 \ + --nnodes ${nnodes} \ + examples/post-training/sft/train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/sft-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/sft-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_steps 100 \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --save_steps 10000000 \ + --logging_steps 1 \ + --eval_steps 10000 \ + --do_train \ + --do_eval \ + --evaluation_strategy steps \ + --tensor_parallel_degree 8 \ + --pipeline_parallel_degree ${nnodes} \ + --sharding_parallel_degree 1 \ + --sharding stage1 \ + --max_seq_len 32768 \ + --seed 23 \ + --gradient_accumulation_steps 8 \ + --warmup_steps 20 \ + --weight_decay 0.1 \ + --learning_rate 1e-5 \ + --min_lr 1e-6 \ + --num_samples_each_epoch 6000000 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 1 \ + --distributed_dataloader 0 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --use_flash_attention 1 \ + --use_sparse_head_and_loss_fn 1 \ + --use_attn_mask_start_row_indices 1 \ + --pipeline_parallel_config "enable_delay_scale_loss enable_release_grads disable_partial_send_recv" \ + --greedy_intokens 1 \ + --release_grads 1 \ + --lr_scheduler_type cosine \ + --sequence_parallel 1 \ + --moe_group "mp" \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --adam_epsilon 1e-8 \ + --amp_master_grad 1 \ + --fuse_rope 1 \ + --disable_ckpt_quant 1 \ + --offload_optim \ + --recompute_use_reentrant True \ + --unified_checkpoint_config "" diff --git a/ernie/ERNIE/examples/post-training/sft/scripts/run_sft_8k.sh b/ernie/ERNIE/examples/post-training/sft/scripts/run_sft_8k.sh new file mode 100644 index 0000000000000000000000000000000000000000..7b495b4e2e3c2065824a1e7d848da63b542de520 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/sft/scripts/run_sft_8k.sh @@ -0,0 +1,101 @@ +#!/bin/bash +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +master_ip=${1:-} +nnodes=${2:-12} +model_path="ERNIE4.5T_chat" +task="sft_8k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3,4,5,6,7 \ + --master ${master_ip}:8080 \ + --nnodes ${nnodes} \ + examples/post-training/sft/train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/sft-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/sft-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_steps 100 \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --save_steps 10000000 \ + --logging_steps 1 \ + --eval_steps 10000 \ + --do_train \ + --do_eval \ + --evaluation_strategy steps \ + --tensor_parallel_degree 8 \ + --pipeline_parallel_degree ${nnodes} \ + --sharding_parallel_degree 1 \ + --sharding stage1 \ + --max_seq_len 8192 \ + --seed 23 \ + --gradient_accumulation_steps 8 \ + --warmup_steps 20 \ + --weight_decay 0.1 \ + --learning_rate 1e-5 \ + --min_lr 1e-6 \ + --num_samples_each_epoch 6000000 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 1 \ + --distributed_dataloader 0 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --use_flash_attention 1 \ + --use_sparse_head_and_loss_fn 1 \ + --use_attn_mask_start_row_indices 1 \ + --pipeline_parallel_config "enable_delay_scale_loss enable_release_grads disable_partial_send_recv" \ + --greedy_intokens 1 \ + --release_grads 1 \ + --lr_scheduler_type cosine \ + --sequence_parallel 1 \ + --moe_group "mp" \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --adam_epsilon 1e-8 \ + --amp_master_grad 1 \ + --fuse_rope 1 \ + --disable_ckpt_quant 1 \ + --offload_optim \ + --recompute_use_reentrant True \ + --unified_checkpoint_config "" diff --git a/ernie/ERNIE/examples/post-training/sft/scripts/run_sft_fp8_8k.sh b/ernie/ERNIE/examples/post-training/sft/scripts/run_sft_fp8_8k.sh new file mode 100644 index 0000000000000000000000000000000000000000..ee776f8e27e1c980355700db56c9a3b8fcb7f16b --- /dev/null +++ b/ernie/ERNIE/examples/post-training/sft/scripts/run_sft_fp8_8k.sh @@ -0,0 +1,108 @@ +#!/bin/bash +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +master_ip=${1:-} +nnodes=${2:-2} +model_path="ERNIE4.5T_chat" +task="sft_fp_8k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3,4,5,6,7 \ + --master ${master_ip}:8080 \ + --nnodes ${nnodes} \ + examples/post-training/sft/train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/sft-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/sft-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_steps 100 \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --save_steps 10000000 \ + --logging_steps 1 \ + --eval_steps 10000 \ + --do_train \ + --do_eval \ + --evaluation_strategy steps \ + --tensor_parallel_degree 8 \ + --pipeline_parallel_degree 2 \ + --sharding stage1 \ + --max_seq_len 8192 \ + --seed 23 \ + --gradient_accumulation_steps 8 \ + --warmup_steps 20 \ + --weight_decay 0.1 \ + --learning_rate 1e-5 \ + --min_lr 1e-6 \ + --num_samples_each_epoch 6000000 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 1 \ + --distributed_dataloader 0 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --use_flash_attention 1 \ + --use_sparse_head_and_loss_fn 1 \ + --use_attn_mask_start_row_indices 1 \ + --pipeline_parallel_config "enable_delay_scale_loss enable_release_grads disable_partial_send_recv" \ + --greedy_intokens 1 \ + --release_grads 1 \ + --lr_scheduler_type cosine \ + --sequence_parallel 1 \ + --moe_group "mp" \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --adam_epsilon 1e-8 \ + --fuse_rope 1 \ + --disable_ckpt_quant 1 \ + --recompute_use_reentrant True \ + --weight_quantize_algo "fp8linear" \ + --apply_hadamard True \ + --optim "adamw_custom" \ + --use_lowprecision_moment True \ + --tensorwise_offload_optimizer True \ + --pp_seg_method "[0,29,57]" \ + --optim_shard_num 8 \ + --unified_checkpoint_config "ignore_merge_optimizer" \ + --num_nextn_predict_layers 0 \ + # --ignore_save_lr_and_optim 1 \ + # --ignore_load_lr_and_optim 1 \ diff --git a/ernie/ERNIE/examples/post-training/sft/scripts/run_sft_lora_32k.sh b/ernie/ERNIE/examples/post-training/sft/scripts/run_sft_lora_32k.sh new file mode 100644 index 0000000000000000000000000000000000000000..664ae9b3ad4b92a3b0dba466bea1d1002737e5a2 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/sft/scripts/run_sft_lora_32k.sh @@ -0,0 +1,100 @@ +#!/bin/bash +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +master_ip=${1:-} +nnodes=${2:-12} +model_path="ERNIE4.5T_chat" +task="sft_lora_32k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3,4,5,6,7 \ + --master ${master_ip}:8080 \ + --nnodes ${nnodes} \ + examples/post-training/sft/train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/sft-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/sft-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_steps 100 \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --save_steps 10000000 \ + --logging_steps 1 \ + --eval_steps 10000 \ + --do_train \ + --do_eval \ + --evaluation_strategy steps \ + --tensor_parallel_degree 8 \ + --pipeline_parallel_degree ${nnodes} \ + --sharding_parallel_degree 1 \ + --sharding stage1 \ + --max_seq_len 32768 \ + --seed 23 \ + --gradient_accumulation_steps 8 \ + --warmup_steps 20 \ + --weight_decay 0.1 \ + --learning_rate 3e-4 \ + --min_lr 1e-6 \ + --num_samples_each_epoch 6000000 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 1 \ + --distributed_dataloader 0 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --use_flash_attention 1 \ + --use_sparse_head_and_loss_fn 1 \ + --use_attn_mask_start_row_indices 1 \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache" \ + --greedy_intokens 1 \ + --release_grads 1 \ + --lr_scheduler_type cosine \ + --sequence_parallel 1 \ + --moe_group "mp" \ + --amp_master_grad 1 \ + --fuse_rope 1 \ + --disable_ckpt_quant 1 \ + --offload_optim \ + --recompute_use_reentrant True \ + --unified_checkpoint_config "async_save" \ + --lora \ + --lora_rank 32 diff --git a/ernie/ERNIE/examples/post-training/sft/scripts/run_sft_lora_8k.sh b/ernie/ERNIE/examples/post-training/sft/scripts/run_sft_lora_8k.sh new file mode 100644 index 0000000000000000000000000000000000000000..c3ed14330dd143e2a5c8f86e56124212a23b92ac --- /dev/null +++ b/ernie/ERNIE/examples/post-training/sft/scripts/run_sft_lora_8k.sh @@ -0,0 +1,100 @@ +#!/bin/bash +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +master_ip=${1:-} +nnodes=${2:-12} +model_path="ERNIE4.5T_chat" +task="sft_lora_8k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3,4,5,6,7 \ + --master ${master_ip}:8080 \ + --nnodes ${nnodes} \ + examples/post-training/sft/train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/sft-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/sft-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_steps 100 \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --save_steps 10000000 \ + --logging_steps 1 \ + --eval_steps 10000 \ + --do_train \ + --do_eval \ + --evaluation_strategy steps \ + --tensor_parallel_degree 8 \ + --pipeline_parallel_degree ${nnodes} \ + --sharding_parallel_degree 1 \ + --sharding stage1 \ + --max_seq_len 8192 \ + --seed 23 \ + --gradient_accumulation_steps 8 \ + --warmup_steps 20 \ + --weight_decay 0.1 \ + --learning_rate 3e-4 \ + --min_lr 1e-6 \ + --num_samples_each_epoch 6000000 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 1 \ + --distributed_dataloader 0 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --use_flash_attention 1 \ + --use_sparse_head_and_loss_fn 1 \ + --use_attn_mask_start_row_indices 1 \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache" \ + --greedy_intokens 1 \ + --release_grads 1 \ + --lr_scheduler_type cosine \ + --sequence_parallel 1 \ + --moe_group "mp" \ + --amp_master_grad 1 \ + --fuse_rope 1 \ + --disable_ckpt_quant 1 \ + --offload_optim \ + --recompute_use_reentrant True \ + --unified_checkpoint_config "async_save" \ + --lora \ + --lora_rank 32 diff --git a/ernie/ERNIE/examples/post-training/sft/scripts/run_sft_wint8mix_lora_32k.sh b/ernie/ERNIE/examples/post-training/sft/scripts/run_sft_wint8mix_lora_32k.sh new file mode 100644 index 0000000000000000000000000000000000000000..01ea365385cdd4a4e6fbbfb1b6962323ad75f0bc --- /dev/null +++ b/ernie/ERNIE/examples/post-training/sft/scripts/run_sft_wint8mix_lora_32k.sh @@ -0,0 +1,104 @@ +#!/bin/bash +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +master_ip=${1:-} +nnodes=${2:-12} +model_path="ERNIE4.5T_chat" +task="sft_wint8mix_lora_32k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3,4,5,6,7 \ + --master ${master_ip}:8080 \ + --nnodes ${nnodes} \ + examples/post-training/sft/train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/sft-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/sft-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_steps 100 \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --save_steps 10000000 \ + --logging_steps 1 \ + --eval_steps 10000 \ + --do_train \ + --do_eval \ + --evaluation_strategy steps \ + --tensor_parallel_degree 8 \ + --pipeline_parallel_degree ${nnodes} \ + --sharding_parallel_degree 1 \ + --sharding stage1 \ + --max_seq_len 32768 \ + --seed 23 \ + --gradient_accumulation_steps 8 \ + --warmup_steps 20 \ + --weight_decay 0.1 \ + --learning_rate 1e-5 \ + --min_lr 1e-6 \ + --num_samples_each_epoch 6000000 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 1 \ + --distributed_dataloader 0 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --use_flash_attention 1 \ + --use_sparse_head_and_loss_fn 1 \ + --use_attn_mask_start_row_indices 1 \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache" \ + --greedy_intokens 1 \ + --release_grads 1 \ + --lr_scheduler_type cosine \ + --sequence_parallel 1 \ + --moe_group "mp" \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --adam_epsilon 1e-8 \ + --amp_master_grad 1 \ + --fuse_rope 1 \ + --disable_ckpt_quant 1 \ + --offload_optim \ + --recompute_use_reentrant True \ + --unified_checkpoint_config "async_save" \ + --lora 1 \ + --lora_rank 32 \ + --weight_quantize_algo weight_only_mix diff --git a/ernie/ERNIE/examples/post-training/sft/scripts/run_sft_wint8mix_lora_8k.sh b/ernie/ERNIE/examples/post-training/sft/scripts/run_sft_wint8mix_lora_8k.sh new file mode 100644 index 0000000000000000000000000000000000000000..22622e2bb6091d8a0069c9834c6b0f3ecdbbec23 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/sft/scripts/run_sft_wint8mix_lora_8k.sh @@ -0,0 +1,104 @@ +#!/bin/bash +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINERS_NUM +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +export PYTHONPATH=$(dirname "$0")/../../../..:$PYTHONPATH +export FLAGS_set_to_1d=False +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_dataloader_use_file_descriptor=False + +master_ip=${1:-} +nnodes=${2:-12} +model_path="ERNIE4.5T_chat" +task="sft_wint8mix_lora_8k" +paddle_log_dir="${model_path}_${task}_log" +vdl_log_dir="${model_path}_${task}_vdl" +output_dir="${model_path}_${task}_checkpoint" + +rm -rf ${log_dir} + +python -m paddle.distributed.launch \ + --log_dir ${paddle_log_dir} \ + --gpus 0,1,2,3,4,5,6,7 \ + --master ${master_ip}:8080 \ + --nnodes ${nnodes} \ + examples/post-training/sft/train.py \ + --logging_dir ${vdl_log_dir} \ + --model_name_or_path ${model_path} \ + --output_dir ${output_dir} \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --train_dataset_path "examples/data/sft-train.jsonl" \ + --train_dataset_prob "1.0" \ + --train_dataset_type "erniekit" \ + --eval_dataset_path "examples/data/sft-eval.jsonl" \ + --eval_dataset_prob "1.0" \ + --eval_dataset_type "erniekit" \ + --max_steps 100 \ + --max_evaluate_steps 10000 \ + --num_train_epochs 1 \ + --save_steps 10000000 \ + --logging_steps 1 \ + --eval_steps 10000 \ + --do_train \ + --do_eval \ + --evaluation_strategy steps \ + --tensor_parallel_degree 8 \ + --pipeline_parallel_degree ${nnodes} \ + --sharding_parallel_degree 1 \ + --sharding stage1 \ + --max_seq_len 8192 \ + --seed 23 \ + --gradient_accumulation_steps 8 \ + --warmup_steps 20 \ + --weight_decay 0.1 \ + --learning_rate 1e-5 \ + --min_lr 1e-6 \ + --num_samples_each_epoch 6000000 \ + --bf16 \ + --fp16_opt_level O2 \ + --disable_tqdm True \ + --recompute 1 \ + --recompute_granularity "full" \ + --dataloader_num_workers 1 \ + --distributed_dataloader 0 \ + --amp_custom_white_list "lookup_table" "lookup_table_v2" "flash_attn" "matmul" "matmul_v2" "fused_gemm_epilogue" \ + --amp_custom_black_list "reduce_sum" "softmax_with_cross_entropy" "c_softmax_with_cross_entropy" "elementwise_div" "sin" "cos" \ + --use_flash_attention 1 \ + --use_sparse_head_and_loss_fn 1 \ + --use_attn_mask_start_row_indices 1 \ + --pipeline_parallel_config "disable_partial_send_recv enable_clear_every_step_cache" \ + --greedy_intokens 1 \ + --release_grads 1 \ + --lr_scheduler_type cosine \ + --sequence_parallel 1 \ + --moe_group "mp" \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ + --adam_epsilon 1e-8 \ + --amp_master_grad 1 \ + --fuse_rope 1 \ + --disable_ckpt_quant 1 \ + --offload_optim \ + --recompute_use_reentrant True \ + --unified_checkpoint_config "async_save" \ + --lora 1 \ + --lora_rank 32 \ + --weight_quantize_algo weight_only_mix diff --git a/ernie/ERNIE/examples/post-training/sft/sft_utils.py b/ernie/ERNIE/examples/post-training/sft/sft_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dd28b99eb4531e4e37aa3a2a67c3698f17e65f17 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/sft/sft_utils.py @@ -0,0 +1,86 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SFT utils""" +from dataclasses import dataclass, field +from typing import Optional + +from train import DataArgument, SFTTrainingArguments + + +class DataGenerator: + """Generates an infinite stream of examples""" + + def __init__(self, data_source): + """ + Initializes the iterator for a given data source. + + Args: + data_source : IterableDataset + + Returns: + None. - Initialization only. No return value. + """ + self.data_source_iter = iter(data_source) + self.data_source = data_source + + def __iter__(self): + """ + Returns: + Iterator: The iterator object itself. + """ + return self + + def __next__(self): + """ + Get the next item from the iterator. If there are no more items left, reset the iterator. + + Returns: + Any: The next item from the iterator. + """ + try: + return next(self.data_source_iter) + except StopIteration: + self.data_source_iter = iter(self.data_source) + return next(self.data_source_iter) + + +@dataclass +class BuildSFTTrainingArguments(SFTTrainingArguments): + """TrainingArguments for building SFT MapDataset""" + + output_dir: Optional[str] = field( + default=None, + metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, + ) + num_of_gpus: int = field( + default=1, + metadata={"help": "The number of GPUs."}, + ) + estimation_output_file: Optional[str] = field( + default=None, metadata={"help": "The file to save estimation results."} + ) + pp_degree: int = field(default=1, metadata={"help": "Pipeline parallel degree."}) + sdp_degree: int = field(default=1, metadata={"help": "Sharding parallel degree."}) + tp_degree: int = field(default=1, metadata={"help": "Tensor parallel degree."}) + + +@dataclass +class BuildDataArgument(DataArgument): + """DataArgument for building SFT MapDataset""" + + dataset_output_dir: Optional[str] = field( + default=None, + metadata={"help": "The output directory where the SFT MapDataset will be written."}, + ) diff --git a/ernie/ERNIE/examples/post-training/sft/train.py b/ernie/ERNIE/examples/post-training/sft/train.py new file mode 100644 index 0000000000000000000000000000000000000000..d2c46d8f03f977d94cc68d2767da2d83bbe5838d --- /dev/null +++ b/ernie/ERNIE/examples/post-training/sft/train.py @@ -0,0 +1,976 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Training Ernie Model. """ + +import gc +import importlib.util +import math +import os +import sys +import time +import json +from dataclasses import dataclass, field +from functools import partial +from typing import Optional + +if importlib.util.find_spec("triton") is not None: + try: + import use_triton_in_paddle + + use_triton_in_paddle.make_triton_compatible_with_paddle() + except Exception as _: + raise RuntimeError( + "Triton is installed, but not yet compatible with Paddle. " + "Please run 'python -m pip install use-triton-in-paddle' to enable Triton support in Paddle." + ) + +import paddle +from paddleformers.trainer import ( + IntervalStrategy, + PdArgumentParser, + RuntimeTimer, + TrainingArguments, + get_last_checkpoint, + set_seed, +) +from paddleformers.trainer.trainer_utils import ShardingOption +from paddleformers.transformers.model_utils import unwrap_model +from paddleformers.utils.log import logger + +from ernie.callbacks import LayerwiseDropoutCallback +from ernie.configuration import Ernie4_5_MoeConfig +from ernie.modeling_moe import Ernie4_5_MoeForCausalLM +from ernie.modeling_moe_pp import Ernie4_5_MoeForCausalLMPipe +from ernie.tokenizer import Ernie4_5_Tokenizer +from ernie.utils.common_utils import ( + add_start_docstrings, + calculate_effective_tokens, + check_refined_recompute, + estimate_training, + save_stop_info, +) + +# isort: off +from trainer import ErnieMoETrainer + +# isort: on + + +@dataclass +@add_start_docstrings(TrainingArguments.__doc__) +class SFTTrainingArguments(TrainingArguments): + """SFT Training Arguments""" + + unified_checkpoint: bool = field( + default=True, + metadata={ + "help": "Enable fused linear grad add strategy, which will reduce elementwise " + "add for grad accumulation in the backward of nn.Linear ." + }, + ) + unified_checkpoint_config: Optional[str] = field( + default="", + metadata={ + "help": ( + "Configs to unify hybrid parallel checkpoint.\n" + "Following options are supports:\n" + "- skip_save_model_weight: do not save model weights when the masters weight exist\n" + "- master_weight_compatible: 1. if the master weights exist, only load when needed\n" + " 2. if master weights does not exist, convert model weights" + " to master weights when needed\n" + "- async_save: enable asynchronous saving checkpoints to disk\n" + "- enable_all_options: enable all optimization configurations\n" + ) + }, + ) + decay_steps: int = field( + default=None, + metadata={ + "help": "The steps use to control the learing rate. If the step > decay_steps, " + "will use the min_learning_rate." + }, + ) + max_estimate_samples: int = field( + default=1e5, + metadata={"help": "Maximum number of samples used in estimation."}, + ) + dropout_warmup_steps: int = field( + default=0, + metadata={"help": "dropout warmup steps"}, + ) + hidden_dropout_prob: float = field( + default=0.0, + metadata={"help": "dropout probability for hidden layers"}, + ) + attention_probs_dropout_prob: float = field( + default=0.0, + metadata={"help": "dropout probability for attention layers"}, + ) + disable_ckpt_quant: bool = field( + default=False, + metadata={"help": "Whether disable checkpoint quantization."}, + ) + sequence_parallel: bool = field( + default=True, metadata={"help": "Whether to use sequence_parallel"} + ) + layerwise_lr_decay_bound: Optional[float] = field( + default=1.0, + metadata={ + "help": "Use a large learning rate for the top layers and " + "a small learning rate for the bottom layers. 1.0: Do not use this strategy." + }, + ) + use_sp_callback: bool = field( + default=False, + metadata={ + "help": "Using the SP callback will skip the implementation of SPHook " + "to avoid redundant gradient computation." + }, + ) + # Quantiztaion + weight_quantize_algo: str = field( + default=None, + metadata={ + "help": "Model weight quantization algorithm including 'nf4'(qlora), 'weight_only_int8'." + }, + ) + + +@dataclass +class DataArgument: + """Data Argument""" + + train_dataset_type: str = field( + default="erniekit", + metadata={"help": "List contains type of training datasets."}, + ) + train_dataset_path: str = field( + default="examples/data/sft-train.jsonl", + metadata={"help": "List contains path of training data sources."}, + ) + train_dataset_prob: str = field( + default="1.0", + metadata={"help": "List contains probabilities of training data sources."}, + ) + eval_dataset_type: str = field( + default="erniekit", metadata={"help": "List contains type of eval datasets."} + ) + eval_dataset_path: str = field( + default="examples/data/sft-eval.jsonl", + metadata={"help": "List contains path of eval data sources."}, + ) + eval_dataset_prob: str = field( + default="1.0", + metadata={"help": "List contains probabilities of eval data sources."}, + ) + max_seq_len: int = field( + default=4096, metadata={"help": "Maximum sequence length."} + ) + in_tokens_batching: bool = field( + default=True, + metadata={"help": "Whether to using in tokens batching strategy."}, + ) + num_samples_each_epoch: int = field( + default=100000, + metadata={"help": "Number of samples per epoch. Used for SFT."}, + ) + num_comparisons: int = field( + default=6, metadata={"help": "Number of candidate responses."} + ) + use_cls: bool = field( + default=True, + metadata={"help": "Whether to use cls to predict RM score."}, + ) + sft_benchmark: bool = field( + default=False, + metadata={"help": "Whether to calculate effective token per second"}, + ) + random_shuffle: bool = field( + default=True, + metadata={ + "help": "Whether to enable authorize code for privatization. Defaults to False." + }, + ) + greedy_intokens: bool = field( + default=True, + metadata={"help": "Whether to use greedy_intokens packing method."}, + ) + dataset_type: str = field( + default="iterable", + metadata={ + "help": ( + "Specify the type of dataset to use. Options are 'iterable' " + "for 'IterableDataset' and 'map' for 'MapDataset'." + ) + }, + ) + offline_dataset_path: str = field( + default=None, + metadata={ + "help": ( + "If 'dataset_type' is set to 'map', this field is required to " + "specify the path to the offline dataset." + ) + }, + ) + + +@dataclass +class ModelArgument: + """Model Argument""" + + model_name_or_path: str = field( + default="ernie-bot", + metadata={"help": "Pretrained model name or path to local directory."}, + ) + tensor_parallel_output: bool = field( + default=True, + metadata={ + "help": "If set to True, this option is used with fleet.meta_parallel. " + "ParallelCrossEntropy to calculate cross-entropy loss for parallel model." + }, + ) + # LoRA + lora: bool = field( + default=False, metadata={"help": "Whether to use LoRA technique."} + ) + lora_rank: int = field(default=8, metadata={"help": "Lora rank."}) + lora_path: str = field( + default=None, metadata={"help": "Initialize lora state dict."} + ) + rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"}) + lora_plus_scale: float = field( + default=1.0, metadata={"help": "Lora B scale in LoRA+ technique"} + ) + lora_alpha: int = field(default=-1, metadata={"help": "lora_alpha"}) + rslora_plus: bool = field( + default=False, metadata={"help": "Strengthen lora performance"} + ) + use_flash_attention: bool = field( + default=True, metadata={"help": "Whether to use flash attention"} + ) + use_sparse_head_and_loss_fn: bool = field( + default=False, + metadata={"help": "Whether to use sparse LM Head and loss function."}, + ) + use_fused_head_and_loss_fn: bool = field( + default=False, + metadata={"help": "Whether to fuse LM Head and loss function."}, + ) + recompute_granularity: str = field( + default="full", + metadata={ + "help": "The granularity of recompute training can be selected as `full` or `full_attn` or `core_attn`. " + " `full` means complete all transformers, `full_attn` indicates only recompute all self attention parts," + " `core_attn` indicates that only the `softmax (qkT) v` part is recomputed. Note: In terms of memory usage," + " `core_attn` > `full_attn` > `full`, if the selected policy generates an OOM error, the recompute can be" + " changed appropriately recompute_granularity. (default: `full`)" + }, + ) + no_recompute_layers: Optional[int] = field( + default=None, + metadata={ + "help": "Specify the full transformer layers that should not be recomputed." + }, + ) + offload_recompute_inputs: bool = field( + default=False, + metadata={ + "help": "Whether to offload input Tensors of recompute to Pinned-Memory/CPU." + }, + ) + virtual_pp_degree: int = field( + default=1, + metadata={"help": "virtual_pp_degree"}, + ) + pp_seg_method: str = field( + default="layer:Ernie4_5_DecoderLayer|EmptyLayer", + metadata={ + "help": ( + "The method used to segment the pipeline layers among pipeline stages. " + "Possible values include `layer:Ernie4_5_DecoderLayer`, " + "`layer:Ernie4_5_DecoderLayer|Empty`, `uniform`, `[0, 30, 59]`." + ) + }, + ) + fuse_linear: bool = field( + default=False, metadata={"help": "Whether to use fused_gemm_epilogue"} + ) + fuse_rope: bool = field( + default=False, + metadata={"help": "Whether to fuse rotary postition embedding"}, + ) + fuse_softmax_mask: bool = field( + default=False, metadata={"help": "Whether to fuse softmax and add"} + ) + fuse_rms_norm: bool = field( + default=True, metadata={"help": "Whether to fuse RMSNorm for efficiency"} + ) + fuse_swiglu: bool = field( + default=True, + metadata={ + "help": "Whether to fuse SwiGLU projection and activation for efficiency" + }, + ) + fuse_gate_detach_matmul: bool = field( + default=True, + metadata={ + "help": "Whether to use the fused gate-detach matmul implementation." + }, + ) + use_attn_mask_start_row_indices: bool = field( + default=True, + metadata={ + "help": "Whether to use attn_mask_start_row_indices in flash attention." + }, + ) + use_sparse_flash_attn: bool = field( + default=True, + metadata={ + "help": "Under use attn_mask_start_row_indices=True, whether use sparse flash attention or not." + }, + ) + recompute_use_reentrant: bool = field( + default=False, + metadata={"help": "recompute_use_reentrant"}, + ) + continue_training: bool = field( + default=True, + metadata={ + "help": ( + "Whether to train from existing paddleformers model weights.\n" + "If set True, the model_name_or_path argument must exist in the paddleformers models." + ) + }, + ) + add_tail_layers: int = field( + default=False, + metadata={ + "help": ( + "Add EmptyLayer after Ernie4_5_DecoderLayerPipe. Only for Pipeline Parallel" + ) + }, + ) + + # MoE + use_recompute_moe: Optional[bool] = field( + default=False, metadata={"help": "Whether to apply recompute to MoE layers."} + ) + moe_group: Optional[str] = field( + default="dummy", + metadata={"help": "MoE communication group. Supported values: 'mp', 'dummy'."}, + ) + moe_multimodal_dispatch_use_allgather: Optional[str] = field( + default="v2-alltoall-unpad", + metadata={"help": "moe dispatch use unpad allgather strategy."}, + ) + moe_group_experts: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether to apply group-wise processing to expert gate logits." + }, + ) + moe_aux_loss_lambda: Optional[float] = field( + default=1e-5, + metadata={"help": "Lambda value for moe aux loss."}, + ) + moe_orthogonal_loss_lambda: Optional[float] = field( + default=0.0, + metadata={"help": "Lambda value for moe orthogonal loss."}, + ) + moe_z_loss_lambda: Optional[float] = field( + default=0.0, + metadata={"help": "Lambda value for moe z loss."}, + ) + moe_use_hard_gate: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether to use hard gate. If `moe_use_hard_gate` is True, a hard " + "routing strategy is used instead of a learned gating network." + }, + ) + moe_use_aux_free: Optional[bool] = field( + default=None, + metadata={ + "help": "Whether to use auxiliary‑loss‑free routing. If True, " + "load balancing (using expert bias adjustments) is used instead " + "of traditional auxiliary loss for MoE." + }, + ) + + apply_hadamard: bool = field( + default=True, metadata={"help": "Whether to apply hadamard"} + ) + hadamard_block_size: int = field( + default=32, metadata={"help": "hadamard block size"} + ) + quant_input_grad: bool = field( + default=False, metadata={"help": "Whether to quantize input grad"} + ) + quant_weight_grad: bool = field( + default=False, metadata={"help": "Whether to quantize weight grad"} + ) + apply_online_actscale_step: int = field( + default=200, + metadata={ + "help": "Use online activation scale for first N step to keep stable training." + }, + ) + actscale_moving_rate: float = field( + default=0.01, metadata={"help": "EMA moving_rate for activation scale"} + ) + fp8_format_type: str = field(default="hybrid", metadata={"help": "FP8 Format"}) + num_nextn_predict_layers: int = field( + default=0, metadata={"help": "Number of nextn predict layers."} + ) + multi_token_pred_lambda: float = field( + default=0.3, metadata={"help": "multi token pred lambda"} + ) + use_recompute_mtp: bool = field( + default=False, metadata={"help": "Whether to use recompute_mtp"} + ) + + +def main(): + """ + The main function that creates a model with parameters configured from pretrained settings, + arguments, and training the sft/lora model. + + Args: + None + + Returns: + None + """ + parser = PdArgumentParser((ModelArgument, DataArgument, SFTTrainingArguments)) + if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"): + model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines() + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if training_args.sequence_parallel: + if training_args.pipeline_parallel_degree > 1: + assert ( + hasattr(training_args, "pipeline_parallel_config") + and "disable_partial_send_recv" + in training_args.pipeline_parallel_config + ), "Should set '--pipeline_parallel_config disable_partial_send_recv' in bash script for pp with sp." + if training_args.tensor_parallel_degree <= 1: + training_args.sequence_parallel = False + logger.info("Tensor_parallel_degree = 1. Set sequence_parallel to False.") + if model_args.lora and model_args.fuse_linear: + model_args.fuse_linear = False + logger.info("LoRA does not support fuse_linear. Set fuse_linear to False.") + if training_args.recompute and model_args.offload_recompute_inputs: + assert ( + model_args.recompute_use_reentrant + ), "offload_recompute_inputs can only be enabled along with reentrant recompute." + assert ( + model_args.recompute_granularity == "full" + ), "To save device memory, please try higher recompute_granularity before enabling offload_recompute_inputs." + if training_args.pipeline_parallel_degree > 1: + logger.debug( + "offload_recompute_inputs is not supported in pipeline parallel. Set offload_recompute_inputs to False." + ) + model_args.offload_recompute_inputs = False + + runtime_timer = RuntimeTimer("Training") + + if training_args.sharding_parallel_degree > 1: + if ( + ShardingOption.SHARD_GRAD_OP in training_args.sharding + or ShardingOption.FULL_SHARD in training_args.sharding + ): + if training_args.release_grads is True: + training_args.release_grads = False + + # checkpoint O1 quantization is open by default. + if ( + not training_args.disable_ckpt_quant + and training_args.ckpt_quant_stage == "O0" + and not model_args.lora + ): + training_args.ckpt_quant_stage = "O1" + elif training_args.disable_ckpt_quant: + training_args.ckpt_quant_stage = "O0" + + training_args.print_config(model_args, "Model") + training_args.print_config(data_args, "Data") + + if data_args.sft_benchmark: + training_args.do_train = True + training_args.do_export = False + training_args.do_predict = False + training_args.do_eval = False + training_args.overwrite_output_dir = True + training_args.load_best_model_at_end = False + training_args.report_to = [] + training_args.save_strategy = IntervalStrategy.NO + training_args.evaluation_strategy = IntervalStrategy.NO + if not training_args.disable_tqdm: + training_args.logging_steps = 1 + training_args.logging_strategy = IntervalStrategy.STEPS + + paddle.set_device(training_args.device) + + set_seed(training_args.seed) + + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: " + f"{training_args.world_size}, distributed training: {bool(training_args.local_rank != -1)}, " + f"16-bits training: {training_args.fp16 or training_args.bf16}" + ) + + last_checkpoint = None + if ( + os.path.isdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + uc_async_save = ( + training_args.unified_checkpoint + and "async_save" in training_args.unified_checkpoint_config + ) + last_checkpoint = get_last_checkpoint( + training_args.output_dir, + signal_folder=training_args.output_signal_dir, + uc_async_save=uc_async_save, + ) + if last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + if ( + last_checkpoint is not None + and model_args.continue_training + and not model_args.lora + ): + model_args.continue_training = False + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. Set `continue_training` to False." + ) + + # Set the dtype for loading model + dtype = paddle.get_default_dtype() + if training_args.fp16_opt_level == "O2": + if training_args.fp16: + dtype = "float16" + if training_args.bf16: + dtype = "bfloat16" + + logger.info("Start to load model ...") + + # Detect torch model. + config_path = os.path.join(model_args.model_name_or_path, "config.json") + with open(config_path, "r", encoding="utf-8") as f: + config_dict = json.load(f) + if "torch_dtype" in config_dict: + raise ValueError( + "Unsupported weight format: Torch weights are not compatible with Paddle model currently." + ) + + model_class = Ernie4_5_MoeForCausalLM + if training_args.pipeline_parallel_degree > 1: + model_class = Ernie4_5_MoeForCausalLMPipe + if ( + model_args.moe_group.lower() in {"data", "dp"} + and training_args.data_parallel_degree > 1 + ): + training_args.use_expert_parallel = True + + # fuse_softmax_mask only support for rocm. + if not paddle.is_compiled_with_rocm(): + if model_args.fuse_softmax_mask: + logger.warning( + "The fuse_softmax_mask flag is only available when using the ROCM version of paddlepaddle. " + ) + model_args.fuse_softmax_mask = False + + check_refined_recompute( + training_args.refined_recompute, + training_args.sequence_parallel, + lora=model_args.lora, + ) + + runtime_timer.start("basemodel loading time") + if training_args.weight_quantize_algo is not None: + if training_args.weight_quantize_algo == "weight_only_mix": + weight_quantize_algo = { + "weight_only_int4": [".*mlp.experts.*"], + "weight_only_int8": [ + ".*self_attn.qkv_proj.*", + ".*self_attn.o_proj.*", + ".*mlp.up_gate_proj.*", + ".*mlp.down_proj.*", + ], + } + else: + weight_quantize_algo = training_args.weight_quantize_algo + quantization_config = dict( + weight_quantize_algo=weight_quantize_algo, + ignore_modules=[".*out_linear.*"], + apply_hadamard=model_args.apply_hadamard, + hadamard_block_size=model_args.hadamard_block_size, + quant_input_grad=model_args.quant_input_grad, + quant_weight_grad=model_args.quant_weight_grad, + apply_online_actscale_step=model_args.apply_online_actscale_step, + actscale_moving_rate=model_args.actscale_moving_rate, + fp8_format_type=model_args.fp8_format_type, + ) + if training_args.weight_quantize_algo == "fp8linear": + quantization_config.update( + { + "dense_quant_type": "tensor_wise_fp8", + "moe_quant_type": "tensor_wise_fp8", + "quantization": "mix_quant", + } + ) + else: + quantization_config = dict( + weight_quantize_algo=training_args.weight_quantize_algo + ) + + model_config = Ernie4_5_MoeConfig.from_pretrained( + model_args.model_name_or_path, + dtype=dtype, + quantization_config=quantization_config, + ) + model_config.tensor_parallel_degree = training_args.tensor_parallel_degree + model_config.tensor_parallel_rank = training_args.tensor_parallel_rank + model_config.recompute = training_args.recompute + model_config.recompute_granularity = model_args.recompute_granularity + model_config.no_recompute_layers = model_args.no_recompute_layers + model_config.refined_recompute = training_args.refined_recompute + model_config.offload_recompute_inputs = model_args.offload_recompute_inputs + model_config.use_flash_attention = model_args.use_flash_attention + model_config.sequence_parallel = training_args.sequence_parallel + model_config.use_sparse_head_and_loss_fn = model_args.use_sparse_head_and_loss_fn + model_config.use_fused_head_and_loss_fn = model_args.use_fused_head_and_loss_fn + model_config.tensor_parallel_output = model_args.tensor_parallel_output + model_config.virtual_pp_degree = model_args.virtual_pp_degree + model_config.pp_seg_method = model_args.pp_seg_method + model_config.add_tail_layers = model_args.add_tail_layers + model_config.fuse_linear = model_args.fuse_linear + model_config.fuse_rope = model_args.fuse_rope + model_config.fuse_softmax_mask = model_args.fuse_softmax_mask + model_config.fuse_rms_norm = model_args.fuse_rms_norm + model_config.fuse_swiglu = model_args.fuse_swiglu + model_config.fuse_gate_detach_matmul = model_args.fuse_gate_detach_matmul + model_config.max_sequence_length = data_args.max_seq_len + model_config.recompute_use_reentrant = model_args.recompute_use_reentrant + model_config.use_sparse_flash_attn = model_args.use_sparse_flash_attn + model_config.use_recompute_moe = model_args.use_recompute_moe + model_config.moe_group = model_args.moe_group + model_config.moe_group_experts = model_args.moe_group_experts + model_config.moe_aux_loss_lambda = model_args.moe_aux_loss_lambda + model_config.moe_orthogonal_loss_lambda = model_args.moe_orthogonal_loss_lambda + model_config.moe_z_loss_lambda = model_args.moe_z_loss_lambda + model_config.moe_use_hard_gate = model_args.moe_use_hard_gate + model_config.moe_multimodal_dispatch_use_allgather = ( + model_args.moe_multimodal_dispatch_use_allgather + ) + if model_args.moe_use_aux_free is False: + model_config.moe_use_aux_free = model_args.moe_use_aux_free + model_config.hidden_dropout_prob = training_args.hidden_dropout_prob + model_config.attention_probs_dropout_prob = ( + training_args.attention_probs_dropout_prob + ) + model_config.num_acc_steps = training_args.gradient_accumulation_steps + model_config.num_nextn_predict_layers = model_args.num_nextn_predict_layers + model_config.multi_token_pred_lambda = model_args.multi_token_pred_lambda + model_config.use_recompute_mtp = model_args.use_recompute_mtp + if model_config.moe_num_experts is None or model_config.moe_num_experts == 0: + model_config.moe_group = ( + "dummy" if model_args.moe_group == "mp" else model_args.moe_group + ) + + if ( + training_args.pipeline_parallel_degree > 1 + and training_args.weight_quantize_algo is not None + and model_config.tie_word_embeddings + ): + raise NotImplementedError( + "Quantization is not supported for models with tied lm_head and word_embedding \ + weights when using Pipeline Parallelism (PP)." + ) + + if model_args.continue_training or training_args.weight_quantize_algo is not None: + model = model_class.from_pretrained( + model_args.model_name_or_path, + config=model_config, + ) + else: + model = model_class.from_config(model_config, dtype=dtype) + + if model.config.head_dim is None: + del model.config.head_dim + + paddle.device.cuda.empty_cache() + logger.info("Loading model successfully !") + logger.debug(f"Model config: {model.config}") + logger.info(f"{runtime_timer.log()}") + + tokenizer = Ernie4_5_Tokenizer.from_pretrained( + model_args.model_name_or_path, + ) + + logger.info("Start to create dataset ...") + dataset_config = { + "tokenizer": tokenizer, + "max_seq_len": data_args.max_seq_len, + "random_seed": training_args.seed, + "num_replicas": training_args.dataset_world_size, + "rank": training_args.dataset_rank, + } + from ernie.dataset.finetuning import collate_fn + + if data_args.dataset_type == "map": + from ernie.dataset.finetuning import ( + create_indexed_dataset as create_dataset, + ) + else: + from ernie.dataset.finetuning import create_dataset + dataset_config.update( + { + "num_samples_each_epoch": data_args.num_samples_each_epoch, + "random_shuffle": data_args.random_shuffle, + "greedy_intokens": data_args.greedy_intokens, + } + ) + + if training_args.should_load_dataset: + if data_args.dataset_type == "map": + train_file_path = os.path.join(data_args.offline_dataset_path, "train") + train_dataset = create_dataset(data_file_prefix=train_file_path) + else: + train_dataset = create_dataset( + task_group=data_args.train_dataset_path, + task_group_prob=data_args.train_dataset_prob, + sub_dataset_type=data_args.train_dataset_type, + **dataset_config, + ) + + if training_args.do_eval and training_args.should_load_dataset: + if data_args.dataset_type == "map": + eval_file_path = os.path.join(data_args.offline_dataset_path, "eval") + eval_dataset = create_dataset(data_file_prefix=eval_file_path) + else: + eval_dataset = create_dataset( + task_group=data_args.eval_dataset_path, + task_group_prob=data_args.eval_dataset_prob, + sub_dataset_type=data_args.eval_dataset_type, + is_valid=True, + **dataset_config, + ) + + logger.info("Creating dataset successfully ...") + + data_collator = partial( + collate_fn, + tokenizer=tokenizer, + model_args=model_args, + max_seq_len=data_args.max_seq_len, + ) + + if model_args.lora: + logger.info("Start to wrap model with LoRA config ...") + + from ernie.utils.peft_utils import initialize_lora_model + + model = initialize_lora_model( + model=model, + training_args=training_args, + model_args=model_args, + resume_from_checkpoint=last_checkpoint is not None, + dtype=dtype, + ) + + if training_args.max_steps == -1: + if training_args.should_load_dataset and paddle.distributed.get_rank() == 0: + if data_args.dataset_type != "map": + training_args.max_steps = estimate_training( + train_dataset, data_args, training_args, model_args + ) + del train_dataset + gc.collect() + train_dataset = create_dataset( + task_group=data_args.train_dataset_path, + task_group_prob=data_args.train_dataset_prob, + sub_dataset_type=data_args.train_dataset_type, + **dataset_config, + ) + else: + global_batch_size = ( + training_args.per_device_train_batch_size + * training_args.gradient_accumulation_steps + * training_args.dataset_world_size + ) + training_args.max_steps = math.ceil( + len(train_dataset) / global_batch_size + ) + + if paddle.distributed.get_world_size() > 1: + paddle.distributed.barrier() + max_steps = paddle.to_tensor([training_args.max_steps]) + paddle.distributed.broadcast(max_steps, src=0) + training_args.max_steps = int(max_steps.item()) + if training_args.max_steps <= 0: + raise ValueError( + f"Invalid max_steps: {training_args.max_steps}. Please check your dataset" + ) + + logger.info(f"Re-setting training_args.max_steps to {training_args.max_steps}.") + # Create the learning_rate sheduler and optimizer + if training_args.decay_steps is None: + training_args.decay_steps = training_args.max_steps + + if training_args.save_strategy == IntervalStrategy.EPOCH: + training_args.save_strategy = IntervalStrategy.STEPS + training_args.save_steps = int( + training_args.max_steps / training_args.num_train_epochs + ) + if training_args.evaluation_strategy == IntervalStrategy.EPOCH: + training_args.evaluation_strategy = IntervalStrategy.STEPS + training_args.eval_steps = int( + training_args.max_steps / training_args.num_train_epochs + ) + if training_args.logging_strategy == IntervalStrategy.EPOCH: + training_args.logging_strategy = IntervalStrategy.STEPS + training_args.logging_steps = int( + training_args.max_steps / training_args.num_train_epochs + ) + + if ( + not model_args.use_sparse_head_and_loss_fn + and not training_args.prediction_loss_only + ): + unwraped_model = unwrap_model(model) + if hasattr(model, "compute_metrics"): + compute_metrics = model.compute_metrics + elif hasattr(unwraped_model, "compute_metrics"): + # NOTE(liuting): if model is LoRAModel, we need to unwrap it first. + compute_metrics = unwraped_model.compute_metrics + else: + compute_metrics = None + else: + compute_metrics = None + + trainer = ErnieMoETrainer( + model=model, + args=training_args, + train_dataset=( + train_dataset + if training_args.do_train and training_args.should_load_dataset + else None + ), + eval_dataset=( + eval_dataset + if training_args.do_eval and training_args.should_load_dataset + else None + ), + tokenizer=tokenizer, + do_generation=False, + data_args=data_args, + data_collator=data_collator, + compute_metrics=compute_metrics, + ) + trainable_parameters = [ + p + for p in model.parameters() + if not p.stop_gradient or ("quantization_linear" in p.name and "w_1" in p.name) + ] + trainer.set_optimizer_grouped_parameters(trainable_parameters) + + if training_args.hidden_dropout_prob or training_args.attention_probs_dropout_prob: + trainer.add_callback(LayerwiseDropoutCallback()) + + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=last_checkpoint) + if not data_args.sft_benchmark: + runtime_timer.start("model saving time") + trainer.save_model( + merge_tensor_parallel=training_args.tensor_parallel_degree > 1 + ) + if paddle.distributed.get_world_size() > 1: + paddle.distributed.barrier() + logger.info(f"{runtime_timer.log()}") + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + + if ( + training_args.should_load_dataset + and data_args.sft_benchmark + and paddle.distributed.get_rank() == 0 + ): + del train_dataset + gc.collect() + train_dataset = create_dataset( + task_group=data_args.train_dataset_path, + task_group_prob=data_args.train_dataset_prob, + sub_dataset_type=data_args.train_dataset_type, + **dataset_config, + ) + total_effective_tokens, total_tokens = calculate_effective_tokens( + training_args, train_dataset, data_args.max_seq_len + ) + + effective_tokens_per_second = ( + total_effective_tokens / train_result.metrics["train_runtime"] + ) + total_tokens_per_second = ( + total_tokens / train_result.metrics["train_runtime"] + ) + effective_ratio = 100 * total_effective_tokens / total_tokens + logger.info( + "[timelog] {}: {:.2f} % ({}) ".format( + "Effective ratio", + effective_ratio, + time.strftime("%Y-%m-%d %H:%M:%S"), + ) + ) + logger.info( + "[timelog] {}: {:.2f} token/s ({}) ".format( + "Effective tokens per second", + effective_tokens_per_second, + time.strftime("%Y-%m-%d %H:%M:%S"), + ) + ) + logger.info( + "[timelog] {}: {:.2f} token/s ({}) ".format( + "Tokens per second", + total_tokens_per_second, + time.strftime("%Y-%m-%d %H:%M:%S"), + ) + ) + + if training_args.do_eval: + eval_result = trainer.evaluate() + trainer.log_metrics("eval", eval_result) + # NOTE(gongenlei): set combined=False to avoid overwriting errors on AFS + trainer.save_metrics("eval", eval_result, combined=False) + + save_stop_info( + training_args, + trainer.state.global_step, + outside_eval=training_args.do_eval, + outside_predict=0, + ) + + +if __name__ == "__main__": + with paddle.amp.auto_cast(enable=False): + main() diff --git a/ernie/ERNIE/examples/post-training/sft/trainer.py b/ernie/ERNIE/examples/post-training/sft/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3854a711db20fa100b06c2c1e24cca10d8f609 --- /dev/null +++ b/ernie/ERNIE/examples/post-training/sft/trainer.py @@ -0,0 +1,738 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Trainer for Ernie-MoE model with enhanced distributed training support. +""" + +import inspect +import os +import random +from collections import OrderedDict +from functools import partial +from typing import Dict + +import numpy as np +import paddle +import paddle.distributed as dist +from paddle import nn +from paddle.distributed import fleet +from paddle.distributed.communication.group import _get_global_group +from paddle.distributed.fleet.utils import mix_precision_utils +from paddle.distributed.fleet.utils.sequence_parallel_utils import register_sequence_parallel_allreduce_hooks +from paddleformers.peft import LoRAModel +from paddleformers.trainer import Trainer +from paddleformers.trainer.trainer_utils import OptimizerNames, ShardingOption, has_length +from paddleformers.transformers.model_utils import _add_variant, unwrap_model +from paddleformers.utils import infohub +from paddleformers.utils.batch_sampler import DistributedBatchSampler +from paddleformers.utils.env import PADDLE_OPTIMIZER_NAME, PADDLE_WEIGHTS_NAME +from paddleformers.utils.log import logger + +try: + from paddleformers.quantization.quantization_linear import QuantizationLinear +except: + QuantizationLinear = None + +# moe hack +from ernie.callbacks import SPGradSyncCallback +from ernie.moe.distributed.data_parallel import DataParallel as MoEDDP +from ernie.moe.distributed.hybrid_parallel_optimizer import ( + HybridParallelClipGrad as MoEHybridParallelClipGrad, +) +from ernie.moe.moe_clip import ClipGradForMOEByGlobalNorm +from ernie.utils.moe_utils import distributed_optimizer_for_moe + + +def is_dp_group_support_in_group_sharded_parallel(): + """ + Check if 'dp_group' parameter is supported in group_sharded_parallel function. + + Returns: + bool: True if 'dp_group' is a valid parameter, False otherwise. + """ + return "dp_group" in set(inspect.signature(paddle.distributed.sharding.group_sharded_parallel).parameters.keys()) + + +class ErnieMoETrainer(Trainer): + """ + Custom trainer class for Ernie-MoE model with enhanced distributed training support. + """ + + def __init__(self, data_args, do_generation: bool, **kwargs): + """ + Initialize ErnieMoETrainer. + + Args: + data_args: Dataset configuration arguments. + do_generation (bool): Flag to enable generation mode. + **kwargs: Additional keyword arguments for base Trainer class. + """ + super().__init__(**kwargs) + self.data_args = data_args + self.do_generation = do_generation + self.data_seed = kwargs.pop("data_seed", None) + + def prediction_pipeline_step_with_logits_acc( + self, + *args, + **kwargs, + ): + """ + Pipeline step for prediction with logits accuracy calculation. + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + tuple: Contains: + - loss (Tensor): Computed loss value. + - preds (Tensor): Predicted indices from logits. + - weight (Tensor): Weights from logits. + - labels (Tensor): Ground truth labels. + """ + loss, _, labels = self.prediction_pipeline_step(*args, **kwargs) + if "pp_preds" in infohub: + preds = paddle.concat(infohub["pp_preds"], axis=0) + weight = paddle.concat(infohub["pp_preds_w"], axis=0) + infohub["pp_preds"] = [] + infohub["pp_preds_w"] = [] + + return (loss, (preds, weight), labels) + return (loss, None, labels) + + def _wrap_model(self, model, training=True): + """ + Wrap model with distributed training components. + + Args: + model: Model to wrap. + training (bool): Whether in training mode. Defaults to True. + + Returns: + Model: Wrapped model with distributed training components. + """ + # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again + if unwrap_model(model) is not model: + return model + + # Note: in paddle.distributed mode, there's no point in wrapping the model + # inside a DistributedDataParallel as we'll be under `no_grad` anyways. + if not training: + return model + + # Mixed precision training + if training and self.do_grad_scaling: # self.args.fp16_opt_level=="O2": + # model, self.optimizer + decorated = paddle.amp.decorate( + models=model, + optimizers=self.optimizer, + level=self.args.fp16_opt_level, + dtype=self.amp_dtype, + excluded_layers=QuantizationLinear, + ) + + if self.optimizer is None: + model = decorated + else: + model, self.optimizer = decorated + + def enable_sequence_parallel(_model): + if self.args.tensor_parallel_degree > 1 and self.args.sequence_parallel: + if self.args.use_sp_callback: + self.add_callback(SPGradSyncCallback(_model._layers)) + else: + register_sequence_parallel_allreduce_hooks( + _model, self.args.gradient_accumulation_steps, self.args.fuse_sequence_parallel_allreduce + ) + + if self.args.world_size == 1: + if self.args.amp_master_grad: + mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) + assert self.optimizer is not None, "optimizer is empty!" + self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) + + # Multi-gpu training + if self.args.use_expert_parallel: + logger.debug("using moe ddp, hack Paddle") # TODO move this into paddle + paddle.DataParallel = MoEDDP + + in_pipeline_parallel_mode = self.args.pipeline_parallel_degree > 1 + in_sharding_parallel_mode = self.sharding is not None + in_tensor_parallel_mode = self.args.tensor_parallel_degree > 1 + + # Multi-gpu training + if ( + self.args.world_size > 1 + and not self.args.use_hybrid_parallel + and not (in_pipeline_parallel_mode or in_sharding_parallel_mode or in_tensor_parallel_mode) + ): + model = paddle.DataParallel(model) + # Distributed training (should be after fp16 initialization) + + if self.args.amp_master_grad: + mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) + assert self.optimizer is not None, "optimizer is empty!" + self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) + + # Pipeline mode + if in_pipeline_parallel_mode: + if self.args.amp_master_grad: + mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) # return value has no use + # hack for pipeline model mini batch to batch + # need batter solution @ZHUI + # make batch_fn compatible for fleet.distributed_model decorate. + prepare_pipeline_inputs_func = ( + model._prepare_pipeline_inputs_func if hasattr(model, "_prepare_pipeline_inputs_func") else None + ) + if isinstance(model, LoRAModel): + model = model.model + model = fleet.distributed_model(model) + if prepare_pipeline_inputs_func is not None: + model._prepare_pipeline_inputs_func = prepare_pipeline_inputs_func + else: + + def _prepare_pipeline_inputs_func(inputs): + first_stage_keys = [ + "input_ids", + "attention_mask", + "position_ids", + ] + last_stage_keys = ["labels"] + + def get_expected_keys(inputs, keys): + ret = tuple([inputs.pop(k) for k in keys if k in inputs]) + if len(ret) == 1: + ret = ret[0] + return ret + + if type(inputs) is dict or type(inputs) is OrderedDict: + return [ + get_expected_keys(inputs, first_stage_keys), + get_expected_keys(inputs, last_stage_keys), + ] + + keys = list(inputs[0].keys()) + inputs_batch = {key: [data.pop(key) for data in inputs] for key in keys} + return [ + get_expected_keys(inputs_batch, first_stage_keys), + get_expected_keys(inputs_batch, last_stage_keys), + ] + + logger.warning( + "Using default prepare pipeline inputs func, only support input_ids and labels as inputs." + ) + model._prepare_pipeline_inputs_func = _prepare_pipeline_inputs_func + + enable_sequence_parallel(model) + + assert self.optimizer is not None, "Pipeline mode need decorate optimizer, pelease init optimizer." + if self.args.amp_master_grad: + self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) + self.optimizer = distributed_optimizer_for_moe(self.optimizer, self.args.use_expert_parallel) + + # No pipeline mode, sharding only + if not in_pipeline_parallel_mode and in_sharding_parallel_mode: + # Sharded DDP! + if self.args.tensor_parallel_degree > 1: + hcg = fleet.get_hybrid_communicate_group() + assert ( + ShardingOption.SHARD_GRAD_OP in self.args.sharding or ShardingOption.SHARD_OP in self.args.sharding + ), "Only support tensor parallel + sharding stage1/stage2 hybrid parallel now." + model = paddle.distributed.fleet.meta_parallel.TensorParallel(model, hcg, strategy=None) + enable_sequence_parallel(model) + + if ShardingOption.SHARD_OP in self.args.sharding: + if self.args.amp_master_grad: + mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) # return value has no use + model = fleet.distributed_model(model) + if self.args.amp_master_grad: + self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) + self.optimizer = distributed_optimizer_for_moe(self.optimizer, self.args.use_expert_parallel) + else: + # sync params (broadcast) buffers in dp group, no quite understanding here. + if ( + not is_dp_group_support_in_group_sharded_parallel() or self.args.use_expert_parallel + ) and self.args.data_parallel_degree > 1: + from paddle.distributed.parallel import sync_params_buffers + + hcg = fleet.get_hybrid_communicate_group() + dp_group = hcg.get_data_parallel_group() + sync_params_buffers(model, comm_group=dp_group, src_rank=dp_group.ranks[0]) + + cpu_offload = ShardingOption.OFFLOAD in self.args.sharding + assert self.optimizer is not None, "optimizer is empty!" + level = None + if ShardingOption.SHARD_GRAD_OP in self.args.sharding: + level = "os_g" + if ShardingOption.FULL_SHARD in self.args.sharding: + level = "p_g_os" + + from paddle.distributed.sharding import group_sharded_parallel + + # add dp_group and exclude_layer params + # https://www.paddlepaddle.org.cn/ + # documentation/docs/zh/develop/ + # api/paddle/distributed/sharding/group_sharded_parallel_cn.html#group-sharded-parallel + extra_kwargs = {} + if is_dp_group_support_in_group_sharded_parallel() and not self.args.use_expert_parallel: + extra_kwargs["dp_group"] = self.dp_group + extra_kwargs["exclude_layer"] = ["GroupNorm"] + + if self.args.amp_master_grad: + assert ( + self.args.data_parallel_degree == 1 + ), "Sharding stage 2 / Sharding stage 3 main grad is not compatible with dp for now." + mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) # return value has no use + self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) + + model, optimizer, _ = group_sharded_parallel( + model, + self.optimizer, + level=level, + scaler=None, + group=self.sharding_group, + offload=cpu_offload, + **extra_kwargs, + ) + if ShardingOption.SHARD_GRAD_OP in self.args.sharding and self.args.amp_master_grad: + assert hasattr(optimizer, "use_main_grad"), ( + "Current installed paddle doesn't support sharding stage 2 with main grad, " + "please upgrade your paddle (using nightly version)." + ) + + sharding_parallel_config = set(self.args.sharding_parallel_config.split(" ")) + if level == "os_g" and "enable_stage2_overlap" in sharding_parallel_config: + model._set_reduce_overlap(True) + optimizer._set_broadcast_overlap(True, model) + self.optimizer = optimizer + + # pure tesnor parallel mode, no pipeline_parallel, no sharding. + if not in_pipeline_parallel_mode and not in_sharding_parallel_mode and in_tensor_parallel_mode: + if self.args.amp_master_grad: + mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) # return value has no use + model = fleet.distributed_model(model) + model.accumulate_steps = self.args.gradient_accumulation_steps + assert self.optimizer is not None, "Tensor parallel mode need decorate optimizer, pelease init optimizer." + enable_sequence_parallel(model) + if self.args.amp_master_grad: + self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) + self.optimizer = distributed_optimizer_for_moe(self.optimizer, self.args.use_expert_parallel) + + return model + + def create_optimizer(self, lr_scheduler=None): + """ + Create and configure the optimizer for training. + + Args: + lr_scheduler (Optional): Learning rate scheduler for adjusting the learning rate during training. + + Returns: + paddle.optimizer.Optimizer: The configured optimizer instance with specified parameters and settings. + """ + self.static_name_to_dyg_name = {p.name: n for n, p in self.model.named_parameters()} + + if self.optimizer is None: + if self.optimizer_grouped_parameters is not None: + optimizer_params = self.optimizer_grouped_parameters + else: + optimizer_params = self.model.parameters() + + decay_parameters = [ + p.name for n, p in self.model.named_parameters() if not any(nd in n for nd in ["bias", "norm"]) + ] + + def apply_decay_param_fun(x): + return x in decay_parameters + + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) + if hasattr(optimizer_cls, "_create_master_weight") and self.args.fp16_opt_level == "O2": + optimizer_kwargs["multi_precision"] = True + if self.args.optim == OptimizerNames.ADAMW_CUSTOM: + optimizer_kwargs["quantization_config"] = self.model.config.quantization_config + optimizer_kwargs["use_lowprecision_moment"] = self.args.use_lowprecision_moment + optimizer_kwargs["tensorwise_offload_optimizer"] = self.args.tensorwise_offload_optimizer + + def _get_layer_lrs(x, lr_lower_bound, n_layers): + """ + Calculate layer-wise learning rates with depth-based scaling. + + Implements a learning rate schedule where layers closer to the input (lower depth) + get smaller learning rates, while deeper layers get progressively higher rates. + This follows the common practice that earlier layers typically need finer tuning. + + Args: + x (Parameter): The model parameter to calculate learning rate for + lr_lower_bound (float): Minimum learning rate (for depth=0 layers) + n_layers (int): Total number of transformer layers in the model + + Returns: + float: Computed learning rate for the given parameter + + Note: + - Special layers (embedding and head) get fixed positions in the depth hierarchy + - The depth-to-LR mapping follows a linear interpolation between lower bound and 1.0 + - TODO: Needs to consider LoRA (Low-Rank Adaptation) parameters in future + """ + name = self.static_name_to_dyg_name[x.name] + if "lm_head" in name or "ernie.norm" in name: + depth = n_layers + 2 + elif "embed_tokens" in name: + depth = 0 + else: + if name.startswith("ernie.layers."): + depth = int(name.split(".")[2]) + else: + depth = int(name.split(".")[0]) + return lr_lower_bound + depth / (n_layers + 2) * (1 - lr_lower_bound) + + lr_ratio_func = None + layerwise_lr_decay_bound = self.args.layerwise_lr_decay_bound + assert ( + layerwise_lr_decay_bound > 0 and layerwise_lr_decay_bound <= 1 + ), f"layerwise_lr_decay_bound: {layerwise_lr_decay_bound} out of range. should be in (0, 1]" + if layerwise_lr_decay_bound < 1: + lr_ratio_func = partial( + _get_layer_lrs, + lr_lower_bound=layerwise_lr_decay_bound, + n_layers=self.model.config.num_hidden_layers, + ) + + if self.args.max_grad_norm <= 0: + grad_clip = None + elif self.args.use_expert_parallel and not self.args.use_hybrid_parallel: + + def expert_fn(p): + return getattr(p, "no_sync", False) + + grad_clip = ClipGradForMOEByGlobalNorm( + self.args.max_grad_norm, + is_expert_param_func=expert_fn, + moe_group=_get_global_group(), + ) + else: + grad_clip = nn.ClipGradByGlobalNorm(self.args.max_grad_norm) + + self.optimizer = optimizer_cls( + learning_rate=(self.lr_scheduler if lr_scheduler is None else lr_scheduler), + apply_decay_param_fun=apply_decay_param_fun, + parameters=optimizer_params, + weight_decay=self.args.weight_decay, + grad_clip=grad_clip, + lr_ratio=lr_ratio_func, + **optimizer_kwargs, + ) + + if self.args.use_expert_parallel and self.args.use_hybrid_parallel: + logger.debug('using moe-hybrid-clip under hybrid parallel') + hcg = fleet.get_hybrid_communicate_group() + self.optimizer._grad_clip = MoEHybridParallelClipGrad( + self.optimizer._grad_clip, + hcg, + moe_group=hcg.get_data_parallel_group(), + ) + + self.optimizer._dtype = paddle.get_default_dtype() + return self.optimizer + + def prediction_step( + self, + model, + inputs, + prediction_loss_only: bool, + ignore_keys=None, + ): + """ + Perform a single prediction step with model inference. + + Args: + model: The neural network model used for prediction. + inputs: Input data for the prediction step. + prediction_loss_only (bool): Flag indicating whether to return only loss. + ignore_keys (Optional): Keys to ignore during the prediction process. + + Returns: + Tuple: A tuple containing loss values, predictions, and labels. + """ + if prediction_loss_only: + return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) + elif not self.do_generation: + try: + hcg = fleet.get_hybrid_communicate_group() + model_parallel_group = hcg.get_model_parallel_group() + tensor_parallel_degree = hcg.get_model_parallel_world_size() + pipeline_parallel_group = hcg.get_pipe_parallel_group() + pipeline_parallel_degree = hcg.get_pipe_parallel_world_size() + except: + model_parallel_group = None + tensor_parallel_degree = 1 + pipeline_parallel_group = None + pipeline_parallel_degree = 1 + + # register `pp_accuracy` flag + if pipeline_parallel_degree > 1: + infohub["pp_accuracy"] = True + inputs = self._prepare_inputs(inputs) + loss, logits, labels = self.prediction_pipeline_step_with_logits_acc( + model, inputs, prediction_loss_only, ignore_keys + ) + else: + loss, logits, labels = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) + + # argmax here to avoid gather all logits, which is too memory-consuming. + # keepdim in order to maintain the same shape as logits + is_tensor_parallel_output = ( + model.config.tensor_parallel_output + if hasattr(model, "config") + else model._layers.config.tensor_parallel_output + ) + + if pipeline_parallel_degree > 1: + if logits is None: + preds = None + preds_shape = [[]] + else: + # preds: [bz, seq], logits: [bz, seq, part_hidden] + vocab_size_part = model._layers.config.vocab_size // tensor_parallel_degree + # logits were already argmax in `modeling.py` in pp mode. + preds = logits[0] + weight = logits[1] + # tp group concat + if tensor_parallel_degree > 1 and is_tensor_parallel_output: + # extract maximum `weight` + # weight: [bz, seq], logits: [bz, seq, part_hidden] + + batch_size, seq_len = preds.shape + + # indices offset + offset = ( + paddle.arange(tensor_parallel_degree) + .unsqueeze(0) + .unsqueeze(0) + .expand([batch_size, seq_len, tensor_parallel_degree]) + * vocab_size_part + ) + preds = paddle.distributed.collective._c_concat(preds, group=model_parallel_group) + preds = preds.reshape([batch_size, -1, seq_len]).transpose([0, 2, 1]) + preds = preds + offset + # preds: [bz, seq, tp_size], weight: [bz, seq, tp_size] + weight = ( + paddle.distributed.collective._c_concat(weight, group=model_parallel_group) + .reshape([batch_size, -1, seq_len]) + .transpose([0, 2, 1]) + ) + + # weight: [bz, seq] + # concat and argmax again to get true maximum + weight = weight.argmax(axis=-1) + + # preds: [bz, seq, tp_size], weight: [bz, seq] + # extract maximum indices `preds` + preds = preds[ + paddle.arange(preds.shape[0]).unsqueeze(1), + paddle.arange(preds.shape[1]).unsqueeze(0), + weight, + ] + + if len(preds.shape) == 1: + # NOTE(hehuang): Adapt evaluation with use_sparse_head_and_loss_fn. + # logits' shape is [num_predictions, vocab_size] when use_sparse_flash_attn is on, + # and need to add virtual batch dim for Trainer._pad_across_processes. + preds = preds[None] + + preds_shape = [preds.shape] + else: + if logits is None: + preds = None + preds_shape = [[]] + else: + # NOTE(hehuang): Decrease the communication cost of nested_gather. + if tensor_parallel_degree > 1 and is_tensor_parallel_output: + logits = paddle.distributed.collective._c_concat(logits, group=model_parallel_group) + preds = logits.argmax(axis=-1) + if len(preds.shape) == 1: + # NOTE(hehuang): Adapt evaluation with use_sparse_head_and_loss_fn. + # logits' shape is [num_predictions, vocab_size] when use_sparse_flash_attn is on, + # and need to add virtual batch dim for Trainer._pad_across_processes. + preds = preds[None] + + preds_shape = [preds.shape] + + if pipeline_parallel_group and pipeline_parallel_group.nranks > 1: + # broadcast logits from pp last rank to others. + paddle.distributed.broadcast_object_list( + preds_shape, + src=pipeline_parallel_group.ranks[-1], + group=pipeline_parallel_group, + ) + if not model.is_pipeline_last_stage(): + preds = paddle.empty(shape=preds_shape[0], dtype=paddle.int64) + task = dist.stream.broadcast( + preds, src=pipeline_parallel_group.ranks[-1], group=pipeline_parallel_group, sync_op=False + ) + task.wait() + + return (loss, preds, labels) + loss = None + model.eval() + with paddle.no_grad(): + generated_tokens = model.generate( + **inputs, + decoding_strategy="sampling", + top_k=1, + max_length=self.data_args.max_seq_len, + bos_token_id=self.tokenizer.bos_token_id, + eos_token_id=self.tokenizer.cls_token_id, + pad_token_id=self.tokenizer.pad_token_id, + use_cache=True, + )[0] + + all_preds = [] + for pred_tokens in generated_tokens: + pred_tokens = pred_tokens[pred_tokens != self.tokenizer.pad_token_id] + all_preds.append(pred_tokens) + max_pred_length = max([len(x) for x in all_preds]) + for index, preds in enumerate(all_preds): + all_preds[index] = paddle.to_tensor(preds.tolist() + [-100] * (max_pred_length - len(preds))) + all_preds = paddle.to_tensor(all_preds) + + if "labels" in inputs: + all_labels = inputs["labels"] + all_labels = paddle.to_tensor(all_labels) + else: + all_labels = None + return (loss, all_preds, all_labels) + + def log(self, logs: Dict[str, float], **kwargs) -> None: + """ + Log training metrics and calculate perplexity where applicable. + + Args: + logs (Dict[str, float]): Dictionary containing training/evaluation metrics. + **kwargs: Additional keyword arguments for logging. + """ + if hasattr(self.model, "ranking_loss"): + logs["ranking_loss"] = self.model.ranking_loss + else: + if "loss" in logs: + logs["ppl"] = np.exp(logs["loss"]) + if "eval_loss" in logs: + logs["eval_ppl"] = np.exp(logs["eval_loss"]) + + train_eval = "train" if "loss" in logs else "eval" + if self.state.epoch is not None and train_eval == "train": + self.state.epoch *= self.args.num_train_epochs + super().log(logs, **kwargs) + + def _load_rng_state(self, checkpoint): + # Load RNG states from `checkpoint` + """ + Load random number generator states from checkpoint. + + Args: + checkpoint: Checkpoint containing saved RNG states for reproducibility. + """ + super()._load_rng_state(checkpoint) + if self.data_seed is not None: + random.setstate(random.getstate()) + np.random.set_state(np.random.get_state()) + + def _save_moe_weights(self, output_dir): + """ + Save model weights and optimizer states for Mixture-of-Experts (MoE) models. + + Args: + output_dir (str): Directory path to save the model and optimizer checkpoints. + """ + os.makedirs(output_dir, exist_ok=True) + state_dict = self.model.state_dict() + optimizer_state_dict = self.optimizer.state_dict() + + filtered_state_dict = OrderedDict() + filter_optimizer_state_dict = OrderedDict() + + param_names_in_master_weights = ( + list(optimizer_state_dict["master_weights"].keys()) if self.args.bf16 or self.args.fp16 else [] + ) + filter_optimizer_state_dict["master_weights"] = OrderedDict() + + for k, v in state_dict.items(): + if getattr(v, 'no_sync', False): + if v.name in param_names_in_master_weights: + filter_optimizer_state_dict["master_weights"][v.name] = optimizer_state_dict["master_weights"][ + v.name + ] + + filtered_state_dict[k] = v + + for op_k, op_v in optimizer_state_dict.items(): + if op_k.startswith(v.name): + filter_optimizer_state_dict[op_k] = op_v + + filter_optimizer_state_dict['LR_Scheduler'] = optimizer_state_dict['LR_Scheduler'] + + self._save_ckpt_func( + filtered_state_dict, + os.path.join( + output_dir, + _add_variant(PADDLE_WEIGHTS_NAME, self.args.weight_name_suffix), + ), + ) + if not self.args.ignore_save_lr_and_optim: + self._save_ckpt_func( + filter_optimizer_state_dict, + os.path.join( + output_dir, + _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix), + ), + ) + + def _get_train_sampler(self): + """ + Create and return appropriate data sampler for distributed training. + + Returns: + DistributedBatchSampler: Configured sampler for distributing training data across devices. + """ + if self._is_iterable_dataset(self.train_dataset): + if self.train_dataset is None or not has_length(self.train_dataset): + return None + + if self.args.world_size <= 1: + return paddle.io.BatchSampler( + dataset=self.train_dataset, + shuffle=True, + batch_size=self.args.per_device_train_batch_size, + drop_last=self.args.dataloader_drop_last, + ) + + return DistributedBatchSampler( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + shuffle=True, + num_replicas=self.args.dataset_world_size, + rank=self.args.dataset_rank, + drop_last=self.args.dataloader_drop_last, + ) + else: + return DistributedBatchSampler( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + shuffle=False, + num_replicas=self.args.dataset_world_size, + rank=self.args.dataset_rank, + drop_last=self.args.dataloader_drop_last, + ) diff --git a/ernie/ERNIE/examples/post-training/tools/mergekit.py b/ernie/ERNIE/examples/post-training/tools/mergekit.py new file mode 100644 index 0000000000000000000000000000000000000000..fb6a5338101a448c17033521d9f9814350facffb --- /dev/null +++ b/ernie/ERNIE/examples/post-training/tools/mergekit.py @@ -0,0 +1,177 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Model Merge Tools.""" + +import argparse +import json +import os +import shutil +import time + +import paddle +from paddleformers.mergekit import MergeConfig, MergeModel +from paddleformers.trainer.argparser import strtobool +from paddleformers.utils.log import logger + + +def parse_arguments(): + """ + Parse command line arguments for model merging configuration. + + This function sets up and configures all available command line arguments + for the model merging process, including paths, device selection, and optional + tokenizer handling. + + Returns: + argparse.Namespace: An object containing all parsed command line arguments. + """ + parser = argparse.ArgumentParser() + parser.add_argument( + "--mergekit_task_config", type=str, help="The merge config path." + ) + parser.add_argument( + "--output_path", required=True, type=str, help="The merge config path." + ) + parser.add_argument( + "--lora_model_path", default=None, type=str, help="The lora model path." + ) + parser.add_argument( + "--model_name_or_path", default=None, type=str, help="The base model path." + ) + parser.add_argument("--device", default="gpu", type=str, help="Device") + parser.add_argument( + "--copy_tokenizer", default="True", type=strtobool, help="Copy tokenizer file" + ) + return parser.parse_args() + + +def logger_merge_config(merge_config, lora_merge): + """ + Logs the merge configuration details to debug output, with different formatting + for LoRA merges versus standard model merges. + + Args: + merge_config (object): Configuration object containing merge parameters. + Expected to have attributes accessible via __dict__. + lora_merge (bool): Flag indicating whether this is a LoRA merge operation. + When True, logs only LoRA-specific parameters. + When False, logs standard merge parameters. + + Outputs: + Writes formatted configuration details to the logger at DEBUG level. + For LoRA merges: Displays centered "LoRA Merge Info" header and specific paths. + For standard merges: Displays centered "Mergekit Config Info" header and all + parameters except excluded ones. + """ + if lora_merge: + logger.debug("{:^40}".format("LoRA Merge Info")) + for k, v in merge_config.__dict__.items(): + if k in ["lora_model_path", "base_model_path"]: + logger.debug(f"{k:30}: {v}") + else: + logger.debug("{:^40}".format("Mergekit Config Info")) + for k, v in merge_config.__dict__.items(): + if k in ["model_path_str", "device", "tensor_type", "merge_preifx"]: + continue + logger.debug(f"{k:30}: {v}") + + +def merge(): + """ + Main function for merging models, supporting both LoRA adapter merging and standard model merging. + + Handles the complete merging workflow including: + - Argument parsing + - Device configuration + - Configuration setup for different merge types + - Model merging execution + - Progress logging and timing + + The function has two main execution paths: + 1. LoRA Merge: When lora_model_path is specified + 2. Standard Merge: When mergekit_task_config is specified + + Returns: + None: Outputs are written to specified paths and logged to console + """ + args = parse_arguments() + + paddle.set_device(args.device) + tensor_type = "np" if args.device == "cpu" else "pd" + + lora_merge = args.lora_model_path is not None + if lora_merge: + start = time.time() + logger.info("***** Start merging LoRA model *****") + config = {} + config["output_path"] = args.output_path + config["lora_model_path"] = args.lora_model_path + config["base_model_path"] = args.model_name_or_path + if args.copy_tokenizer: + config["copy_file_list"] = [ + "tokenizer.model", + "tokenizer_config.json", + "special_tokens_map.json", + ] + merge_config = MergeConfig(**config) + mergekit = MergeModel(merge_config) + logger_merge_config(merge_config, lora_merge) + mergekit.merge_model() + src_file = os.path.join(args.model_name_or_path, "config.json") + dst_file = os.path.join(args.output_path, "config.json") + if os.path.isfile(src_file): + shutil.copy2(src_file, dst_file) + else: + logger.debug( + f"Copy failed: 'config.json' not found in {args.model_name_or_path}" + ) + logger.info( + f"***** Successfully finished merging LoRA model. Time cost: {time.time() - start} s *****" + ) + else: + with open(args.mergekit_task_config, "r", encoding="utf-8") as f: + config_list = json.load(f) + if not ( + isinstance(config_list, list) + and all(isinstance(config, dict) for config in config_list) + ): + raise ValueError( + "The mergekit_task_config must be a list of dict. Please check config." + ) + + for i, config in enumerate(config_list): + logger.info("=" * 30) + start = time.time() + logger.info(f"***** Start merging model id: {i} *****") + config["output_path"] = os.path.join( + args.output_path, config.pop("output_folder_name") + ) + config["tensor_type"] = tensor_type + if args.copy_tokenizer: + config["copy_file_list"] = [ + "tokenizer.model", + "tokenizer_config.json", + "special_tokens_map.json", + ] + merge_config = MergeConfig(**config) + mergekit = MergeModel(merge_config) + logger_merge_config(merge_config, lora_merge) + mergekit.merge_model() + logger.info( + f"***** Successfully finished merging model id: {i}. Time cost: {time.time() - start} s *****" + ) + + +if __name__ == "__main__": + merge() diff --git a/ernie/ERNIE/examples/post-training/tools/run_lora_merge.sh b/ernie/ERNIE/examples/post-training/tools/run_lora_merge.sh new file mode 100644 index 0000000000000000000000000000000000000000..0758b710f7b300dc6a330a38e427c033a675e5ca --- /dev/null +++ b/ernie/ERNIE/examples/post-training/tools/run_lora_merge.sh @@ -0,0 +1,22 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +export PYTHONPATH=$(dirname "$0")/..:$PYTHONPATH + +python -m paddle.distributed.launch \ + --gpus 0,1,2,3,4,5,6,7 \ + examples/post-training/tools/mergekit.py \ + --lora_model_path "checkpoint-path/ernie-model-lora" \ + --model_name_or_path "base-model-path" \ + --output_path "checkpoint-path/merge_lora_model" \ diff --git a/ernie/ERNIE/examples/pre-training/ernie/src/__init__.py b/ernie/ERNIE/examples/pre-training/ernie/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9cc79cc9d7f1977efe8e066facf32c20c8ad3af --- /dev/null +++ b/ernie/ERNIE/examples/pre-training/ernie/src/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/ernie/ERNIE/examples/pre-training/ernie/src/trainers/__init__.py b/ernie/ERNIE/examples/pre-training/ernie/src/trainers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..254a42c39d7fed9884d56f720ea30b85a4452b01 --- /dev/null +++ b/ernie/ERNIE/examples/pre-training/ernie/src/trainers/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .pretraining_trainer import ( + PreTrainingArguments, + PretrainingTrainer, + WeightedDistributedSampler, +) + +__all__ = [ + 'PretrainingTrainer', + 'PreTrainingArguments', + 'WeightedDistributedSampler', +] diff --git a/ernie/ERNIE/examples/pre-training/ernie/src/trainers/data_parallel.py b/ernie/ERNIE/examples/pre-training/ernie/src/trainers/data_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..4d32137517bc7b8fe01e0da6d9a94968715616ae --- /dev/null +++ b/ernie/ERNIE/examples/pre-training/ernie/src/trainers/data_parallel.py @@ -0,0 +1,99 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle import framework +from paddle.distributed import fleet +from paddle.framework import ( + base as imperative_base, +) +from paddle.framework import ( + core, + in_dynamic_mode, +) + + +class DataParallel(paddle.DataParallel): + def init_reducer(self): + layers_param = [] + params_set = set() + for sublayer in self.sublayers(): + for _, param in sublayer.named_parameters(include_sublayers=False): + if param is None or param in params_set: + continue + params_set.add(param) + if not isinstance(param, self.var_dtype): + raise TypeError("The data type of '%s' must be '%s'" % (param.name, self.var_dtype)) + if param.trainable: + layers_param.append((sublayer, param)) + + trainable_parameters = list( + filter( + lambda x: not getattr(x, "no_sync", False), + [param for _, param in layers_param], + ) + ) + + assert len(trainable_parameters) > 0, ( + "This model does not have any parameters to train, and " "does not need to use DataParallel" + ) + + def check_layer_sparse(sublayer): + if isinstance(sublayer, paddle.nn.layer.common.Embedding): + return sublayer._sparse + return False + + is_sparse_gradient = [ + check_layer_sparse(sublayer) for sublayer, param in layers_param if not getattr(param, "no_sync", False) + ] + + if in_dynamic_mode(): + self.group_indices = core.eager_assign_group_by_size( + trainable_parameters, + is_sparse_gradient, + [self.last_comm_buffer_size, self.comm_buffer_size], + ) + self._reducer = core.EagerReducer( + trainable_parameters, + list(reversed(self.group_indices)), + is_sparse_gradient, + self.group.process_group, + [self.last_comm_buffer_size, self.comm_buffer_size], + self.find_unused_parameters, + ) + + +@imperative_base.no_grad +@framework.dygraph_only +def sync_dp_moe_params_across_sharding(model: paddle.nn.Layer) -> None: + hcg = fleet.fleet._hcg + sharding_parallel_group = hcg.get_sharding_parallel_group() + src_rank = hcg.get_sharding_parallel_group_src_rank() + model_vars = [] + for _, param in model._obtain_parameters_buffers().items(): + if not isinstance(param, core.eager.Tensor): + raise TypeError(f"The data type of '{param.name}' must be core.eager.Tensor") + + if param.type == core.VarDesc.VarType.VOCAB: + continue + + if getattr(param, "no_sync", False): + model_vars.append(param.detach()) + + if len(model_vars) == 0: + return + + for var in model_vars: + var = var.contiguous() + paddle.distributed.broadcast(var, src=src_rank, group=sharding_parallel_group, sync_op=True) diff --git a/ernie/ERNIE/examples/pre-training/ernie/src/trainers/pretraining_trainer.py b/ernie/ERNIE/examples/pre-training/ernie/src/trainers/pretraining_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2fe1b365ec91d251156208e0e37c31c1a3a0f4 --- /dev/null +++ b/ernie/ERNIE/examples/pre-training/ernie/src/trainers/pretraining_trainer.py @@ -0,0 +1,1348 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = [ + "PretrainingTrainer", +] + + +import contextlib +import json +import logging +import math +import os +import pickle +import random +import re +import time +from collections import OrderedDict, defaultdict +from dataclasses import dataclass, field +from types import MethodType +from typing import Optional + +import numpy as np +import paddle +import paddle.amp.auto_cast as autocast +from paddle import framework, nn +from paddle.base import core +from paddle.distributed.communication.group import _get_global_group +from paddle.distributed.fleet.utils import mix_precision_utils +from paddleformers.trainer import ( + Trainer, + TrainingArguments, + speed_metrics, +) +from paddleformers.utils.tools import get_env_device + +try: + from paddleformers.trainer import TRAINING_ARGS_NAME +except ImportError: + TRAINING_ARGS_NAME = "training_args.bin" + +try: + from paddleformers.utils.env import ( + PADDLE_OPTIMIZER_NAME, + ) +except ImportError: + from paddleformers.trainer.trainer import ( + OPTIMIZER_NAME, + ) + + PADDLE_OPTIMIZER_NAME = OPTIMIZER_NAME + +try: + from paddleformers.trainer.trainer import ( + PADDLE_WEIGHT_FILE_NAME as PADDLE_WEIGHTS_NAME, + ) +except ImportError: + from paddleformers.utils.env import PADDLE_WEIGHTS_NAME +import paddle.distributed as dist +from models.sequence_parallel_utils import register_sequence_parallel_allreduce_hooks +from models.utils import global_training_logs_enabled +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import ( + HybridParallelOptimizer, +) +from paddleformers.datasets import MapDataset +from paddleformers.trainer.trainer_callback import PrinterCallback +from paddleformers.trainer.trainer_utils import ( + ShardingOption, +) +from paddleformers.trainer.utils import add_start_docstrings +from paddleformers.transformers.model_utils import _add_variant, unwrap_model +from paddleformers.utils.batch_sampler import ( + DistributedBatchSampler as PaddleNLPDistributedBatchSampler, +) + +from src.callbacks import ( + GCCallback, + LoggingCallback, + SPGradSyncCallback, + TensorBoardCallback, + FP8QuantWeightCallback, +) +from src.callbacks.moe_logging_callback import MoeLoggingCallback +from src.clip import ClipGradForMOEByGlobalNorm +from src.lr_schedulers import get_wsd_schedule_with_warmup +from src.trainers.data_parallel import sync_dp_moe_params_across_sharding +from src.utils.misc import global_training_logs +from src.utils.training_utils import ( + reset_per_device_batch_size, +) + +logger = logging.getLogger(__name__) + + +def distributed_optimizer_maybe_overwrite( + optimizer, + use_moe, +): + if use_moe: + from src.trainers.dygraph_optimizer.hybrid_parallel_optimizer import ( + HybridParallelOptimizer as MoEHybridParallelOptimizer, + ) + + fleet_env = fleet.fleet + fleet_env.user_defined_optimizer = optimizer + hp_optim = MoEHybridParallelOptimizer(optimizer, fleet_env._hcg, fleet_env._user_defined_strategy) + + if fleet_env._user_defined_strategy.hybrid_configs["pp_configs"].dp_comm_overlap: + hp_optim._dp_enable = False + + if fleet_env._user_defined_strategy.hybrid_configs["pp_configs"].sharding_comm_overlap: + hp_optim._sharding_enable = False + return hp_optim + else: + return fleet.distributed_optimizer(optimizer) + + +@dataclass +@add_start_docstrings(TrainingArguments.__doc__) +class PreTrainingArguments(TrainingArguments): + vocab_path: str = field(default=None, metadata={"help": "eb35 streaming data vocab"}) + model_name_or_path: str = field( + default=None, + metadata={ + "help": "Path to pretrained model or model identifier from " + "https://paddleformers.readthedocs.io/zh/latest/model_zoo/transformers.html" + }, + ) + prefetch_factor: int = field( + default=2, + metadata={"help": "global random seed factor."}, + ) + eval_iters: int = field( + default=-1, + metadata={"help": "eval iteration for every evaluation."}, + ) + num_consecutive: int = field( + default=1, + metadata={"help": "H5 file consecutive num."}, + ) + min_lr: float = field( + default=0.0, + metadata={"help": "minus learning rate"}, + ) + dataset: str = field( + default=None, + metadata={"help": "The name of the dataset to use (via the datasets library)."}, + ) + + input_dir: str = field(default=None, metadata={"help": "data path"}) + split: str = field(default="949,50,1", metadata={"help": "Train/valid/test data split ratio"}) + + max_seq_length: int = field( + default=512, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + global_batch_size: int = field( + default=-1, + metadata={ + "help": "if `global_batch_size` and `per_device_train_batch_size` is provied, " + "`gradient_accumulation_steps` will be ignored" + }, + ) + tokenizer_name: Optional[str] = field( + default=None, + metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}, + ) + sequence_parallel: Optional[int] = field( + default=0, + metadata={}, + ) + virtual_pp_degree: Optional[int] = field( + default=1, + metadata={ + "help": "vpp", + }, + ) + from_scratch: Optional[int] = field(default=1, metadata={"help": "train from scratch"}) + same_data: Optional[bool] = field( + default=None, + metadata={"help": "when resume from checkpoint, keey data same with the ckpt"}, + ) + base_seq_length: Optional[int] = field(default=4096, metadata={"help": "reeao min seq_length"}) + shuffle_consecutive: Optional[bool] = field( + default=False, + metadata={"help": "shuffle num_consecutive or not"}, + ) + global_shuffle_num_examples: Optional[int] = field( + default=0, + metadata={"help": "max num of shuffling among different parts"}, + ) + use_async_save: Optional[bool] = field(default=False, metadata={"help": "use async save or not"}) + pre_alloc_memory: float = field( + default=0.0, + metadata={ + "help": "Pre-allocate one specific-capacity empty tensor " + "and release it for avoiding memory fragmentation" + }, + ) + enable_global_training_logs: bool = field(default=False, metadata={"help": "use global_training_logs or not"}) + moe_group: Optional[str] = field(default="dp", metadata={"help": "moe comm group"}) + use_moe: Optional[bool] = field(default=False, metadata={"help": "enable expert parallel"}) + log_global_grad_norm: Optional[bool] = field( + default=False, + metadata={"help": "print global grad-norm"}, + ) + multi_token_pred_depth: Optional[int] = field( + default=0, + metadata={}, + ) + enable_mtp_magic_send: Optional[bool] = field(default=False, metadata={"help": ""}) + enable_optimizer_timer: Optional[bool] = field(default=False, metadata={"help": "enable timer in zero-1"}) + lr_scheduler: str = field( + default="cosine", + metadata={"help": "The scheduler type to use. support linear, cosine, constant, constant_with_warmup"}, + ) + decay_function: str = field( + default="half_life", + metadata={"help": "The decay function for WSD LR scheduler. support half_life(default), 1-sqrt"}, + ) + moe_gate_lr_ratio: float = field( + default=None, + metadata={"help": ("special handle the lr for gate/router")}, + ) + + gc_interval: int = field(default=0, metadata={"help": "gc time"}) + use_sp_callback: int = field( + default=True, + metadata={"help": "use callback for sequence parallel"}, + ) + moe_use_aux_free_update_coef: float = field( + default=1.0e-3, + metadata={"help": "moe aux free update coef,"}, + ) + use_fp8: bool = field( + default=False, + metadata={"help": "whether to use fp8 training"}, + ) + global_logging_interval: int = field( + default=1, + metadata={"help": "the logging interval of global_training_logs"}, + ) + train_moe_only: int = field(default=None, metadata={"help": "train moe params only"}) + + @property + def use_moe(self): # noqa: F811 + return getattr(self, "use_expert_parallel", self._use_moe) + + @use_moe.setter + def use_moe(self, value): + self.use_expert_parallel = value + self._use_moe = value + + @property + def need_data(self): + return self.pipeline_parallel_rank == 0 and self.tensor_parallel_rank == 0 + + @property + def combine_batch(self): + return self.max_seq_length // self.base_seq_length + + @property + def reeao_dataset_rank(self): + return super().dataset_rank + + @property + def reeao_dataset_world_size(self): + return super().dataset_world_size + + def __post_init__(self): + super().__post_init__() + + if self.global_batch_size > 0: + micro_bsz, acc_steps = reset_per_device_batch_size( + self.global_batch_size, + self.per_device_train_batch_size, + self.dataset_world_size, + ) + logger.info(f"global_batch={self.global_batch_size} micro-bsz:{micro_bsz}, accumulate_steps:{acc_steps}") + if ( + acc_steps != 1 + and self.gradient_accumulation_steps != 1 + and acc_steps != self.gradient_accumulation_steps + ): + raise ValueError( + f"global_accumulation_steps={self.gradient_accumulation_steps}" + f"& global_batch={self.global_batch_size} are both set" + ) + self.per_device_train_batch_size, self.gradient_accumulation_steps = ( + micro_bsz, + acc_steps, + ) + + self.max_gradient_accumulation_steps = self.gradient_accumulation_steps + + if self.pipeline_parallel_degree > 1: + self.per_device_eval_batch_size = self.per_device_train_batch_size * self.gradient_accumulation_steps + logger.warn(f"eval_batch_size set to {self.per_device_eval_batch_size} in Pipeline Parallel!") + user_defined_strategy = fleet.fleet._user_defined_strategy + user_defined_strategy.strategy.pipeline_configs.accumulate_steps = self.gradient_accumulation_steps + self.max_gradient_accumulation_steps = self.gradient_accumulation_steps + logger.info(f"fixing pp configs: {user_defined_strategy.pipeline_configs}") + else: + self.per_device_eval_batch_size = self.per_device_train_batch_size + logger.warn(f"eval_batch_size set to {self.per_device_eval_batch_size}") + + if self.sharding_parallel_degree > 1: + sharding_parallel_config = ( + set(self.sharding_parallel_config.split(" ")) if self.sharding_parallel_config else set() + ) + sharding_comm_overlap_non_pp = ( + True + if "shardingv1_comm_overlap" in sharding_parallel_config + or "sharding_comm_overlap" in sharding_parallel_config + else False + ) + if sharding_comm_overlap_non_pp: + assert hasattr(fleet.fleet, "_user_defined_strategy") + user_defined_strategy = fleet.fleet._user_defined_strategy + user_defined_strategy.hybrid_configs["sharding_configs"].accumulate_steps = ( + self.gradient_accumulation_steps + ) + + if hasattr(fleet.fleet, "_user_defined_strategy"): + user_defined_strategy = fleet.fleet._user_defined_strategy + if ( + hasattr(user_defined_strategy, "hybrid_configs") + and "sharding_configs" in user_defined_strategy.hybrid_configs + ): + sd_configs = user_defined_strategy.hybrid_configs["sharding_configs"] + if sd_configs.comm_overlap: + assert self.global_batch_size % self.dataset_world_size == 0, ( + f"global_batch_size[{self.global_batch_size}] should be divisible by " + f"dataset_world_size[{self.dataset_world_size}]" + ) + lbs = self.global_batch_size // self.dataset_world_size + assert lbs % self.per_device_train_batch_size == 0, ( + f"local_batch_size[{lbs}] should be divisible by " + f"per_device_train_batch_size[{self.per_device_train_batch_size}]" + ) + assert lbs // self.per_device_train_batch_size == sd_configs.accumulate_steps, ( + f"local_batch_size[{lbs}] should be equal to " + f"accumulate_steps[{sd_configs.accumulate_steps}] * " + f"per_device_train_batch_size[{self.per_device_train_batch_size}]" + ) + + if ShardingOption.SHARD_GRAD_OP in self.sharding: + logger.info("disabling `sp_callback` b/c using sharding stage2") + self.use_sp_callback = False + + +class WeightedDistributedSampler(PaddleNLPDistributedBatchSampler): + def __init__( + self, + dataset, + batch_size, + output_dir, + dp_rank, + dp_size, + num_consecutive=1, + seed=0, + gradient_accumulation_steps=None, + max_gradient_accumulation_steps=None, + per_device_train_batch_size=None, + combine_batch: int = 1, + shuffle_consecutive: bool = False, + global_shuffle_num_examples: int = 0, + same_data: bool = False, + **kwargs, + ): + self.num_consecutive = num_consecutive + self.seed = seed + super().__init__(dataset, batch_size, **kwargs) + self.weights = None + self.batch_size = batch_size + self.output_dir = output_dir + self.rng = random.Random(self.seed + self.epoch) + self.dp_rank = dp_rank + self.dp_size = dp_size + self.gradient_accumulation_steps = gradient_accumulation_steps + self.max_gradient_accumulation_steps = max_gradient_accumulation_steps + self.per_device_train_batch_size = per_device_train_batch_size + self.combine_batch = combine_batch + self.shuffle_consecutive = shuffle_consecutive + self.global_shuffle_seed = 0 + self.global_shuffle_num_examples = global_shuffle_num_examples + self.same_data = same_data + self.load_data_seq = False + if isinstance(self.dataset, MapDataset): + self.inner_dataset = self.dataset.data + else: + self.inner_dataset = self.dataset + assert self.inner_dataset._load + + self.max_part_id = self.inner_dataset.global_max_part_id + + self.set_epoch(0) + + def set_epoch(self, epoch=0, consumed_samples=0): + consumed_samples = consumed_samples // self.dp_size + logger.info(f"set consumed samples={consumed_samples}, epoch={epoch}") + super().set_epoch(epoch, consumed_samples) + + def gen_data_seq(self): + total = [] + for ex in self.inner_dataset.exs: + total.extend([(ex.part, 0, i) for i in range(ex.data_status, len(ex))]) + assert len(total) > self.num_consecutive, f"total={total} < num_consecutive={self.num_consecutive}" + indices = np.array_split(np.array(total), len(total) // self.num_consecutive) + if self.shuffle: + self.rng.shuffle(indices) + indices = np.concatenate(indices) + indices = self.roundup_and_shard(indices) + logger.info(indices[:10]) + return indices + + def load_data_seq_from_cache(self): + indices_file = os.path.join( + self.output_dir, + f"data_seq.epoch{self.epoch}.dp_{self.dp_rank}_of_{self.dp_size}" + f"_shard_{self.local_rank}_of_{self.nranks}.pth", + ) + if self.same_data and os.path.exists(indices_file): + logger.info(f"load data seq from file - {indices_file}") + self.load_data_seq = True + with open(indices_file, "rb") as of: + return pickle.load(of) + return None + + def gen_data_seq_weighted(self, num_examples, data_type=None): + assert self.load_data_seq is False, "需要保证所有epoch的data_seq都从文件加载,否则下次删data_seq无法控住随机性" + logger.info( + f"generating data sequence... #non_consecutive_data_chunks={num_examples}," + f" num_consecutive={self.num_consecutive}" + ) + + if num_examples > 1e5: + logger.info("generating data sequence for very large data, consider use large `num_consecutive`") + if data_type is not None: + weights = [ex.weights for ex in self.inner_dataset.exs if ex.data_type == data_type] + exs = [ex for ex in self.inner_dataset.exs if ex.data_type == data_type] + else: + weights = [ex.weights for ex in self.inner_dataset.exs] + exs = self.inner_dataset.exs + assert len(exs) > 0, f"data_type={data_type}, no data found" + total_w = sum(weights) + weights = [w / total_w for w in weights] + + logger.info( + f"using weighted sampler, num_consecutive={self.num_consecutive}:\n" + + "\n".join(["%-100s...%.3e" % (e.path, w) for w, e in zip(weights, exs)]) + ) + + part_indices_gen = {} + indices = [] + for i, ex in enumerate(exs): + sample_size = int(weights[i] * num_examples) + logger.info(f"part_data_pre_sampling--[part-{ex.part}]-[sampler-size-{sample_size}]") + assert ex.combine_batch == self.combine_batch + part_indices_gen[ex.part] = ex.sampler() + indices.extend([ex.part] * sample_size) + + logger.info(f"shuffle part placeholder index, size={len(indices)}, exmaple={indices[0]}") + if self.shuffle: + self.rng.shuffle(indices) + logger.info("shuffle done") + indices_ret = [] + logger.info("build_index from shuffled placeholder") + + for part_id in indices: + epoch, _index = next(part_indices_gen[part_id]) + if len(_index) % self.combine_batch != 0: + _index += [-1] * (self.combine_batch - len(_index) % self.combine_batch) + indices_ret += [(part_id, epoch, i) for i in _index] + + if self.shuffle_consecutive and self.combine_batch >= 1: + part_data_gen = defaultdict(lambda: []) + logger.info("consecutive placeholder 2 shuffle") + for item in indices_ret: + part_data_gen[item[0]].append(item) + logger.info("consecutive placeholder 2 shuffle...") + part_data_gen_iter = {} + for key in part_data_gen.keys(): + part_data_gen_iter[key] = iter(part_data_gen[key]) + logger.info("consecutive placeholder 2 shuffle......") + placeholder_indices = [i[0] for i in indices_ret] + placeholder_indices = [ + placeholder_indices[i : i + self.combine_batch] + for i in range(0, len(placeholder_indices), self.combine_batch) + ] + logger.info("consecutive placeholder 2 shuffle..........") + self.rng.shuffle(placeholder_indices) + logger.info("consecutive placeholder 2 shuffle.............") + placeholder_indices = [item for sublist in placeholder_indices for item in sublist] + logger.info("consecutive placeholder 2 shuffle................") + indices_ret = [next(part_data_gen_iter[i]) for i in placeholder_indices] + logger.info("consecutive placeholder 2 shuffle done") + + logger.info("build index done") + indices = np.array(indices_ret) + del indices_ret + logger.info(f"num_data_seq={len(indices)}, example={indices[:10]}") + indices = self.roundup_and_shard(indices) + return indices + + def roundup_and_shard(self, indices): + if self.nranks == 1: + return indices + + padding_size = self.total_size - len(indices) + logger.info(f"padding-size={padding_size}, total_size={self.total_size} shard={self.local_rank}/{self.nranks}") + if padding_size < 0: + indices = indices[:padding_size] + else: + indices = np.concatenate( + [ + indices, + np.tile(indices, math.ceil(padding_size / len(indices)))[:padding_size], + ] + ) + + assert len(indices) == self.total_size, (len(indices), self.total_size) + + indices = indices[self.local_rank : self.total_size : self.nranks] + assert len(indices) == self.num_samples + return indices + + def __len__(self): + raise TypeError + + def __iter__(self): + self.rng = random.Random(self.seed + self.epoch + self.global_shuffle_seed) + logger.info(f"seed={self.seed + self.epoch + self.global_shuffle_seed}") + weights = [e.weights for e in self.inner_dataset.exs] + if any(w is None for w in weights) or sum(weights) == 0.0: + logger.info(f"using normal sampler, num_consecutive={self.num_consecutive}") + indices = self.gen_data_seq() + self.weights = None + else: + self.weights = weights + num_examples = sum([ex.num_examples for ex in self.inner_dataset.exs]) + + if self.global_shuffle_num_examples > 0: + num_examples = min([self.global_shuffle_num_examples, num_examples]) + logger.info(f"using global shuffle num examples: {self.global_shuffle_num_examples}") + indices = self.load_data_seq_from_cache() + if indices is None: + indices = self.gen_data_seq_weighted(num_examples) + + if self.output_dir: + with open( + os.path.join( + self.output_dir, + f"data_seq.epoch{self.epoch}.dp_{self.dp_rank}_of_{self.dp_size}" + f"_shard_{self.local_rank}_of_{self.nranks}.pth", + ), + "wb", + ) as of: + pickle.dump(indices, of, protocol=4) + + def ret(): + nonlocal indices + buf = [] + logger.info(f"start training sequence, data-sequence: {indices[:10]}") + while 1: + if self.consumed_samples >= len(indices): + self.consumed_samples -= len(indices) + else: + for i in range(self.consumed_samples, len(indices)): + if len(buf) == self.batch_size: + yield buf + buf = [] + buf.append(indices[i].tolist()) + self.consumed_samples = 0 + self.epoch += 1 + logger.info(f"epoch done, #data={self.total_size}, reshuffle-sequence: epoch={self.epoch}") + + self.rng = random.Random(self.seed + self.epoch) + if self.weights: + indices = self.load_data_seq_from_cache() + if indices is None: + indices = self.gen_data_seq_weighted(num_examples) + else: + indices = self.gen_data_seq() + if self.output_dir: + with open( + os.path.join( + self.output_dir, + f"data_seq.epoch{self.epoch}.dp_{self.dp_rank}_of_{self.dp_size}" + f"_shard_{self.local_rank}_of_{self.nranks}.pth", + ), + "wb", + ) as of: + pickle.dump(indices, of, protocol=4) + + return ret() + + +class DummySampler(PaddleNLPDistributedBatchSampler): + def __init__(self, dataset, batch_size=1, **kwargs): + super().__init__(dataset, batch_size=batch_size, **kwargs) + + def __len__(self): + raise TypeError + + def __iter__(self): + while True: + yield [0] * self.batch_size + + +class PretrainingTrainer(Trainer): + def __init__(self, args=None, model=None, callbacks=[], **kwargs): + callbacks = [ + FP8QuantWeightCallback(), + LoggingCallback(), + TensorBoardCallback(args, model=model, log_tokens_per_step=True, log_flops_per_step=False), + GCCallback(), + ] + callbacks + + args.use_async_save = args.use_async_save and args.save_sharded_model and args.load_sharded_model + super().__init__(args=args, model=model, callbacks=callbacks, **kwargs) + self.pop_callback(PrinterCallback) + self.pp_data_buffer = [] + self._tokens_per_sec_per_card_buffer = [] + self._start_save_time = time.time() + self._end_save_time = time.time() + self._first_end_save_time = time.time() + self.resume_global_step = -1 + self.first_skip_step = 5 if self.args.save_steps > 5 else self.args.save_steps / 2 + global_training_logs.enable_skip_zero([r".*aux_loss.*"]) + global_training_logs.set_trainer_interval(self, self.args.global_logging_interval) + + def autocast_smart_context_manager(self): + if self.enable_autocast_context_manager: + black = [ + "reduce_sum", + "c_softmax_with_cross_entropy", + "elementwise_div", + "sin", + "cos", + ] + white = [ + "lookup_table", + "lookup_table_v2", + "flash_attn", + "flash_attn_v1", + "matmul", + "matmul_v2", + "fused_gemm_epilogue", + ] + if self.args.bf16 and self.args.fp16_opt_level == "O2": + black.append("c_embedding") + + ctx_manager = autocast( + True, + custom_black_list=black, + custom_white_list=white, + level=self.args.fp16_opt_level, + dtype=self.amp_dtype, + ) + else: + ctx_manager = contextlib.nullcontext() + return ctx_manager + + def _load_optimizer_state(self, checkpoint): + def _broadcast_moe_optimizer_state(state_dict): + base_state_dict = {"master_weights": {}} + buf = [ + {i: j.shape for i, j in state_dict.items() if i not in ["master_weights", "LR_Scheduler"]}, + {i: j.shape for i, j in state_dict["master_weights"].items()}, + {"LR_Scheduler": state_dict.get("LR_Scheduler", {})}, + ] + + if self.args.use_hybrid_parallel: + hcg = fleet.get_hybrid_communicate_group() + src_rank = hcg.get_data_parallel_group_src_rank() + group = hcg.get_data_parallel_group() + else: + src_rank = 0 + group = None + + dist.broadcast_object_list(buf, src=src_rank, group=group) + for k, s in buf[0].items(): + v = state_dict.get(k, paddle.zeros(s, "float32")).to(get_env_device()) + v.name = k + dist.broadcast(v, src=src_rank, group=group) + logger.info(f"broadcast moe optimizer {k} from {src_rank}") + base_state_dict[k] = v.cpu() + for k, s in buf[1].items(): + v = state_dict["master_weights"].get(k, paddle.zeros(s, "float32")).to(get_env_device()) + v.name = k + dist.broadcast(v, src=src_rank, group=group) + logger.info(f"broadcast moe optimizer-master_weights {k} from {src_rank}") + base_state_dict["master_weights"][k] = v.cpu() + base_state_dict.update(buf[2]) + return base_state_dict + + state_dict = super()._load_optimizer_state(checkpoint) + + if self.args.use_moe: + base_state_dict = _broadcast_moe_optimizer_state(state_dict) + if self.args.data_parallel_rank > 0: + master_weight = state_dict.pop("master_weights", {}) + base_state_dict.update(state_dict) + if master_weight: + if "master_weights" in base_state_dict: + base_state_dict["master_weights"].update(master_weight) + else: + base_state_dict["master_weights"] = master_weight + state_dict = base_state_dict + del base_state_dict + return state_dict + + def _save_moe_weights(self, output_dir): + optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix) + saved_signal_path = os.path.join(output_dir, f"saved_signal_{dist.get_rank()}") + + os.makedirs(output_dir, exist_ok=True) + state_dict = self.model.state_dict() + optimzier_state_dict = self.optimizer.state_dict() + + filtered_state_dict = OrderedDict() + filter_optimzier_state_dict = OrderedDict() + + param_names_in_master_weights = list(optimzier_state_dict["master_weights"].keys()) if self.args.bf16 else [] + filter_optimzier_state_dict["master_weights"] = OrderedDict() + + for k, v in state_dict.items(): + if getattr(v, "no_sync", False): + + if v.name in param_names_in_master_weights: + filter_optimzier_state_dict["master_weights"][v.name] = optimzier_state_dict["master_weights"][ + v.name + ] + if not ( + getattr(self.args, "should_save_sharding_stage1_model", False) + or getattr(self.args, "save_sharding_stage1_model", False) + ): + filtered_state_dict[k] = v + for op_k, op_v in optimzier_state_dict.items(): + if op_k.startswith(v.name): + filter_optimzier_state_dict[op_k] = op_v + + if getattr(self.args, "should_save_sharding_stage1_model", False) or getattr( + self.args, "save_sharding_stage1_model", False + ): + self._save(output_dir=output_dir) + else: + if self.args.sharding_parallel_rank == 0: + paddle.save( + filtered_state_dict, + os.path.join( + output_dir, + _add_variant(PADDLE_WEIGHTS_NAME, self.args.weight_name_suffix), + ), + ) + paddle.save(filter_optimzier_state_dict, os.path.join(output_dir, optimizer_name)) + with open(saved_signal_path, mode="w+") as f: + f.write("1") + + def _wrap_model(self, model, training=True): + if unwrap_model(model) is not model: + return model + if not training: + return model + if self.args.fp16 or self.args.bf16: + model = paddle.amp.decorate(models=model, level=self.args.fp16_opt_level, dtype=self.amp_dtype) + + if self.args.use_moe: + from src.trainers.data_parallel import DataParallel as MoEDDP + + paddle.DataParallel = MoEDDP + + if self.args.world_size > 1 and not self.args.use_hybrid_parallel: + model = paddle.DataParallel(model) + + in_pipeline_parallel_mode = self.args.pipeline_parallel_degree > 1 + in_sharding_parallel_mode = self.sharding is not None + in_tensor_parallel_model = self.args.tensor_parallel_degree > 1 + + def enable_sequence_parallel(_model): + if self.args.tensor_parallel_degree > 1 and self.args.sequence_parallel: + if self.args.use_sp_callback: + self.add_callback(SPGradSyncCallback(_model._layers)) + else: + register_sequence_parallel_allreduce_hooks(_model) + + is_dp_moe = self.args.use_moe and self.args.moe_group in {"data", "dp"} + + if in_pipeline_parallel_mode: + if self.args.amp_master_grad: + mix_precision_utils.MixPrecisionLayer( + model, + dtype=self.amp_dtype if hasattr(self, "amp_dtype") else "float16", + ) + prepare_pipeline_inputs_func = ( + model._prepare_pipeline_inputs_func if hasattr(model, "_prepare_pipeline_inputs_func") else None + ) + model = fleet.distributed_model(model) + if is_dp_moe: + logger.info("start broadcast dp moe parameters across sharding group") + sync_dp_moe_params_across_sharding(model._layers) + if prepare_pipeline_inputs_func is not None: + model._prepare_pipeline_inputs_func = prepare_pipeline_inputs_func + else: + + def _prepare_pipeline_inputs_func(inputs): + first_stage_keys = ["input_ids", "attention_mask", "position_ids"] + last_stage_keys = ["labels"] + + def get_expected_keys(inputs, keys): + ret = tuple([inputs.pop(k) for k in keys if k in inputs]) + if len(ret) == 1: + ret = ret[0] + return ret + + if type(inputs) is dict: + return [ + get_expected_keys(inputs, first_stage_keys), + get_expected_keys(inputs, last_stage_keys), + ] + + keys = list(inputs[0].keys()) + inputs_batch = {key: [data.pop(key) for data in inputs] for key in keys} + return [ + get_expected_keys(inputs_batch, first_stage_keys), + get_expected_keys(inputs_batch, last_stage_keys), + ] + + logger.warning( + "Using default prepare pipeline inputs func, only support input_ids and labels as inputs." + ) + model._prepare_pipeline_inputs_func = _prepare_pipeline_inputs_func + + enable_sequence_parallel(model) + + assert self.optimizer is not None, "Pipeline mode need decorate optimizer, pelease init optimizer." + if self.args.amp_master_grad: + self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) + self.optimizer = distributed_optimizer_maybe_overwrite(self.optimizer, self.args.use_moe) + + if not in_pipeline_parallel_mode and in_sharding_parallel_mode: + if self.args.tensor_parallel_degree > 1: + hcg = fleet.get_hybrid_communicate_group() + assert ( + ShardingOption.SHARD_GRAD_OP in self.args.sharding or ShardingOption.SHARD_OP in self.args.sharding + ), "Only support tensor parallel + sharding stage1/stage2 hybrid parallel now." + model = paddle.distributed.fleet.meta_parallel.TensorParallel(model, hcg, strategy=None) + model.accumulate_steps = self.args.gradient_accumulation_steps + enable_sequence_parallel(model) + + if ShardingOption.SHARD_OP in self.args.sharding: + if self.args.amp_master_grad: + mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) + model = fleet.distributed_model(model) + if is_dp_moe: + logger.info("start broadcast dp moe parameters across sharding group") + sync_dp_moe_params_across_sharding(model._layers) + if self.args.amp_master_grad: + self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) + self.optimizer = distributed_optimizer_maybe_overwrite(self.optimizer, self.args.use_moe) + + else: + if (self.args.use_moe) and self.args.data_parallel_degree > 1: + try: + from paddle.fluid.dygraph.parallel import sync_params_buffers + except ImportError: + from paddle.distributed.parallel import sync_params_buffers + + hcg = fleet.get_hybrid_communicate_group() + dp_group = hcg.get_data_parallel_group() + sync_params_buffers(model, comm_group=dp_group, src_rank=dp_group.ranks[0]) + + if is_dp_moe: + logger.info("start broadcast dp moe parameters across sharding group") + sync_dp_moe_params_across_sharding(model) + + cpu_offload = ShardingOption.OFFLOAD in self.args.sharding + assert self.optimizer is not None, "optimizer is empty!" + level = None + if ShardingOption.SHARD_GRAD_OP in self.args.sharding: + level = "os_g" + if ShardingOption.FULL_SHARD in self.args.sharding: + level = "p_g_os" + + from paddle.distributed.sharding import group_sharded_parallel + + extra_kwargs = {} + if not self.args.use_moe: + extra_kwargs["dp_group"] = self.dp_group + extra_kwargs["exclude_layer"] = ["GroupNorm"] + + model, optimizer, _ = group_sharded_parallel( + model, + self.optimizer, + level=level, + scaler=None, + group=self.sharding_group, + offload=cpu_offload, + **extra_kwargs, + ) + self.optimizer = optimizer + + if not in_pipeline_parallel_mode and not in_sharding_parallel_mode and in_tensor_parallel_model: + if self.args.amp_master_grad: + mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) + + model = fleet.distributed_model(model) + model.accumulate_steps = self.args.gradient_accumulation_steps + enable_sequence_parallel(model) + assert self.optimizer is not None, "Tensor parallel mode need decorate optimizer, pelease init optimizer." + if self.args.amp_master_grad: + self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) + + self.optimizer = distributed_optimizer_maybe_overwrite(self.optimizer, self.args.use_moe) + + if self.args.use_moe: + self.callback_handler.callbacks.insert(0, MoeLoggingCallback(self.optimizer)) + + try: + from paddle.fluid.dygraph.parallel import sync_params_buffers + except ImportError: + from paddle.distributed.parallel import sync_params_buffers + + self._new_gradclip() + return model + + def _new_gradclip(self): + if ( + isinstance(self.optimizer, HybridParallelOptimizer) + and self.args.log_global_grad_norm + and self.args.max_grad_norm > 0 + ): + gradclip = self.optimizer._inner_opt._grad_clip + oldcomm = gradclip._comm_and_clip + oldclip = gradclip._dygraph_clip + hcg = fleet.get_hybrid_communicate_group() + num_pp = hcg.get_pipe_parallel_world_size() + + @paddle.no_grad() + def newcomm( + self, + params_grads, + global_norm_var_dist, + global_norm_var_not_dist, + *args, + ): + if num_pp > 1: + for p, g in params_grads: + if getattr(p, "need_clip", True) == "pp_non_distributed": + g.scale_(np.sqrt(num_pp)) + ret = oldcomm(params_grads, global_norm_var_dist, global_norm_var_not_dist, *args) + global_norm_var_fp32 = paddle.sqrt(global_norm_var_dist + global_norm_var_not_dist) + if global_training_logs_enabled(): + global_training_logs.update(global_grad_norm=global_norm_var_fp32.item()) + return ret + + @paddle.no_grad() + def new_dygraph_clip(self, params_grads): + if num_pp > 1: + for p, g in params_grads: + if getattr(p, "need_clip", True) == "pp_non_distributed": + g.scale_(1 / np.sqrt(num_pp)) + ret = oldclip(params_grads) + return ret + + self.optimizer._inner_opt._grad_clip._comm_and_clip = MethodType( + newcomm, self.optimizer._inner_opt._grad_clip + ) + self.optimizer._inner_opt._grad_clip._dygraph_clip = MethodType( + new_dygraph_clip, self.optimizer._inner_opt._grad_clip + ) + + def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval"): + self.model_wrapped.accumulate_steps = self.args.gradient_accumulation_steps + eval_dataloader = self.get_eval_dataloader(eval_dataset) + + start_time = time.time() + compute_metrics = self.compute_metrics + eval_loop = self.evaluation_loop + + output = eval_loop( + eval_dataloader, + description="Evaluation", + prediction_loss_only=True if compute_metrics is None else None, + ignore_keys=ignore_keys, + max_eval_iters=self.args.eval_iters, + ) + + total_batch_size = self.args.eval_batch_size * self.args.world_size + output.metrics.update( + speed_metrics( + metric_key_prefix, + start_time, + num_samples=output.num_samples, + num_steps=math.ceil(output.num_samples / total_batch_size), + ) + ) + + self.log(output.metrics) + + self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) + return output.metrics + + def prediction_pipeline_step(self, model, inputs, prediction_loss_only, ignore_keys): + loss, _, labels = super().prediction_pipeline_step(model, inputs, prediction_loss_only, ignore_keys) + num_tokens = (labels != self.tokenizer.ignored_index).sum().item() + loss_avg = loss * self.model_wrapped.accumulate_steps / num_tokens + return loss_avg, loss, labels + + def restore_dataloader_status(self): + if self.args.same_data is None or self.args.same_data == "": + if self.args.resume_from_checkpoint is not None: + train_bin_file = os.path.join(self.args.resume_from_checkpoint, TRAINING_ARGS_NAME) + assert os.path.exists(train_bin_file), f"{train_bin_file} not found." + train_bin = paddle.load(train_bin_file) + old_data_filelist = train_bin.data_filelist + old_data_weights = train_bin.data_weights + old_sharding_degree = train_bin.sharding_parallel_degree + old_data_parallel_degree = train_bin.data_parallel_degree + old_reeao_data_world_size = getattr(train_bin, "reeao_data_world_size", None) + new_data_filelist = self.args.data_filelist + new_data_weights = self.args.data_weights + new_sharding_degree = self.args.sharding_parallel_degree + new_data_parallel_degree = self.args.data_parallel_degree + self.args.same_data = ( + (old_data_filelist == new_data_filelist) + and (old_data_weights == new_data_weights) + and (old_sharding_degree == new_sharding_degree) + and (old_data_parallel_degree == new_data_parallel_degree) + and (not self.args.multimodal) + and ( + old_reeao_data_world_size is None + or old_reeao_data_world_size == self.args.reeao_data_world_size + ) + ) + logger.info(f"Automatically setting same_data value: {self.args.same_data}") + else: + self.args.same_data = False + logger.info(f"Training from scratch, setting same_data value: {self.args.same_data}") + else: + logger.info(f"User has defined same_data value: {self.args.same_data}") + + if self.args.same_data: + logger.warning( + "same_data has been set to True. \ + Carefully check whether the data, population proportion, " + "and DP count are completely consistent with those before." + ) + else: + logger.warning( + "same_data has been set to False. \ + which will regenerate the global shuffle domain." + ) + + def _get_eval_sampler(self, eval_dataset) -> Optional[paddle.io.Sampler]: + return PaddleNLPDistributedBatchSampler( + eval_dataset, + batch_size=self.args.per_device_eval_batch_size, + shuffle=False, + num_replicas=self.args.dataset_world_size, + rank=self.args.dataset_rank, + drop_last=self.args.dataloader_drop_last, + ) + + def _get_train_sampler(self) -> Optional[paddle.io.Sampler]: + return PaddleNLPDistributedBatchSampler( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + shuffle=False, + num_replicas=self.args.dataset_world_size, + rank=self.args.dataset_rank, + drop_last=self.args.dataloader_drop_last, + ) + + def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval, **kwargs): + flag_log = self.control.should_log + if self.control.should_log: + logs = {} + tr_loss_single_dp_scalar = tr_loss.item() + dist.all_reduce(tr_loss, dist.ReduceOp.SUM) + tr_loss_scalar = tr_loss.item() / dist.get_world_size() + tr_loss.zero_() + + logs["loss"] = tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged) + logs["loss_cur_dp"] = tr_loss_single_dp_scalar / (self.state.global_step - self._globalstep_last_logged) + logs["learning_rate"] = float(self._get_learning_rate()) + logs["global_step"] = int(self.state.global_step) + + divisor = 2**30 + + current_device = framework._current_expected_place_() + device_id = current_device.get_device_id() + current_memory_allocated = core.device_memory_stat_current_value("Allocated", device_id) + current_memory_reserved = core.device_memory_stat_current_value("Reserved", device_id) + max_memory_allocated = core.device_memory_stat_peak_value("Allocated", device_id) + max_memory_reserved = core.device_memory_stat_peak_value("Reserved", device_id) + logs["mem_allocated_gb"] = current_memory_allocated / divisor + logs["max_mem_allocated_gb"] = max_memory_allocated / divisor + logs["mem_reserved_gb"] = current_memory_reserved / divisor + logs["max_mem_reserved_gb"] = max_memory_reserved / divisor + + if not self.args.enable_global_training_logs: + global_training_logs.global_meters_keys = [] + + if get_env_device() == "gpu": + info_callback = global_training_logs.dict(use_async=True) + + if hasattr(self, "scaler"): + logs["loss_scale"] = float(f"{self.scaler._scale.item():.3e}") + + total_train_batch_size = ( + self.args.train_batch_size * self.args.gradient_accumulation_steps * self.args.reeao_dataset_world_size + ) + num_steps = self.state.global_step - self._globalstep_last_logged + logs.update( + speed_metrics( + "global", + self._globalstep_last_start_time, + num_samples=total_train_batch_size * num_steps, + num_steps=num_steps, + ) + ) + if not hasattr(self, "model_numel"): + model_numel = sum( + p.numel().item() + for n, p in model.named_parameters() + if not p.stop_gradient and "embeddings" not in n and "embed_tokens" not in n + ) + numel_tensor = paddle.to_tensor(model_numel) + dist.all_reduce(numel_tensor) + self.model_numel = numel_tensor.item() // self.args.dataset_world_size + + tokens_per_steps = self.args.max_seq_length * total_train_batch_size + logs["tokens_trained_current_step"] = tokens_per_steps + logs["timestamp"] = int(time.time() * 1000) + logs["TFLOPS_per_sec_per_card"] = round( + 6 + * tokens_per_steps + * self.model_numel + * logs["global_steps_per_second"] + / 1e12 + / self.args.world_size, + 3, + ) + logs["tokens_per_sec_per_card"] = round( + tokens_per_steps * logs["global_steps_per_second"] / self.args.world_size, + 1, + ) + self._tokens_per_sec_per_card_buffer.append(logs["tokens_per_sec_per_card"]) + logs["tokens_per_sec_per_card_average"] = round(np.mean(self._tokens_per_sec_per_card_buffer), 1) + if self.resume_global_step == -1: + self.resume_global_step = self.state.global_step - 1 + if self.state.global_step <= self.resume_global_step + self.first_skip_step: + self._tokens_per_sec_per_card_buffer = [] + self._end_save_time = time.time() + + self._total_loss_scalar += tr_loss_scalar + self._globalstep_last_logged = self.state.global_step + self._globalstep_last_start_time = time.time() + + info, gathered_info = info_callback() + global_training_logs.reset() + logs.update({f"{k}_cur_dp": v for k, v in info.items()}) + logs.update(gathered_info) + if self.args.enable_global_training_logs: + info_list = [] + dist.all_gather_object(info_list, info) + logs.update( + { + k: np.mean([v[k] for v in info_list if k in v]) + for k in {key for item in info_list for key in item.keys()} + } + ) + + self.log(logs, **kwargs) + + metrics = None + if self.control.should_evaluate: + metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) + + if self.control.should_save: + if hasattr(self.args, "flash_device_save_steps") and self.args.flash_device_save_steps > 0: + is_persistent_ckpt = 1 if self.state.global_step % self.args.save_steps == 0 else 0 + else: + is_persistent_ckpt = 1 + + if is_persistent_ckpt: + self._start_save_time = time.time() + else: + zcc_start_save_time = time.time() + self._save_checkpoint(model, metrics=metrics) + paddle.distributed.barrier() + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + if flag_log: + logs = {"is_persistent_ckpt": is_persistent_ckpt} + tbk = self._start_save_time - self._end_save_time + if (self.state.global_step == self.resume_global_step + self.args.save_steps) or ( + hasattr(self.args, "flash_device_save_steps") + and (self.state.global_step == self.resume_global_step + self.args.flash_device_save_steps) + ): + actual_tbk = self._start_save_time - self._first_end_save_time + actual_avg_speed_step = self.args.save_steps * tokens_per_steps / actual_tbk / self.args.world_size + tbk = tbk / (self.args.save_steps - self.first_skip_step) * self.args.save_steps + if is_persistent_ckpt: + ts = time.time() - self._start_save_time + else: + ts = time.time() - zcc_start_save_time + logs["save_ckpt_time_sec"] = ts + logs["global_save_step"] = self.state.global_step + if is_persistent_ckpt: + tokens_per_steps = self.args.max_seq_length * total_train_batch_size + avg_speed_step = self.args.save_steps * tokens_per_steps / tbk / self.args.world_size + logs["train_time_sec_without_save"] = tbk + logs["average_tokens_per_sec_per_card_without_save"] = round(avg_speed_step, 1) + logs["average_tokens_per_sec_per_card_with_save"] = round( + self.args.save_steps * tokens_per_steps / (tbk + ts) / self.args.world_size, + 2, + ) + if self.state.global_step == self.resume_global_step + self.args.save_steps: + logs["actual_average_tokens_per_sec_per_card_without_save"] = round(actual_avg_speed_step, 1) + logs["actual_average_tokens_per_sec_per_card_with_save"] = round( + self.args.save_steps * tokens_per_steps / (actual_tbk + ts) / self.args.world_size, + 2, + ) + logs["one_day_billion_tokens_without_save"] = round( + 0.0000864 * self.args.save_steps * tokens_per_steps / tbk, 2 + ) + logs["one_day_billion_tokens_with_save"] = round( + 0.0000864 * self.args.save_steps * tokens_per_steps / (tbk + ts), + 2, + ) + self.log(logs, **kwargs) + if is_persistent_ckpt: + self._globalstep_last_start_time = time.time() + self._tokens_per_sec_per_card_buffer = [] + if is_persistent_ckpt: + self._end_save_time = time.time() + + def create_scheduler(self, num_training_steps): + if self.args.warmup_steps > 0: + warmup = self.args.warmup_steps + else: + warmup = int(self.args.warmup_ratio * num_training_steps) + + assert self.args.lr_scheduler.startswith("wsd") + scheduler = self.args.lr_scheduler.split(":") + if len(scheduler) == 2: + num_steady_steps = int(scheduler[1]) + else: + num_steady_steps = None + logger.info(f"using wsd lr scheduler, num_steady_steps={num_steady_steps}") + self.lr_scheduler = get_wsd_schedule_with_warmup( + self.args.learning_rate, + warmup, + self.args.max_steps, + decay_function=self.args.decay_function, + min_lr=self.args.min_lr if self.args.min_lr else 0.0, + num_steady_steps=num_steady_steps, + ) + + return self.lr_scheduler + + def create_optimizer(self, lr_scheduler=None): + optimizer_params = [p for n, p in self.model.named_parameters() if p.stop_gradient is False] + if self.args.train_moe_only: + optimizer_params = ( + [p for n, p in self.model.named_parameters() if "mlp.experts" in n or "mlp.gate" in n] + if self.args.train_moe_only + else [p for n, p in self.model.named_parameters() if p.stop_gradient is False] + ) + logger.info(f"using `train_moe-only`, #moe params={len(optimizer_params)}") + elif len(optimizer_params) < len(self.model.parameters()): + logger.info( + f"some params are not optimized, #totally={len(self.model.parameters())}, \ + #optimized={len(optimizer_params)}" + ) + if self.optimizer is None: + decay_parameters = [ + p.name for n, p in self.model.named_parameters() if not any(nd in n for nd in ["bias", "norm"]) + ] + + def apply_decay_param_fun(x): + return x in decay_parameters + + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) + + if self.args.use_moe and not self.args.use_hybrid_parallel: + logger.info("using moe Global clip") + + def expert_fn(p): + return getattr(p, "no_sync", False) + + grad_clip = ClipGradForMOEByGlobalNorm( + self.args.max_grad_norm, + is_expert_param_func=expert_fn, + moe_group=_get_global_group(), + local_clip=False, + ) + else: + grad_clip = nn.ClipGradByGlobalNorm(self.args.max_grad_norm) if self.args.max_grad_norm > 0 else None + + self.static_name_to_dyg_name = {p.name: n for n, p in self.model.state_dict().items()} + gate_pattern = re.compile(r"ernie\.layers\.0\.mlp\.gate\.weight") + + def lr_ratio_fn(param): + name = self.static_name_to_dyg_name[param.name] + if self.args.moe_gate_lr_ratio is not None and gate_pattern.match(name): + logger.info(f"apply moe_gate_lr_ratio to {name}, ratio={self.args.moe_gate_lr_ratio}") + return float(self.args.moe_gate_lr_ratio) + return 1.0 + + self.optimizer = optimizer_cls( + learning_rate=(self.lr_scheduler if lr_scheduler is None else lr_scheduler), + apply_decay_param_fun=apply_decay_param_fun, + parameters=optimizer_params, + weight_decay=self.args.weight_decay, + grad_clip=grad_clip, + multi_precision=True, + lr_ratio=(lr_ratio_fn if (self.args.moe_gate_lr_ratio is not None) else None), + **optimizer_kwargs, + ) + + return self.optimizer + + def save_model(self, output_dir=None): + super().save_model(output_dir) + if self.args.should_save: + with open(os.path.join(output_dir, "static_name_to_dyg_name.json"), "w") as of: + of.write(json.dumps(self.static_name_to_dyg_name)) + + def _load_rng_state(self, checkpoint): + pass diff --git a/ernie/ERNIE/examples/pre-training/models/comm_utils.py b/ernie/ERNIE/examples/pre-training/models/comm_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0618d397961555541581eedc3433724702b0c5b2 --- /dev/null +++ b/ernie/ERNIE/examples/pre-training/models/comm_utils.py @@ -0,0 +1,297 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import logging +from contextlib import contextmanager + +import numpy as np +import paddle +from paddle import distributed as dist +from paddle.distributed import fleet +from paddle.distributed.communication.batch_isend_irecv import ( + _coalescing_manager as batch_isend_irecv_coalescing_manager, +) +from paddle.nn import functional as F +from paddleformers.trainer.plugins.timer import get_timers + +logger = logging.getLogger(__name__) + + +def scatter(input, group=None, axis=0): + if group is None: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + parallelism = group.nranks + if parallelism == 1: + return input.clone() + rank = group.rank + seq_len = input.shape[axis] + assert seq_len % parallelism == 0, ( + f"Input sequence length {seq_len} can't be divided exactly" f" by sequence parallelism {parallelism}" + ) + interval = seq_len // parallelism + input = paddle.slice(input, axes=[axis], starts=[interval * rank], ends=[interval * (rank + 1)]) + input = paddle.assign(input) + return input + + +def mp_slice(x, indices=None, group=None, axis=0): + if indices is None: + return scatter(x, group, axis) + if group is None: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + parallelism = group.nranks + if parallelism == 1: + return x + rank = group.rank + assert len(indices) == parallelism, (len(indices), parallelism) + indices = F.pad(paddle.to_tensor(indices).cumsum(0), [1, 0]) + input = paddle.slice(x, axes=[axis], starts=[indices[rank]], ends=[indices[rank + 1]]) + input = paddle.assign(input) + return input + + +def all_gather_varlen(input, indices, group=None, axis=0, sync_op=True): + assert axis == 0, "only support axis=0" + if group is None: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + parallelism = group.nranks + input_sizes = [len(input)] * parallelism + output_sizes = indices + out = paddle.empty([sum(indices)] + input.shape[1:], dtype=input.dtype) + task = dist.stream.alltoall_single( + out, + (paddle.concat([input] * parallelism, 0) if len(input) else input), + output_sizes, + input_sizes, + group=group, + sync_op=sync_op, + use_calc_stream=sync_op, + ) + task.wait() + return out + + +def scatter_varlen(x, recv_tensor, indices, src_rank, group, sync_op=True): + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + + if rank == src_rank: + in_split_size = indices + else: + x = paddle.empty([], dtype=recv_tensor.dtype) + in_split_size = [0] * world_size + out_split_size = [indices[rank] if i == src_rank else 0 for i in range(world_size)] + task = dist.stream.alltoall_single( + recv_tensor, + x, + out_split_size, + in_split_size, + group=group, + sync_op=sync_op, + use_calc_stream=sync_op, + ) + task.wait() + + +def all_gather(input, group=None, axis=0): + if group is None: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + parallelism = group.nranks + if parallelism == 1: + return input.clone() + output_shape = input.shape + if axis == 0: + output_shape[axis] = output_shape[axis] * parallelism + output = paddle.empty(shape=output_shape, dtype=input.dtype) + dist.stream.all_gather(output, input, group=group, use_calc_stream=True) + return output + outputs = [paddle.empty(output_shape, dtype=input.dtype) for _ in range(parallelism)] + dist.stream.all_gather(outputs, input, group=group, use_calc_stream=True) + output = paddle.concat(outputs, axis=axis) + return output + + +def reduce_scatter(input, group=None): + if group is None: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + parallelism = group.nranks + if parallelism == 1: + return input.clone() + output_shape = input.shape + assert ( + input.shape[0] % parallelism == 0 + ), f"Input sequence length {input.shape[0]} can't be divided exactly by sequence parallelism {parallelism}" + output_shape[0] = output_shape[0] // parallelism + output = paddle.empty(shape=output_shape, dtype=input.dtype) + dist.stream.reduce_scatter(output, input, op=dist.ReduceOp.SUM, group=group, use_calc_stream=True) + return output + + +def subbatch(f, arg_idx, axis, bs, out_idx, use_recompute=False, same_arg_idx={}): + @functools.wraps(f) + def wrapper(*args, **kwargs): + + assert len(arg_idx) == len(axis), "Number of batching args and number of batching dims should match." + + inps = [args[i] for i in arg_idx] + axis_width = [inp.shape[d] for inp, d in zip(inps, axis)] + assert len(set(axis_width)) == 1, "Batch sizes should be kept equal." + + inp_axis = {inp: d for inp, d in zip(inps, axis)} + + axis_width = axis_width[0] + if axis_width < bs: + return f(*args, **kwargs) + + outs = [] + for slice_at in np.arange(0, axis_width, bs): + _args = [] + for i, inp in enumerate(args): + if i in same_arg_idx: + assert ( + i > same_arg_idx[i] + ), f"expect i > same_arg_idx[i], but got i: {i} and same_arg_idx[i]: {same_arg_idx[i]}" + _args.append(_args[same_arg_idx[i]]) + elif i in arg_idx: + inp = inp.slice( + [inp_axis[inp]], + [slice_at], + [min(inp.shape[inp_axis[inp]], slice_at + bs)], + ) + _args.append(inp) + else: + _args.append(inp) + if use_recompute: + out = paddle.distributed.fleet.utils.recompute(f, *_args, **kwargs) + else: + out = f(*_args, **kwargs) + outs.append(out) + + return paddle.concat(outs, out_idx) + + return wrapper + + +def gather_varlen(input, dst, group, offload_pp_data_chunk_size=0, all_shape_and_dtype=None): + if dist.get_world_size(group) <= 1: + return input + if group is None: + group = dist.collective._get_global_group() + + shape_and_dtype = (None, None) if input is None else (input.shape, input.dtype) + if all_shape_and_dtype is None: + all_shape_and_dtype = [] + dist.all_gather_object(all_shape_and_dtype, shape_and_dtype, group=group) + assert any(s is not None for s, _ in all_shape_and_dtype), all_shape_and_dtype + + any_shape = None + shape0_all = [] + for s, d in all_shape_and_dtype: + if s is not None and any_shape is None: + any_shape = s + elif s is not None and any_shape is not None: + assert any_shape[1:] == s[1:], f"{any_shape[1:]} != {s[1:]}" + shape0_all.append(s if s is not None else 0) + + output = [] + if offload_pp_data_chunk_size > 0: + assert (group.nranks >= offload_pp_data_chunk_size) and (group.nranks % offload_pp_data_chunk_size == 0), ( + f"group.nranks {group.nranks} must be greater than offload_pp_data_chunk_size {offload_pp_data_chunk_size} " + f"and group.nranks % offload_pp_data_chunk_size == 0" + ) + if group.ranks[group.rank] == dst: + num_sub_group = group.nranks // offload_pp_data_chunk_size + for sub_group_idx in range(num_sub_group): + start = sub_group_idx * offload_pp_data_chunk_size + end = start + offload_pp_data_chunk_size + tasks = [] + output_ptr = len(output) + with batch_isend_irecv_coalescing_manager(group, tasks): + for src in range(start, end): + if all_shape_and_dtype[src][0] is None or all_shape_and_dtype[src][0][0] == 0: + pass + elif src != group.rank: + recv_tensor = paddle.empty( + all_shape_and_dtype[src][0], + dtype=all_shape_and_dtype[src][1], + ) + output.append(recv_tensor) + task = dist.irecv(recv_tensor, group.ranks[src], group=group) + tasks.append(task) + else: + output.append(input) + for task in tasks: + task.wait() + for i in range(output_ptr, len(output)): + output[i] = output[i].pin_memory() + else: + num_sub_group = group.nranks // offload_pp_data_chunk_size + for sub_group_idx in range(num_sub_group): + start = sub_group_idx * offload_pp_data_chunk_size + end = start + offload_pp_data_chunk_size + tasks = [] + with batch_isend_irecv_coalescing_manager(group, tasks): + for _ in range(1): + if group.rank in list(range(start, end)) and input is not None and input.shape[0] != 0: + task = dist.isend(input, dst, group=group) + tasks.append(task) + for task in tasks: + task.wait() + else: + if group.ranks[group.rank] == dst: + tasks = [] + with batch_isend_irecv_coalescing_manager(group, tasks): + for src in range(group.nranks): + if all_shape_and_dtype[src][0] is None: + pass + elif src != group.rank: + recv_tensor = paddle.empty( + all_shape_and_dtype[src][0], + dtype=all_shape_and_dtype[src][1], + ) + output.append(recv_tensor) + task = dist.irecv(recv_tensor, group.ranks[src], group=group) + tasks.append(task) + else: + output.append(input) + for task in tasks: + task.wait() + else: + tasks = [] + with batch_isend_irecv_coalescing_manager(group, tasks): + for _ in range(1): + if input is not None: + task = dist.isend(input, dst, group=group) + tasks.append(task) + for task in tasks: + task.wait() + + if len(output) != 0: + output = paddle.concat(output, 0) + return output + + +@contextmanager +def profile(name, use_event=True): + if get_timers() is not None: + get_timers()(name, use_event=use_event).start() + yield + if get_timers() is not None: + get_timers()(name, use_event=use_event).stop() diff --git a/ernie/ERNIE/examples/pre-training/models/ernie/__init__.py b/ernie/ERNIE/examples/pre-training/models/ernie/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b00b0579073950b389d410bf2d2a93101881d245 --- /dev/null +++ b/ernie/ERNIE/examples/pre-training/models/ernie/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration import ErnieMoEConfig +from .modeling_pp import ErnieMoEForCausalLMPipe + +__all__ = ['ErnieMoEConfig', 'ErnieMoEForCausalLMPipe'] diff --git a/ernie/ERNIE/examples/pre-training/models/ernie/configuration.py b/ernie/ERNIE/examples/pre-training/models/ernie/configuration.py new file mode 100644 index 0000000000000000000000000000000000000000..53255e184925c8c9d2b0a76d3a7e06945446262a --- /dev/null +++ b/ernie/ERNIE/examples/pre-training/models/ernie/configuration.py @@ -0,0 +1,434 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +from typing import Optional, Union + +import paddle.distributed.communication.group +from paddleformers.transformers.configuration_utils import PretrainedConfig + +logger = logging.getLogger(__name__) + +__all__ = [ + "ERNIE_PRETRAINED_INIT_CONFIGURATION", + "ErnieMoEConfig", + "ERNIE_PRETRAINED_RESOURCE_FILES_MAP", +] + +ERNIE_PRETRAINED_INIT_CONFIGURATION = { + "ernie/tiny-random-ernie": { + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_position_embeddings": 2048, + "model_type": "ernie", + "num_attention_heads": 2, + "num_hidden_layers": 2, + "rms_norm_eps": 1e-06, + "vocab_size": 32000, + "bos_token_id": 1, + "eos_token_id": 2, + "pad_token_id": 0, + "use_cache": False, + "use_recompute": False, + "use_flash_attn": True, + "use_mem_eff_attn": False, + }, +} + +ERNIE_PRETRAINED_RESOURCE_FILES_MAP = { + "model_state": { + "facebookresearch/tiny-random-ernie": "https://bj.bcebos.com/paddleformers/models/community/facebookresearch/tiny-random-ernie/model_state.pdparams", + }, +} + + +class ErnieMoEConfig(PretrainedConfig): + model_type = "ernie" + attribute_map = { + "n_positions": "max_position_embeddings", + "n_embd": "hidden_size", + "n_layer": "num_hidden_layers", + "n_head": "num_attention_heads", + "n_inner": "intermediate_size", + "activation_function": "hidden_act", + } + pretrained_init_configuration = ERNIE_PRETRAINED_INIT_CONFIGURATION + + def __init__( + self, + vocab_size=32000, + hidden_size=768, + intermediate_size=11008, + max_position_embeddings=32768, + num_hidden_layers=2, + num_attention_heads=2, + head_dim=None, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=False, + use_flash_attn=True, + use_mem_eff_attn=False, + use_flash_attn_with_mask=False, + use_recompute=False, + use_recompute_attn=False, + recompute_use_reentrant=False, + use_rmsnorm=True, + fuse_rms_norm=False, + fuse_ln=False, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + fuse_attn_ffn=False, + fuse_swiglu=False, + use_bias=False, + expert_mlp_use_bias=None, + rope_reorder=True, + rope_theta=10000, + fuse_rope=False, + use_fast_ln=False, + weight_share_add_bias=True, + fuse_linear=False, + seqlen=False, + ignored_index=-100, + remove_tail_layer=False, + use_recompute_lm_head=False, + use_recompute_loss_fn=False, + use_recompute_mtp=False, + use_recompute_dnd=False, + selective_no_recompute_num=0, + use_mp_gathered_weight=False, + attention_probs_dropout_prob=0.0, + hidden_dropout_prob=0.0, + compression_ratio: float = 1.0, + num_key_value_heads=None, + use_sparse_head_and_loss_fn=False, + using_dynamic_sequence_length=False, + micro_batch_size=-1, + use_qk_norm=False, + use_tpsp_comm_overlap=False, + use_ep_comm_overlap=False, + offload_pp_data_chunk_size=0, + use_fused_head_loss_fn=False, + use_recompute_resampler=False, + resampler_fuse_rms_norm=False, + token_balance_loss=False, + token_balance_seqlen=False, + use_fp8=False, + fp8_configs=dict(), + use_fp8_mlp=False, + use_fp8_fuse_node=False, + fp8_mem_configs=dict(), + fp8_fused_ops_configs=dict(), + rope_3d=False, + freq_allocation=0, + moe_layer_feed_fake_token=False, + decoderlayer_act_offload_settings={"type": "", "value": ""}, + loss_subbatch_seqlen=32768, + moe_num_experts: Union[int, list] = 0, + use_recompute_moe=False, + moe_capacity=(), + moe_layer_interval=2, + moe_layer_start_index: Union[int, list] = 0, + moe_layer_end_index: Union[int, list] = -1, + moe_aux_loss_lambda=1e-2, + global_aux_loss=False, + moe_dropout_prob=0.0, + moe_group="world", + num_experts_per_tok: int = 8, + moe_intermediate_size: Union[int, list] = 0, + moe_num_shared_experts: int = 0, + moe_num_dense_experts: int = 0, + moe_dense_experts_token_type_id: int = 3, + moe_reverse_token_drop: bool = False, + moe_gate_act: str = "softmax", + moe_norm_gate_logits=True, + moe_fuse_experts: bool = False, + moe_all_to_all_dropout: float = 0.0, + moe_k=2, + moe_use_aux_free: bool = False, + moe_group_experts: bool = False, + enable_delay_scale_loss: bool = True, + num_acc_steps: Optional[int] = None, + insert_empty_layer: Optional[list] = None, + pp_no_recompute_layer: Optional[list] = None, + multi_token_pred_depth: int = 0, + multi_token_pred_lambda: float = 0.3, + fuse_gate_detach_matmul: bool = False, + enable_mtp_magic_send: bool = False, + n_group: int = 0, + topk_group: int = 0, + scaling_factor: Optional[float] = None, + aux_loss_type: str = "", + use_linear_residual_norm_recompute: bool = False, + use_rms_qkv_recompute: bool = False, + use_combine_before_a2a=False, + use_quant_before_a2a=False, + **kwargs, + ): + if "tie_word_embeddings" not in kwargs: + kwargs["tie_word_embeddings"] = False + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.max_position_embeddings = max_position_embeddings + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.use_recompute_attn = use_recompute_attn + if use_recompute_attn: + logger.warning("set `use_recompute_attn`=True, disabling `use_recompute`") + use_recompute = False + self.use_recompute = use_recompute + self.use_flash_attn = use_flash_attn + self.recompute_use_reentrant = recompute_use_reentrant + self.use_mem_eff_attn = use_mem_eff_attn + self.use_flash_attn_with_mask = use_flash_attn_with_mask + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.fuse_attn_ffn = fuse_attn_ffn + self.fuse_swiglu = fuse_swiglu + self.fuse_rms_norm = fuse_rms_norm + self.fuse_ln = fuse_ln + self.use_rmsnorm = use_rmsnorm + self.using_dynamic_sequence_length = using_dynamic_sequence_length + if using_dynamic_sequence_length: + assert micro_batch_size > 0, "micro_batch_size should be set when using_dynamic_sequence_length" + self.micro_batch_size = micro_batch_size + self.use_qk_norm = use_qk_norm + + self.seqlen = seqlen + self.use_bias = use_bias + self.weight_share_add_bias = weight_share_add_bias + self.rope_reorder = rope_reorder + self.rope_theta = rope_theta + self.fuse_rope = fuse_rope + self.use_fast_ln = use_fast_ln + + self.fuse_linear = fuse_linear + self.ignored_index = ignored_index + self.remove_tail_layer = remove_tail_layer + self.use_recompute_lm_head = use_recompute_lm_head + self.use_recompute_loss_fn = use_recompute_loss_fn + self.use_recompute_mtp = use_recompute_mtp + self.use_recompute_dnd = use_recompute_dnd + + self.use_mp_gathered_weight = use_mp_gathered_weight + self.selective_no_recompute_num = selective_no_recompute_num + + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.hidden_dropout_prob = hidden_dropout_prob + self.compression_ratio = compression_ratio + self.skip_recompute_ops = dict() + self.num_key_value_heads = num_key_value_heads + self.use_sparse_head_and_loss_fn = use_sparse_head_and_loss_fn + self.use_tpsp_comm_overlap = use_tpsp_comm_overlap + self.use_ep_comm_overlap = use_ep_comm_overlap + self.offload_pp_data_chunk_size = offload_pp_data_chunk_size + self.use_fused_head_loss_fn = use_fused_head_loss_fn + self.use_recompute_resampler = use_recompute_resampler + self.resampler_fuse_rms_norm = resampler_fuse_rms_norm + self.token_balance_loss = token_balance_loss + self.token_balance_seqlen = token_balance_seqlen + self.rope_3d = rope_3d + self.freq_allocation = freq_allocation + self.decoderlayer_act_offload_settings = decoderlayer_act_offload_settings + self.loss_subbatch_seqlen = loss_subbatch_seqlen + self.use_combine_before_a2a = use_combine_before_a2a + + # Fuse activation quantization into the dispatch kernel, using FP8 for All-to-All (A2A) communication. + # Additionally, overlap the A2A operation with weight gradient computation during backward propagation. + self.use_quant_before_a2a = use_quant_before_a2a + + default_fp8_configs = { + "quant_scheme": "DelayedScaling", + "recipe": { + "format": "hybrid", + "calibrating": True, + "amax_history_len": 1024, + "amax_compute_algo": "max", + "fuse_wgrad_accumulation": False, + "quant_weight_at_first_microbatch": False, + }, + "layers": { + "attn_fc1_linear": True, + "attn_fc2_linear": True, + "mlp_fc1_linear": True, + "mlp_fc2_linear": True, + "attn_tp_fc1_linear": True, + "attn_tp_fc2_linear": True, + "mlp_tp_fc1_linear": True, + "mlp_tp_fc2_linear": True, + }, + "smooth_swiglu": False, + } + + def update_nested_dict(default_dict, update_dict): + for key, value in update_dict.items(): + if isinstance(value, dict) and key in default_dict and isinstance(default_dict[key], dict): + update_nested_dict(default_dict[key], value) + else: + default_dict[key] = value + + update_nested_dict(default_fp8_configs, fp8_configs) + self.fp8_configs = default_fp8_configs + self.use_fp8 = use_fp8 + self.expert_mlp_use_bias = expert_mlp_use_bias + self.use_fp8_mlp = use_fp8_mlp + self.use_fp8_fuse_node = use_fp8_fuse_node + default_fp8_mem_configs = { + "shared_expert": False, + "recompute_fwd_gate_up": False, + "dequant_input": False, + "offline_quant_expert_weight": False, + "clear_origin_weight_when_offline_quant": False, + } + update_nested_dict(default_fp8_mem_configs, fp8_mem_configs) + self.fp8_mem_configs = default_fp8_mem_configs + default_fp8_fused_ops_configs = { + "stack_quant": False, + "swiglu_probs_bwd": False, + "split_group_gemm": True, + } + update_nested_dict(default_fp8_fused_ops_configs, fp8_fused_ops_configs) + self.fp8_fused_ops_configs = default_fp8_fused_ops_configs + self.moe_layer_feed_fake_token = moe_layer_feed_fake_token + + if self.sequence_parallel: + assert ( + self.using_dynamic_sequence_length or self.seqlen + ), "seqlen not provided in sequence-parallel when not using dygramic sequence length" + + assert ( + self.tensor_parallel_degree > 1 + ), f"sequence-parallel only works in mp, got mp={self.tensor_parallel_degree}" + + if use_recompute_moe: + logger.warning("set `use_recompute_moe`=True, disabling `use_recompute`") + kwargs["use_recompute"] = False + + self.use_recompute_moe = use_recompute_moe + self.moe_num_experts = moe_num_experts + self.moe_capacity = moe_capacity + self.moe_aux_loss_lambda = moe_aux_loss_lambda + self.global_aux_loss = global_aux_loss + self.moe_layer_interval = moe_layer_interval + self.moe_dropout_prob = moe_dropout_prob + self.moe_group = moe_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_num_shared_experts = moe_num_shared_experts + self.moe_num_dense_experts = moe_num_dense_experts + self.moe_dense_experts_token_type_id = moe_dense_experts_token_type_id + self.moe_intermediate_size = moe_intermediate_size + self.moe_reverse_token_drop = moe_reverse_token_drop + self.moe_fuse_experts = moe_fuse_experts + self.moe_k = moe_k + self.moe_all_to_all_dropout = moe_all_to_all_dropout + self.moe_group_experts = moe_group_experts + self.enable_delay_scale_loss = enable_delay_scale_loss + self.num_acc_steps = num_acc_steps + self.moe_layer_start_index = moe_layer_start_index + self.moe_layer_end_index = self.num_hidden_layers - 1 if moe_layer_end_index == -1 else moe_layer_end_index + self.moe_gate_act = moe_gate_act + self.moe_norm_gate_logits = moe_norm_gate_logits + self.moe_use_aux_free = moe_use_aux_free + self.fuse_gate_detach_matmul = fuse_gate_detach_matmul + if insert_empty_layer is not None: + assert isinstance(insert_empty_layer, list), "insert_empty_layer should be a list" + else: + insert_empty_layer = [] + + self.multi_token_pred_depth = multi_token_pred_depth + self.multi_token_pred_lambda = multi_token_pred_lambda + self.enable_mtp_magic_send = enable_mtp_magic_send + self.insert_empty_layer = insert_empty_layer + self.n_group = n_group + self.topk_group = topk_group + self.scaling_factor = scaling_factor + + self.use_linear_residual_norm_recompute = use_linear_residual_norm_recompute + self.use_rms_qkv_recompute = use_rms_qkv_recompute + + assert aux_loss_type in ["", "default", "seq_aux_loss", "switch_aux_loss"] + self.aux_loss_type = aux_loss_type + + if pp_no_recompute_layer is not None: + assert isinstance(insert_empty_layer, list), "pp_no_recompute_layer should be a list" + + self.pp_no_recompute_layer = pp_no_recompute_layer + self.register_nonsaveable_keys("moe_group") + self.register_nonsaveable_keys("pp_no_recompute_layer") + self.register_nonsaveable_keys("use_recompute") + self.register_nonsaveable_keys("recompute_use_reentrant") + self.register_nonsaveable_keys("use_recompute_attn") + self.register_nonsaveable_keys("use_recompute_lm_head") + self.register_nonsaveable_keys("use_recompute_loss_fn") + self.register_nonsaveable_keys("skip_recompute_ops") + + def __setattr__(self, name: str, value): + super().__setattr__(name, value) + if getattr(self, "use_recompute", False): + assert not getattr( + self, "use_recompute_attn", False + ), "cannot set `use_recompute_attn=True` when `use_recompute=True`" + if getattr(self, "use_recompute", False): + assert not getattr( + self, "use_recompute_moe", False + ), "cannot set `use_recompute_moe=True` when `use_recompute=True`" + + def register_nonsaveable_keys(self, keys): + if hasattr(super(), "register_nonsaveable_keys"): + return super().register_nonsaveable_keys(keys) + elif hasattr(super(), "register_unsavable_keys"): + return super().register_unsavable_keys(keys) + else: + raise AttributeError("register_nonsaveable_keys not found in PretrainedConfig") + + @property + def use_moe(self) -> bool: + return self.moe_num_experts > 0 + + def to_json_string(self, use_diff: bool = True) -> str: + if use_diff is True: + config_dict = self.to_diff_dict() + else: + config_dict = self.to_dict() + + def _serializer(obj): + if isinstance(obj, paddle.distributed.communication.group.Group): + return repr(obj) + raise TypeError(f"Type {type(obj)} is not serializable") + + return ( + json.dumps( + config_dict, + indent=2, + sort_keys=True, + ensure_ascii=False, + default=_serializer, + ) + + "\n" + ) diff --git a/ernie/ERNIE/examples/pre-training/models/ernie/modeling.py b/ernie/ERNIE/examples/pre-training/models/ernie/modeling.py new file mode 100644 index 0000000000000000000000000000000000000000..79993d1a21bc5ef7041175c93f4af1ece67ef85b --- /dev/null +++ b/ernie/ERNIE/examples/pre-training/models/ernie/modeling.py @@ -0,0 +1,2381 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import copy +import logging +import math +from functools import partial +from typing import Optional, Tuple + +import numpy as np +import paddle +import paddle.distributed as dist +import paddle.nn.functional as F +from models.comm_utils import subbatch +from models.fp8_linear import MemEfficientFp8FusedMlpFunc +from models.sequence_parallel_utils import ( + AllGatherVarlenOp, + ColumnSequenceParallelLinear, + GatherOp, + RowSequenceParallelLinear, + ScatterOp, + mark_as_sequence_parallel_parameter, + sequence_parallel_sparse_mask_labels, +) +from paddle import nn +from paddle.autograd import PyLayer +from paddle.distributed import fleet +from paddle.distributed.fleet.layers.mpu import mp_ops +from paddle.distributed.fleet.layers.mpu.mp_layers import ( + ColumnParallelLinear, + RowParallelLinear, + VocabParallelEmbedding, +) +from paddle.distributed.fleet.layers.mpu.random import get_rng_state_tracker +from paddle.distributed.fleet.utils import recompute +from paddle.incubate.nn.functional import fused_rms_norm_ext +from paddle.incubate.nn.memory_efficient_attention import ( + BlockDiagonalCausalMask, + memory_efficient_attention, +) +from paddleformers.transformers.conversion_utils import ( + StateDictNameMapping, + init_name_mappings, +) +from paddleformers.transformers.model_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, +) +from paddleformers.transformers.model_utils import PretrainedModel, register_base_model +from paddleformers.utils.tools import get_env_device + +from .configuration import ErnieMoEConfig + +logger = logging.getLogger(__name__) + +NativeLinear = nn.Linear + + +try: + from paddle.nn.functional.flash_attention import flash_attention + + logger.warning("Use flash attention in scaled-dot-product. Attention mask is deprecated") +except (ImportError, ModuleNotFoundError): + flash_attention = None + +try: + from paddle.nn.functional.flash_attention import flash_attention_with_mask +except (ImportError, ModuleNotFoundError): + try: + from paddle.nn.functional.flash_attention import ( + scaled_dot_product_attention as flash_attention_with_mask, + ) + except (ImportError, ModuleNotFoundError): + flash_attention_with_mask = None + + +try: + from paddle.nn.functional.flash_attention import flashmask_attention +except (ImportError, ModuleNotFoundError): + flashmask_attention = None + +try: + from paddle.incubate.nn.functional import ( + fused_rotary_position_embedding as fused_rope, + ) +except (ImportError, ModuleNotFoundError): + logger.warning("fused_rotary_position_embedding not found") + fused_rope = None + +try: + from paddle.incubate.nn.functional import swiglu as fused_swiglu +except (ImportError, ModuleNotFoundError): + fused_swiglu = None + +try: + from src.utils.misc import global_training_logs +except ModuleNotFoundError: + global_training_logs = {} + + +ERNIE_PRETRAINED_MODEL_ARCHIVE_LIST = [] + +__all__ = [ + "ErnieModel", + "ErniePretrainedModel", + "ErnieForCausalLM", +] + + +def get_triangle_upper_mask(x, mask=None): + if mask is not None: + return mask + shape = x.shape + shape[1] = 1 + mask = paddle.full(shape, -np.inf, dtype=x.dtype) + mask.stop_gradient = True + mask = paddle.triu(mask, diagonal=1) + mask.stop_gradient = True + return mask + + +def gqa_qkv_split_func( + weight, + tensor_parallel_degree, + tensor_parallel_rank, + num_attention_heads, + num_key_value_heads, + head_dim, +): + q, k, v = paddle.split( + weight, + [ + num_attention_heads * head_dim, + num_key_value_heads * head_dim, + num_key_value_heads * head_dim, + ], + axis=-1, + ) + if tensor_parallel_rank is None: + q_list = paddle.split(q, tensor_parallel_degree, axis=-1) + k_list = paddle.split(k, tensor_parallel_degree, axis=-1) + v_list = paddle.split(v, tensor_parallel_degree, axis=-1) + ret = [paddle.concat([q, k, v], axis=-1) for q, k, v in zip(q_list, k_list, v_list)] + return ret + else: + q = paddle.split(q, tensor_parallel_degree, axis=-1)[tensor_parallel_rank] + k = paddle.split(k, tensor_parallel_degree, axis=-1)[tensor_parallel_rank] + v = paddle.split(v, tensor_parallel_degree, axis=-1)[tensor_parallel_rank] + return paddle.concat([q, k, v], axis=-1) + + +def gqa_qkv_merge_func(weight_list, num_attention_heads, num_key_value_heads, head_dim): + tensor_parallel_degree = len(weight_list) + num_attention_heads = num_attention_heads // tensor_parallel_degree + num_key_value_heads = num_key_value_heads // tensor_parallel_degree + q_list, k_list, v_list = [], [], [] + for weight in weight_list: + q, k, v = paddle.split( + weight, + [ + num_attention_heads * head_dim, + num_key_value_heads * head_dim, + num_key_value_heads * head_dim, + ], + axis=-1, + ) + q_list.append(q) + k_list.append(k) + v_list.append(v) + return paddle.concat(q_list + k_list + v_list, axis=-1) + + +def parallel_matmul( + x, + y, + bias=None, + transpose_y=False, + tensor_parallel_degree=1, + tensor_parallel_output=True, + fuse_linear=False, + training=True, +): + if tensor_parallel_degree > 1: + if isinstance(y, paddle.base.framework.EagerParamBase): + assert y.is_distributed + + pg = fleet.get_hybrid_communicate_group().get_model_parallel_group() + input_parallel = paddle.distributed.collective._c_identity(x, group=pg) + if transpose_y: + logits = paddle.matmul(input_parallel, y, transpose_y=True) + if bias is not None: + logits += bias + else: + if fuse_linear: + logits = paddle.incubate.nn.functional.fused_linear(input_parallel, y, bias) + else: + logits = F.linear(input_parallel, y, bias) + + if tensor_parallel_output: + return logits + + return paddle.distributed.collective._c_concat(logits, group=pg) + + else: + if fuse_linear: + logits = paddle.incubate.nn.functional.fused_linear(x, y, bias, transpose_weight=transpose_y) + else: + logits = paddle.matmul(x, y, transpose_y=transpose_y) + if bias is not None: + logits += bias + return logits + + +def calc_lm_head_logits(config, hidden_states, weight, bias, tensor_parallel_output=None, training=True): + if config.sequence_parallel: + if config.use_sparse_head_and_loss_fn: + pass + else: + lm_head_use_gather = getattr(config, "lm_head_use_gather", True) + if lm_head_use_gather: + hidden_states = GatherOp.apply(hidden_states) + if not config.using_dynamic_sequence_length: + hidden_states = hidden_states.reshape([-1, config.seqlen, hidden_states.shape[-1]]) + else: + assert config.micro_batch_size, "micro_batch_size should be set when using dygramic sequence length." + hidden_states = hidden_states.reshape([config.micro_batch_size, -1, hidden_states.shape[-1]]) + + if tensor_parallel_output is None: + tensor_parallel_output = config.tensor_parallel_output + logits = parallel_matmul( + hidden_states, + weight, + bias=bias, + transpose_y=config.tie_word_embeddings, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_output=tensor_parallel_output, + fuse_linear=config.fuse_linear, + training=training, + ) + + return logits + + +def finfo(dtype: paddle.dtype = None): + if dtype is None: + dtype = paddle.get_default_dtype() + + if dtype == paddle.bfloat16: + + class BFloatFInfo: + min = -3.3895313892515355e38 + + return BFloatFInfo + if dtype == paddle.float32: + return np.finfo(np.float32) + if dtype == paddle.float16: + return np.finfo(np.float16) + if dtype == paddle.float64: + return np.finfo(np.float64) + + +def masked_fill(x, mask, value): + y = paddle.full(x.shape, value, x.dtype) + return paddle.where(mask, y, x) + + +def mem_eff_attn(query, key, value, pack_offset, drop_prob=0.0, dtype=paddle.bfloat16, training=True): + pack_offset = pack_offset.numpy() + shape = pack_offset.shape + assert len(shape) == 2, len(shape) + assert shape[0] == 1, shape[0] + n = pack_offset.size + pack_offset = pack_offset.flatten() + seqlens = [] + assert pack_offset[0] == 0, pack_offset[0] + for i in range(1, n): + if pack_offset[i] < 0: + break + cur = pack_offset[i] - pack_offset[i - 1] + assert cur > 0 + seqlens.append(cur) + + assert drop_prob == 0.0, drop_prob + assert dtype == paddle.bfloat16, dtype + + def cast(x): + return x.astype(dtype) if x.dtype != dtype else x + + if len(seqlens) == 1: + out, _ = flash_attention(query, key, value, drop_prob, causal=True, training=training) + else: + mask = BlockDiagonalCausalMask.from_seqlens(seqlens) + out = memory_efficient_attention( + cast(query), + cast(key), + cast(value), + attn_bias=mask, + p=drop_prob, + training=training, + ) + return out + + +def inbatch_pack_offset_to_attn_mask_start_row_indices(inbatch_pack_offset): + inbatch_pack_offset = inbatch_pack_offset.numpy() + attn_mask_row_start_indices = [] + min_start_row = np.inf + for bidx in range(inbatch_pack_offset.shape[0]): + item = inbatch_pack_offset[bidx] + cumsum_item = item[item != -1] + record_lens = cumsum_item[1:] - cumsum_item[0:-1] + min_start_row = min(cumsum_item[1], min_start_row) + row_start_indices = np.repeat(cumsum_item[1:], record_lens) + attn_mask_row_start_indices.append(row_start_indices[None, None, ...]) + attn_mask_row_start_indices = np.concatenate(attn_mask_row_start_indices, axis=0) + return paddle.to_tensor(attn_mask_row_start_indices, dtype=paddle.int32), int(min_start_row) + + +def scaled_dot_product_attention( + query_states, + key_states, + value_states, + attention_mask, + output_attentions, + config, + is_causal=True, + inbatch_pack_offset=None, + training=True, + startend_row_indices=None, +): + bsz, q_len, num_heads, head_dim = query_states.shape + _, kv_seq_len, num_key_value_heads, _ = value_states.shape + + if startend_row_indices is not None: + flashmask_attention_func = flashmask_attention + + attn_output = flashmask_attention_func( + query_states.astype(value_states.dtype), + key_states.astype(value_states.dtype), + value_states.astype(value_states.dtype), + startend_row_indices=startend_row_indices, + dropout=config.attention_probs_dropout_prob, + causal=False, + ) + attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) + return attn_output, None + + can_use_fa = config.use_flash_attn and flash_attention is not None + can_use_fa_sparse_mask = ( + config.use_mem_eff_attn and inbatch_pack_offset is not None and flashmask_attention is not None + ) + + if not can_use_fa and not can_use_fa_sparse_mask: + if query_states.shape[-2] != key_states.shape[-2]: + key_states = key_states.repeat_interleave(num_heads // num_key_value_heads, axis=-2) + if query_states.shape[-2] != value_states.shape[-2]: + value_states = value_states.repeat_interleave(num_heads // num_key_value_heads, axis=-2) + + if can_use_fa: + assert not (config.use_mem_eff_attn and inbatch_pack_offset is not None) + attn_output, attn_weights = flash_attention( + query_states, + key_states, + value_states, + dropout=config.attention_probs_dropout_prob, + causal=is_causal and query_states.shape[1] != 1, + return_softmax=output_attentions, + ) + + attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) + return attn_output, attn_weights + else: + + query_states = paddle.transpose(query_states, [0, 2, 1, 3]) / math.sqrt(head_dim) + key_states = paddle.transpose(key_states, [0, 2, 1, 3]) + value_states = paddle.transpose(value_states, [0, 2, 1, 3]) + + attn_weights = paddle.matmul(query_states, key_states.transpose([0, 1, 3, 2])) + + if attn_weights.shape != [bsz, num_heads, q_len, kv_seq_len]: + raise ValueError( + f"Attention weights should be of shape {(bsz, num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.shape}" + ) + + if attention_mask is None: + attention_mask = get_triangle_upper_mask(attn_weights) + + attention_mask = attention_mask.reshape([bsz, 1, q_len, kv_seq_len]) + if attention_mask.shape != [bsz, 1, q_len, kv_seq_len]: + raise ValueError( + f"Attention mask should be of shape {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.shape}" + ) + if training: + attn_weights = attention_mask + attn_weights + attn_weights = paddle.maximum( + attn_weights, + paddle.to_tensor(float(finfo(query_states.dtype).min), dtype=query_states.dtype), + ) + + if paddle.in_dynamic_mode(): + with paddle.amp.auto_cast(False): + attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype) + else: + attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype) + else: + attn_weights = attn_weights.cast(paddle.float32) + attention_mask = attention_mask.cast(paddle.float32) + attn_weights = attn_weights.add_(attention_mask) + attn_weights = F.softmax_(attn_weights, axis=-1).astype(query_states.dtype) + + if config.attention_probs_dropout_prob > 0.0: + if config.tensor_parallel_degree > 1: + with get_rng_state_tracker().rng_state("local_seed"): + attn_weights = F.dropout( + attn_weights, + config.attention_probs_dropout_prob, + training=training, + mode="upscale_in_train", + ) + else: + attn_weights = F.dropout( + attn_weights, + config.attention_probs_dropout_prob, + training=training, + mode="upscale_in_train", + ) + + attn_output = paddle.matmul(attn_weights, value_states) + attn_output = attn_output.transpose([0, 2, 1, 3]) + attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) + if output_attentions: + return attn_output, attn_weights + return attn_output, None + + +def _make_causal_mask(input_ids_shape, past_key_values_length, dtype): + batch_size, target_length = input_ids_shape + + mask = paddle.full((target_length, target_length), float(finfo(dtype).min)) + + mask_cond = paddle.arange(mask.shape[-1]) + mask = masked_fill(mask, mask_cond < (mask_cond + 1).reshape([mask.shape[-1], 1]), 0) + + if past_key_values_length > 0: + mask = paddle.concat([paddle.zeros([target_length, past_key_values_length]), mask], axis=-1) + + return mask[None, None, :, :].expand([batch_size, 1, target_length, target_length + past_key_values_length]) + + +def _expand_mask(mask, dtype, tgt_length): + if mask.ndim == 4: + expanded_mask = mask + elif mask.ndim == 3: + expanded_mask = mask[:, None, :, :] + else: + batch_size, src_length = mask.shape[0], mask.shape[-1] + tgt_length = tgt_length if tgt_length is not None else src_length + + expanded_mask = mask[:, None, None, :].expand([batch_size, 1, tgt_length, src_length]) + + inverted_mask = 1.0 - expanded_mask + return masked_fill(inverted_mask, inverted_mask.cast("bool"), float(finfo(dtype).min)) + + +class FusedDropoutImpl(nn.Layer): + + def __init__(self, prob, mode): + super().__init__() + self.prob = prob + self.mode = mode + + self.dropout = nn.Dropout(p=prob, mode=mode) + + def forward(self, x, y): + if self.prob > 0: + x = self.dropout(x) + output = x + y + + return output + + +class RMSNorm(nn.Layer): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.weight = paddle.create_parameter( + shape=[self.hidden_size], + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.Constant(1.0), + ) + self.variance_epsilon = config.rms_norm_eps + self.config = config + + if config.sequence_parallel: + mark_as_sequence_parallel_parameter(self.weight) + + def forward(self, hidden_states): + if self.config.fuse_rms_norm: + return fused_rms_norm_ext(hidden_states, self.weight, self.variance_epsilon)[0] + if paddle.in_dynamic_mode(): + with paddle.amp.auto_cast(False): + variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True) + hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + else: + variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True) + hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + + if self.weight.dtype in [paddle.float16, paddle.bfloat16]: + hidden_states = paddle.cast(hidden_states, self.weight.dtype) + return hidden_states * self.weight + + +class RotaryEmbedding(nn.Layer): + def __init__(self, dim, max_position_embeddings=4096, base=10000): + super().__init__() + self.base = base + self.max_position_embeddings = max_position_embeddings + inv_freq = 1.0 / (base ** (paddle.cast(paddle.arange(0, dim, 2), dtype="float32") / dim)) + + t = paddle.arange(max_position_embeddings, dtype="float32") + freqs = paddle.einsum("i,j->ij", t, inv_freq.cast("float32")) + emb = paddle.concat([freqs, freqs], axis=-1) + + self.cos_cached = emb.cos() + self.sin_cached = emb.sin() + + self._cast_to_low_precision = False + self._cast_to_low_precision = False + + def forward(self, x, seq_len=None): + + return ( + self.cos_cached[:seq_len, :], + self.sin_cached[:seq_len, :], + ) + + @classmethod + def rotate_half(cls, x): + + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return paddle.concat([-x2, x1], axis=-1) + + @classmethod + def apply_rotary_pos_emb(cls, q, k, cos, sin, offset: int = 0, position_ids=None): + if position_ids is not None: + assert offset == 0, offset + cos = F.embedding(position_ids, cos) + sin = F.embedding(position_ids, sin) + else: + cos = cos.unsqueeze(0) + sin = sin.unsqueeze(0) + cos = cos[:, offset : q.shape[1] + offset, None, :] + sin = sin[:, offset : q.shape[1] + offset, None, :] + + q_embed = paddle.add(paddle.multiply(q, cos), paddle.multiply(cls.rotate_half(q), sin)) + k_embed = paddle.add(paddle.multiply(k, cos), paddle.multiply(cls.rotate_half(k), sin)) + q_embed = q_embed.astype(q.dtype) + k_embed = k_embed.astype(k.dtype) + return q_embed, k_embed + + +class RopeEmbeddingLegacy(nn.Layer): + def __init__(self, head_dim, compression_ratio=1.0, base=10000, freq_allocation=0): + super().__init__() + self.head_dim = head_dim + self.compression_ratio = compression_ratio + self.base = base + self.freq_allocation = freq_allocation + + def forward(self, seq_length, position_ids=None): + indices = paddle.arange(0, self.head_dim, 2, dtype="float32") + indices = 1 / self.base ** (indices / self.head_dim) + if position_ids is None: + position_ids = paddle.arange(0, seq_length, 1, dtype="float32").unsqueeze(1) + position_ids = position_ids / self.compression_ratio + sinusoid_inp = position_ids * indices.unsqueeze(0) + else: + position_ids = position_ids / self.compression_ratio + seq_length = position_ids.shape[-1] + sinusoid_inp = position_ids.unsqueeze(-1).astype("float32") * indices.unsqueeze(0) + pos_emb = paddle.concat([paddle.sin(sinusoid_inp), paddle.cos(sinusoid_inp)], axis=-1) + pos_emb = paddle.reshape(pos_emb, (-1, 1, seq_length, self.head_dim)) + pos_emb.stop_gradient = True + return pos_emb + + def apply_rotary(self, rp, q, k): + sin, cos = paddle.chunk(rp, 2, axis=-1) + sin_pos = paddle.reshape(paddle.stack([sin, sin], axis=-1), rp.shape) + cos_pos = paddle.reshape(paddle.stack([cos, cos], axis=-1), rp.shape) + rotate_half_q = paddle.reshape( + paddle.stack([-q[:, :, :, 1::2], q[:, :, :, 0::2]], axis=-1), + paddle.shape(q), + ) + query = paddle.add( + paddle.multiply(q.astype("float32"), cos_pos), + paddle.multiply(rotate_half_q.astype("float32"), sin_pos), + ) + rotate_half_k = paddle.reshape( + paddle.stack([-k[:, :, :, 1::2], k[:, :, :, 0::2]], axis=-1), + paddle.shape(k), + ) + key = paddle.add( + paddle.multiply(k.astype("float32"), cos_pos), + paddle.multiply(rotate_half_k.astype("float32"), sin_pos), + ) + return query, key + + def apply_rotary_3d(self, rp, q, k, position_ids): + sin, cos = paddle.chunk(rp, 2, axis=-1) + assert position_ids.shape[:1] == q.shape[:1] + batch_indices = paddle.arange(end=position_ids.shape[0]) + batch_indices = batch_indices[..., None] + sin = sin.tile([position_ids.shape[0], 1, 1, 1]) + cos = cos.tile([position_ids.shape[0], 1, 1, 1]) + + assert self.freq_allocation != 0 + sin_t = sin[batch_indices, position_ids[..., 0], :, -self.freq_allocation :] + sin_h = sin[ + batch_indices, + position_ids[..., 1], + :, + : self.head_dim // 2 - self.freq_allocation : 2, + ] + sin_w = sin[ + batch_indices, + position_ids[..., 2], + :, + 1 : self.head_dim // 2 - self.freq_allocation : 2, + ] + sin_hw = paddle.stack([sin_h, sin_w], axis=-1).reshape(sin_h.shape[:-1] + [sin_h.shape[-1] * 2]) + sin_thw = paddle.concat([sin_hw, sin_t], axis=-1) + + cos_t = cos[batch_indices, position_ids[..., 0], :, -self.freq_allocation :] + cos_h = cos[ + batch_indices, + position_ids[..., 1], + :, + : self.head_dim // 2 - self.freq_allocation : 2, + ] + cos_w = cos[ + batch_indices, + position_ids[..., 2], + :, + 1 : self.head_dim // 2 - self.freq_allocation : 2, + ] + cos_hw = paddle.stack([cos_h, cos_w], axis=-1).reshape(cos_h.shape[:-1] + [cos_h.shape[-1] * 2]) + cos_thw = paddle.concat([cos_hw, cos_t], axis=-1) + + sin_pos = paddle.reshape( + paddle.stack([sin_thw, sin_thw], axis=-1), + sin_thw.shape[:3] + [sin_thw.shape[-1] * 2], + ) + cos_pos = paddle.reshape( + paddle.stack([cos_thw, cos_thw], axis=-1), + cos_thw.shape[:3] + [cos_thw.shape[-1] * 2], + ) + + rotate_half_q = paddle.reshape( + paddle.stack([-q[:, :, :, 1::2], q[:, :, :, 0::2]], axis=-1), + paddle.shape(q), + ) + query = paddle.add( + paddle.multiply(q.astype("float32"), cos_pos), + paddle.multiply(rotate_half_q.astype("float32"), sin_pos), + ) + rotate_half_k = paddle.reshape( + paddle.stack([-k[:, :, :, 1::2], k[:, :, :, 0::2]], axis=-1), + paddle.shape(k), + ) + key = paddle.add( + paddle.multiply(k.astype("float32"), cos_pos), + paddle.multiply(rotate_half_k.astype("float32"), sin_pos), + ) + return query, key + + def forward_single(self, position_ids): + batch_size, seq_length = position_ids.shape[:2] + rope_emb = paddle.zeros((2, batch_size, seq_length, 1, self.head_dim), dtype="float32") + inv_freq = self.base ** (-paddle.arange(0, self.head_dim, 2, dtype="float32") / self.head_dim) + position_ids = position_ids.cast("float32") + position_ids = position_ids / self.compression_ratio + freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq) + emb = paddle.stack([freqs, freqs], axis=-1).reshape((batch_size, seq_length, self.head_dim)) + emb = paddle.unsqueeze(emb, 2) + + rope_emb[0] = paddle.cos(emb) + rope_emb[1] = paddle.sin(emb) + return rope_emb + + @staticmethod + def apply_rotary_single(x, rope_emb): + rotate_half_x = paddle.reshape( + paddle.stack([-x[:, :, :, 1::2], x[:, :, :, 0::2]], axis=-1), + paddle.shape(x), + ) + return x * rope_emb[0] + rotate_half_x * rope_emb[1] + + +class ErnieMLP(nn.Layer): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.fuse_ffn = config.fuse_attn_ffn + + if config.tensor_parallel_degree > 1: + ColumnLN = ColumnSequenceParallelLinear if config.sequence_parallel else ColumnParallelLinear + RowLN = RowSequenceParallelLinear if config.sequence_parallel else RowParallelLinear + + column_ln_configs = ( + {"use_rr": config.use_recompute and config.skip_recompute_ops.get("mlp_column_ln", False)} + if config.sequence_parallel and get_env_device() == "gpu" + else {} + ) + if config.sequence_parallel and get_env_device() == "gpu": + column_ln_configs["use_tpsp_comm_overlap"] = config.use_tpsp_comm_overlap + if config.fuse_attn_ffn: + self.up_gate_proj = ColumnLN( + self.hidden_size, + self.intermediate_size * 2, + gather_output=False, + has_bias=config.use_bias, + fuse_matmul_bias=config.fuse_linear, + **column_ln_configs, + ) + else: + self.gate_proj = ColumnLN( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=config.use_bias, + fuse_matmul_bias=config.fuse_linear, + **column_ln_configs, + ) + self.up_proj = ColumnLN( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=config.use_bias, + fuse_matmul_bias=config.fuse_linear, + **column_ln_configs, + ) + else: + LinearFN = paddle.incubate.nn.FusedLinear if config.fuse_linear else NativeLinear + if config.fuse_attn_ffn: + self.up_gate_proj = LinearFN( + self.hidden_size, + self.intermediate_size * 2, + bias_attr=config.use_bias, + ) + else: + self.gate_proj = LinearFN(self.hidden_size, self.intermediate_size, bias_attr=config.use_bias) + self.up_proj = LinearFN(self.hidden_size, self.intermediate_size, bias_attr=config.use_bias) + + if config.tensor_parallel_degree > 1: + row_ln_configs = ( + {"use_rr": config.use_recompute and config.skip_recompute_ops.get("mlp_row_ln", False)} + if config.sequence_parallel and get_env_device() == "gpu" + else {} + ) + if config.sequence_parallel and get_env_device() == "gpu": + row_ln_configs["use_tpsp_comm_overlap"] = config.use_tpsp_comm_overlap + self.down_proj = RowLN( + self.intermediate_size, + self.hidden_size, + input_is_parallel=True, + has_bias=config.use_bias, + fuse_matmul_bias=config.fuse_linear, + **row_ln_configs, + ) + else: + LinearFN = paddle.incubate.nn.FusedLinear if config.fuse_linear else NativeLinear + self.down_proj = LinearFN(self.intermediate_size, self.hidden_size, bias_attr=config.use_bias) + + self.fuse_swiglu = config.fuse_swiglu + if self.fuse_swiglu: + assert fused_swiglu is not None, "fused_swiglu operator is not found." + + def forward(self, x): + if ( + self.config.tensor_parallel_degree <= 1 + and self.fuse_ffn + and self.config.use_fp8_mlp + and not self.config.use_bias + ): + return MemEfficientFp8FusedMlpFunc.apply(x, self.up_gate_proj.weight, self.down_proj.weight) + + if self.fuse_swiglu: + if self.fuse_ffn: + if self.config.use_fp8 and self.config.fp8_configs["smooth_swiglu"]: + x, gate = self.up_gate_proj(x).chunk(2, axis=-1) + with paddle.no_grad(): + scale = paddle.clip(gate.abs().max(axis=-1, keepdim=True), 1e-8) + + gate = gate / scale + if self.config.sequence_parallel: + scale = ScatterOp.apply(scale) + + x = paddle.concat([x, gate], axis=-1) + else: + x = self.up_gate_proj(x) + x = fused_swiglu(x) + else: + x = fused_swiglu(self.gate_proj(x), self.up_proj(x)) + else: + if self.fuse_ffn: + x, gate = self.up_gate_proj(x).chunk(2, axis=-1) + x = F.silu(x) * gate + else: + x = F.silu(self.gate_proj(x)) * self.up_proj(x) + if self.config.use_fp8 and self.config.fp8_configs["smooth_swiglu"]: + return self.down_proj(x) * scale + return self.down_proj(x) + + +class ErnieAttention(nn.Layer): + + def __init__(self, config, layer_idx=0): + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + if config.head_dim is None: + self.head_dim = self.hidden_size // self.num_heads + else: + self.head_dim = config.head_dim + self.fuse_attn = config.fuse_attn_ffn + self.use_recompute_attn = config.use_recompute_attn + logger.info(f"using recompute attn={self.use_recompute_attn}") + self.is_gqa = config.num_key_value_heads is not None and config.num_key_value_heads != self.num_heads + if config.fuse_rope: + assert fused_rope is not None, "fused_rope is not supported" + self.fuse_rope = config.fuse_rope + self.rope_3d = config.rope_3d + if self.rope_3d: + assert not self.fuse_rope, "does not support fuse rope when rope_3d is on for now." + assert not config.rope_reorder, "does not support rope_reorder when rope_3d is on for now." + assert config.freq_allocation is not None, "freq_allocation must be provided if rope_3d is on." + + if config.tensor_parallel_degree > 1: + assert ( + self.num_heads % config.tensor_parallel_degree == 0 + ), f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}" + self.num_heads = self.num_heads // config.tensor_parallel_degree + if self.is_gqa: + assert ( + self.num_key_value_heads % config.tensor_parallel_degree == 0 + ), f"num_heads: {self.num_key_value_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}" + self.num_key_value_heads = self.num_key_value_heads // config.tensor_parallel_degree + if self.is_gqa: + logger.info(f"use GQA - num_heads: {self.num_heads}- num_key_value_heads: {self.num_key_value_heads}") + assert ( + self.num_heads % self.num_key_value_heads == 0 + ), f"num_heads: {self.num_heads}, num_key_value_heads: {self.num_key_value_heads}" + q_hidden_size = self.head_dim * config.num_attention_heads + kv_hidden_size = self.head_dim * config.num_key_value_heads + else: + q_hidden_size = kv_hidden_size = self.head_dim * config.num_attention_heads + + if config.tensor_parallel_degree > 1: + ColumnLN = ColumnSequenceParallelLinear if config.sequence_parallel else ColumnParallelLinear + RowLN = RowSequenceParallelLinear if config.sequence_parallel else RowParallelLinear + column_ln_configs = ( + {"use_rr": config.use_recompute and config.skip_recompute_ops.get("attention_column_ln", False)} + if config.sequence_parallel and get_env_device() == "gpu" + else {} + ) + if config.sequence_parallel and get_env_device() == "gpu": + column_ln_configs["use_tpsp_comm_overlap"] = config.use_tpsp_comm_overlap + + if config.fuse_attn_ffn: + self.qkv_proj = ColumnLN( + self.hidden_size, + q_hidden_size + 2 * kv_hidden_size, + has_bias=config.use_bias, + gather_output=False, + fuse_matmul_bias=config.fuse_linear, + **column_ln_configs, + ) + else: + self.q_proj = ColumnLN( + self.hidden_size, + q_hidden_size, + has_bias=config.use_bias, + gather_output=False, + fuse_matmul_bias=config.fuse_linear, + **column_ln_configs, + ) + self.k_proj = ColumnLN( + self.hidden_size, + kv_hidden_size, + has_bias=config.use_bias, + gather_output=False, + fuse_matmul_bias=config.fuse_linear, + **column_ln_configs, + ) + self.v_proj = ColumnLN( + self.hidden_size, + kv_hidden_size, + has_bias=config.use_bias, + gather_output=False, + fuse_matmul_bias=config.fuse_linear, + **column_ln_configs, + ) + else: + LinearFN = paddle.incubate.nn.FusedLinear if config.fuse_linear else NativeLinear + if config.fuse_attn_ffn: + self.qkv_proj = LinearFN( + self.hidden_size, + q_hidden_size + 2 * kv_hidden_size, + bias_attr=config.use_bias, + ) + else: + self.q_proj = LinearFN( + self.hidden_size, + q_hidden_size, + bias_attr=config.use_bias, + ) + self.k_proj = LinearFN( + self.hidden_size, + kv_hidden_size, + bias_attr=config.use_bias, + ) + self.v_proj = LinearFN( + self.hidden_size, + kv_hidden_size, + bias_attr=config.use_bias, + ) + + if config.tensor_parallel_degree > 1: + row_ln_configs = ( + {"use_rr": config.use_recompute and config.skip_recompute_ops.get("attention_row_ln", False)} + if config.sequence_parallel and get_env_device() == "gpu" + else {} + ) + if config.sequence_parallel and get_env_device() == "gpu": + row_ln_configs["use_tpsp_comm_overlap"] = config.use_tpsp_comm_overlap + + self.o_proj = RowLN( + q_hidden_size, + self.hidden_size, + has_bias=config.use_bias, + input_is_parallel=True, + fuse_matmul_bias=config.fuse_linear, + **row_ln_configs, + ) + else: + LinearFN = paddle.incubate.nn.FusedLinear if config.fuse_linear else NativeLinear + self.o_proj = LinearFN( + q_hidden_size, + self.hidden_size, + bias_attr=config.use_bias, + ) + if config.rope_reorder: + self.rotary_emb = RotaryEmbedding( + self.head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + else: + self.rotary_emb = RopeEmbeddingLegacy( + self.head_dim, + compression_ratio=config.compression_ratio, + base=config.rope_theta, + freq_allocation=config.freq_allocation, + ) + self.config = config + + if self.config.use_qk_norm: + logger.info(f"use_qk_norm, the head_dim is {self.head_dim}") + Norm = RMSNorm + + qk_norm_config = copy.deepcopy(config) + qk_norm_config.hidden_size = self.head_dim + self.q_norm = Norm(qk_norm_config) + self.k_norm = Norm(qk_norm_config) + + def forward( + self, + hidden_states, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + attention_mask: Optional[paddle.Tensor] = None, + position_ids: Optional[Tuple[paddle.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + inbatch_pack_offset: Optional[Tuple[paddle.Tensor]] = None, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: + if self.config.sequence_parallel: + if not self.config.using_dynamic_sequence_length: + bsz = hidden_states.shape[0] * self.config.tensor_parallel_degree // self.config.seqlen + q_len = self.config.seqlen + else: + assert ( + self.config.micro_batch_size + ), "micro_batch_size should be set when using dygramic sequence length." + + bsz = self.config.micro_batch_size + q_len = hidden_states.shape[0] * self.config.tensor_parallel_degree // bsz + else: + bsz, q_len, _ = hidden_states.shape + query_states = key_states = value_states = mix_layer = None + if self.fuse_attn: + mix_layer = self.qkv_proj(hidden_states) + if self.is_gqa: + query_states, key_states, value_states = paddle.split( + mix_layer.reshape([bsz, q_len, -1, self.head_dim]), + [ + self.num_heads, + self.num_key_value_heads, + self.num_key_value_heads, + ], + axis=2, + ) + mix_layer = None + else: + mix_layer = mix_layer.reshape([bsz, q_len, self.num_heads, 3 * self.head_dim]) + else: + query_states = self.q_proj(hidden_states).reshape(shape=[bsz, q_len, self.num_heads, self.head_dim]) + key_states = self.k_proj(hidden_states).reshape( + shape=[ + bsz, + q_len, + self.num_key_value_heads if self.is_gqa else self.num_heads, + self.head_dim, + ] + ) + value_states = self.v_proj(hidden_states).reshape( + shape=[ + bsz, + q_len, + self.num_key_value_heads if self.is_gqa else self.num_heads, + self.head_dim, + ] + ) + + if self.use_recompute_attn: + assert past_key_value is None, "do not use kv cache in recompute" + assert not use_cache + attn_output, attn_weights, past_key_value = recompute( + self.rope_attn, + mix_layer, + query_states, + key_states, + value_states, + attention_mask, + position_ids, + output_attentions, + past_key_value, + use_cache, + inbatch_pack_offset, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + attn_output, attn_weights, past_key_value = self.rope_attn( + mix_layer=mix_layer, + query_states=query_states, + key_states=key_states, + value_states=value_states, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + past_key_value=past_key_value, + use_cache=use_cache, + inbatch_pack_offset=inbatch_pack_offset, + ) + + if self.config.sequence_parallel: + attn_output = attn_output.reshape([-1, attn_output.shape[-1]]) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def rope_attn( + self, + mix_layer, + query_states, + key_states, + value_states, + attention_mask, + position_ids, + output_attentions=False, + past_key_value=None, + use_cache=False, + inbatch_pack_offset=None, + ): + if mix_layer is not None: + query_states, key_states, value_states = paddle.split(mix_layer, 3, axis=-1) + query_states_dtype = query_states.dtype + + if self.rope_3d: + assert position_ids is not None, "rope3d requires pos-id" + kv_seq_len = key_states.shape[-3] if not self.rope_3d else position_ids.max() + 1 + offset = 0 + if past_key_value is not None: + if not self.rope_3d: + offset = past_key_value[0].shape[-3] + kv_seq_len += offset + else: + offset = position_ids.max() + kv_seq_len = position_ids.max() + 1 + position_ids = position_ids[:, -1:, :] + + if self.config.rope_reorder: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = self.rotary_emb.apply_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + position_ids=position_ids, + offset=offset if position_ids is None else 0, + ) + else: + if offset > 0 or position_ids is not None or not self.fuse_rope: + if not self.rope_3d: + cos_sin = self.rotary_emb(kv_seq_len, position_ids).transpose([0, 2, 1, 3]) + if offset > 0 and position_ids is None: + cos_sin = cos_sin[:, offset:] + query_states, key_states = self.rotary_emb.apply_rotary(cos_sin, query_states, key_states) + else: + cos_sin = self.rotary_emb(kv_seq_len).transpose([0, 2, 1, 3]) + + if offset > 0 and position_ids is None: + cos_sin = cos_sin[:, offset:] + + query_states, key_states = self.rotary_emb.apply_rotary_3d( + cos_sin, query_states, key_states, position_ids + ) + else: + assert not self.rope_3d + bsz, q_len, num_heads, head_dim = query_states.shape + _, kv_seq_len, num_key_value_heads, _ = key_states.shape + if num_heads != num_key_value_heads: + query_states, _, _ = fused_rope(query_states, None, None, rotary_emb_base=self.config.rope_theta) + key_states, _, _ = fused_rope(key_states, None, None, rotary_emb_base=self.config.rope_theta) + else: + query_states, key_states, _ = fused_rope( + query_states, + key_states, + None, + rotary_emb_base=self.config.rope_theta, + ) + + if use_cache: + query_states = query_states.astype(query_states_dtype) + key_states = key_states.astype(query_states_dtype) + if past_key_value is not None: + key_states = paddle.concat([past_key_value[0], key_states], axis=1) + value_states = paddle.concat([past_key_value[1], value_states], axis=1) + + past_key_value = [key_states, value_states] if use_cache else None + + if self.config.use_qk_norm: + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + attn_output, attn_weights = scaled_dot_product_attention( + query_states=query_states, + key_states=key_states, + value_states=value_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + config=self.config, + inbatch_pack_offset=inbatch_pack_offset, + training=self.training, + ) + return attn_output, attn_weights, past_key_value + + +class ErnieDecoderLayer(nn.Layer): + def __init__(self, config, layer_idx=0): + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.self_attn = ErnieAttention(config, layer_idx) + self.mlp = ErnieMLP(config) + Norm = RMSNorm + + self.input_layernorm = Norm(config) + self.post_attention_layernorm = Norm(config) + self.residual_add1 = FusedDropoutImpl(config.hidden_dropout_prob, mode="upscale_in_train") + self.residual_add2 = FusedDropoutImpl(config.hidden_dropout_prob, mode="upscale_in_train") + self.config = config + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + inbatch_pack_offset: Optional[paddle.Tensor] = None, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + use_cache=use_cache, + inbatch_pack_offset=inbatch_pack_offset, + ) + + if self.config.tensor_parallel_degree > 1 and self.config.hidden_dropout_prob > 0.0: + current_seed = "local_seed" if self.config.sequence_parallel else "global_seed" + with get_rng_state_tracker().rng_state(current_seed): + hidden_states = self.residual_add1(hidden_states, residual) + else: + hidden_states = self.residual_add1(hidden_states, residual) + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + + if self.config.tensor_parallel_degree > 1 and self.config.hidden_dropout_prob > 0.0: + current_seed = "local_seed" if self.config.sequence_parallel else "global_seed" + with get_rng_state_tracker().rng_state(current_seed): + hidden_states = self.residual_add2(hidden_states, residual) + else: + hidden_states = self.residual_add2(hidden_states, residual) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class ErniePretrainedModel(PretrainedModel): + config_class = ErnieMoEConfig + base_model_prefix = "ernie" + + @classmethod + def _get_name_mappings(cls, config: ErnieMoEConfig) -> StateDictNameMapping: + mappings: StateDictNameMapping = [] + model_mappings = [ + ["embed_tokens.weight"], + ["norm.weight"], + ] + for layer_index in range( + config.num_hidden_layers if not config.remove_tail_layer else config.num_hidden_layers - 1 + ): + if config.fuse_attn_ffn: + layer_mappings = [ + [ + f"layers.{layer_index}.self_attn.qkv_proj.weight", + None, + "transpose", + ], + [ + f"layers.{layer_index}.self_attn.o_proj.weight", + None, + "transpose", + ], + [f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"], + [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"], + [ + f"layers.{layer_index}.mlp.up_gate_proj.weight", + None, + "transpose", + ], + [f"layers.{layer_index}.input_layernorm.weight"], + [f"layers.{layer_index}.post_attention_layernorm.weight"], + ] + else: + layer_mappings = [ + [ + f"layers.{layer_index}.self_attn.q_proj.weight", + None, + "transpose", + ], + [ + f"layers.{layer_index}.self_attn.k_proj.weight", + None, + "transpose", + ], + [ + f"layers.{layer_index}.self_attn.v_proj.weight", + None, + "transpose", + ], + [ + f"layers.{layer_index}.self_attn.o_proj.weight", + None, + "transpose", + ], + [f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"], + [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"], + [f"layers.{layer_index}.input_layernorm.weight"], + [f"layers.{layer_index}.post_attention_layernorm.weight"], + ] + model_mappings.extend(layer_mappings) + + init_name_mappings(mappings=model_mappings) + if "ErnieModel" not in config.architectures: + for mapping in model_mappings: + mapping[0] = "model." + mapping[0] + mapping[1] = "ernie." + mapping[1] + model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"]) + + mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)] + return mappings + + @classmethod + def _get_tensor_parallel_mappings(cls, config, is_split=True): + + from paddleformers.transformers.conversion_utils import split_or_merge_func + + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + if config.num_key_value_heads is not None and config.num_key_value_heads != config.num_attention_heads: + if is_split: + qkv_fn = partial( + gqa_qkv_split_func, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + head_dim=config.hidden_size // config.num_attention_heads, + ) + else: + qkv_fn = partial( + gqa_qkv_merge_func, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + head_dim=config.hidden_size // config.num_attention_heads, + ) + else: + qkv_fn = partial(fn, is_column=True) + + def get_tensor_parallel_split_mappings(num_layers): + final_actions = {} + if config.fuse_attn_ffn: + base_actions = { + "layers.0.self_attn.qkv_proj.weight": qkv_fn, + "layers.0.mlp.up_gate_proj.weight": partial(fn, is_column=True, is_naive_2fuse=True), + "lm_head.weight": partial(fn, is_column=not config.tie_word_embeddings), + "embed_tokens.weight": partial(fn, is_column=False), + "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), + "layers.0.mlp.down_proj.weight": partial(fn, is_column=False), + } + if config.use_bias: + base_actions.update( + { + "layers.0.self_attn.qkv_proj.bias": qkv_fn, + "layers.0.mlp.up_gate_proj.bias": partial(fn, is_column=True, is_naive_2fuse=True), + "lm_head.bias": partial(fn, is_column=True), + } + ) + else: + base_actions = { + "layers.0.self_attn.q_proj.weight": partial(fn, is_column=True), + "layers.0.self_attn.k_proj.weight": partial(fn, is_column=True), + "layers.0.self_attn.v_proj.weight": partial(fn, is_column=True), + "layers.0.mlp.gate_proj.weight": partial(fn, is_column=True), + "layers.0.mlp.up_proj.weight": partial(fn, is_column=True), + "lm_head.weight": partial(fn, is_column=not config.tie_word_embeddings), + "embed_tokens.weight": partial(fn, is_column=False), + "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), + "layers.0.mlp.down_proj.weight": partial(fn, is_column=False), + } + if config.use_bias: + base_actions.update( + { + "layers.0.self_attn.q_proj.bias": partial(fn, is_column=True), + "layers.0.self_attn.k_proj.bias": partial(fn, is_column=True), + "layers.0.self_attn.v_proj.bias": partial(fn, is_column=True), + "layers.0.mlp.gate_proj.bias": partial(fn, is_column=True), + "layers.0.mlp.up_proj.bias": partial(fn, is_column=True), + "lm_head.bias": partial(fn, is_column=True), + } + ) + for key, action in base_actions.items(): + if "layers.0." in key: + for i in range(num_layers): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + final_actions[key] = action + + return final_actions + + mappings = get_tensor_parallel_split_mappings( + config.num_hidden_layers if not config.remove_tail_layer else config.num_hidden_layers - 1 + ) + + return mappings + + def _init_weights(self, layer): + if self.config.tensor_parallel_degree > 1: + rng_tracker = get_rng_state_tracker().rng_state + else: + rng_tracker = contextlib.nullcontext + + if isinstance( + layer, + ( + ColumnParallelLinear, + RowParallelLinear, + ColumnSequenceParallelLinear, + RowSequenceParallelLinear, + VocabParallelEmbedding, + ErnieLMHead, + nn.Embedding, + NativeLinear, + paddle.incubate.nn.FusedLinear, + ), + ): + + with rng_tracker(): + dtype = paddle.get_default_dtype() + paddle.set_default_dtype("float32") + layer.weight.set_value( + paddle.randn(layer.weight.shape, dtype=dtype).scale(self.config.initializer_range) + ) + paddle.set_default_dtype(dtype) + logger.info( + f"dist-init-fc: shape={layer.weight.shape}, " + f" range={self.config.initializer_range}, dtype={layer.weight.dtype} " + f' type={type(layer)},norm={layer.weight.astype("float32").norm().item()}' + ) + + elif isinstance(layer, RotaryEmbedding): + head_dim = self.config.hidden_size // self.config.num_attention_heads + inv_freq = 1.0 / (layer.base ** (np.arange(0, head_dim, 2).astype("float32") / head_dim)) + + t = np.arange(layer.max_position_embeddings, dtype="float32") + freqs = np.einsum("i,j->ij", t, inv_freq) + emb = np.concatenate([freqs, freqs], axis=-1) + cos_cached = np.cos(emb)[:, :] + sin_cached = np.sin(emb)[:, :] + layer.cos_cached.set_value(cos_cached) + layer.sin_cached.set_value(sin_cached) + + +@register_base_model +class ErnieModel(ErniePretrainedModel): + + def __init__(self, config: ErnieMoEConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.hidden_size = config.hidden_size + self.config = config + + if config.tensor_parallel_degree > 1: + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + self.hidden_size, + ) + else: + self.embed_tokens = nn.Embedding( + self.vocab_size, + self.hidden_size, + ) + + layers_list = [ + ErnieDecoderLayer(config, layer_idx) + for layer_idx in range( + config.num_hidden_layers - 1 if config.remove_tail_layer else config.num_hidden_layers + ) + ] + + self.layers = nn.LayerList(layers_list) + Norm = RMSNorm + + self.norm = Norm(config) + + self.gradient_checkpointing = False + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @classmethod + def _prepare_decoder_attention_mask(cls, attention_mask, input_shape, past_key_values_length, dtype): + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, past_key_values_length=past_key_values_length, dtype=dtype + ) + + if attention_mask is not None: + expanded_attn_mask = _expand_mask(attention_mask, dtype, tgt_length=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + combined_attention_mask = paddle.maximum( + combined_attention_mask.astype(dtype), + paddle.to_tensor(float(finfo(dtype).min), dtype=dtype), + ) + return combined_attention_mask + + @paddle.jit.not_to_static + def recompute_training( + self, + layer_module, + hidden_states, + attention_mask, + position_ids, + output_attentions, + past_key_value, + use_cache, + inbatch_pack_offset, + ): + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = recompute( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + position_ids, + output_attentions, + past_key_value, + use_cache, + inbatch_pack_offset, + use_reentrant=self.config.recompute_use_reentrant, + ) + return hidden_states + + def forward( + self, + input_ids=None, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + use_cache=None, + past_key_values=None, + output_attentions=False, + output_hidden_states=None, + return_dict=False, + inbatch_pack_offset=None, + **kwargs, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if past_key_values is None: + past_key_values = tuple([None] * len(self.layers)) + + seq_length_with_past = seq_length + cache_length = 0 + if past_key_values[0] is not None: + cache_length = paddle.shape(past_key_values[0][0])[1] + seq_length_with_past += cache_length + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if self.embed_tokens is not None: + inputs_embeds = inputs_embeds.astype(self.embed_tokens.weight.dtype) + + if self.config.sequence_parallel: + inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[-1]]) + inputs_embeds = ScatterOp.apply(inputs_embeds) + + can_use_fa = self.config.use_flash_attn and flash_attention is not None + can_mem_eff_attn = self.config.use_mem_eff_attn and inbatch_pack_offset is not None + if can_use_fa or can_mem_eff_attn: + if attention_mask is not None: + attention_mask = None + logger.warning( + f"set attention_mask = None when (can_use_fa or can_mem_eff_attn) and attention_mask is not None, " + f"can_use_fa = {can_use_fa}, can_mem_eff_attn = {can_mem_eff_attn}, " + f"attention_mask is not None = {attention_mask is not None}" + ) + elif attention_mask is None: + attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) + + if attention_mask is not None: + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + cache_length, + inputs_embeds.dtype, + ) + hidden_states = inputs_embeds + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, (decoder_layer) in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + has_gradient = not hidden_states.stop_gradient + if self.config.use_recompute and has_gradient: + layer_outputs = self.recompute_training( + decoder_layer, + hidden_states, + attention_mask, + position_ids, + output_attentions, + past_key_value, + use_cache, + inbatch_pack_offset, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + position_ids, + output_attentions, + past_key_value, + use_cache, + inbatch_pack_offset, + ) + + if isinstance(layer_outputs, (tuple, list)): + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if use_cache: + hidden_states = paddle.unsqueeze(hidden_states[:, -1, :], 1) + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=None, + ) + + +class FusedHeadParallelCrossEntropy(PyLayer): + + @staticmethod + def forward( + ctx, + hidden_states, + weight, + bias, + labels, + tensor_parallel_degree, + mp_group=None, + ignore_index=-100, + seq_chunk_size=8192, + transpose_y=False, + fuse_linear=False, + training=True, + ): + + ctx.tensor_parallel_degree = tensor_parallel_degree + ctx.ignore_index = ignore_index + ctx.seq_chunk_size = seq_chunk_size + ctx.transpose_y = transpose_y + ctx.fuse_linear = fuse_linear + ctx.training = training + + ctx.hidden_states_shape = hidden_states.shape + + if ctx.tensor_parallel_degree > 1: + ctx.mp_group = ( + fleet.get_hybrid_communicate_group().get_model_parallel_group() if mp_group is None else mp_group + ) + ctx.rank = ctx.mp_group.rank + ctx.world_size = ctx.mp_group.nranks + else: + ctx.mp_group = None + ctx.rank = 0 + ctx.world_size = 1 + + loss_all = [] + labels_all = [] + with paddle.no_grad(): + labels = labels.reshape_([-1]) + hidden_states = hidden_states.reshape_([-1, hidden_states.shape[-1]]) + + num_tokens_per_rank = [] + if ctx.tensor_parallel_degree > 1: + dist.stream.all_gather( + num_tokens_per_rank, + paddle.to_tensor(hidden_states.shape[0], dtype=paddle.int32), + group=ctx.mp_group, + ) + ctx.num_tokens_per_rank = num_tokens_per_rank + + for idx in range(ctx.world_size): + if idx == ctx.rank: + hidden_states_recv = hidden_states + labels_recv = labels + else: + hidden_states_recv = paddle.empty( + [ctx.num_tokens_per_rank[idx], hidden_states.shape[-1]], + dtype=hidden_states.dtype, + ) + labels_recv = paddle.empty([ctx.num_tokens_per_rank[idx]], dtype=labels.dtype) + + if ctx.tensor_parallel_degree > 1: + dist.stream.broadcast( + hidden_states_recv, + src=ctx.mp_group.ranks[idx], + group=ctx.mp_group, + ) + dist.stream.broadcast(labels_recv, src=ctx.mp_group.ranks[idx], group=ctx.mp_group) + + seq_len = hidden_states_recv.shape[0] + num_chunk = (seq_len + ctx.seq_chunk_size - 1) // ctx.seq_chunk_size + + loss_chunk = [] + for chunk_idx in range(num_chunk): + start = chunk_idx * ctx.seq_chunk_size + end = min(start + ctx.seq_chunk_size, seq_len) + hidden_states_chunk = hidden_states_recv._slice(start, end) + labels_chunk = labels_recv._slice(start, end) + + logits = parallel_matmul( + hidden_states_chunk, + weight, + bias=bias, + transpose_y=ctx.transpose_y, + tensor_parallel_degree=ctx.tensor_parallel_degree, + tensor_parallel_output=True, + fuse_linear=ctx.fuse_linear, + training=ctx.training, + ) + + with paddle.amp.auto_cast(False): + if ctx.tensor_parallel_degree > 1: + loss = mp_ops._c_softmax_with_cross_entropy( + logits.cast("float32"), + labels_chunk.unsqueeze(-1), + group=ctx.mp_group, + ignore_index=ctx.ignore_index, + ) + else: + loss = paddle.nn.functional.softmax_with_cross_entropy( + logits.cast("float32"), + labels_chunk.unsqueeze(-1), + ignore_index=ctx.ignore_index, + ) + loss_chunk.append(loss) + loss_all.append(paddle.concat(loss_chunk, axis=0)) + labels_all.append(labels_recv) + + ctx.loss_concat_sections = [loss.shape[0] for loss in loss_all] + loss_all = paddle.concat(loss_all, axis=0) + labels_all = paddle.concat(labels_all, axis=0) + + tensor_inputs = [hidden_states, weight, bias, labels] + ctx.save_for_backward(*tensor_inputs) + + return loss_all, labels_all + + @staticmethod + def backward(ctx, loss_all_grad, labels_all_grad): + + hidden_states, weight, bias, labels = ctx.saved_tensor() + + loss_all_grad_list = paddle.split(loss_all_grad, ctx.loss_concat_sections, axis=0) + + def detach_variable(inp): + if inp is None: + return None + x = inp.detach() + x.stop_gradient = inp.stop_gradient + return x + + if weight.stop_gradient is False: + weight_main_grad = paddle.zeros(weight.shape, dtype=paddle.float32) + else: + weight_main_grad = None + if bias is not None and bias.stop_gradient is False: + bias_main_grad = paddle.zeros(bias.shape, dtype=paddle.float32) + else: + bias_main_grad = None + + hidden_states = detach_variable(hidden_states) + weight = detach_variable(weight) + bias = detach_variable(bias) + labels = detach_variable(labels) + + with paddle.base.dygraph.guard(): + tracer = paddle.base.framework._dygraph_tracer() + tracer._has_grad = True + + for idx in range(ctx.world_size): + if idx == ctx.rank: + hidden_states_recv = hidden_states + labels_recv = labels + else: + hidden_states_recv = paddle.empty( + [ctx.num_tokens_per_rank[idx], hidden_states.shape[-1]], + dtype=hidden_states.dtype, + ) + labels_recv = paddle.empty([ctx.num_tokens_per_rank[idx]], dtype=labels.dtype) + if ctx.tensor_parallel_degree > 1: + dist.stream.broadcast( + hidden_states_recv, + src=ctx.mp_group.ranks[idx], + group=ctx.mp_group, + ) + dist.stream.broadcast(labels_recv, src=ctx.mp_group.ranks[idx], group=ctx.mp_group) + hidden_states_recv.stop_gradient = False + + seq_len = hidden_states_recv.shape[0] + num_chunk = (seq_len + ctx.seq_chunk_size - 1) // ctx.seq_chunk_size + + for chunk_idx in range(num_chunk): + start = chunk_idx * ctx.seq_chunk_size + end = min(start + ctx.seq_chunk_size, seq_len) + hidden_states_chunk = hidden_states_recv.slice(axes=[0], starts=[start], ends=[end]) + labels_chunk = labels_recv._slice(start, end) + loss_grad_chunk = loss_all_grad_list[idx]._slice(start, end) + + logits = parallel_matmul( + hidden_states_chunk, + weight, + bias=bias, + transpose_y=ctx.transpose_y, + tensor_parallel_degree=ctx.tensor_parallel_degree, + tensor_parallel_output=True, + fuse_linear=ctx.fuse_linear, + training=ctx.training, + ) + + with paddle.amp.auto_cast(False): + if ctx.tensor_parallel_degree > 1: + loss_chunk = mp_ops._c_softmax_with_cross_entropy( + logits.cast("float32"), + labels_chunk.unsqueeze(-1), + group=ctx.mp_group, + ignore_index=ctx.ignore_index, + ) + else: + loss_chunk = paddle.nn.functional.softmax_with_cross_entropy( + logits.cast("float32"), + labels_chunk.unsqueeze(-1), + ignore_index=ctx.ignore_index, + ) + + with paddle.amp.auto_cast(enable=False): + paddle.autograd.backward(loss_chunk, loss_grad_chunk) + + if weight_main_grad is not None: + weight_main_grad.add_(weight.grad.cast(paddle.float32)) + weight.clear_gradient(True) + if bias_main_grad is not None: + bias_main_grad.add_(bias.grad.cast(paddle.float32)) + bias.clear_gradient(True) + + if idx == ctx.rank: + hidden_states_grad = hidden_states_recv.grad + hidden_states_grad = hidden_states_grad.reshape(ctx.hidden_states_shape) + + if weight_main_grad is not None: + weight_main_grad = weight_main_grad.astype(weight.dtype) + if bias_main_grad is not None: + bias_main_grad = bias_main_grad.astype(bias.dtype) + + if bias_main_grad is not None: + return ( + hidden_states_grad, + weight_main_grad, + bias_main_grad, + ) + else: + return ( + hidden_states_grad, + weight_main_grad, + ) + + +class ErniePretrainingCriterion(paddle.nn.Layer): + + def __init__(self, config, return_tuple=True): + super(ErniePretrainingCriterion, self).__init__() + self.ignored_index = getattr(config, "ignored_index", -100) + self.config = config + self.return_tuple = return_tuple + self.enable_parallel_cross_entropy = config.tensor_parallel_degree > 1 and config.tensor_parallel_output + + if self.enable_parallel_cross_entropy: + self.loss_func = fleet.meta_parallel.ParallelCrossEntropy() + else: + self.loss_func = paddle.nn.CrossEntropyLoss( + reduction="none", + ) + self.token_balance_loss = config.token_balance_loss + + def forward(self, prediction_scores, masked_lm_labels): + + if self.config.use_sparse_head_and_loss_fn: + hidden_states, outlinear_weight, outlinear_bias = prediction_scores + + if self.config.sequence_parallel: + masked_lm_labels, sparse_label_idx = sequence_parallel_sparse_mask_labels( + masked_lm_labels, self.ignored_index + ) + sparse_label_idx = sparse_label_idx.reshape([-1, 1]) + hidden_states = paddle.gather(hidden_states, sparse_label_idx, axis=0) + hidden_states = AllGatherVarlenOp.apply(hidden_states) + else: + masked_lm_labels = masked_lm_labels.flatten() + sparse_label_idx = paddle.nonzero(masked_lm_labels != self.ignored_index).flatten() + masked_lm_labels = paddle.take_along_axis(masked_lm_labels, sparse_label_idx, axis=0) + + hidden_states = hidden_states.reshape([-1, hidden_states.shape[-1]]) + hidden_states = paddle.take_along_axis(hidden_states, sparse_label_idx.reshape([-1, 1]), axis=0) + + if self.config.use_recompute_loss_fn: + offload_kwargs = {} + if self.config.get("offload_lm_head", False): + offload_kwargs["offload_indices"] = [1] + res = recompute( + self.forward_impl_with_calc_logits, + masked_lm_labels, + hidden_states, + outlinear_weight, + outlinear_bias, + **offload_kwargs, + ) + else: + logits = calc_lm_head_logits( + self.config, + hidden_states, + outlinear_weight, + outlinear_bias, + training=self.training, + ) + res = self.forward_impl(logits, masked_lm_labels) + elif self.config.use_recompute_loss_fn: + if self.config.use_fused_head_loss_fn: + res = self.forward_impl_with_fused_head_loss_fn(masked_lm_labels, *prediction_scores) + else: + assert isinstance(prediction_scores, tuple) and len(prediction_scores) in [3, 4], prediction_scores + res = recompute( + self.forward_impl_with_calc_logits, + masked_lm_labels, + *prediction_scores, + ) + else: + res = self.forward_impl(prediction_scores, masked_lm_labels) + + return res + + def forward_impl_with_fused_head_loss_fn(self, masked_lm_labels, hidden_states, outlinear_weight, outlinear_bias): + masked_lm_labels.stop_gradient = True + masked_lm_loss, masked_lm_labels_all = FusedHeadParallelCrossEntropy.apply( + hidden_states, + outlinear_weight, + outlinear_bias, + masked_lm_labels, + self.config.tensor_parallel_degree, + ignore_index=self.ignored_index, + seq_chunk_size=self.config.get("loss_subbatch_seqlen", 32768), + transpose_y=self.config.tie_word_embeddings, + fuse_linear=self.config.fuse_linear, + training=self.training, + ) + lossmask = masked_lm_labels_all != self.ignored_index + if (~lossmask).all(): + logger.warning(f"encounter empty span when calculate loss, ignored_index={self.ignored_index}") + loss = paddle.mean(masked_lm_loss) * 0.0 + loss_sum = masked_lm_loss.sum().detach() + else: + lossmask = lossmask.reshape([-1]).cast(paddle.float32) + + masked_lm_loss = paddle.sum(masked_lm_loss.cast(paddle.float32).reshape([-1]) * lossmask) + loss = masked_lm_loss / lossmask.sum() + if self.token_balance_loss: + _loss = masked_lm_loss / self.config.token_balance_seqlen + global_training_logs.update(token_balance_loss=_loss.detach()) + loss = _loss - _loss.detach() + loss.detach() + loss_sum = masked_lm_loss.sum().detach() + if not self.return_tuple: + if self.training: + return loss + return loss_sum + return loss, loss_sum + + def forward_impl_with_calc_logits(self, masked_lm_labels, hidden_states, outlinear_weight, outlinear_bias): + + logits = calc_lm_head_logits( + self.config, + hidden_states, + outlinear_weight, + outlinear_bias, + training=self.training, + ) + + return self.forward_impl(logits, masked_lm_labels) + + def loss_impl(self, prediction_scores, masked_lm_labels): + prediction_scores = prediction_scores.cast("float32") + masked_lm_loss = self.loss_func(prediction_scores, masked_lm_labels.unsqueeze(-1)) + + return masked_lm_loss + + def forward_impl(self, prediction_scores, masked_lm_labels): + if self.enable_parallel_cross_entropy: + assert prediction_scores.shape[-1] != self.config.vocab_size, ( + f"enable_parallel_cross_entropy, the vocab_size should be splited:" + f" {prediction_scores.shape[-1]}, {self.config.vocab_size}" + ) + + with paddle.amp.auto_cast(False): + prediction_scores_dims = len(prediction_scores.shape) + if prediction_scores_dims == 2 and prediction_scores.shape[0] > self.config.get( + "loss_subbatch_seqlen", 32768 + ): + sb_loss_func = subbatch( + self.loss_impl, + [0, 1], + [0, 0], + self.config.get("loss_subbatch_seqlen", 32768), + 0, + ) + masked_lm_loss = sb_loss_func(prediction_scores, masked_lm_labels) + elif prediction_scores_dims == 3 and prediction_scores.shape[1] > self.config.get( + "loss_subbatch_seqlen", 32768 + ): + sb_loss_func = subbatch( + self.loss_impl, + [0, 1], + [1, 1], + self.config.get("loss_subbatch_seqlen", 32768), + 1, + ) + masked_lm_loss = sb_loss_func(prediction_scores, masked_lm_labels) + else: + masked_lm_loss = self.loss_impl(prediction_scores, masked_lm_labels) + + lossmask = masked_lm_labels != self.ignored_index + if (~lossmask).all(): + logger.warning(f"encounter empty span when calculate loss, ignored_index={self.ignored_index}") + loss = paddle.mean(masked_lm_loss) * 0.0 + loss_sum = masked_lm_loss.sum().detach() + else: + lossmask = lossmask.reshape([-1]).cast(paddle.float32) + masked_lm_loss = paddle.sum(masked_lm_loss.cast(paddle.float32).reshape([-1]) * lossmask) + loss = masked_lm_loss / lossmask.sum() + if self.token_balance_loss: + _loss = masked_lm_loss / self.config.token_balance_seqlen + global_training_logs.update(token_balance_loss=_loss.detach()) + loss = _loss - _loss.detach() + loss.detach() + loss_sum = masked_lm_loss.sum().detach() + if not self.return_tuple: + if self.training: + return loss + return loss_sum + return loss, loss_sum + + +class ErnieLMHead(nn.Layer): + def __init__(self, config): + super(ErnieLMHead, self).__init__() + self.config = config + if config.tensor_parallel_degree > 1: + vocab_size = config.vocab_size // config.tensor_parallel_degree + else: + vocab_size = config.vocab_size + + self.weight = self.create_parameter( + shape=( + [vocab_size, config.hidden_size] if config.tie_word_embeddings else [config.hidden_size, vocab_size] + ), + dtype=paddle.get_default_dtype(), + ) + logger.info(f"output-weight:{self.weight.shape} config.tie_word_embeddings={config.tie_word_embeddings}") + if config.weight_share_add_bias and config.use_bias: + self.bias = self.create_parameter( + shape=[vocab_size], + dtype=paddle.get_default_dtype(), + attr=paddle.ParamAttr(initializer=paddle.nn.initializer.constant.Constant(0.0)), + ) + else: + self.bias = None + + self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False + if config.weight_share_add_bias and config.use_bias: + self.bias.is_distributed = True if (vocab_size != config.vocab_size) else False + + if self.weight.is_distributed: + self.weight.split_axis = 1 + if config.weight_share_add_bias and config.use_bias and self.bias.is_distributed: + self.bias.split_axis = 0 + + if self.config.use_recompute_loss_fn: + logger.info( + "Using recompute_loss_fn, the calculation of logits will be moved into " + "loss_fn for memory optimization" + ) + + def forward(self, hidden_states, tensor_parallel_output=None): + if self.config.use_recompute_loss_fn or self.config.use_sparse_head_and_loss_fn: + out_tensors = ( + (hidden_states, self.weight, self.bias) + if tensor_parallel_output is None + else (hidden_states, self.weight, self.bias, tensor_parallel_output) + ) + + return out_tensors + + return calc_lm_head_logits( + self.config, + hidden_states, + self.weight, + self.bias, + tensor_parallel_output, + training=self.training, + ) + + +class ErnieForCausalLM(ErniePretrainedModel): + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + + if config.sequence_parallel: + logger.info(f"using sequence_parallel, input seqlen={config.seqlen}") + if config.using_dynamic_sequence_length: + assert ( + not config.micro_batch_size + ), "sequence-parallel needs micro_batch_size setting when using dynamic_sequence_length" + else: + assert config.seqlen is not None + + assert ( + config.tensor_parallel_degree > 1 + ), f"sequence-parallel needs mp>1, got mp={config.tensor_parallel_degree}" + + new_initializer_range = math.sqrt(0.3333 / config.hidden_size) + logger.info(f"change initializer-range from {config.initializer_range} to {new_initializer_range}") + config.initializer_range = new_initializer_range + self.config = config + + self.ernie = ErnieModel(config) + self.lm_head = ErnieLMHead(config) + self.criterion = ErniePretrainingCriterion(config) + + self.tie_weights() + + if self.config.fuse_rms_norm: + logger.info("Use fusedRMSNorm") + else: + logger.info("Use normal RMSNorm") + + def _post_init(self, original_init, *args, **kwargs): + super()._post_init(self, original_init, *args, **kwargs) + factor = 1 / math.sqrt(2 * self.config.num_hidden_layers) + logger.info(f"using post init div: factor:{factor}") + with paddle.no_grad(): + for layer in self.ernie.layers: + layer.self_attn.o_proj.weight.scale_(factor) + layer.mlp.down_proj.weight.scale_(factor) + + def get_input_embeddings(self): + return self.ernie.embed_tokens + + def set_input_embeddings(self, value): + self.ernie.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.ernie = decoder + + def get_decoder(self): + return self.ernie + + @staticmethod + def prepare_attention_mask_for_generation(input_ids, pad_token_id, eos_token_id): + is_pad_token_in_inputs_ids = (pad_token_id is not None) and paddle.any( + input_ids == pad_token_id + ).numpy().item() + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ( + (eos_token_id is not None) and (pad_token_id != eos_token_id) + ) + if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id: + attention_mask = (input_ids != pad_token_id).astype("int64") + else: + attention_mask = paddle.ones_like(input_ids, dtype="int64") + return attention_mask + + def prepare_inputs_for_generation( + self, + input_ids, + use_cache=False, + past_key_values=None, + inputs_embeds=None, + **kwargs, + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + attention_mask = kwargs.get("attention_mask", None) + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": True, + "attention_mask": attention_mask, + "return_dict": True, + } + ) + + if self.config.rope_3d: + model_inputs.update({"position_ids": kwargs["position_ids"]}) + + return model_inputs + + def update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder=False): + if isinstance(outputs, tuple) and len(outputs) > 1 and not isinstance(outputs[1], paddle.Tensor): + model_kwargs["past_key_values"] = outputs[1] + + if isinstance(outputs, CausalLMOutputWithCrossAttentions) and "past_key_values" in outputs: + model_kwargs["past_key_values"] = outputs.past_key_values + + if "token_type_ids" in model_kwargs and model_kwargs["token_type_ids"] is not None: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = paddle.concat([token_type_ids, token_type_ids[:, -1:]], axis=-1) + + if not is_encoder_decoder: + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = paddle.concat( + [ + attention_mask, + paddle.ones([attention_mask.shape[0], 1], dtype="int64"), + ], + axis=-1, + ) + if "role_ids" in model_kwargs and model_kwargs["role_ids"] is not None: + role_ids = model_kwargs["role_ids"] + model_kwargs["role_ids"] = paddle.concat([role_ids, role_ids[:, -1:]], axis=-1) + + if self.config.rope_3d: + assert "position_ids" in model_kwargs, "position_ids must be provided if rope_3d is on" + position_ids = model_kwargs["position_ids"] + model_kwargs["position_ids"] = paddle.concat( + [ + position_ids, + position_ids.max(axis=(1, 2), keepdim=True).tile([1, 1, 3]) + 1, + ], + axis=1, + ) + + return model_kwargs + + def forward( + self, + input_ids, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + labels=None, + use_cache=False, + past_key_values=None, + output_attentions=None, + output_hidden_states=None, + return_dict=False, + ignored_index=0, + data_id=None, + src_id=None, + inbatch_pack_offset=None, + loss_mask=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.ernie( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + inbatch_pack_offset=inbatch_pack_offset, + ) + + hidden_states = outputs[0] + + logits = self.lm_head( + hidden_states, + ) + + if return_dict: + if labels is not None: + loss, _ = self.criterion(logits, labels) + else: + loss = None + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + assert labels is not None + loss, loss_sum = self.criterion(logits, labels) + return loss, loss_sum diff --git a/ernie/ERNIE/examples/pre-training/models/ernie/modeling_moe.py b/ernie/ERNIE/examples/pre-training/models/ernie/modeling_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..1d3c913dc3c5c003cffd9b02ff8090724c3ca115 --- /dev/null +++ b/ernie/ERNIE/examples/pre-training/models/ernie/modeling_moe.py @@ -0,0 +1,2186 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import logging +import math +import random +import re +from copy import deepcopy +from dataclasses import dataclass +from functools import partial +from typing import Dict, Optional, Tuple, Union + +import numpy as np +import paddle +import paddle.distributed as dist +import paddle.distributed.communication.group +import paddle.nn.functional as F +from models.comm_utils import profile +from models.ernie import ErnieMoEConfig +from models.ernie.modeling import ( + ErnieAttention, + ErnieLMHead, + ErnieMLP, + FusedDropoutImpl, + RMSNorm, + RotaryEmbedding, + _expand_mask, + _make_causal_mask, + finfo, +) +from models.ernie.modeling import ( + ErniePretrainingCriterion as ErniePretrainingCriterionBase, +) +from models.fp8_linear import Fp8FusedMlpFunc, MemEfficientFp8FusedMlpFunc +from models.moe.moe_layer import ( + MOELayer, + MoEStatics, +) +from models.moe.top2_gate import Top2Gate, TopKGateFused +from models.sequence_parallel_utils import ( + ColumnSequenceParallelLinear, + GatherOp, + RowSequenceParallelLinear, + ScatterOp, + get_async_loader, + hack_offload_wait, + mark_as_sequence_parallel_parameter, +) +from models.utils import get_global_training_logs +from paddle import nn +from paddle.autograd import PyLayer +from paddle.distributed import fleet +from paddle.distributed.communication.group import _get_global_group +from paddle.distributed.fleet.layers.mpu import mp_ops +from paddle.distributed.fleet.layers.mpu.mp_layers import ( + ColumnParallelLinear, + RowParallelLinear, + VocabParallelEmbedding, +) +from paddle.distributed.fleet.layers.mpu.random import get_rng_state_tracker +from paddle.distributed.fleet.utils import recompute +from paddle.incubate.nn.functional import fused_rms_norm_ext +from paddle.incubate.tensor.manipulation import async_offload +from paddleformers.transformers.conversion_utils import ( + StateDictNameMapping, + init_name_mappings, +) +from paddleformers.transformers.model_outputs import ( + BaseModelOutputWithPastAndCrossAttentions as _BaseModelOutput, +) +from paddleformers.transformers.model_outputs import ( + CausalLMOutputWithCrossAttentions as _CausalLMOutput, +) +from paddleformers.transformers.model_utils import PretrainedModel, register_base_model +from paddleformers.utils.tools import get_env_device + +try: + from paddle.incubate.nn.functional import swiglu as fused_swiglu +except (ImportError, ModuleNotFoundError): + fused_swiglu = None + +logger = logging.getLogger(__name__) +paddle.distributed.communication.group.Group.__deepcopy__ = lambda self, _: self +paddle.distributed.communication.group.Group.to_json = lambda self: repr(self) + + +@dataclass +class BaseModelOutputWithPastAndCrossAttentions(_BaseModelOutput): + router_loss: Optional[paddle.Tensor] = None + gate_logits: Optional[Tuple[paddle.Tensor]] = None + mtp_outputs: Optional[paddle.Tensor] = None + + +@dataclass +class CausalLMOutputWithCrossAttentions(_CausalLMOutput): + router_loss: Optional[paddle.Tensor] = None + + +global_training_logs = get_global_training_logs() + +ERNIE_PRETRAINED_MODEL_ARCHIVE_LIST = [] + +__all__ = [ + "ErnieMoEForCausalLM", + "ErniePretrainingCriterion", + "CausalLMOutputWithCrossAttentions", +] + +gate_class = dict( + top2=Top2Gate, + top2_fused=TopKGateFused, +) + + +def get_gate( + config: ErnieMoEConfig, + expert: Tuple[Tuple[int, nn.Layer]], + layer_idx: int, +) -> Tuple[nn.Layer, nn.LayerList]: + moe_num_experts = config.moe_num_experts + assert ( + moe_num_experts >= config.moe_world_size + ), f"expert moe_num_experts={moe_num_experts} >= moe_world_size={config.moe_world_size}" + assert ( + moe_num_experts % config.moe_world_size == 0 + ), f"expert moe_num_experts={moe_num_experts} % moe_world_size={config.moe_world_size} == 0" + moe_num_experts_per_device = moe_num_experts // config.moe_world_size + if not config.moe_fuse_experts: + experts = nn.LayerList([]) + for expert_id, (experts_num, fc) in enumerate(expert): + assert experts_num % config.moe_world_size == 0 + num_experts_per_device = experts_num // config.moe_world_size + experts_to_append = [] + if not hasattr(fc, "__len__"): + experts_to_append.append(fc) + if expert_id == 1: + with paddle.utils.unique_name.guard("_mm_deepcopy"): + for _ in range(num_experts_per_device - 1): + experts_to_append.append(deepcopy(fc)) + else: + for _ in range(num_experts_per_device - 1): + experts_to_append.append(deepcopy(fc)) + else: + experts_to_append = fc + for ex in experts_to_append: + for p in ex.parameters(): + p.expert_type = f"expert_type_{expert_id}" + experts.extend(experts_to_append) + assert ( + len(experts) == moe_num_experts_per_device + ), f"experts.len={len(experts)} != moe_num_experts_per_device={moe_num_experts_per_device}" + else: + assert expert[0][0] == 1, "experts are fused and must be one" + experts = deepcopy(expert[0][1]) + + logger.info(f"using moe-world-size: {config.moe_world_size} " f"expert-per-device: {moe_num_experts_per_device} ") + if moe_num_experts <= 2: + gate = None + logger.info("MOE-GATE:-hard-gate") + else: + logger.info(f"MOE-GATE:-{config.moe_gate}") + gate = gate_class[config.moe_gate.lower()](config, layer_idx=layer_idx, group=config.moe_group) + + lm_gate, lm_experts = gate, experts + logger.info(f"LM-experts-{lm_experts} -- experts-{experts}") + return gate, experts, lm_gate, lm_experts + + +def build_mpdp_group(): + hcg = fleet.get_hybrid_communicate_group() + mp_world_size = hcg.get_model_parallel_world_size() + dp_world_size = hcg.get_data_parallel_world_size() + sharding_world_size = hcg.get_sharding_parallel_world_size() + pp_world_size = hcg.get_pipe_parallel_world_size() + + world_size = dist.get_world_size() + rank = dist.get_rank() + topo = np.arange(world_size).reshape([pp_world_size, sharding_world_size, dp_world_size, mp_world_size]) + this_group = None + for i in range(pp_world_size): + for j in range(sharding_world_size): + ranks = topo[i, j, :, :].reshape([-1]).tolist() + group = dist.new_group(ranks) + if rank in ranks: + logger.info(f"building mpdp group, this group has rank: {ranks}") + this_group = group + return this_group + + +def _parse_moe_group( + moe_group: str, +) -> Union[str, paddle.distributed.communication.group.Group]: + moe_group = moe_group.lower() + assert moe_group in { + "sharding", + "data", + "dp", + "mp", + "tp", + "model", + "dummy", + "none", + "world", + "all", + "mpdp", + "ep", + }, f"moe-group not supported, got: {moe_group}" + logger.info(f"using moe-group: {moe_group}") + if not hasattr(fleet.fleet, "_hcg"): + assert moe_group in { + "dummy", + "none", + "world", + "data", + }, "only support dummy gate in `single-model`" + if moe_group == "sharding": + moe_group = fleet.get_hybrid_communicate_group().get_sharding_parallel_group() + elif moe_group == "ep": + moe_group = fleet.get_hybrid_communicate_group().get_expert_parallel_group() + elif moe_group in {"data", "dp"}: + if hasattr(fleet.fleet, "_hcg"): + moe_group = fleet.get_hybrid_communicate_group().get_data_parallel_group() + else: + moe_group = _get_global_group() + elif moe_group in {"mp", "model", "tp"}: + moe_group = fleet.get_hybrid_communicate_group().get_model_parallel_group() + elif moe_group in {"dummy"}: + dummy_group = paddle.distributed.communication.group.Group(0, None, [0]) + moe_group = dummy_group + elif moe_group in {"mpdp"}: + moe_group = build_mpdp_group() + else: + moe_group = _get_global_group() + return moe_group + + +def moe_ep2mp(state_dict: Dict[str, paddle.Tensor], config: ErnieMoEConfig, split_actions): + if config.tensor_parallel_degree <= 1 or dist.get_world_size(config.moe_group) > 1: + return state_dict + if isinstance(config.moe_num_experts, (list, tuple)): + num_lm_experts, num_mm_experts = config.moe_num_experts + num_experts = sum(config.moe_num_experts) + else: + num_lm_experts, num_mm_experts = config.moe_num_experts, 0 + num_experts = config.moe_num_experts + expert_ids = [int(re.search(r"mlp\.experts\.(\d+)", k).group(1)) for k in state_dict.keys() if "mlp.experts" in k] + if expert_ids and max(expert_ids) == num_experts - 1: + return state_dict + + logger.info("auto ep2mp") + hcg = fleet.get_hybrid_communicate_group() + mp_group = hcg.get_model_parallel_group() + world_size = dist.get_world_size(mp_group) + num_lm_local_experts = num_lm_experts // world_size + num_mm_local_experts = num_mm_experts // world_size + + new_sd = {} + + actual_keys = [] + for k in state_dict.keys(): + actual_keys.append(k) + actual_keys_sorted = sorted(actual_keys) + + for k in actual_keys_sorted: + if "mlp.experts" in k: + expert_id = int(re.search(r"mlp\.experts\.(\d+)", k).group(1)) + gathered_experts = [] + tensor = paddle.to_tensor(state_dict[k]) + dist.all_gather(gathered_experts, tensor, group=mp_group) + for rank in range(len(gathered_experts)): + if expert_id < num_lm_local_experts: + real_id = expert_id + rank * num_lm_local_experts + else: + if num_mm_experts > 0: + real_id = num_lm_experts + (expert_id - num_lm_local_experts) + rank * num_mm_local_experts + else: + continue + new_k = k.replace(f"mlp.experts.{expert_id}", f"mlp.experts.{real_id}") + logger.info(f"auto ep2mp: {k}->{new_k}, expert_id: {expert_id}, real_id: {real_id}") + new_sd[new_k] = split_actions[new_k.replace("ernie.", "")](gathered_experts[rank]) + else: + new_sd[k] = state_dict[k] + return new_sd + + +def moe_statedict_cherry_pick(state_dict: Dict[str, paddle.Tensor], config: ErnieMoEConfig): + moe_num_experts = ( + sum(config.moe_num_experts) if isinstance(config.moe_num_experts, (list, tuple)) else config.moe_num_experts + ) + if moe_num_experts <= 1: + return state_dict + moe_world_size = config.moe_world_size + if moe_world_size <= 1: + moe_world_size = 1 + moe_world_size_per_device = moe_num_experts // moe_world_size + for key in list(state_dict.keys()): + if "mlp.experts" in key: + imoe = int(re.search(r"mlp\.experts\.(\d+)", key).group(1)) + if imoe >= moe_world_size_per_device: + continue + maybe_moe_name = key.replace( + f"mlp.experts.{imoe}", + f"mlp.experts.{config.moe_rank * moe_world_size_per_device + imoe}", + ) + if maybe_moe_name != key and maybe_moe_name in state_dict: + logger.info(f"moe auto changed state-dict using {maybe_moe_name} as {key}") + state_dict[key] = state_dict.pop(maybe_moe_name) + return state_dict + + +def moe_statedict_upcycle( + state_dict: Dict[str, paddle.Tensor], + config: ErnieMoEConfig, + dtype, + merge_actions, + split_actions, + layer_idxs=None, +): + if not isinstance(config.moe_intermediate_size, int): + logger.warning("moe upcycle only supports single modality expand !") + return state_dict + + moe_layer_start_index = ( + min(config.moe_layer_start_index) + if isinstance(config.moe_layer_start_index, (tuple, list)) + else config.moe_layer_start_index + ) + moe_layer_end_index = ( + max(config.moe_layer_end_index) + if isinstance(config.moe_layer_end_index, (tuple, list)) + else config.moe_layer_end_index + ) + + if config.moe_num_experts > 0: + moe_world_size = config.moe_world_size + if moe_world_size <= 1: + moe_world_size = 1 + moe_world_size_per_device = config.moe_num_experts // moe_world_size + + granularity = ( + 1 if config.moe_intermediate_size == 0 else config.intermediate_size // config.moe_intermediate_size + ) + + def slice_granularity(w, global_expert_id, column=True, shuffle=False, group_experts=False): + if group_experts: + part_id = global_expert_id // (config.moe_num_experts // config.moe_k) + else: + part_id = global_expert_id % config.moe_k + part_id = part_id % granularity + if shuffle: + rng = random.Random(global_expert_id // config.moe_k) + if column: + idx = np.arange(w.shape[-1]) + rng.shuffle(idx) + w = w.index_select(paddle.to_tensor(idx), axis=-1) + else: + idx = np.arange(w.shape[0]) + rng.shuffle(idx) + w = w.index_select(paddle.to_tensor(idx), axis=0) + if granularity == 1: + return w + if column: + per_expert = w.shape[-1] // granularity + return w[..., part_id * per_expert : (part_id + 1) * per_expert] + per_expert = w.shape[0] // granularity + w *= config.moe_k + return w[part_id * per_expert : (part_id + 1) * per_expert, ...] + + def slice_granularity_shared(w, column=True): + if column: + per_expert = w.shape[-1] // granularity + return w[..., -(per_expert * config.moe_num_shared_experts) :] + per_expert = w.shape[0] // granularity + return w[-(per_expert * config.moe_num_shared_experts) :, ...] + + def _chunk(t): + return t.chunk(2, axis=-1) if isinstance(w, paddle.Tensor) else np.split(w, 2, axis=-1) + + def _cat(t): + return paddle.concat(t, -1) if isinstance(t[0], paddle.Tensor) else np.concatenate(t, -1) + + granularity = ( + 1 if config.moe_intermediate_size == 0 else config.intermediate_size // config.moe_intermediate_size + ) + is_mp_moe = ( + hasattr(fleet.fleet, "_hcg") + and config.moe_group is fleet.get_hybrid_communicate_group().get_model_parallel_group() + ) + logger.info(f"UPCYCLE-IS_MP_MOE: {is_mp_moe}") + if is_mp_moe and fleet.get_hybrid_communicate_group().get_model_parallel_world_size() > 1: + mp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group() + else: + mp_group = None + + for ilayer in range(config.num_hidden_layers): + if layer_idxs and ilayer not in layer_idxs: + continue + if ilayer < moe_layer_start_index or ilayer > moe_layer_end_index: + continue + if (ilayer + 1) % config.moe_layer_interval == 0: + for k in ["up_proj", "gate_proj", "down_proj", "up_gate_proj"]: + for tail in ["weight", "bias"]: + non_moe_key = f"ernie.layers.{ilayer}.mlp.{k}.{tail}" + if non_moe_key in state_dict: + w = state_dict[non_moe_key] + if mp_group is not None and not (k == "down_proj" and tail == "bias"): + w = paddle.to_tensor(w).to(get_env_device()) + gathered_w = [] + logger.info(f"all_gather {non_moe_key} for moe upcycling") + dist.all_gather(gathered_w, w, group=mp_group) + w = w.cpu() + gathered_w = [v.cpu() for v in gathered_w] + gathered_w = merge_actions[non_moe_key.replace("ernie.", "")](gathered_w) + logger.info(f"gathered w is {gathered_w.shape}, type {gathered_w.dtype}") + w = gathered_w + for imoe in range(moe_world_size_per_device): + moe_name = f"ernie.layers.{ilayer}.mlp.experts.{imoe}.{k}.{tail}" + if moe_name not in state_dict and non_moe_key in state_dict: + if k == "up_gate_proj": + w_ = _cat( + [ + slice_granularity( + ww, + config.moe_rank * moe_world_size_per_device + imoe, + column=True, + group_experts=config.moe_group_experts, + ) + for ww in _chunk(w) + ] + ) + elif k == "down_proj" and tail == "bias": + w_ = deepcopy(w) + else: + w_ = slice_granularity( + w, + config.moe_rank * moe_world_size_per_device + imoe, + column=k in {"up_proj", "gate_proj", "up_gate_proj"}, + group_experts=config.moe_group_experts, + ) + logger.info(f"before slice: {w.shape} -> {w_.shape}") + logger.info( + f"moe auto expand state-dict, ffn name G={granularity}: " + f"{moe_name} {w_.shape} {w_.dtype} {dtype}" + ) + if isinstance(w_, np.ndarray): + w_ = paddle.to_tensor(w_) + if w_.dtype == dtype: + state_dict[moe_name] = w_ + else: + state_dict[moe_name] = w_.cast(dtype) + + if config.moe_num_shared_experts > 0: + moe_name = f"ernie.layers.{ilayer}.mlp.shared_experts.{k}.{tail}" + if moe_name not in state_dict and non_moe_key in state_dict: + if k == "up_gate_proj": + w_ = _cat([slice_granularity_shared(ww, column=True) for ww in _chunk(w)]) + if mp_group is not None: + w_ = split_actions[non_moe_key.replace("ernie.", "")](w_) + elif k == "down_proj" and tail == "bias": + w_ = deepcopy(w) + else: + w_ = slice_granularity_shared( + w, + column=k in {"up_proj", "gate_proj", "up_gate_proj"}, + ) + logger.info(f"W_ {k}-{w.shape}--shape-{w_.shape}") + if mp_group is not None: + w_ = split_actions[non_moe_key.replace("ernie.", "")](w_) + logger.info( + f"moe auto expand state-dict, shared experts, ffn name G={granularity}: " + f"{moe_name} {w_.shape} {w_.dtype}" + ) + if isinstance(w_, np.ndarray): + w_ = paddle.to_tensor(w_) + if w_.dtype == dtype: + state_dict[moe_name] = w_ + else: + state_dict[moe_name] = w_.cast(dtype) + + return state_dict + + +class ErnieMoeMLP(ErnieMLP): + def __init__(self, config, is_shared_expert=False): + if getattr(config, "disable_ffn_model_parallel", False): + config = deepcopy(config) + config.tensor_parallel_degree = 1 + super().__init__(config) + self.moe_dropout_prob = config.moe_dropout_prob + self.fuse_swiglu = config.fuse_swiglu + if self.fuse_swiglu: + assert fused_swiglu is not None, "fused_swiglu operator is not found." + self.is_shared_expert = is_shared_expert + self.shared_expert_mem_efficient = self.config.fp8_mem_configs["shared_expert"] + + def forward(self, x, use_comm=True): + if ( + self.config.tensor_parallel_degree <= 1 + and self.fuse_ffn + and self.config.use_fp8_mlp + and not self.config.use_bias + ): + if self.is_shared_expert and self.shared_expert_mem_efficient: + return MemEfficientFp8FusedMlpFunc.apply(x, self.up_gate_proj.weight, self.down_proj.weight) + return Fp8FusedMlpFunc.apply(x, self.up_gate_proj.weight, self.down_proj.weight) + + if self.fuse_ffn: + up_gate_proj = ( + partial(self.up_gate_proj, use_comm=use_comm) + if (isinstance(self.up_gate_proj, ColumnSequenceParallelLinear)) + else self.up_gate_proj + ) + else: + gate_proj = ( + partial(self.gate_proj, use_comm=use_comm) + if (isinstance(self.gate_proj, ColumnSequenceParallelLinear)) + else self.gate_proj + ) + up_proj = ( + partial(self.up_proj, use_comm=use_comm) + if (isinstance(self.up_proj, ColumnSequenceParallelLinear)) + else self.up_proj + ) + + if self.fuse_swiglu: + if self.fuse_ffn: + if self.config.use_fp8 and self.config.fp8_configs["smooth_swiglu"]: + x, gate = up_gate_proj(x).chunk(2, axis=-1) + + with paddle.no_grad(): + scale = paddle.clip(gate.abs().max(axis=-1, keepdim=True), 1e-8) + + gate = gate / scale + if self.config.sequence_parallel: + scale = ScatterOp.apply(scale) + + x = paddle.concat([x, gate], axis=-1) + else: + x = up_gate_proj(x) + x = fused_swiglu(x) + else: + x = fused_swiglu(gate_proj(x), up_proj(x)) + else: + if self.fuse_ffn: + x, gate = up_gate_proj(x).chunk(2, axis=-1) + x = F.silu(x) * gate + else: + x = F.silu(gate_proj(x)) * up_proj(x) + if self.moe_dropout_prob > 0: + with get_rng_state_tracker().rng_state("local_seed"): + x = F.dropout(x=x, p=self.moe_dropout_prob) + if self.config.use_fp8 and self.config.fp8_configs["smooth_swiglu"]: + return self.down_proj(x) * scale + ret = self.down_proj(x) + return ret + + +class ErnieMoeDenseExpert(nn.Layer): + def __init__(self, config): + super().__init__() + self.config = config + LinearFN = paddle.incubate.nn.FusedLinear if config.fuse_linear else nn.Linear + mp_degree = max(1, config.tensor_parallel_degree) + self.is_mp = mp_degree > 1 + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.fuse_ffn = config.fuse_attn_ffn + + if config.fuse_attn_ffn: + self.up_gate_proj = LinearFN( + self.hidden_size, + self.intermediate_size * 2 // mp_degree, + bias_attr=config.use_bias, + ) + self.up_gate_proj.weight.is_distributed = self.is_mp + if config.use_bias: + self.up_gate_proj.bias.is_distributed = self.is_mp + else: + self.gate_proj = LinearFN( + self.hidden_size, + self.intermediate_size // mp_degree, + bias_attr=config.use_bias, + ) + self.up_proj = LinearFN( + self.hidden_size, + self.intermediate_size // mp_degree, + bias_attr=config.use_bias, + ) + self.gate_proj.weight.is_distributed = self.is_mp + self.up_proj.weight.is_distributed = self.is_mp + if config.use_bias: + self.gate_proj.bias.is_distributed = self.is_mp + self.up_proj.bias.is_distributed = self.is_mp + self.down_proj = LinearFN( + self.intermediate_size // mp_degree, + self.hidden_size, + bias_attr=config.use_bias, + ) + self.down_proj.weight.is_distributed = self.is_mp + + self.fuse_swiglu = config.fuse_swiglu + if self.fuse_swiglu: + assert fused_swiglu is not None, "fused_swiglu operator is not found." + if self.is_mp: + self.mp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group() + + def forward(self, x): + if self.fuse_swiglu: + if self.fuse_ffn: + x = fused_swiglu(self.up_gate_proj(x)) + else: + x = fused_swiglu(self.gate_proj(x), self.up_proj(x)) + else: + if self.fuse_ffn: + x, gate = self.up_gate_proj(x).chunk(2, axis=-1) + x = F.silu(x) * gate + else: + x = F.silu(self.gate_proj(x)) * self.up_proj(x) + if self.is_mp: + x = F.linear(x, self.down_proj.weight) + output_ = mp_ops._mp_allreduce( + x, + group=self.mp_group, + use_calc_stream=True, + use_model_parallel=True, + ) + output = output_ + self.down_proj.bias if self.config.use_bias else output_ + else: + output = self.down_proj(x) + + return output + + +class BMMLinear(nn.Layer): + def __init__(self, experts, d_in, d_out, use_bias=False): + super().__init__() + self.weight = self.create_parameter([experts, d_in, d_out], dtype=paddle.get_default_dtype()) + if use_bias: + self.bias = self.create_parameter([experts, d_out], dtype=paddle.get_default_dtype(), is_bias=True) + else: + self.bias = None + + def forward(self, x): + if self.bias is not None: + return paddle.bmm(x, self.weight) + self.bias + return paddle.bmm(x, self.weight) + + +class ErnieMoeMLPFused(nn.Layer): + def __init__(self, config): + assert ( + hasattr(config, "disable_ffn_model_parallel") or config.tensor_parallel_degree == 1 + ), f"fused mlp only support mp-moe, mp={config.tensor_parallel_degree}" + assert config.fuse_attn_ffn, "fused mlp only support fuse_attn_ffn" + super().__init__() + self.moe_dropout_prob = config.moe_dropout_prob + self.num_local_experts = config.moe_num_experts // config.moe_world_size + logger.info( + f"fused-expert-weight-shape: {[self.num_local_experts, config.hidden_size, config.intermediate_size]}" + ) + + self.up_gate_proj = BMMLinear(self.num_local_experts, config.hidden_size, config.intermediate_size * 2) + self.down_proj = BMMLinear(self.num_local_experts, config.intermediate_size, config.hidden_size) + self.fuse_swiglu = config.fuse_swiglu + if self.fuse_swiglu: + assert fused_swiglu is not None, "fused_swiglu operator is not found." + + def __len__(self): + return self.num_local_experts + + def __iter__(self): + return (self for _ in range(1)) + + def forward(self, x): + if self.fuse_swiglu: + x = fused_swiglu(self.up_gate_proj(x)) + else: + gate, x = self.up_gate_proj(x).chunk(2, axis=-1) + x = F.silu(gate) * x + x = self.down_proj(x) + return x + + +class FusedLinearAddNormFunc(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, residual, linear_weight, rms_norm_weight, eps): + linear_out = paddle.matmul(x, linear_weight) + add_out = linear_out + residual + norm_out, invar = fused_rms_norm_ext(add_out, rms_norm_weight, eps) + + ctx.save_for_backward(x, residual, linear_weight, rms_norm_weight, eps) + + return norm_out, add_out + + @staticmethod + def backward(ctx, d_rms_norm_out, d_residual_out): + x, residual, linear_weight, rms_norm_weight, eps = ctx.saved_tensor() + + linear_out = paddle.matmul(x, linear_weight) + add_out = linear_out + residual + + rms_out, invar = fused_rms_norm_ext(add_out, rms_norm_weight, eps) + + d_add_out, d_rms_norm_weight = paddle._C_ops.fused_rms_norm_ext_grad( + add_out, rms_norm_weight, invar, d_rms_norm_out, eps + ) + + d_residual = d_add_out + d_residual_out + d_linear_out = d_residual + dx, d_linear_weight = paddle._C_ops.matmul_grad(x, linear_weight, d_linear_out, False, False) + + return dx, d_residual, d_linear_weight, d_rms_norm_weight + + +class FusedLinearAddNorm(paddle.nn.Layer): + def __init__(self, hidden_size, eps=1e-6) -> None: + super().__init__() + self._dtype = self._helper.get_default_dtype() + + self.linear_weight = self.create_parameter( + shape=[hidden_size, hidden_size], + dtype=self._dtype, + is_bias=False, + ) + + self.rms_norm_weight = self.create_parameter( + shape=[hidden_size], + dtype=self._dtype, + default_initializer=nn.initializer.Constant(1.0), + ) + + self.eps = eps + + def forward(self, x, residual): + return FusedLinearAddNormFunc.apply(x, residual, self.linear_weight, self.rms_norm_weight, self.eps) + + +class FusedRMSLinearFunc(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, rms_norm_weight, linear_weight, eps): + hidden_states, invar = fused_rms_norm_ext(x, rms_norm_weight, eps) + q = paddle.matmul(hidden_states, linear_weight) + + ctx.save_for_backward(x, rms_norm_weight, linear_weight, eps) + return q + + @staticmethod + def backward(ctx, d_qkv): + x, rms_norm_weight, linear_weight, eps = ctx.saved_tensor() + hidden_states, invar = fused_rms_norm_ext(x, rms_norm_weight, eps) + h_grad, d_linear_weight = paddle._C_ops.matmul_grad(hidden_states, linear_weight, d_qkv, False, False) + + dx, d_rms_norm_weight = paddle._C_ops.fused_rms_norm_ext_grad(x, rms_norm_weight, invar, h_grad, eps) + + return dx, d_rms_norm_weight, d_linear_weight + + +class FusedRMSLinear(paddle.nn.Layer): + def __init__(self, hidden_size, eps=1e-6, num_heads=1, num_key_value_heads=1) -> None: + super().__init__() + self._dtype = self._helper.get_default_dtype() + + self.rms_norm_weight = self.create_parameter( + shape=[hidden_size], + dtype=self._dtype, + default_initializer=nn.initializer.Constant(1.0), + ) + kv_hidden_size = hidden_size // num_heads * num_key_value_heads + qkv_out = hidden_size + kv_hidden_size * 2 + + self.linear_weight = self.create_parameter( + shape=[hidden_size, qkv_out], + dtype=self._dtype, + is_bias=False, + ) + self.eps = eps + + def forward(self, x): + return FusedRMSLinearFunc.apply(x, self.rms_norm_weight, self.linear_weight, self.eps) + + +class ErnieMoEAttention(ErnieAttention): + def __init__(self, config, layer_idx): + super().__init__(config) + + self.use_linear_residual_norm_recompute = config.use_linear_residual_norm_recompute + self.use_rms_qkv_recompute = config.use_rms_qkv_recompute + if config.use_rms_qkv_recompute is True: + + assert config.use_rmsnorm is True and config.fuse_rms_norm is True + assert config.fuse_linear is True and config.use_bias is False + + assert self.fuse_attn is True + + if self.is_gqa: + self.fused_rms_norm_linear = FusedRMSLinear( + self.hidden_size, + config.rms_norm_eps, + self.num_heads, + self.num_key_value_heads, + ) + else: + self.fused_rms_norm_linear = FusedRMSLinear(self.hidden_size, config.rms_norm_eps) + del self.qkv_proj + + def forward( + self, + hidden_states, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + attention_mask: Optional[paddle.Tensor] = None, + position_ids: Optional[Tuple[paddle.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + inbatch_pack_offset: Optional[Tuple[paddle.Tensor]] = None, + token_type_ids: Optional[Tuple[paddle.Tensor]] = None, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: + if token_type_ids is not None: + token_type_ids = token_type_ids[:, :-1] + if self.config.sequence_parallel: + if token_type_ids is not None: + token_type_ids = token_type_ids.reshape([-1]) + token_type_ids = ScatterOp.apply(token_type_ids) + token_type_ids.stop_gradient = True + q_len = self.config.seqlen + else: + q_len = hidden_states.shape[-2] + + query_states = key_states = value_states = mix_layer = None + if self.use_rms_qkv_recompute: + mix_layer = self.fused_rms_norm_linear(hidden_states) + else: + if self.fuse_attn: + mix_layer = self.qkv_proj( + hidden_states, + ) + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + if self.fuse_attn: + if self.is_gqa: + query_states, key_states, value_states = paddle.split( + mix_layer.reshape( + [ + -1, + q_len, + self.num_heads + 2 * self.num_key_value_heads, + self.head_dim, + ] + ), + [ + self.num_heads, + self.num_key_value_heads, + self.num_key_value_heads, + ], + axis=2, + ) + mix_layer = None + else: + mix_layer = mix_layer.reshape([-1, q_len, self.num_heads, 3 * self.head_dim]) + + else: + query_states = query_states.reshape(shape=[-1, q_len, self.num_heads, self.head_dim]) + key_states = key_states.reshape( + shape=[ + -1, + q_len, + self.num_key_value_heads if self.is_gqa else self.num_heads, + self.head_dim, + ] + ) + value_states = value_states.reshape( + shape=[ + -1, + q_len, + self.num_key_value_heads if self.is_gqa else self.num_heads, + self.head_dim, + ] + ) + if self.use_recompute_attn: + assert past_key_value is None, "do not use kv cache in recompute" + assert not use_cache + attn_output, attn_weights, past_key_value = recompute( + self.rope_attn, + mix_layer, + query_states, + key_states, + value_states, + attention_mask, + position_ids, + output_attentions, + past_key_value, + use_cache, + inbatch_pack_offset, + use_reentrant=False, + ) + else: + attn_output, attn_weights, past_key_value = self.rope_attn( + mix_layer=mix_layer, + query_states=query_states, + key_states=key_states, + value_states=value_states, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + past_key_value=past_key_value, + use_cache=use_cache, + inbatch_pack_offset=inbatch_pack_offset, + ) + if self.config.sequence_parallel: + attn_output = attn_output.reshape([-1, attn_output.shape[-1]]) + + if self.use_linear_residual_norm_recompute is False: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class FakeMoERouterLoss(PyLayer): + @staticmethod + def forward(ctx, x, router_loss, num_acc_steps, enable_delay_scale_loss): + ctx.num_acc_steps = num_acc_steps + ctx.loss_shape = router_loss.shape + ctx.loss_dtype = router_loss.dtype + ctx.enable_delay_scale_loss = enable_delay_scale_loss + return x + + @staticmethod + def backward(ctx, out_grad): + if ctx.enable_delay_scale_loss: + router_loss_grad_value = 1.0 + else: + router_loss_grad_value = 1.0 / ctx.num_acc_steps + + return out_grad, paddle.full(ctx.loss_shape, router_loss_grad_value, dtype=ctx.loss_dtype) + + +class ErnieDecoderLayer(nn.Layer): + def __init__(self, config, layer_idx): + super().__init__() + self._training = True + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.is_moe_infer = config.get("is_moe_infer", False) + self.config = config + self.use_moe = config.use_moe + self.self_attn = ErnieMoEAttention(config, layer_idx) + self.use_linear_residual_norm_recompute = config.use_linear_residual_norm_recompute + self.use_rms_qkv_recompute = config.use_rms_qkv_recompute + + moe_layer_start_index = ( + min(config.moe_layer_start_index) + if isinstance(config.moe_layer_start_index, (tuple, list)) + else config.moe_layer_start_index + ) + moe_layer_end_index = ( + max(config.moe_layer_end_index) + if isinstance(config.moe_layer_end_index, (tuple, list)) + else config.moe_layer_end_index + ) + + if ( + self.use_moe + and ((layer_idx + 1) % config.moe_layer_interval == 0) + and layer_idx >= moe_layer_start_index + and layer_idx <= moe_layer_end_index + ): + gate, experts, lm_gate, lm_experts, moe_statics = self._init_gate_and_experts(layer_idx) + shared_experts = self._init_shared_experts() + dense_experts = self._init_dense_experts(layer_idx) + moe_cls = MOELayer + logger.info(f"moe_cls={moe_cls}") + assert dense_experts is None + self.mlp = moe_cls( + gate, + experts, + layer_idx=layer_idx, + shared_experts=shared_experts, + group=config.moe_group, + recompute=config.use_recompute_moe, + k=config.moe_k, + all_to_all_dropout=config.moe_all_to_all_dropout, + group_experts=config.moe_group_experts, + moe_statics=moe_statics, + ) + if config.sequence_parallel: + for p in gate.parameters(): + mark_as_sequence_parallel_parameter(p) + else: + self.mlp = ErnieMLP(config) + + Norm = RMSNorm + + if self.use_rms_qkv_recompute is False: + self.input_layernorm = Norm(config) + + if self.use_linear_residual_norm_recompute is True: + assert config.hidden_dropout_prob == 0.0 + assert config.fuse_linear is True and config.use_bias is False + assert config.use_rmsnorm is True and config.fuse_rms_norm is True + self.fused_linear_add_norm = FusedLinearAddNorm(self.hidden_size, config.rms_norm_eps) + del self.self_attn.o_proj + else: + self.residual_add1 = FusedDropoutImpl(config.hidden_dropout_prob, mode="upscale_in_train") + self.post_attention_layernorm = Norm(config) + + self.residual_add2 = FusedDropoutImpl(config.hidden_dropout_prob, mode="upscale_in_train") + + if config.sequence_parallel: + if self.use_linear_residual_norm_recompute is True: + mark_as_sequence_parallel_parameter(self.fused_linear_add_norm.rms_norm_weight) + else: + mark_as_sequence_parallel_parameter(self.post_attention_layernorm.weight) + if not hasattr(config, "disable_ffn_model_parallel"): + if self.use_rms_qkv_recompute is True: + mark_as_sequence_parallel_parameter(self.self_attn.fused_rms_norm_linear.rms_norm_weight) + else: + mark_as_sequence_parallel_parameter(self.input_layernorm.weight) + + if not config.use_rmsnorm: + mark_as_sequence_parallel_parameter(self.post_attention_layernorm.bias) + mark_as_sequence_parallel_parameter(self.input_layernorm.bias) + + @property + def training(self): + return self._training + + @training.setter + def training(self, new): + if hasattr(self, "mlp_text"): + for c in self.mlp_text().sublayers(): + c.training = new + self._training = new + + + def fp8_quant_weight(self): + if isinstance(self.mlp, MOELayer): + logger.info(f"fp8 quant weight for mlp {type(self.mlp)}") + self.mlp.fp8_quant_weight() + + def _init_gate_and_experts(self, layer_idx): + cfg = deepcopy(self.config) + fc_cls = ErnieMoeMLPFused if cfg.moe_fuse_experts and not cfg.use_fp8_mlp else ErnieMoeMLP + if self.config.expert_mlp_use_bias is not None: + cfg.use_bias = self.config.expert_mlp_use_bias + + if cfg.moe_intermediate_size: + if isinstance(cfg.moe_intermediate_size, (tuple, list)): + assert isinstance(cfg.moe_num_experts, (tuple, list)) and len(cfg.moe_num_experts) == len( + cfg.moe_intermediate_size + ) + fc = [] + for _i, (num_experts, intermediate_size) in enumerate( + zip(cfg.moe_num_experts, cfg.moe_intermediate_size) + ): + ex_cfg = deepcopy(cfg) + ex_cfg.intermediate_size = intermediate_size + cur_modality_start_layer_idx = ( + cfg.moe_layer_start_index[_i] + if isinstance(cfg.moe_layer_start_index, (tuple, list)) + else cfg.moe_layer_start_index + ) + cur_modality_end_layer_idx = ( + cfg.moe_layer_end_index[_i] + if isinstance(cfg.moe_layer_end_index, (tuple, list)) + else cfg.moe_layer_end_index + ) + if layer_idx >= cur_modality_start_layer_idx and layer_idx <= cur_modality_end_layer_idx: + if _i == 1: + with paddle.utils.unique_name.guard(f"mm_expert_{layer_idx}_"): + fc.append((num_experts, fc_cls(ex_cfg))) + else: + fc.append((num_experts, fc_cls(ex_cfg))) + else: + logger.info(f"moe experts use Identity layer_idx: {layer_idx}") + fc.append((num_experts, nn.Identity())) + else: + cfg.intermediate_size = cfg.moe_intermediate_size + if cfg.moe_fuse_experts: + fc = [(1, fc_cls(cfg))] + else: + fc = [(cfg.moe_num_experts, fc_cls(cfg))] + else: + fc = [(cfg.moe_num_experts, fc_cls(cfg))] + gate, experts, lm_gate, lm_experts = get_gate(self.config, fc, layer_idx) + if cfg.moe_use_aux_free: + moe_statics = MoEStatics(cfg, layer_idx) + else: + moe_statics = None + return gate, experts, lm_gate, lm_experts, moe_statics + + def _init_shared_experts(self): + cfg = deepcopy(self.config) + if cfg.moe_num_shared_experts > 0: + if cfg.moe_intermediate_size: + inter_size = ( + cfg.moe_intermediate_size[0] + if isinstance(cfg.moe_intermediate_size, (tuple, list)) + else cfg.moe_intermediate_size + ) + cfg.intermediate_size = inter_size * cfg.moe_num_shared_experts + else: + cfg.intermediate_size = cfg.intermediate_size * cfg.moe_num_shared_experts + cfg.disable_ffn_model_parallel = False + shared_experts = ErnieMoeMLP(cfg, True) + else: + shared_experts = None + return shared_experts + + def _init_dense_experts(self, layer_idx): + cfg = deepcopy(self.config) + cfg.sequence_parallel = False + if cfg.moe_num_dense_experts > 0: + logger.info("using dense experts") + if cfg.moe_intermediate_size: + inter_size = ( + cfg.moe_intermediate_size[0] + if isinstance(cfg.moe_intermediate_size, (tuple, list)) + else cfg.moe_intermediate_size + ) + cfg.intermediate_size = inter_size * cfg.moe_num_dense_experts + else: + cfg.intermediate_size = cfg.intermediate_size * cfg.moe_num_shared_experts + cfg.disable_ffn_model_parallel = False + with paddle.utils.unique_name.guard(f"audio_expert_{layer_idx}_"): + dense_experts = ErnieMoeDenseExpert(cfg) + for p in dense_experts.parameters(): + p.expert_type = "expert_type_3" + else: + dense_experts = None + return dense_experts + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + token_type_ids: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + inbatch_pack_offset: Optional[paddle.Tensor] = None, + output_gate_logits=True, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + residual = hidden_states + if token_type_ids is not None: + is_multimodel_token = token_type_ids.any() + has_dense_experts_token = (token_type_ids == self.config.moe_dense_experts_token_type_id).any() + async_loader = get_async_loader() + is_multimodel_token_cpu, is_multimodel_token_task = async_offload(is_multimodel_token, async_loader) + has_dense_experts_token_cpu, has_dense_experts_token_task = async_offload( + has_dense_experts_token, async_loader + ) + else: + is_multimodel_token_task = None + has_dense_experts_token_task = None + + if self.use_rms_qkv_recompute is False: + hidden_states = self.input_layernorm(hidden_states) + + (hidden_states, self_attn_weights, present_key_value) = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + use_cache=use_cache, + inbatch_pack_offset=inbatch_pack_offset, + token_type_ids=token_type_ids, + ) + + if self.use_linear_residual_norm_recompute is True: + hidden_states, residual = self.fused_linear_add_norm(hidden_states, residual) + else: + with self.model_parallel_dropout(): + hidden_states = self.residual_add1(hidden_states, residual) + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + if isinstance( + self.mlp, + (MOELayer,), + ): + if is_multimodel_token_task is not None: + hack_offload_wait(is_multimodel_token_task) + if has_dense_experts_token_task is not None: + hack_offload_wait(has_dense_experts_token_task) + + with profile("moe-mlp"): + hidden_states, _, router_loss, gate_logits = self.mlp(hidden_states, token_type_ids) + else: + hidden_states = self.mlp(hidden_states) + gate_logits = None + + with self.model_parallel_dropout(): + hidden_states = self.residual_add2(hidden_states, residual) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if self.use_moe: + if output_gate_logits: + outputs += (gate_logits,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + def model_parallel_dropout(self): + if self.config.tensor_parallel_degree > 1 and self.config.hidden_dropout_prob > 0.0: + current_seed = "local_seed" if self.config.sequence_parallel else "global_seed" + return get_rng_state_tracker().rng_state(current_seed) + return contextlib.nullcontext() + + +class ErniePretrainedModel(PretrainedModel): + config_class = ErnieMoEConfig + base_model_prefix = "ernie" + + @classmethod + def _get_name_mappings(cls, config: ErnieMoEConfig) -> StateDictNameMapping: + model_mappings = [ + ["embed_tokens.weight"], + ["norm.weight"], + ] + for layer_index in range(config.num_hidden_layers): + if config.fuse_attn_ffn: + layer_mappings = [ + [ + f"layers.{layer_index}.self_attn.qkv_proj.weight", + None, + "transpose", + ], + [ + f"layers.{layer_index}.self_attn.o_proj.weight", + None, + "transpose", + ], + [f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"], + [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"], + [ + f"layers.{layer_index}.mlp.up_gate_proj.weight", + None, + "transpose", + ], + [f"layers.{layer_index}.input_layernorm.weight"], + [f"layers.{layer_index}.post_attention_layernorm.weight"], + ] + else: + layer_mappings = [ + [ + f"layers.{layer_index}.self_attn.q_proj.weight", + None, + "transpose", + ], + [ + f"layers.{layer_index}.self_attn.k_proj.weight", + None, + "transpose", + ], + [ + f"layers.{layer_index}.self_attn.v_proj.weight", + None, + "transpose", + ], + [ + f"layers.{layer_index}.self_attn.o_proj.weight", + None, + "transpose", + ], + [f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"], + [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"], + [f"layers.{layer_index}.input_layernorm.weight"], + [f"layers.{layer_index}.post_attention_layernorm.weight"], + ] + model_mappings.extend(layer_mappings) + + init_name_mappings(mappings=model_mappings) + if "ErnieModel" not in config.architectures: + for mapping in model_mappings: + mapping[0] = "model." + mapping[0] + mapping[1] = "ernie." + mapping[1] + model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"]) + + mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)] + return mappings + + @classmethod + def _get_tensor_parallel_mappings(cls, config, is_split=True): + from models.ernie.modeling import gqa_qkv_merge_func, gqa_qkv_split_func + from paddleformers.transformers.conversion_utils import split_or_merge_func + + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + if config.num_key_value_heads is not None and config.num_key_value_heads != config.num_attention_heads: + if is_split: + qkv_fn = partial( + gqa_qkv_split_func, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + head_dim=config.hidden_size // config.num_attention_heads, + ) + else: + qkv_fn = partial( + gqa_qkv_merge_func, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + head_dim=config.hidden_size // config.num_attention_heads, + ) + else: + qkv_fn = partial(fn, is_column=True) + + def get_tensor_parallel_split_mappings(num_layers): + final_actions = {} + if config.fuse_attn_ffn: + base_actions = { + "layers.0.self_attn.qkv_proj.weight": qkv_fn, + "layers.0.mlp.up_gate_proj.weight": partial(fn, is_column=True, is_naive_2fuse=True), + "lm_head.weight": partial(fn, is_column=not config.tie_word_embeddings), + "embed_tokens.weight": partial(fn, is_column=False), + "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), + "layers.0.mlp.down_proj.weight": partial(fn, is_column=False), + } + if config.use_bias: + base_actions.update( + { + "layers.0.self_attn.qkv_proj.bias": qkv_fn, + "layers.0.mlp.up_gate_proj.bias": partial(fn, is_column=True, is_naive_2fuse=True), + "layers.0.mlp.down_proj.bias": lambda x: x, + "lm_head.bias": partial(fn, is_column=True), + } + ) + else: + base_actions = { + "layers.0.self_attn.q_proj.weight": partial(fn, is_column=True), + "layers.0.self_attn.k_proj.weight": partial(fn, is_column=True), + "layers.0.self_attn.v_proj.weight": partial(fn, is_column=True), + "layers.0.mlp.gate_proj.weight": partial(fn, is_column=True), + "layers.0.mlp.up_proj.weight": partial(fn, is_column=True), + "embed_tokens.weight": partial(fn, is_column=False), + "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), + "layers.0.mlp.down_proj.weight": partial(fn, is_column=False), + } + if config.use_bias: + base_actions.update( + { + "layers.0.self_attn.q_proj.bias": partial(fn, is_column=True), + "layers.0.self_attn.k_proj.bias": partial(fn, is_column=True), + "layers.0.self_attn.v_proj.bias": partial(fn, is_column=True), + "layers.0.mlp.gate_proj.bias": partial(fn, is_column=True), + "layers.0.mlp.up_proj.bias": partial(fn, is_column=True), + "layers.0.mlp.down_proj.bias": lambda x: x, + "lm_head.bias": partial(fn, is_column=True), + } + ) + moe_in_mp = config.moe_group in {"mp", "model", "tp", "mpdp"} + for key, action in base_actions.items(): + if "layers.0." in key: + for i in range(num_layers): + newkey = key.replace("layers.0.", f"layers.{i}.") + if config.moe_group in {"mpdp"}: + final_actions[newkey] = lambda x: x + else: + final_actions[newkey] = action + if "mlp" in key and (i + 1) % config.moe_layer_interval == 0: + moe_num_experts = config.moe_num_experts + if moe_num_experts > 0: + for expert_id in range(moe_num_experts): + _key = key.replace( + "layers.0.mlp", + f"layers.{i}.mlp.experts.{expert_id}", + ) + if moe_in_mp: + final_actions[_key] = lambda x: x + else: + final_actions[_key] = action + for _ in range(config.moe_num_shared_experts): + _key = key.replace("layers.0.mlp", f"layers.{i}.mlp.shared_experts") + final_actions[_key] = action + for _ in range(config.moe_num_dense_experts): + _key = key.replace("layers.0.mlp", f"layers.{i}.mlp.dense_experts") + final_actions[_key] = action + else: + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + + elif "self_attn" in key and ( + "qkv_proj" in key or "q_proj" in key or "k_proj" in key or "v_proj" in key + ): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + else: + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + else: + final_actions[key] = action + return final_actions + + mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers) + return mappings + + def _init_weights(self, layer): + if get_rng_state_tracker().states_: + rng_tracker = get_rng_state_tracker().rng_state + else: + rng_tracker = contextlib.nullcontext + + if isinstance( + layer, + ( + ColumnParallelLinear, + RowParallelLinear, + ColumnSequenceParallelLinear, + RowSequenceParallelLinear, + VocabParallelEmbedding, + ErnieLMHead, + nn.Embedding, + BMMLinear, + nn.Linear, + paddle.incubate.nn.FusedLinear, + ), + ): + if not hasattr(layer, "weight"): + return + + is_moe = getattr(layer.weight, "no_sync", False) + with rng_tracker("local_seed" if is_moe else "model_parallel_rng"): + dtype = paddle.get_default_dtype() + paddle.set_default_dtype("float32") + layer.weight.set_value( + paddle.randn(layer.weight.shape, dtype=dtype).scale(self.config.initializer_range) + ) + paddle.set_default_dtype(dtype) + logger.info( + f"dist-init-fc: shape={layer.weight.shape}, dtype={layer.weight.dtype} " + f"range={self.config.initializer_range},type={type(layer)}, " + f'norm={layer.weight.astype("float32").norm().item()},is_moe={is_moe}' + ) + elif isinstance(layer, (Top2Gate, TopKGateFused)): + if not hasattr(layer, "weight"): + return + with rng_tracker("model_parallel_rng"): + dtype = paddle.get_default_dtype() + paddle.set_default_dtype("float32") + layer.weight.set_value( + paddle.randn(layer.weight.shape, dtype=layer.weight.dtype).scale(self.config.initializer_range) + ) + logger.info( + f"dist-init-moe_gate: shape={layer.weight.shape}, dtype={layer.weight.dtype} " + f"range={self.config.initializer_range},type={type(layer)}, " + f'norm={layer.weight.astype("float32").norm().item()}' + ) + if isinstance(self.config.moe_num_experts, (tuple, list)): + for i in range(1, len(self.config.moe_num_experts)): + layer_weight = getattr(layer, f"weight_{i}") + layer_weight.set_value( + paddle.randn(layer_weight.shape, dtype=layer_weight.dtype).scale( + self.config.initializer_range + ) + ) + logger.info( + f"dist-init-moe_gate: shape={layer_weight.shape}, dtype={layer_weight.dtype} " + f"range={self.config.initializer_range},type={type(layer)}, " + f'norm={layer_weight.astype("float32").norm().item()}' + ) + paddle.set_default_dtype(dtype) + + elif isinstance(layer, RotaryEmbedding): + head_dim = self.config.hidden_size // self.config.num_attention_heads + inv_freq = 1.0 / (layer.base ** (np.arange(0, head_dim, 2).astype("float32") / head_dim)) + t = np.arange(layer.max_position_embeddings, dtype="float32") + freqs = np.einsum("i,j->ij", t, inv_freq) + emb = np.concatenate([freqs, freqs], axis=-1) + cos_cached = np.cos(emb)[:, :] + sin_cached = np.sin(emb)[:, :] + layer.cos_cached.set_value(cos_cached) + layer.sin_cached.set_value(sin_cached) + + +@register_base_model +class ErnieModel(ErniePretrainedModel): + def __init__(self, config: ErnieMoEConfig): + if config.moe_group in {"mp", "model", "tp", "mpdp"}: + logger.info(f"disable FFN tensor model parallel, moe-group={config.moe_group}") + config.disable_ffn_model_parallel = True + + config.moe_group = _parse_moe_group(config.moe_group) + + config.moe_world_size = dist.get_world_size(config.moe_group) + if config.moe_world_size < 0: + config.moe_world_size = 1 + config.moe_rank = dist.get_rank(config.moe_group) + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.hidden_size = config.hidden_size + self.config = config + + if config.tensor_parallel_degree > 1: + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + self.hidden_size, + ) + else: + self.embed_tokens = nn.Embedding( + self.vocab_size, + self.hidden_size, + ) + + self.layers = nn.LayerList([ErnieDecoderLayer(config, i) for i in range(config.num_hidden_layers)]) + Norm = RMSNorm + + self.norm = Norm(config) + + self.gradient_checkpointing = False + + if self.config.multi_token_pred_depth > 0: + self.mtp_block = paddle.nn.LayerList( + [ErnieDecoderLayer(config, layer_idx) for layer_idx in range(self.config.multi_token_pred_depth)] + ) + Norm = RMSNorm + + self.mtp_hidden_norm = paddle.nn.LayerList( + [Norm(config) for _ in range(self.config.multi_token_pred_depth)] + ) + self.mtp_emb_norm = paddle.nn.LayerList([Norm(config) for _ in range(self.config.multi_token_pred_depth)]) + + LinearFN = paddle.incubate.nn.FusedLinear if config.fuse_linear else paddle.nn.Linear + self.mtp_linear_proj = paddle.nn.LayerList( + [ + LinearFN( + self.config.hidden_size * 2, + self.config.hidden_size, + bias_attr=config.use_bias, + ) + for _ in range(self.config.multi_token_pred_depth) + ] + ) + if config.sequence_parallel: + for mtp_linear in self.mtp_linear_proj: + mark_as_sequence_parallel_parameter(mtp_linear.weight) + if config.use_bias: + mark_as_sequence_parallel_parameter(mtp_linear.bias) + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @classmethod + def _prepare_decoder_attention_mask(cls, attention_mask, input_shape, past_key_values_length, dtype): + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, past_key_values_length=past_key_values_length, dtype=dtype + ) + + if attention_mask is not None: + expanded_attn_mask = _expand_mask(attention_mask, dtype, tgt_length=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + combined_attention_mask = paddle.maximum( + combined_attention_mask.astype(dtype), + paddle.to_tensor(float(finfo(dtype).min), dtype=dtype), + ) + return combined_attention_mask + + @paddle.jit.not_to_static + def recompute_training( + self, + layer_module, + hidden_states, + attention_mask, + position_ids, + token_type_ids, + output_attentions, + past_key_value, + use_cache, + inbatch_pack_offset, + ): + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_gate_logits=False) + + return custom_forward + + decoderlayer_act_offload_settings = self.config.get( + "decoderlayer_act_offload_settings", {"type": "", "value": ""} + ) + + setting_type = decoderlayer_act_offload_settings["type"] + offload_value = decoderlayer_act_offload_settings["value"] + + def get_offload_kwargs(layer_idx, setting_type, offload_value): + offload_kwargs = {} + if "mod" == setting_type: + assert isinstance(offload_value, (list, tuple)) + v1, v2 = offload_value + offload_kwargs["offload_indices"] = [0] if layer_idx % v1 == v2 else [] + elif "layer_idxs" == setting_type: + offload_kwargs["offload_indices"] = [0] if layer_idx in offload_value else [] + return offload_kwargs + + layer_idx = layer_module.layer_idx + if layer_idx == 0: + offload_kwargs = {} + else: + offload_kwargs = get_offload_kwargs(layer_idx, setting_type, offload_value) + + hidden_states = recompute( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + position_ids, + token_type_ids, + output_attentions, + past_key_value, + use_cache, + inbatch_pack_offset, + **offload_kwargs, + ) + return hidden_states + + def forward( + self, + input_ids=None, + position_ids=None, + token_type_ids=None, + attention_mask=None, + inputs_embeds=None, + use_cache=None, + past_key_values=None, + output_attentions=False, + output_hidden_states=None, + return_dict=False, + inbatch_pack_offset=None, + **kwargs, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if past_key_values is None: + past_key_values = tuple([None] * len(self.layers)) + + seq_length -= self.config.multi_token_pred_depth + seq_length_with_past = seq_length + cache_length = 0 + if past_key_values[0] is not None: + cache_length = paddle.shape(past_key_values[0][0])[1] + seq_length_with_past += cache_length + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = inputs_embeds.astype(self.embed_tokens.weight.dtype) + + if self.config.multi_token_pred_depth > 0: + inputs_embeds_extra = inputs_embeds[:, -self.config.multi_token_pred_depth :, :] + inputs_embeds = inputs_embeds[:, : -self.config.multi_token_pred_depth, :] + inputs_embeds_ori = inputs_embeds + + if self.config.sequence_parallel: + inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[-1]]) + inputs_embeds = ScatterOp.apply(inputs_embeds) + + can_use_fa = self.config.use_flash_attn + can_mem_eff_attn = self.config.use_mem_eff_attn and inbatch_pack_offset is not None + if can_use_fa or can_mem_eff_attn: + if attention_mask is not None: + attention_mask = None + + elif attention_mask is None: + attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) + + if attention_mask is not None: + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + cache_length, + inputs_embeds.dtype, + ) + hidden_states = inputs_embeds + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + all_router_loss = 0.0 if self.config.use_moe else None + all_gate_logits = () + mtp_outputs = [] + + for idx, (decoder_layer) in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + has_gradient = not hidden_states.stop_gradient + if self.config.use_recompute and has_gradient: + layer_outputs = self.recompute_training( + decoder_layer, + hidden_states, + attention_mask, + position_ids, + token_type_ids, + output_attentions, + past_key_value, + use_cache, + inbatch_pack_offset, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + position_ids, + token_type_ids, + output_attentions, + past_key_value, + use_cache, + inbatch_pack_offset, + ) + + if isinstance(layer_outputs, (tuple, list)): + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + if self.config.use_moe: + if not (self.config.use_recompute and has_gradient): + layer_outputs, gate_logits = layer_outputs[:-1], layer_outputs[-1] + all_gate_logits = all_gate_logits + (gate_logits,) + + if self.config.multi_token_pred_depth > 0: + mtp_outputs.append(hidden_states) + + for depth in range(self.config.multi_token_pred_depth): + if self.config.sequence_parallel: + hidden_states = GatherOp.apply(hidden_states) + hidden_states = hidden_states.reshape([-1, seq_length, hidden_states.shape[-1]]) + + inputs_embeds_cur_depth = paddle.concat( + [ + inputs_embeds_ori[:, (depth + 1) :, :], + inputs_embeds_extra[:, : (depth + 1), :], + ], + axis=1, + ) + + inputs_embeds_cur_depth_norm = self.mtp_emb_norm[depth](inputs_embeds_cur_depth) + hidden_states_norm = self.mtp_hidden_norm[depth](hidden_states) + + inputs_embeds_cur_depth = self.mtp_linear_proj[depth]( + paddle.concat([inputs_embeds_cur_depth_norm, hidden_states_norm], axis=-1) + ) + + if self.config.sequence_parallel: + inputs_embeds_cur_depth = inputs_embeds_cur_depth.reshape([-1, inputs_embeds_cur_depth.shape[-1]]) + inputs_embeds_cur_depth = ScatterOp.apply(inputs_embeds_cur_depth) + + decoder_layer = self.mtp_block[depth] + past_key_value = None + layer_outputs = decoder_layer( + inputs_embeds_cur_depth, + attention_mask, + position_ids, + token_type_ids, + output_attentions, + past_key_value, + use_cache, + inbatch_pack_offset, + ) + + if isinstance(layer_outputs, (tuple, list)): + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + if self.config.use_moe: + if not (self.config.use_recompute and has_gradient): + layer_outputs, gate_logits = ( + layer_outputs[:-1], + layer_outputs[-1], + ) + all_gate_logits = all_gate_logits + (gate_logits,) + + mtp_outputs.append(hidden_states) + mtp_outputs = [self.norm(hidden_states) for hidden_states in mtp_outputs] + hidden_states, mtp_outputs = mtp_outputs[0], mtp_outputs[1:] + else: + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + all_router_loss, + all_gate_logits, + mtp_outputs, + ] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=None, + router_loss=all_router_loss, + gate_logits=all_gate_logits, + mtp_outputs=mtp_outputs, + ) + + +ErnieMoELMHead = ErnieLMHead + + +class ErniePretrainingCriterion(ErniePretrainingCriterionBase): + def __init__(self, config, return_tuple=True): + super(ErniePretrainingCriterion, self).__init__(config, return_tuple=return_tuple) + self.ignored_index = getattr(config, "ignored_index", -100) + self.config = config + self.return_tuple = return_tuple + self.enable_parallel_cross_entropy = config.tensor_parallel_degree > 1 and config.tensor_parallel_output + + if self.enable_parallel_cross_entropy: + logger.info("using parallel cross entropy, take care") + self.loss_func = fleet.meta_parallel.ParallelCrossEntropy() + else: + self.loss_func = paddle.nn.CrossEntropyLoss( + reduction="none", + ) + + def forward(self, prediction_scores, masked_lm_labels, router_loss=None, mtp_logits=None): + if self.config.multi_token_pred_depth > 0: + masked_lm_labels_ori = masked_lm_labels + masked_lm_labels = masked_lm_labels[:, : -self.config.multi_token_pred_depth] + seq_length = masked_lm_labels.shape[1] + res = super().forward( + prediction_scores, + masked_lm_labels, + ) + global_training_logs = get_global_training_logs() + + if self.config.multi_token_pred_depth > 0: + global_training_logs.update(mtp_depth_0_loss=res[0].clone().detach()) + mtp_loss_res = [] + for depth in range(self.config.multi_token_pred_depth): + prediction_scores_cur_depth = mtp_logits[depth] + masked_lm_labels_cur_depth = masked_lm_labels_ori[:, (depth + 1) : (depth + 1 + seq_length)] + res_cur_depth = super().forward( + prediction_scores_cur_depth, + masked_lm_labels_cur_depth, + ) + mtp_loss_res.append(res_cur_depth) + global_training_logs.update(**{f"mtp_depth_{depth + 1}_loss": res_cur_depth[0].clone().detach()}) + + def add_loss(main_loss, loss): + return main_loss + loss - loss.detach() + + if self.return_tuple: + loss, loss_sum = res + if self.config.multi_token_pred_depth > 0: + loss = add_loss( + loss, + self.config.multi_token_pred_lambda * sum([x[0] for x in mtp_loss_res]) / len(mtp_loss_res), + ) + loss_sum = loss_sum + self.config.multi_token_pred_lambda * sum( + [x[1].detach() for x in mtp_loss_res] + ) / len(mtp_loss_res) + else: + loss, loss_sum = res, None + if self.config.multi_token_pred_depth > 0: + loss = add_loss( + loss, + self.config.multi_token_pred_lambda * sum([x[0] for x in mtp_loss_res]) / len(mtp_loss_res), + ) + + global_training_logs.update(lm_loss=loss.clone().detach()) + if router_loss is not None and isinstance(router_loss, paddle.Tensor): + loss = loss + router_loss - router_loss.detach() + if isinstance(router_loss, paddle.Tensor): + global_training_logs.update(router_loss=router_loss.detach()) + return loss, loss_sum + + +class ErnieMoEForCausalLM(ErniePretrainedModel): + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + + if config.sequence_parallel: + logger.info(f"using sequence_parallel, input seqlen={config.seqlen}") + assert config.seqlen is not None + assert ( + config.tensor_parallel_degree > 1 + ), f"sequence-parallel needs mp>1, got mp={config.tensor_parallel_degree}" + + new_initializer_range = math.sqrt(0.3333 / config.hidden_size) + logger.info(f"change initializer-range from {config.initializer_range} to {new_initializer_range}") + config.initializer_range = new_initializer_range + self.config = config + self.ernie = ErnieModel(config) + self.lm_head = ErnieMoELMHead(config) + self.criterion = ErniePretrainingCriterion(config) + + self.tie_weights() + + if self.config.fuse_rms_norm: + logger.info("Use fusedRMSNorm") + else: + logger.info("Use normal RMSNorm") + + def _post_init(self, original_init, *args, **kwargs): + super()._post_init(self, original_init, *args, **kwargs) + factor = 1 / math.sqrt(2 * self.config.num_hidden_layers) + logger.info(f"using post init div: factor:{factor}") + with paddle.no_grad(): + for layer in self.ernie.layers: + if self.config.use_linear_residual_norm_recompute is True: + layer.fused_linear_add_norm.linear_weight.scale_(factor) + else: + if isinstance( + layer.self_attn.o_proj, + (MOELayer,), + ): + for e in layer.self_attn.o_proj.experts: + e.weight.scale_(factor) + if hasattr(layer.self_attn.o_proj, "dense_experts"): + layer.self_attn.o_proj.dense_experts.down_proj.weight.scale_(factor) + else: + layer.self_attn.o_proj.weight.scale_(factor) + + if isinstance( + layer.mlp, + (MOELayer,), + ): + for e in layer.mlp.experts: + if isinstance(e, ErnieMLP): + e.down_proj.weight.scale_(factor) + if getattr(layer.mlp, "dense_experts", None) and isinstance(layer.mlp.dense_experts, ErnieMLP): + layer.mlp.dense_experts.down_proj.weight.scale_(factor) + else: + layer.mlp.down_proj.weight.scale_(factor) + + def set_state_dict(self, state_dict, *args, **kwargs): + state_dict = moe_statedict_upcycle( + state_dict, + self.config, + self.lm_head.weight.dtype, + self._get_tensor_parallel_mappings(self.config, is_split=False), + self._get_tensor_parallel_mappings(self.config, is_split=True), + ) + state_dict = moe_statedict_cherry_pick(state_dict, self.config) + state_dict = moe_ep2mp( + state_dict, + self.config, + self._get_tensor_parallel_mappings(self.config, is_split=True), + ) + ret = super().set_state_dict(state_dict, *args, **kwargs) + logger.info(f"set_state_dict: {ret}") + return ret + + def get_input_embeddings(self): + return self.ernie.embed_tokens + + def set_input_embeddings(self, value): + self.ernie.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.ernie = decoder + + def get_decoder(self): + return self.ernie + + @staticmethod + def prepare_attention_mask_for_generation(input_ids, pad_token_id, eos_token_id): + is_pad_token_in_inputs_ids = (pad_token_id is not None) and paddle.any( + input_ids == pad_token_id + ).numpy().item() + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ( + (eos_token_id is not None) and (pad_token_id != eos_token_id) + ) + if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id: + attention_mask = (input_ids != pad_token_id).astype("int64") + else: + attention_mask = paddle.ones_like(input_ids, dtype="int64") + return attention_mask + + def prepare_inputs_for_generation( + self, + input_ids, + use_cache=False, + past_key_values=None, + inputs_embeds=None, + **kwargs, + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + attention_mask = kwargs.get("attention_mask", None) + + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": True, + "attention_mask": attention_mask, + "return_dict": True, + } + ) + + if self.config.rope_3d: + model_inputs.update({"position_ids": kwargs["position_ids"]}) + + return model_inputs + + def update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder=False): + if isinstance(outputs, tuple) and len(outputs) > 1 and not isinstance(outputs[1], paddle.Tensor): + model_kwargs["past_key_values"] = outputs[1] + + if isinstance(outputs, CausalLMOutputWithCrossAttentions) and "past_key_values" in outputs: + model_kwargs["past_key_values"] = outputs.past_key_values + + if "token_type_ids" in model_kwargs and model_kwargs["token_type_ids"] is not None: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = paddle.concat([token_type_ids, token_type_ids[:, -1:]], axis=-1) + + if not is_encoder_decoder: + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = paddle.concat( + [ + attention_mask, + paddle.ones([attention_mask.shape[0], 1], dtype="int64"), + ], + axis=-1, + ) + if "role_ids" in model_kwargs and model_kwargs["role_ids"] is not None: + role_ids = model_kwargs["role_ids"] + model_kwargs["role_ids"] = paddle.concat([role_ids, role_ids[:, -1:]], axis=-1) + + if self.config.rope_3d: + assert "position_ids" in model_kwargs, "position_ids must be provided if rope_3d is on" + position_ids = model_kwargs["position_ids"] + + model_kwargs["position_ids"] = paddle.concat( + [ + position_ids, + position_ids.max(axis=(1, 2), keepdim=True).tile([1, 1, 3]) + 1, + ], + axis=1, + ) + + return model_kwargs + + def forward( + self, + input_ids, + position_ids=None, + attention_mask=None, + token_type_ids=None, + inputs_embeds=None, + labels=None, + use_cache=False, + past_key_values=None, + output_attentions=None, + output_hidden_states=None, + return_dict=False, + ignored_index=0, + data_id=None, + src_id=None, + inbatch_pack_offset=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.ernie( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + inbatch_pack_offset=inbatch_pack_offset, + ) + + hidden_states = outputs.last_hidden_state + mtp_outputs = outputs.mtp_outputs + + logits = self.lm_head(hidden_states) + mtp_logits = [] + if len(mtp_outputs) > 0: + mtp_logits = [self.lm_head(_hidden_states) for _hidden_states in mtp_outputs] + + if return_dict: + if labels is not None: + loss, _ = self.criterion(logits, labels) + else: + loss = None + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_loss=outputs.router_loss if self.config.use_moe else None, + ) + if self.config.use_moe: + router_loss = outputs.router_loss + else: + router_loss = None + assert labels is not None + return self.criterion(logits, labels, router_loss, mtp_logits) diff --git a/ernie/ERNIE/examples/pre-training/models/ernie/modeling_pp.py b/ernie/ERNIE/examples/pre-training/models/ernie/modeling_pp.py new file mode 100644 index 0000000000000000000000000000000000000000..6a21cceea910e3405dc53238d6029ebdd94486ea --- /dev/null +++ b/ernie/ERNIE/examples/pre-training/models/ernie/modeling_pp.py @@ -0,0 +1,995 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import copy +import logging +import math +from collections import deque + +import numpy as np +import paddle +import paddle.distributed as dist +from models.ernie import ErnieMoEConfig +from models.ernie.modeling_moe import ( + ErnieDecoderLayer, + ErnieMLP, + ErnieModel, + ErnieMoELMHead, + ErniePretrainedModel, + ErniePretrainingCriterion, + RMSNorm, + RotaryEmbedding, + _parse_moe_group, + moe_ep2mp, + moe_statedict_upcycle, +) +from models.moe.moe_layer import MOELayer +from models.moe.top2_gate import Top2Gate, TopKGateFused +from models.sequence_parallel_utils import ( + ColumnSequenceParallelLinear, + RowSequenceParallelLinear, + ScatterOp, + mark_as_sequence_parallel_parameter, +) +from models.utils import inplace_offload +from paddle import nn +from paddle.distributed import fleet +from paddle.distributed.fleet.layers.mpu.mp_layers import ( + ColumnParallelLinear, + RowParallelLinear, + VocabParallelEmbedding, +) +from paddle.distributed.fleet.layers.mpu.random import get_rng_state_tracker +from paddle.distributed.fleet.meta_parallel import ( + LayerDesc, + PipelineLayer, + SharedLayerDesc, +) +from paddle.distributed.fleet.utils import recompute +from paddleformers.transformers import PretrainedModel + +try: + from src.utils.misc import global_training_logs +except ModuleNotFoundError: + global_training_logs = {} + + +try: + from paddle.distributed.fleet.meta_parallel.pipeline_parallel import ( + pipeline_bubble_hooks_, + ) +except ImportError: + pipeline_bubble_hooks_ = None + +try: + from paddle.framework.recall_error import AADIFF_ERROR +except ImportError: + AADIFF_ERROR = "CUDA error(1001)" + + +input_ids_for_mtp = deque() +NativeLinear = nn.Linear + +logger = logging.getLogger(__name__) + + +class ErnieEmbeddingPipe(nn.Layer): + def __init__(self, config): + self.sequence_parallel = config.sequence_parallel + self.use_mem_eff_attn = config.use_mem_eff_attn + self.config = config + + super(ErnieEmbeddingPipe, self).__init__() + self.use_moe = config.use_moe + if config.tensor_parallel_degree > 1: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + else: + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + + @property + def embedding_weight(self): + return self.embed_tokens.weight + + def forward(self, args): + if isinstance(args, tuple): + if len(args) == 3: + input_ids, attention_mask, position_ids = args + inbatch_pack_offset = None + elif len(args) == 2: + if self.use_mem_eff_attn: + input_ids, inbatch_pack_offset = args + position_ids, attention_mask = None, None + inbatch_pack_offset.stop_gradient = True + else: + input_ids, attention_mask = args + position_ids = None + inbatch_pack_offset = None + + else: + input_ids = args + attention_mask, position_ids, inbatch_pack_offset = None, None, None + + if position_ids is not None: + position_ids.stop_gradient = True + + emb = self.embed_tokens(input_ids).astype(self.embed_tokens.weight.dtype) + + if self.config.multi_token_pred_depth > 0: + if self.config.enable_mtp_magic_send: + emb = emb[:, : -self.config.multi_token_pred_depth, :] + if self.sequence_parallel: + emb = emb.reshape([-1, emb.shape[-1]]) + emb = ScatterOp.apply(emb) + else: + inputs_embeds_extra = emb[:, -self.config.multi_token_pred_depth :, :] + inputs_embeds = emb[:, : -self.config.multi_token_pred_depth, :] + inputs_embeds_ori = inputs_embeds + batch_size, seq_length, _ = inputs_embeds.shape + + if self.sequence_parallel: + inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[-1]]) + inputs_embeds = ScatterOp.apply(inputs_embeds) + mtp_emb_res = [inputs_embeds] + for depth in range(self.config.multi_token_pred_depth): + inputs_embeds_mtp = paddle.concat( + [ + inputs_embeds_ori[:, (depth + 1) :, :], + inputs_embeds_extra[:, : (depth + 1), :], + ], + axis=1, + ) + if self.sequence_parallel: + inputs_embeds_mtp = inputs_embeds_mtp.reshape([-1, inputs_embeds_mtp.shape[-1]]) + inputs_embeds_mtp = ScatterOp.apply(inputs_embeds_mtp) + mtp_emb_res.append(inputs_embeds_mtp) + res = paddle.concat(mtp_emb_res) + return [res] + else: + if self.sequence_parallel: + emb = emb.reshape([-1, emb.shape[-1]]) + emb = ScatterOp.apply(emb) + + if attention_mask is not None: + batch_size, seq_length = input_ids.shape + attention_mask = ErnieModel._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), 0, emb.dtype + ) + attention_mask.stop_gradient = True + + ret = (emb,) + if attention_mask is not None: + ret += (attention_mask.clone(),) + if position_ids is not None: + ret += (position_ids.clone(),) + if inbatch_pack_offset is not None: + ret += (inbatch_pack_offset.clone(),) + if self.config.multi_token_pred_depth > 0 and not self.config.enable_mtp_magic_send: + ret += (input_ids,) + assert len(ret) == 2, "mtp only support one input which is input_ids" + if len(ret) == 1: + ret = ret[0] + return ret + + +class MTPEmbeddingPipe(ErnieEmbeddingPipe): + def __init__(self, config): + super(MTPEmbeddingPipe, self).__init__(config) + + @property + def embedding_weight(self): + return self.embed_tokens.weight + + def forward(self, args): + assert ( + self.config.enable_mtp_magic_send + ), "MTPEmbedding can only be added into model only support enable_mtp_magic_send=True" + + global input_ids_for_mtp + assert len(input_ids_for_mtp) > 0, "input_ids for mtp is empty" + hidden_states = args[0] + input_ids = input_ids_for_mtp.popleft() + input_embeds = self.embed_tokens(input_ids).astype(self.embed_tokens.weight.dtype) + + return (hidden_states, input_embeds) + + +class EmptyLayer(nn.Layer): + def __init__(self): + super().__init__() + + def forward(self, x): + return x + + +class ErnieDecoderLayerPipe(ErnieDecoderLayer): + def __init__(self, config, layer_idx, use_full_recompute=False): + super().__init__(config, layer_idx) + self.layer_idx = layer_idx + self.use_full_recompute = use_full_recompute + logger.info(f"using pp full recompute={use_full_recompute}") + self.use_mem_eff_attn = config.use_mem_eff_attn + + def forward(self, args): + if self.config.multi_token_pred_depth > 0 and not self.config.enable_mtp_magic_send: + res = args[0] + tensor_list = paddle.split(res, self.config.multi_token_pred_depth + 1) + inputs_embeds = tensor_list[-self.config.multi_token_pred_depth :] + args = tuple(tensor_list[: -self.config.multi_token_pred_depth]) + else: + res = None + + if isinstance(args, tuple): + if len(args) == 3: + hidden_states, attention_mask, position_ids = args + elif len(args) == 2: + if self.use_mem_eff_attn: + hidden_states, inbatch_pack_offset = args + position_ids, attention_mask = None, None + inbatch_pack_offset.stop_gradient = True + else: + hidden_states, attention_mask = args + position_ids, inbatch_pack_offset = None, None + elif len(args) == 1: + (hidden_states,) = args + attention_mask, position_ids, inbatch_pack_offset = None, None, None + else: + hidden_states = args + attention_mask, position_ids, inbatch_pack_offset = None, None, None + + if position_ids is not None: + position_ids.stop_gradient = True + + if attention_mask is not None: + attention_mask.stop_gradient = True + + if self.training and self.use_full_recompute: + decoderlayer_act_offload_settings = self.config.get( + "decoderlayer_act_offload_settings", {"type": "", "value": ""} + ) + setting_type = decoderlayer_act_offload_settings["type"] + offload_value = decoderlayer_act_offload_settings["value"] + offload_kwargs = {} + if "mod" == setting_type: + assert isinstance(offload_value, (list, tuple)) + v1, v2 = offload_value + offload_kwargs["offload_indices"] = [0] if self.layer_idx % v1 == v2 else [] + elif "layer_idxs" == setting_type: + offload_kwargs["offload_indices"] = [0] if self.layer_idx in offload_value else [] + + if offload_kwargs.get("offload_indices", []) and res is not None: + inplace_offload(res) + + ret = recompute( + super().forward, + hidden_states, + attention_mask, + position_ids, + None, + False, + None, + False, + inbatch_pack_offset, + False, + **offload_kwargs, + ) + else: + ret = super().forward( + hidden_states, + attention_mask, + position_ids, + None, + False, + None, + False, + inbatch_pack_offset, + False, + ) + if isinstance(ret, paddle.Tensor): + ret = (ret,) + if attention_mask is not None: + ret += (attention_mask.clone(),) + if position_ids is not None: + ret += (position_ids.clone(),) + if inbatch_pack_offset is not None: + ret += (inbatch_pack_offset.clone(),) + if len(ret) == 1: + (ret,) = ret + if self.config.multi_token_pred_depth > 0: + if self.config.enable_mtp_magic_send: + ret = (ret,) + else: + ret = (paddle.concat([ret, *inputs_embeds]),) + return ret + + +class RMSNormPipe(RMSNorm): + def __init__(self, config): + super().__init__(config) + self.use_moe = config.use_moe + mark_as_sequence_parallel_parameter(self.weight) + + def forward(self, args): + if self.config.multi_token_pred_depth > 0: + if self.config.enable_mtp_magic_send: + assert len(args) == self.config.multi_token_pred_depth + 1, "the length is not valid in mtp" + mtp_outputs = [] + for hidden_states in args: + mtp_outputs.append(super().forward(hidden_states)) + return mtp_outputs + else: + tensor_list = paddle.split(args[0], self.config.multi_token_pred_depth + 1) + mtp_outputs = [] + for hidden_states in tensor_list: + mtp_outputs.append(super().forward(hidden_states)) + return mtp_outputs + else: + if self.use_moe: + hidden_states = args[:1] + if isinstance(args, tuple): + if len(args) == 3: + hidden_states, attention_mask, position_ids = args + elif len(args) == 2: + hidden_states, attention_mask = args + else: + hidden_states = args + hidden_states = super().forward(hidden_states) + return hidden_states + + +class ErnieMoELMHeadPipe(ErnieMoELMHead): + def forward(self, args): + if self.config.multi_token_pred_depth > 0: + logits = list() + for _hidden_states in args: + logits.append(super().forward(_hidden_states)) + return logits + hidden_states = args + logits = super().forward(hidden_states) + return logits + + +class MTPLayer(nn.Layer): + def __init__(self, config): + super().__init__() + config = copy.deepcopy(config) + self.config = config + if self.config.use_recompute_mtp: + self.config.use_recompute = False + assert self.config.multi_token_pred_depth > 0, "Adding MTPLayer must assign value to multi_token_pred_depth" + + self.mtp_block = paddle.nn.LayerList( + [ErnieDecoderLayer(config, layer_idx) for layer_idx in range(self.config.multi_token_pred_depth)] + ) + Norm = RMSNorm + self.mtp_hidden_norm = paddle.nn.LayerList([Norm(config) for _ in range(self.config.multi_token_pred_depth)]) + self.mtp_emb_norm = paddle.nn.LayerList([Norm(config) for _ in range(self.config.multi_token_pred_depth)]) + + LinearFN = paddle.incubate.nn.FusedLinear if config.fuse_linear else paddle.nn.Linear + self.mtp_linear_proj = paddle.nn.LayerList( + [ + LinearFN( + self.config.hidden_size * 2, + self.config.hidden_size, + bias_attr=config.use_bias, + ) + for _ in range(self.config.multi_token_pred_depth) + ] + ) + if config.sequence_parallel: + for mtp_linear in self.mtp_linear_proj: + mark_as_sequence_parallel_parameter(mtp_linear.weight) + if config.use_bias: + mark_as_sequence_parallel_parameter(mtp_linear.bias) + + def forward(self, args): + def custom_forward(*inputs): + return self.forward_impl(*inputs) + + if self.config.use_recompute_mtp: + return recompute(custom_forward, *args) + else: + return custom_forward(*args) + + def forward_impl(self, *args): + if self.config.enable_mtp_magic_send: + assert isinstance(args, tuple), "Input for MTPLayer must be tuple" + hidden_states, inputs_embeds = args + inputs_embeds_extra = inputs_embeds[:, -self.config.multi_token_pred_depth :, :] + inputs_embeds = inputs_embeds[:, : -self.config.multi_token_pred_depth, :] + inputs_embeds_ori = inputs_embeds + else: + res = args[0] + tensor_list = paddle.split(res, self.config.multi_token_pred_depth + 1) + hidden_states = tensor_list[0] + inputs_embeds_cur_depth_list = tensor_list[1:] + + output_list = [hidden_states] + for depth in range(self.config.multi_token_pred_depth): + if self.config.enable_mtp_magic_send: + inputs_embeds_cur_depth = paddle.concat( + [ + inputs_embeds_ori[:, (depth + 1) :, :], + inputs_embeds_extra[:, : (depth + 1), :], + ], + axis=1, + ) + + if self.config.sequence_parallel: + inputs_embeds_cur_depth = inputs_embeds_cur_depth.reshape([-1, inputs_embeds_cur_depth.shape[-1]]) + inputs_embeds_cur_depth = ScatterOp.apply(inputs_embeds_cur_depth) + else: + inputs_embeds_cur_depth = inputs_embeds_cur_depth_list[depth] + + inputs_embeds_cur_depth_norm = self.mtp_emb_norm[depth](inputs_embeds_cur_depth) + hidden_states_norm = self.mtp_hidden_norm[depth](hidden_states) + + inputs_embeds_cur_depth = self.mtp_linear_proj[depth]( + paddle.concat([inputs_embeds_cur_depth_norm, hidden_states_norm], axis=-1) + ) + + decoder_layer = self.mtp_block[depth] + + layer_outputs = decoder_layer( + inputs_embeds_cur_depth, + None, + None, + None, + False, + None, + False, + None, + False, + ) + + if isinstance(layer_outputs, (tuple, list)): + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + output_list.append(hidden_states) + + if self.config.enable_mtp_magic_send: + return tuple(output_list) + else: + res = paddle.concat(output_list) + return (res,) + + +class ErniePretrainingCriterionPipe(ErniePretrainingCriterion): + def __init__(self, config): + super().__init__(config) + + def forward(self, logits, labels): + if self.config.multi_token_pred_depth > 0: + mtp_logits = logits[1:] + logits = logits[0] + loss, loss_sum = super().forward(logits, labels, mtp_logits=mtp_logits) + if not self.training: + return loss_sum + return loss + loss, loss_sum = super().forward(logits, labels) + if not self.training: + return loss_sum + return loss + + +class PipelinePretrainedModel(PretrainedModel): + def __init__(self, config, *args, **kwargs): + self.config = config + super().__init__(config, *args, **kwargs) + + def init(self, config, *args, **kwargs): + self._sequential_layers = [] + self._pipeline_name_mapping = None + self._pp_to_single_mapping = None + + def add_sequential_layer(self, layer_desc, name_prefix=""): + self._sequential_layers.append({"layer": layer_desc, "name_prefix": name_prefix}) + + def get_sequential_layers(self): + return [x["layer"] for x in self._sequential_layers] + + def get_sequential_name_prefixs(self): + return {str(index): x["name_prefix"] for index, x in enumerate(self._sequential_layers)} + + def get_shardlayer_prefix(self, name_splited): + shared_layer_names = {s.layer_name for s in self._layers_desc if isinstance(s, SharedLayerDesc)} + assert name_splited[1] in shared_layer_names, f"The shared layer name {name_splited[1]} must be in prefixes!" + shared_layer_key = name_splited[1] + for idx, layer in enumerate(self._layers_desc): + if isinstance(layer, SharedLayerDesc) and layer.layer_name == shared_layer_key: + if self.get_stage_from_index(idx) == self._stage_id: + return self.get_sequential_name_prefixs()[str(idx)] + + raise ValueError(f"The shared layer {shared_layer_key} must be in the current stage!") + + def _set_pipeline_name_mapping(self, mappings=None): + if mappings is not None: + self._pipeline_name_mapping = mappings + else: + single_to_pp_mapping = {} + pp_to_single_mapping = {} + + state_dict_keys = list(super().state_dict().keys()) + first_key = "" + for k in state_dict_keys: + if "shared_layers" not in k: + first_key = k + break + first_key = first_key.split(".") + use_virtual_pp_degree = first_key[0].isdigit() and first_key[1].isdigit() + + prefixes = self.get_sequential_name_prefixs() + for k in state_dict_keys: + name_splited = k.split(".") + if use_virtual_pp_degree: + if name_splited[0].isdigit(): + if name_splited[1].isdigit(): + idx = str(int(name_splited[0]) + int(name_splited[1])) + single_name = [prefixes[idx]] + single_name.extend(name_splited[2:]) + else: + single_name = [prefixes[str(len(prefixes) - 1)]] + single_name.extend(name_splited[2:]) + logger.warning( + f"Please check! we treat this key as last layer, get {k}, \ + set origin name as {'.'.join(single_name)}" + ) + elif name_splited[0] == "shared_layers": + single_name = [self.get_shardlayer_prefix(name_splited)] + single_name.extend(name_splited[2:]) + else: + single_to_pp_mapping[k] = k + pp_to_single_mapping[k] = k + continue + else: + idx = name_splited[0] + if idx.isdigit(): + single_name = [] if prefixes[idx] == "" else [prefixes[idx]] + single_name.extend(name_splited[1:]) + elif idx == "shared_layers": + single_name = [self.get_shardlayer_prefix(name_splited)] + single_name.extend(name_splited[2:]) + else: + single_to_pp_mapping[k] = k + pp_to_single_mapping[k] = k + continue + + single_to_pp_mapping[".".join(single_name)] = k + pp_to_single_mapping[k] = ".".join(single_name) + + self._pipeline_name_mapping = single_to_pp_mapping + self._pp_to_single_mapping = pp_to_single_mapping + + return self._pipeline_name_mapping + + def _check_shared_model_state(self): + if self._pipeline_name_mapping is None: + self._set_pipeline_name_mapping() + + super_state_dict = super().state_dict() + structure_name_to_tensor = {} + for k, v in super_state_dict.items(): + k = self._pp_to_single_mapping[k] + if k not in structure_name_to_tensor: + structure_name_to_tensor[k] = v + else: + old_v = structure_name_to_tensor[k] + assert old_v is v, f"Shared tensor with different structure name: {k}" + + missing_shared_keys = {} + for k, v in self._pp_to_single_mapping.items(): + mapped_k = self._pipeline_name_mapping[v] + if k != mapped_k: + missing_shared_keys[k] = mapped_k + return missing_shared_keys + + def state_dict(self, *args, **kwargs): + state_dict = super().state_dict(*args, **kwargs) + + if self._pipeline_name_mapping is None: + self._set_pipeline_name_mapping() + + for k in list(state_dict.keys()): + v = state_dict.pop(k) + state_dict[self._pp_to_single_mapping[k]] = v + + return state_dict + + def _init_weights(self, layer): + if self.config.tensor_parallel_degree > 1: + rng_tracker = get_rng_state_tracker().rng_state + else: + rng_tracker = contextlib.nullcontext + + if isinstance( + layer, + ( + ColumnParallelLinear, + RowParallelLinear, + ColumnSequenceParallelLinear, + RowSequenceParallelLinear, + VocabParallelEmbedding, + ErnieMoELMHead, + nn.Embedding, + NativeLinear, + paddle.incubate.nn.FusedLinear, + ), + ): + is_moe = getattr(layer.weight, "no_sync", False) + with rng_tracker("local_seed" if is_moe else "model_parallel_rng"): + dtype = paddle.get_default_dtype() + paddle.set_default_dtype("float32") + layer.weight.set_value( + paddle.randn(layer.weight.shape, dtype=dtype).scale(self.config.initializer_range) + ) + paddle.set_default_dtype(dtype) + + elif isinstance(layer, (Top2Gate, TopKGateFused)): + if not hasattr(layer, "weight"): + return + with rng_tracker("model_parallel_rng"): + dtype = paddle.get_default_dtype() + paddle.set_default_dtype("float32") + moe_num_experts = self.config.moe_num_experts + if isinstance(moe_num_experts, (list, tuple)): + moe_num_experts = moe_num_experts[0] + if self.config.moe_group_experts: + layer.weight.set_value( + paddle.randn(layer.weight.shape, dtype=layer.weight.dtype).scale(self.config.initializer_range) + ) + else: + layer.weight.set_value( + paddle.randn( + [self.config.hidden_size, moe_num_experts], + dtype="float32", + ).scale(self.config.initializer_range) + ) + if isinstance(self.config.moe_num_experts, (tuple, list)): + for i in range(1, len(self.config.moe_num_experts)): + layer_weight = getattr(layer, f"weight_{i}") + layer_weight.set_value( + paddle.randn(layer_weight.shape, dtype=layer_weight.dtype).scale( + self.config.initializer_range + ) + ) + paddle.set_default_dtype(dtype) + + elif isinstance(layer, RotaryEmbedding): + head_dim = self.config.hidden_size // self.config.num_attention_heads + inv_freq = 1.0 / (layer.base ** (np.arange(0, head_dim, 2).astype("float32") / head_dim)) + + t = np.arange(layer.max_position_embeddings, dtype="float32") + freqs = np.einsum("i,j->ij", t, inv_freq) + emb = np.concatenate([freqs, freqs], axis=-1) + cos_cached = np.cos(emb)[:, :] + sin_cached = np.sin(emb)[:, :] + + layer.cos_cached.set_value(cos_cached) + layer.sin_cached.set_value(sin_cached) + + +def get_pp_vp_split_layers(config): + hcg = fleet.get_hybrid_communicate_group() + pp_size = max(hcg.get_pipe_parallel_world_size(), 1) + vp_size = max(config.virtual_pp_degree, 1) + layer_num = config.num_hidden_layers + selective_no_recompute_num = config.selective_no_recompute_num + + no_recompute_layer_num = [] + if selective_no_recompute_num == 0: + return set(no_recompute_layer_num) + + assert layer_num % (pp_size * vp_size) == 0, ( + "layer_num must be divisible by pp_size * vp_size," + f" but got layer_num: {layer_num}, pp_size: {pp_size}, vp_size: {vp_size}" + ) + + chunk_size = layer_num // (pp_size * vp_size) + chunk_list = [list(range(i * chunk_size, (i + 1) * chunk_size)) for i in range(pp_size * vp_size)] + + stage_chunk_list = [[] for _ in range(pp_size)] + for i in range(pp_size * vp_size): + stage_chunk_list[i % pp_size].append(chunk_list[i]) + + if config.use_recompute_attn: + logger.error("selective recompute only support full recompute now, please set use_recompute_attn to False") + + for i in range(pp_size): + no_recompute_layer_num.extend(stage_chunk_list[i][-selective_no_recompute_num:]) + + return set(sum(no_recompute_layer_num, [])) + + +class ErnieMoEForCausalLMPipe(PipelinePretrainedModel, PipelineLayer): + + config_class = ErnieMoEConfig + _get_tensor_parallel_mappings = ErniePretrainedModel._get_tensor_parallel_mappings + + ErnieEmbeddingPipeClass = ErnieEmbeddingPipe + ErnieDecoderLayerPipeClass = ErnieDecoderLayerPipe + MTPEmbeddingPipeClass = MTPEmbeddingPipe + MTPLayerClass = MTPLayer + RMSNormPipeClass = RMSNormPipe + ErnieMoELMHeadPipeClass = ErnieMoELMHeadPipe + + @classmethod + def _prepare_pipeline_inputs_func(cls, data): + global input_ids_for_mtp + input_ids_for_mtp.clear() + for d in data: + assert "input_ids" in d + input_ids_for_mtp.append(d["input_ids"]) + inputs = tuple( + [ + [d[k] for d in data] + for k in [ + "input_ids", + "attention_mask", + "position_ids", + "inbatch_pack_offset", + ] + if k in data[0] + ] + ) + if len(inputs) == 1: + inputs = inputs[0] + labels = [d["labels"] for d in data] + return inputs, labels + + def __init__( + self, + config, + ): + new_initializer_range = math.sqrt(0.3333 / config.hidden_size) + config.initializer_range = new_initializer_range + + if config.moe_group == "mp": + assert config.sequence_parallel + + if config.moe_group in {"mp", "model", "tp", "mpdp"}: + assert config.sequence_parallel + logger.info(f"disable FFN tensor model parallel, moe-group={config.moe_group}") + config.disable_ffn_model_parallel = True + + config.moe_group = _parse_moe_group(config.moe_group) + config.moe_world_size = dist.get_world_size(config.moe_group) + if config.moe_world_size < 0: + config.moe_world_size = 1 + config.moe_rank = dist.get_rank(config.moe_group) + + self.config = config + + hcg = fleet.get_hybrid_communicate_group() + tensor_parallel_degree = max(hcg.get_model_parallel_world_size(), 1) + tensor_parallel_rank = max(hcg.get_model_parallel_rank(), 0) + logger.info(f"using vpp={config.virtual_pp_degree}") + if config.sequence_parallel: + logger.info(f"using sequence_parallel, input seqlen={config.seqlen}") + assert config.seqlen is not None + assert ( + config.tensor_parallel_degree > 1 + ), f"sequence-parallel needs mp>1, got mp={config.tensor_parallel_degree}" + + config.tensor_parallel_degree = tensor_parallel_degree + config.tensor_parallel_rank = tensor_parallel_rank + PipelinePretrainedModel.init(self, config=config) + + if config.pp_no_recompute_layer is not None: + no_recompute_layers = config.pp_no_recompute_layer + else: + no_recompute_layers = get_pp_vp_split_layers(config) + logger.info(f"use no_recompute_layers: {no_recompute_layers}") + + def _need_full_recompute(layer_idx): + return layer_idx not in no_recompute_layers and config.use_recompute + + insert_empty_layer = config.insert_empty_layer + if len(insert_empty_layer) > 0: + assert min(insert_empty_layer) >= 0, "cannot insert empty layer as first layer of the model" + assert max(insert_empty_layer) < config.num_hidden_layers, "empty layers location exceed the num layers" + logger.info(f"use insert_empty_layer: {insert_empty_layer}") + + if config.multi_token_pred_depth == 0: + self.add_sequential_layer(LayerDesc(self.ErnieEmbeddingPipeClass, config=config), "ernie") + else: + if config.enable_mtp_magic_send: + self.add_sequential_layer( + SharedLayerDesc( + key="embed_weight_share", + layer_func=self.ErnieEmbeddingPipeClass, + shared_weight_attr="embedding_weight", + config=config, + ), + "ernie.embed", + ) + else: + self.add_sequential_layer(LayerDesc(self.ErnieEmbeddingPipeClass, config=config), "ernie") + + num_empty_layers = config.remove_tail_layer if isinstance(config.remove_tail_layer, int) else 1 + for i in range(config.num_hidden_layers - num_empty_layers): + self.add_sequential_layer( + LayerDesc( + self.ErnieDecoderLayerPipeClass, + config=config, + layer_idx=i, + use_full_recompute=_need_full_recompute(i), + ), + f"ernie.layers.{i}", + ) + if i in insert_empty_layer: + self.add_sequential_layer( + LayerDesc( + EmptyLayer, + ), + f"empty.layers.{i}", + ) + + if config.multi_token_pred_depth > 0: + if config.enable_mtp_magic_send: + self.add_sequential_layer( + SharedLayerDesc( + key="embed_weight_share", + layer_func=self.MTPEmbeddingPipeClass, + shared_weight_attr="embedding_weight", + config=config, + ), + "embed_share", + ) + self.add_sequential_layer(LayerDesc(self.MTPLayerClass, config=config), "ernie") + num_empty_layers = num_empty_layers - config.multi_token_pred_depth + + if config.remove_tail_layer: + for n in range(num_empty_layers): + self.add_sequential_layer( + LayerDesc( + EmptyLayer, + ), + f"empty.layers.{n}", + ) + else: + for n in range(num_empty_layers): + self.add_sequential_layer( + LayerDesc( + self.ErnieDecoderLayerPipeClass, + config=config, + layer_idx=i, + use_full_recompute=_need_full_recompute(i), + ), + f"ernie.layers.{n + config.num_hidden_layers - num_empty_layers}", + ) + + i = config.num_hidden_layers + if i in insert_empty_layer: + self.add_sequential_layer( + LayerDesc( + EmptyLayer, + ), + f"empty.layers.{i}", + ) + + self.add_sequential_layer( + LayerDesc(self.RMSNormPipeClass, config=config), + "ernie.norm", + ) + + self.add_sequential_layer(LayerDesc(self.ErnieMoELMHeadPipeClass, config=config), "lm_head") + + recompute_interval = 0 + + seg_method = "layer:ErnieDecoderLayer|EmptyLayer|MTPLayer" + if config.num_hidden_layers % fleet.get_hybrid_communicate_group().topology().get_dim_size("pipe") != 0: + seg_method = "uniform" + logger.info(f"using recompute_interval={recompute_interval}, seg_method={seg_method}") + + PipelineLayer.__init__( + self, + layers=self.get_sequential_layers(), + loss_fn=self.get_loss_fn(config), + topology=fleet.get_hybrid_communicate_group().topology(), + seg_method=seg_method, + recompute_interval=recompute_interval, + recompute_ctx={ + "mp_group": fleet.get_hybrid_communicate_group().get_model_parallel_group(), + "offload": False, + "partition": False, + }, + num_virtual_pipeline_stages=config.virtual_pp_degree, + ) + + def get_loss_fn(self, config): + return ErniePretrainingCriterionPipe(config) + + def rename_model_params(self, func): + if self.config.virtual_pp_degree == 1: + _layers = iter(self.run_function) + else: + _layers = (cc for c in self._model_chunks for cc in c.run_function) + func(self.config, _layers) + + def fp8_quant_weight(self): + with paddle.no_grad(): + for i, layer in self._sub_layers.items(): + if isinstance(layer, ErnieDecoderLayer) and hasattr(layer, "fp8_quant_weight"): + layer.fp8_quant_weight() + + def _post_init(self, original_init, *args, **kwargs): + super()._post_init(self, original_init, *args, **kwargs) + with paddle.no_grad(): + for i, layer in self._sub_layers.items(): + if isinstance(layer, ErnieDecoderLayer): + factor = 1 / math.sqrt(2 * self.config.num_hidden_layers) + if self.config.use_linear_residual_norm_recompute: + layer.fused_linear_add_norm.linear_weight.scale_(factor) + else: + layer.self_attn.o_proj.weight.scale_(factor) + if isinstance(layer.mlp, (MOELayer)): + for e in layer.mlp.experts: + if isinstance(e, ErnieMLP): + e.down_proj.weight.scale_(factor) + else: + layer.mlp.down_proj.weight.scale_(factor) + + def set_state_dict(self, state_dict, *args, **kwargs): + if self._pipeline_name_mapping is None: + self._set_pipeline_name_mapping() + + layer_idxs = [] + if self.config.virtual_pp_degree == 1: + _layers = iter(self.run_function) + else: + _layers = (cc for c in self._model_chunks for cc in c.run_function) + + for layer in _layers: + if isinstance(layer, self.ErnieDecoderLayerPipeClass): + layer_idxs.append(layer.layer_idx) + logger.info(f"this pipeline stage has ErnieDecoderLayers: {layer_idxs}") + if not self.parameters(): + logger.info("this pipe not need param, skip set state-dict") + return {}, {} + state_dict = moe_statedict_upcycle( + state_dict, + self.config, + next(iter(self.parameters())).dtype, + self._get_tensor_parallel_mappings(self.config, is_split=False), + self._get_tensor_parallel_mappings(self.config, is_split=True), + layer_idxs, + ) + state_dict = moe_ep2mp( + state_dict, + self.config, + self._get_tensor_parallel_mappings(self.config, is_split=True), + ) + + for k in list(state_dict.keys()): + v = state_dict.pop(k) + if k not in self._pipeline_name_mapping: + continue + state_dict[self._pipeline_name_mapping[k]] = v + missing_keys, mismatch_keys = super().set_state_dict(state_dict, *args, **kwargs) + + missing_shared_keys = self._check_shared_model_state() + tmp_missing_keys = [] + for key in missing_keys: + if key in missing_shared_keys and missing_shared_keys[key] not in missing_keys: + continue + tmp_missing_keys.append(key) + missing_keys = tmp_missing_keys + + logger.info(f"moe_set_state_dict: {missing_keys}, {mismatch_keys}") + return missing_keys, mismatch_keys diff --git a/ernie/ERNIE/examples/pre-training/models/fp8_linear.py b/ernie/ERNIE/examples/pre-training/models/fp8_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..349621a8439d07fc5c94bb330a343c141eaa439a --- /dev/null +++ b/ernie/ERNIE/examples/pre-training/models/fp8_linear.py @@ -0,0 +1,567 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +FP8 Linear Layer Implementation for PaddlePaddle + +This module implements FP8 (8-bit floating point) linear layers using PaddlePaddle's +incubate APIs for low-precision training. Key features include: + +1. FP8 matrix multiplication with block-wise quantization +2. Memory-efficient forward/backward passes +3. PaddlePaddle-specific optimizations like: + - Using paddle.incubate.fp8 APIs + - Leveraging Paddle's automatic differentiation system + - Optimized for Paddle's tensor layout and memory management +""" + + +import numpy +import paddle +from paddle.incubate.fp8 import deep_gemm +from paddle.incubate.nn.functional import swiglu + +# Keep reference to original linear op for fallback if needed +original_linear = paddle.nn.functional.linear + + +# Expose only the main class to public API +__all__ = ["Fp8FusedMlp"] + + +def fp8_gemm( + x_fp8, + x_scale, + w_fp8, + w_scale, + is_a_1d_scaled, + is_b_1d_scaled, + out=None, + rtn_dtype=paddle.bfloat16, +): + """ + Performs FP8 matrix multiplication (GEMM) operation, using blockwise GEMM algorithm. + + Args: + x_fp8 (Tensor): Input tensor in FP8 format + x_scale (Tensor): Scaling factor for input tensor + w_fp8 (Tensor): Weight tensor in FP8 format + w_scale (Tensor): Scaling factor for weight tensor + is_a_1d_scaled (bool): Whether input tensor uses 1D scaling + is_b_1d_scaled (bool): Whether weight tensor uses 1D scaling + out (Tensor, optional): Output tensor for accumulation. Defaults to None + rtn_dtype (dtype, optional): Return data type. Defaults to paddle.bfloat16 + + Returns: + Tensor: Result of the matrix multiplication + """ + accumulate = out is not None + if numpy.prod(x_fp8.shape) != 0 and numpy.prod(w_fp8.shape) != 0: + # Using Paddle's blockwise FP8 GEMM with split accumulator for numerical stability + y = paddle.incubate.nn.functional.fp8_gemm_blockwise( + a=x_fp8, + a_decode_scale=x_scale, # Input scaling factors + b=w_fp8, + b_decode_scale=w_scale, # Weight scaling factors + out_dtype=rtn_dtype, # Output dtype (bfloat16) + out=out, # Optional output tensor for accumulation + accumulate=accumulate, # Whether to accumulate into out tensor + use_split_accumulator=True, # Paddle-specific optimization + is_a_1d_scaled=is_a_1d_scaled, # 1D scaling for input + is_b_1d_scaled=is_b_1d_scaled, # 1D scaling for weights + ) + else: + y = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], rtn_dtype) + if out is not None: + out = out + y + return out + + return y + + +def padding(x, axis): + """ + Pads the input tensor along specified axis to make its size divisible by 512 or 128. + + Args: + x (Tensor): Input tensor to be padded + axis (int): Axis along which to pad (0 for rows, 1 for columns) + + Returns: + Tensor: Padded tensor + """ + if x.shape[axis] % 512 != 0: + if (x.shape[axis] + 128 - (x.shape[axis] % 128)) % 512 != 0: + padding_size = 512 + else: + padding_size = 128 + pad_size = padding_size - (x.shape[axis] % padding_size) + if axis == 0: + x = paddle.concat([x, paddle.zeros([pad_size, x.shape[-1]], dtype=x.dtype)], axis=0) + else: + x = paddle.concat([x, paddle.zeros([x.shape[0], pad_size], dtype=x.dtype)], axis=-1) + return x + + +class Fp8FusedMlpFunc(paddle.autograd.PyLayer): + """ + Custom PyLayer implementation of FP8 fused MLP operation. + + This class implements both forward and backward passes for a memory-efficient + FP8 (8-bit floating point) multi-layer perceptron using PaddlePaddle's + FP8 quantization APIs. + """ + + @staticmethod + def forward(ctx, x, w1, w2): + """ + Forward pass for FP8 fused multi-layer perceptron (MLP) operation. + + Args: + ctx (PyLayerContext): Context object to save tensors for backward pass + x (paddle.Tensor): Input tensor of shape [batch_size, hidden_size] + w1 (paddle.Tensor): First weight matrix of shape [hidden_size, intermediate_size*2] + w2 (paddle.Tensor): Second weight matrix of shape [intermediate_size, hidden_size] + + Returns: + paddle.Tensor: Output tensor of shape [batch_size, hidden_size] + + Note: + - Uses Paddle's FP8 quantization for memory efficiency + - Implements SWiGLU activation internally + - Handles tensor padding for optimal FP8 GEMM performance + """ + x_orig_shape = x.shape + x = x.reshape([-1, x_orig_shape[-1]]) + + if x.shape[0] % 512 != 0: + x_fp8, x_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + x, + quant_method="1x128", + input_transpose=False, + output_scale_transpose=True, + ) + x = padding(x, 0) + _, _, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + x, + quant_method="1x128", + input_transpose=True, + output_scale_transpose=True, + ) + + else: + x_fp8, x_scale, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + x, + quant_method="1x128", + input_transpose=True, + output_scale_transpose=True, + ) + + _, _, w1_fp8, w1_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + w1, + quant_method="128x128", + input_transpose=True, + output_scale_transpose=False, + ) + o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=x.dtype) + deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1) + + o2 = swiglu(o1) + o2_fp8, o2_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + o2, quant_method="1x128", input_transpose=False, output_scale_transpose=True + ) + + _, _, w2_t_fp8, w2_t_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + w2, + quant_method="128x128", + input_transpose=True, + output_scale_transpose=False, + ) + o3 = paddle.empty([o2_fp8.shape[0], w2_t_fp8.shape[0]], dtype=o1.dtype) + deep_gemm.gemm_fp8_fp8_bf16_nt((o2_fp8, o2_scale.T), (w2_t_fp8, w2_t_scale), o3) + if len(x_orig_shape) > 2: + o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]]) + + ctx.save_for_backward( + x_t_fp8, + x_t_scale, + w1, + o1, + w2, + paddle.to_tensor(x_orig_shape, dtype="int64", place=paddle.CPUPlace()), + ) + return o3 + + @staticmethod + def backward(ctx, do3): + """ + Memory-efficient backward pass for FP8 fused MLP operation. + + Args: + ctx: Context object containing saved tensors from forward pass + do3 (Tensor): Gradient of the loss with respect to the output + + Returns: + Tuple[Tensor, Tensor, Tensor]: Gradients with respect to x, w1, and w2 + """ + do3_orig_shape = do3.shape + do3 = do3.reshape([-1, do3_orig_shape[-1]]) + + x_t_fp8, x_t_scale, w1, o1, w2, x_orig_shape = ctx.saved_tensor() + x_orig_shape = x_orig_shape.numpy() + + o2 = swiglu(o1) + if do3.shape[0] % 512 != 0: + do3_fp8, do3_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + do3, + quant_method="1x128", + input_transpose=False, + output_scale_transpose=True, + ) + do3 = padding(do3, 0) + _, _, do3_t_fp8, do3_t_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + do3, + quant_method="1x128", + input_transpose=True, + output_scale_transpose=True, + ) + else: + do3_fp8, do3_scale, do3_t_fp8, do3_t_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + do3, + quant_method="1x128", + input_transpose=True, + output_scale_transpose=True, + ) + w2_fp8, w2_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + w2, + quant_method="128x128", + input_transpose=False, + output_scale_transpose=False, + ) + do2 = paddle.empty([do3_fp8.shape[0], w2_fp8.shape[0]], do3.dtype) + deep_gemm.gemm_fp8_fp8_bf16_nt((do3_fp8, do3_scale.T), (w2_fp8, w2_scale), do2) + + o2 = padding(o2, 0) + _, _, o2_t_fp8, o2_t_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + o2, + quant_method="1x128", + input_transpose=True, + output_scale_transpose=True, + ) + + dw2 = fp8_gemm( + o2_t_fp8, + o2_t_scale, + do3_t_fp8, + do3_t_scale, + True, + True, + rtn_dtype=paddle.float32, + ) + + do1, _ = paddle._C_ops.swiglu_grad(o1, None, do2) + + if do1.shape[0] % 512 != 0: + do1_fp8, do1_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + do1, + quant_method="1x128", + input_transpose=False, + output_scale_transpose=True, + ) + do1 = padding(do1, 0) + _, _, do1_t_fp8, do1_t_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + do1, + quant_method="1x128", + input_transpose=True, + output_scale_transpose=True, + ) + else: + do1_fp8, do1_scale, do1_t_fp8, do1_t_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + do1, + quant_method="1x128", + input_transpose=True, + output_scale_transpose=True, + ) + w1_fp8, w1_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + w1, + quant_method="128x128", + input_transpose=False, + output_scale_transpose=False, + ) + dx = paddle.empty([do1_fp8.shape[0], w1_fp8.shape[0]], do1.dtype) + deep_gemm.gemm_fp8_fp8_bf16_nt((do1_fp8, do1_scale.T), (w1_fp8, w1_scale), dx) + if len(x_orig_shape) > 2: + dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]]) + + dw1 = fp8_gemm( + x_t_fp8, + x_t_scale, + do1_t_fp8, + do1_t_scale, + True, + True, + rtn_dtype=paddle.float32, + ) + return dx, dw1, dw2 + + +class MemEfficientFp8FusedMlpFunc(paddle.autograd.PyLayer): + """ + Memory-optimized version of FP8 fused MLP operation. + + This implementation reduces memory usage during training by: + - Avoiding redundant tensor storage in forward pass + - Recomputing intermediate values during backward pass + - Using optimized FP8 quantization strategies + + Inherits from paddle.autograd.PyLayer to implement custom backward pass. + """ + + @staticmethod + def forward(ctx, x, w1, w2): + """ + Memory-efficient forward pass for FP8 fused MLP operation. + + Args: + ctx (PyLayerContext): Context object to save minimal tensors for backward pass + x (paddle.Tensor): Input tensor of shape [batch_size, hidden_size] + w1 (paddle.Tensor): First weight matrix of shape [hidden_size, intermediate_size*2] + w2 (paddle.Tensor): Second weight matrix of shape [intermediate_size, hidden_size] + + Returns: + paddle.Tensor: Output tensor of shape [batch_size, hidden_size] + + Note: + - Saves only essential tensors for backward pass to reduce memory usage + - Uses recomputation strategy during backward pass + - Maintains same numerical accuracy as standard implementation + """ + x_orig_shape = x.shape + x = x.reshape([-1, x_orig_shape[-1]]) + + x_fp8, x_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + x, quant_method="1x128", input_transpose=False, output_scale_transpose=True + ) + + _, _, w1_fp8, w1_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + w1, + quant_method="128x128", + input_transpose=True, + output_scale_transpose=False, + ) + o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=x.dtype) + deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1) + + o2 = swiglu(o1) + o2_fp8, o2_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + o2, quant_method="1x128", input_transpose=False, output_scale_transpose=True + ) + + _, _, w2_t_fp8, w2_t_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + w2, + quant_method="128x128", + input_transpose=True, + output_scale_transpose=False, + ) + o3 = paddle.empty([o2_fp8.shape[0], w2_t_fp8.shape[0]], dtype=o1.dtype) + deep_gemm.gemm_fp8_fp8_bf16_nt((o2_fp8, o2_scale.T), (w2_t_fp8, w2_t_scale), o3) + if len(x_orig_shape) > 2: + o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]]) + + ctx.save_for_backward( + x_fp8, + x_scale, + w1, + w2, + paddle.to_tensor(x_orig_shape, dtype="int64", place=paddle.CPUPlace()), + ) + return o3 + + @staticmethod + def backward(ctx, do3): + do3_orig_shape = do3.shape + do3 = do3.reshape([-1, do3_orig_shape[-1]]) + + x_fp8, x_scale, w1, w2, x_orig_shape = ctx.saved_tensor() + x_orig_shape = x_orig_shape.numpy() + + _, _, w1_fp8, w1_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + w1, + quant_method="128x128", + input_transpose=True, + output_scale_transpose=False, + ) + o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype) + deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1) + + x_dequant_fp16 = paddle.incubate.nn.functional.fused_act_dequant(x_fp8, x_scale.T.contiguous()) + x_dequant_fp16 = padding(x_dequant_fp16, 0) + + _, _, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + x_dequant_fp16, + quant_method="1x128", + input_transpose=True, + output_scale_transpose=True, + ) + + o2 = swiglu(o1) + + if do3.shape[0] % 512 != 0: + do3_fp8, do3_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + do3, + quant_method="1x128", + input_transpose=False, + output_scale_transpose=True, + ) + do3 = padding(do3, 0) + _, _, do3_t_fp8, do3_t_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + do3, + quant_method="1x128", + input_transpose=True, + output_scale_transpose=True, + ) + else: + do3_fp8, do3_scale, do3_t_fp8, do3_t_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + do3, + quant_method="1x128", + input_transpose=True, + output_scale_transpose=True, + ) + w2_fp8, w2_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + w2, + quant_method="128x128", + input_transpose=False, + output_scale_transpose=False, + ) + do2 = paddle.empty([do3_fp8.shape[0], w2_fp8.shape[0]], do3.dtype) + deep_gemm.gemm_fp8_fp8_bf16_nt((do3_fp8, do3_scale.T), (w2_fp8, w2_scale), do2) + + o2 = padding(o2, 0) + _, _, o2_t_fp8, o2_t_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + o2, + quant_method="1x128", + input_transpose=True, + output_scale_transpose=True, + ) + + dw2 = fp8_gemm( + o2_t_fp8, + o2_t_scale, + do3_t_fp8, + do3_t_scale, + True, + True, + rtn_dtype=paddle.float32, + ) + + do1, _ = paddle._C_ops.swiglu_grad(o1, None, do2) + + if do1.shape[0] % 512 != 0: + do1_fp8, do1_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + do1, + quant_method="1x128", + input_transpose=False, + output_scale_transpose=True, + ) + do1 = padding(do1, 0) + _, _, do1_t_fp8, do1_t_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + do1, + quant_method="1x128", + input_transpose=True, + output_scale_transpose=True, + ) + else: + do1_fp8, do1_scale, do1_t_fp8, do1_t_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + do1, + quant_method="1x128", + input_transpose=True, + output_scale_transpose=True, + ) + w1_fp8, w1_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + w1, + quant_method="128x128", + input_transpose=False, + output_scale_transpose=False, + ) + dx = paddle.empty([do1_fp8.shape[0], w1_fp8.shape[0]], do1.dtype) + deep_gemm.gemm_fp8_fp8_bf16_nt((do1_fp8, do1_scale.T), (w1_fp8, w1_scale), dx) + if len(x_orig_shape) > 2: + dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]]) + + dw1 = fp8_gemm( + x_t_fp8, + x_t_scale, + do1_t_fp8, + do1_t_scale, + True, + True, + rtn_dtype=paddle.float32, + ) + return dx, dw1, dw2 + + +class Fp8FusedMlp(paddle.nn.Layer): + """ + PaddlePaddle Layer implementing FP8 fused multi-layer perceptron (MLP). + + This layer combines: + - FP8 precision matrix operations for improved performance + - Fused MLP architecture with SWiGLU activation + - Memory-efficient training through custom PyLayer implementation + + """ + + def __init__(self, config): + """ + Initializes the FP8 Fused MLP layer. + + Args: + config (object): Configuration object containing: + - hidden_size (int): Dimension of the input/output features + - intermediate_size (int): Dimension of the intermediate features + + Note: + - Weights are initialized using Paddle's create_parameter + - Uses bfloat16 precision for weight storage + - No bias terms are used in this implementation + """ + + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + self.w1 = self.create_parameter( + shape=[self.hidden_size, self.intermediate_size * 2], + dtype="bfloat16", # Using Paddle's bfloat16 dtype + is_bias=False, # Paddle-specific parameter attribute + ) + self.w2 = self.create_parameter( + shape=[self.intermediate_size, self.hidden_size], + dtype="bfloat16", + is_bias=False, + ) + + def forward(self, x): + """ + Forward pass of the FP8 fused MLP layer. + + Args: + x (Tensor): Input tensor + + Returns: + Tensor: Output tensor after MLP transformation + """ + return Fp8FusedMlpFunc.apply(x, self.w1, self.w2) diff --git a/ernie/ERNIE/examples/pre-training/models/moe/__init__.py b/ernie/ERNIE/examples/pre-training/models/moe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9cc79cc9d7f1977efe8e066facf32c20c8ad3af --- /dev/null +++ b/ernie/ERNIE/examples/pre-training/models/moe/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/ernie/ERNIE/examples/pre-training/models/moe/moe_layer.py b/ernie/ERNIE/examples/pre-training/models/moe/moe_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..35c251de9d25d941d3c6759a62ccb0d89ee37cf2 --- /dev/null +++ b/ernie/ERNIE/examples/pre-training/models/moe/moe_layer.py @@ -0,0 +1,1228 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import logging +from collections import namedtuple + +import numpy as np +import paddle +import paddle.distributed as dist +import paddle.nn.functional as F +from models.comm_utils import profile +from models.moe.token_dispatcher.fp8_utils import ( + ExpertsGroupGemmContiguousNode, + ExpertsGroupGemmNode, + ExpertsGroupGemmWLCHNode, +) +from models.moe.token_dispatcher.moe_utils import ( + UnZipNode, + ZipNode, +) +from models.sequence_parallel_utils import ScatterOp +from models.utils import ( + manual_backward, +) +from paddle import framework, nn +from paddle.autograd import PyLayer +from paddle.distributed import fleet +from paddle.distributed.communication import stream +from paddle.distributed.fleet.utils import recompute +from paddle.incubate.nn.functional import ( + moe_combine, + moe_gate_dispatch, + moe_gate_dispatch_permute, +) + +try: + from paddle.incubate.nn.functional import ( + moe_gate_dispatch_and_quant, + ) +except ImportError: + moe_gate_dispatch_and_quant = None + +try: + from src.utils.misc import global_training_logs +except ModuleNotFoundError: + global_training_logs = {} + +logger = logging.getLogger(__name__) + +GateOutput = namedtuple( + "GateOutput", + [ + "aux", + "z", + "logits", + ], +) + +class Fp8MoeGateDispatchAndQuant(paddle.autograd.PyLayer): + """Fp8MoeGateDispatchAndQuant""" + + @staticmethod + def forward( + ctx, + x, + gate_logtis, + corr_bias, + k, + capacity, + use_pad, + use_pow2_scale=True, + ): + """forward""" + assert moe_gate_dispatch_and_quant is not None, "Please use new version Paddle." + with paddle.amp.auto_cast(enable=False): + ( + out_fp8, + scale, + combine_weights, + scatter_index, + expert_offset, + expert_id, + ) = moe_gate_dispatch_and_quant( + x, + gate_logtis, + corr_bias=corr_bias, + k=k, + capacity=capacity, + use_pad=use_pad, + use_pow2_scale=use_pow2_scale, + ) + assert out_fp8.shape[0] == scale.shape[0] + + out_fp8.stop_gradient = False + combine_weights.stop_gradient = False + scatter_index.stop_gradient = True + expert_offset.stop_gradient = True + expert_id.stop_gradient = True + scale.stop_gradient = True + + ctx.k = k + ctx.capacity = capacity + ctx.use_pad = use_pad + ctx.combine_weights = combine_weights + ctx.scatter_index = scatter_index + ctx.expert_id = expert_id + ctx.has_corr_bias = corr_bias is not None + + return ( + out_fp8, + combine_weights, + scatter_index, + expert_offset, + expert_id, + { + "scale": scale, + }, + ) + + @staticmethod + def backward(ctx, *grads): + """backward""" + out_grad, combine_weights_grad = grads[0], grads[1] + x_grad, gate_logits_grad = paddle._C_ops.moe_gate_dispatch_grad( + ctx.combine_weights, + ctx.scatter_index, + ctx.expert_id, + out_grad, + combine_weights_grad, + ctx.k, + ctx.capacity, + ctx.use_pad, + ) + if ctx.has_corr_bias: + return x_grad, gate_logits_grad, None + else: + return x_grad, gate_logits_grad + +def recompute_fwd_gate_up_func(config, layer_idx): + if "recompute_fwd_gate_up" in config.fp8_mem_configs: + if isinstance(config.fp8_mem_configs["recompute_fwd_gate_up"], bool): + return config.fp8_mem_configs["recompute_fwd_gate_up"] + if isinstance(config.fp8_mem_configs["recompute_fwd_gate_up"], list): + return layer_idx in config.fp8_mem_configs["recompute_fwd_gate_up"] + + return False + + +class MoEStatics(nn.Layer): + def __init__(self, config, layer_idx): + super().__init__() + self._cast_to_low_precision = False + self._cast_to_low_precision = False + num_experts = config.moe_num_experts + + with paddle.utils.unique_name.guard(f"mm_layer_{layer_idx}_"): + num_experts_groups = 1 + p = self.create_parameter( + shape=[num_experts_groups, num_experts], + dtype="float32", + is_bias=True, + attr=paddle.ParamAttr(name=paddle.utils.unique_name.generate("corr_bias")), + ) + p.stop_gradient = False + self.e_score_correction_bias = p + self.e_score_correction_bias.is_distributed = True + self.e_score_correction_bias.unused_param = True + p = paddle.zeros( + shape=[num_experts_groups, num_experts], + dtype="int64", + ) + p.stop_gradient = True + self.expert_usage = p + + +class GateCombine(PyLayer): + @staticmethod + def forward(ctx, x, combine_weights, scatter_index): + ctx.x = x + ctx.combine_weights = combine_weights + ctx.scatter_index = scatter_index + ret = moe_combine(x, combine_weights, scatter_index) + return ret + + @staticmethod + def backward(ctx, grad_y, *_): + # assert moe_combine is not None + grad_x, grad_combine_weight_helper = paddle._C_ops.moe_combine_grad( + ctx.x, ctx.combine_weights, ctx.scatter_index, grad_y + ) + grad_combine_weight = grad_combine_weight_helper.sum(-1) + return grad_x, grad_combine_weight.reshape(ctx.combine_weights.shape), None + + +class FusionFP8Expert(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, hidden_states, custom_map): + ctx.node = ExpertsGroupGemmNode(None, custom_map) + + t1 = hidden_states.transpose([1, 0, 2, 3]).contiguous() + expert_num = t1.shape[0] + tokens_num = t1.shape[1] * t1.shape[2] + tokens_per_expert = paddle.full([expert_num], fill_value=tokens_num, dtype="int32") + + t1 = t1.reshape([-1, hidden_states.shape[-1]]) + out = ctx.node.forward_no_prob(t1, tokens_per_expert) + + expert_output = out.reshape(hidden_states.shape).transpose([1, 0, 2, 3]).contiguous() + + ctx.save_for_backward(tokens_per_expert) + return expert_output + + @staticmethod + def backward(ctx, output_grad): + (tokens_per_expert,) = ctx.saved_tensor() + + t1 = output_grad.transpose([1, 0, 2, 3]).contiguous() + t1 = t1.reshape([-1, output_grad.shape[-1]]) + + dx = ctx.node.backward_no_prob(t1, tokens_per_expert) + dx = dx.reshape(output_grad.shape).transpose([1, 0, 2, 3]).contiguous() + return dx + + +class AlltoAll(PyLayer): + @staticmethod + def forward(ctx, x, group, sync_op=True): + ctx.group = group + if dist.get_world_size(group) <= 1: + return x + output = paddle.empty_like(x) + output.stop_gradient = False + task = stream.alltoall_single(output, x, None, None, group, sync_op=sync_op, use_calc_stream=sync_op) + if not sync_op: + return output, task + else: + return output + + @staticmethod + def backward(ctx, *dx): + return AlltoAll.apply(*dx, group=ctx.group) + + +class AlltoAllExpertOverlap(PyLayer): + @staticmethod + def forward(ctx, input, group, num_local_experts, forward_func_dict, is_first_fwd=False): + assert ( + dist.get_world_size(group) > 1 + ), "AlltoAllExpertOverlap is not supported for a world size less than or equal to 1." + + ctx.bw_funcs = {} + ctx.group = group + ctx.num_local_experts = num_local_experts + + assert isinstance(forward_func_dict, nn.LayerList) + all2all_tasks = [] + all2all_ins = paddle.unbind(input, axis=0) + for stage_id in range(1): + stage_input = all2all_ins[stage_id] + x_out, task = AlltoAll.apply(stage_input, group=group, sync_op=False) + all2all_tasks.append((task, x_out)) + + expert_outputs = [] + for stage_id in range(num_local_experts): + if stage_id + 1 != num_local_experts: + stage_input = all2all_ins[stage_id + 1] + x_out, task = AlltoAll.apply(stage_input, group=group, sync_op=False) + all2all_tasks.append((task, x_out)) + + task, dispatched_input = all2all_tasks[stage_id] + task.wait() + bwf, (expert_outputs_cur_stage,) = manual_backward( + forward_func_dict[stage_id], is_first_fwd, dispatched_input + ) + ctx.bw_funcs[stage_id] = bwf + expert_outputs.append(expert_outputs_cur_stage) + + expert_output = paddle.stack(expert_outputs, axis=1) + return expert_output + + @staticmethod + def backward(ctx, out_grad): + all2all_tasks = [] + expert_outputs = [] + + out_grad_list = paddle.split(out_grad, num_or_sections=out_grad.shape[1], axis=1) + for stage_id in range(ctx.num_local_experts): + (grad_cur_stage,) = ctx.bw_funcs[stage_id](out_grad_list[stage_id]) + + x_out, task = AlltoAll.apply(grad_cur_stage, group=ctx.group, sync_op=False) + all2all_tasks.append(task) + expert_outputs.append(x_out) + + for task in all2all_tasks: + task.wait() + + expert_output = paddle.stack(expert_outputs, axis=0) + return expert_output + + +class AlltoAllAsync(PyLayer): + @staticmethod + def forward(ctx, x, *fn_args, group=None, fn=None, is_first_fwd=False): + assert fn is not None, "use AlltoAll no async" + ctx.group = group + if dist.get_world_size(group) <= 1: + ctx.bwf, fn_out = manual_backward(fn, is_first_fwd, *fn_args) + return (x,) + fn_out + x_out = paddle.empty_like(x) + x_out.stop_gradient = False + task = stream.alltoall_single( + x_out, + x, + None, + None, + group, + sync_op=False, + ) + ctx.bwf, fn_out = manual_backward(fn, is_first_fwd, *fn_args) + task.wait() + return (x_out,) + fn_out + + @staticmethod + def backward(ctx, dx_out, *fn_out_grads): + if dist.get_world_size(ctx.group) <= 1: + fn_args_grads = ctx.bwf(*fn_out_grads) + return (dx_out,) + fn_args_grads + + dx = paddle.empty_like(dx_out) + dx.stop_gradient = False + task = stream.alltoall_single( + dx, + dx_out, + None, + None, + ctx.group, + sync_op=False, + ) + fn_args_grads = ctx.bwf(*fn_out_grads) + task.wait() + return (dx,) + fn_args_grads + + +def dispatching(x, dispatch_mask, scatter_index, num_experts, capacity): + output = None + orig_dtype = x.dtype + scatter_index = scatter_index.unbind(1) + dispatch_mask = dispatch_mask.unbind(1) + for i_scatter_index, i_dispatch_mask in zip(scatter_index, dispatch_mask): + init_output = paddle.zeros([num_experts * capacity, x.shape[-1]], dtype="float32") + updates = x * i_dispatch_mask.unsqueeze(-1).cast(x.dtype) + if output is None: + output = paddle.scatter( + init_output, + i_scatter_index, + updates, + overwrite=False, + ) + else: + output = output + paddle.scatter( + init_output, + i_scatter_index, + updates, + overwrite=False, + ) + if output.dtype != orig_dtype: + output = output.cast(orig_dtype) + return output + + +def combining_fused(x, combine_weights, scatter_index, hard_gate=False): + if hard_gate: + x_gatherd = F.embedding(scatter_index, x) + return x_gatherd.squeeze(-2) + ret = GateCombine.apply(x, combine_weights, scatter_index) + ret.stop_gradient = False + return ret + + +class MOELayer(nn.Layer): + """Mixture of Experts (MoE) Layer implementation. + + This layer dynamically routes input tokens to different expert networks + based on a gating mechanism, allowing for conditional computation. + + """ + + def __init__( + self, + gate, + experts, + layer_idx, + shared_experts, + group, + recompute=False, + enable_logging=False, + k=2, + enable_bpr=False, + all_to_all_dropout=0, + group_experts=False, + moe_statics=None, + ): + """Initialize the MoE layer. + + Args: + gate (nn.Layer): Gating network that outputs routing scores. + experts (nn.LayerList, optional): List of expert networks. + layer_idx (int): Identifier for this layer (used for logging). + shared_experts (nn.Layer): Shared expert applied to all tokens (optional). + group (dist.ProcessGroup): Process group for distributed expert parallelism. + recompute (bool, optional): If True, enables gradient checkpointing. Defaults to False. + enable_logging (bool, optional): If True, tracks expert usage statistics. Defaults to False. + k (int, optional): Number of experts to route each token to. Defaults to 2. + enable_bpr (bool, optional): If True, uses balanced positive routing. Defaults to False. + all_to_all_dropout (float, optional): Dropout rate for cross-device communication. Defaults to 0. + group_experts (bool, optional): If True, optimizes expert communication. Defaults to False. + """ + + super().__init__() + self.gate = gate + self.layer_idx = layer_idx + self.recompute = recompute + logger.info(f"using moe recompute={recompute}") + for p in self.gate.parameters(): + p.is_gate = True + if isinstance(experts, nn.LayerList): + self.experts = experts + else: + logger.info(f"using fused experts, type={type(experts)}") + self.experts = experts + self.shared_experts = shared_experts + + self.group = group + self.k = k + self.all_to_all_dropout = all_to_all_dropout + self.enable_logging = enable_logging + self.use_correction_bias = moe_statics is not None + self.moe_statics = moe_statics + if self.use_correction_bias: + logger.info(f"using correction bias, aux-coef:{self.gate.config.moe_aux_loss_lambda}") + assert self.gate.config.moe_use_aux_free + + self.is_mp_moe = ( + hasattr(fleet.fleet, "_hcg") and group is fleet.get_hybrid_communicate_group().get_model_parallel_group() + ) + self.is_ep_moe = ( + hasattr(fleet.fleet, "_hcg") + and hasattr( + fleet.get_hybrid_communicate_group(), + "get_moe_sharding_parallel_world_size", + ) + and fleet.get_hybrid_communicate_group().get_moe_sharding_parallel_world_size() > 0 + ) + is_dummy_moe = dist.get_world_size(group) == 1 + + for p in experts.parameters(): + p.expert = not (self.is_mp_moe or is_dummy_moe) + p.no_sync = not (self.is_mp_moe or is_dummy_moe) + logger.info(f"expert no-sync={p.no_sync}-{p.name}") + if self.is_mp_moe or self.is_ep_moe: + p.is_distributed = True + + expert_color = None + if self.is_ep_moe: + moe_grad_group = fleet.get_hybrid_communicate_group().get_moe_sharding_parallel_group() + expert_color = {"color": "moe_expert", "group": moe_grad_group} + elif self.config.offline_quant_expert_weight and self.config.clear_origin_weight_when_offline_quant: + expert_color = {"color": "moe_expert"} + + self.world_size = dist.get_world_size(self.group) + self.rank = dist.get_rank(self.group) + if self.world_size < 1: + self.world_size = 1 + if self.rank < 0: + self.rank = 0 + + self.num_local_experts = len(self.experts) + self.dispatch_by_task = hasattr(self.gate, "dispatch_by_task") and self.gate.dispatch_by_task + + if self.dispatch_by_task: + assert 0, "no supported, checkout earylier code" + assert self.num_local_experts == 1 + + self.input_preprocess = self.output_postprocess = None + self.group_experts = group_experts + self.config = self.gate.config + self.zero = paddle.to_tensor(0, dtype=paddle.float32) + + if hasattr(fleet.fleet, "_hcg"): + hcg = fleet.get_hybrid_communicate_group() + if hasattr(hcg, "get_moe_sharding_parallel_world_size") and hcg.get_moe_sharding_parallel_world_size() > 0: + moe_grad_group = hcg.get_moe_sharding_parallel_group() + for p in self.experts.parameters(): + p.color = {"color": "moe_expert", "group": moe_grad_group} + + def forward_experts(self, dispatched_input): + with profile("fwd-expert"): + dispatched_input = dispatched_input.reshape( + [ + self.world_size, + self.num_local_experts, + -1, + dispatched_input.shape[-1], + ] + ) + expert_outputs = [] + if isinstance(self.experts, nn.LayerList): + if self.config.use_fp8_fuse_node: + expert_output = FusionFP8Expert.apply(dispatched_input, self) + else: + chunks = dispatched_input.transpose([1, 0, 2, 3]).contiguous().unbind(0) + assert len(chunks) == len(self.experts), ( + len(chunks), + len(self.experts), + ) + for chunk, expert in zip(chunks, self.experts): + expert_outputs += [expert(chunk)] + + expert_output = paddle.stack(expert_outputs, axis=1) + + else: + dispatched_input = dispatched_input.transpose([1, 0, 2, 3]) + dispatched_input.contiguous() + orig_shape = dispatched_input.shape + chunks = dispatched_input.reshape([orig_shape[0], -1, orig_shape[-1]]) + chunks = self.experts(chunks) + chunks = chunks.reshape(orig_shape[:-1] + [chunks.shape[-1]]).unbind(0) + expert_outputs += chunks + expert_output = paddle.stack(expert_outputs, axis=1) + return expert_output + + def fp8_quant_weight(self): + expert_w1_list = [expert.up_gate_proj.weight for expert in self.experts if expert is not None] + expert_w2_list = [expert.down_proj.weight for expert in self.experts if expert is not None] + + expert_w1 = expert_w1_list[0] + expert_w2 = expert_w2_list[0] + + fp8_weight_stacked_w1, fp8_scale_stacked_w1 = paddle.incubate.nn.functional.fused_stack_transpose_quant(expert_w1_list, transpose=False) + setattr(expert_w1, "fp8_weight_stacked", fp8_weight_stacked_w1) + setattr(expert_w1, "fp8_scale_stacked", fp8_scale_stacked_w1) + + fp8_weight_stacked_w1_t, fp8_scale_stacked_w1_t = paddle.incubate.nn.functional.fused_stack_transpose_quant(expert_w1_list, transpose=True) + setattr(expert_w1, "fp8_weight_stacked_transpose", fp8_weight_stacked_w1_t) + setattr(expert_w1, "fp8_scale_stacked_transpose", fp8_scale_stacked_w1_t) + + fp8_weight_stacked_w2, fp8_scale_stacked_w2 = paddle.incubate.nn.functional.fused_stack_transpose_quant(expert_w2_list, transpose=False) + setattr(expert_w2, "fp8_weight_stacked", fp8_weight_stacked_w2) + setattr(expert_w2, "fp8_scale_stacked", fp8_scale_stacked_w2) + + fp8_weight_stacked_w2_t, fp8_scale_stacked_w2_t = paddle.incubate.nn.functional.fused_stack_transpose_quant(expert_w2_list, transpose=True) + setattr(expert_w2, "fp8_weight_stacked_transpose", fp8_weight_stacked_w2_t) + setattr(expert_w2, "fp8_scale_stacked_transpose", fp8_scale_stacked_w2_t) + + + def fused_gate_logits_process(self, gate_logits, token_type_ids, offload_helper=None): + k = self.k + max_prob = None + + if self.group_experts: + prob = self.gate.act(gate_logits.reshape([gate_logits.shape[0], k, -1])) + max_prob = prob.max(-1, keepdim=True) + prob /= max_prob + prob = prob.reshape([prob.shape[0], -1]) + else: + prob = self.gate.act(gate_logits) + return prob, max_prob + + def gate_distpach_and_quant(self, input, token_type_ids): + """ + Quantization is performed within the op + """ + assert not self.config.use_ep_comm_overlap, "ep_comm_overlap is not supported" + + seqlen, d_model = input.shape + args = () + assert token_type_ids is None + + ( + gate_logits, + capacity, + router_loss, + ) = self.gate(input, *args) + + if self.input_preprocess is not None: + input, gate_logits = self.input_preprocess(input, gate_logits, capacity) + + k = self.k + prob, max_prob = self.fused_gate_logits_process(gate_logits, token_type_ids) + + with profile("dispatch_op"): + corr_bias = self.moe_statics.e_score_correction_bias[0].detach() if self.use_correction_bias else None + + ( + dispatched_input, + combine_weights_unnorm, + scatter_index, + dispatch_mask, + _, + fp8_dispatched_handle, + ) = Fp8MoeGateDispatchAndQuant.apply(input, prob, corr_bias, k=k, capacity=capacity, use_pad=True) + + dispatch_mask = paddle.diff(F.pad(dispatch_mask, (1, 0))) + if self.use_correction_bias: + self.moe_statics.expert_usage[0] += dispatch_mask.detach() + dispatched_input.stop_gradient = False + combine_weights_unnorm.stop_gradient = False + scatter_index.stop_gradient = True + dispatch_mask.stop_gradient = True + + scatter_index = scatter_index.transpose([1, 0]) # [k,s] ->[s,k] + if self.group_experts: + if max_prob is not None: + if token_type_ids is not None: + p = paddle.ones_like(combine_weights_unnorm.unsqueeze(-1)) + p = paddle.scatter_nd_add(p, paddle.nonzero(token_type_ids == 0), -1 + max_prob) + else: + p = max_prob + combine_weights_unnorm = (combine_weights_unnorm.unsqueeze(-1) * p).squeeze(-1) + prob = (prob.reshape([p.shape[0], k, -1]) * p).reshape([p.shape[0], -1]) + if self.gate.norm_gate_logits: + combine_weights = combine_weights_unnorm / paddle.clip( + combine_weights_unnorm.sum(-1, keepdim=True), min=1e-12 + ) + else: + combine_weights = combine_weights_unnorm + combine_weights = combine_weights.cast("bfloat16") + + def reshape_for_a2a(tensor): + return tensor.reshape( + [ + self.world_size * self.num_local_experts, + capacity, + -1, + ] + ) + + dispatched_input = reshape_for_a2a(dispatched_input) + fp8_dispatched_handle["scale"] = reshape_for_a2a(fp8_dispatched_handle["scale"]) + dispatch_mask.stop_gradient = True + scatter_index.stop_gradient = True + return ( + dispatched_input, + combine_weights, + dispatch_mask, + scatter_index, + router_loss, + gate_logits, + prob, + fp8_dispatched_handle, + ) + + def gate_and_distpach(self, input, token_type_ids): + seqlen, d_model = input.shape + args = () + assert token_type_ids is None + + ( + gate_logits, + capacity, + router_loss, + ) = self.gate(input, *args) + + if self.input_preprocess is not None: + input, gate_logits = self.input_preprocess(input, gate_logits, capacity) + + k = self.k + prob, max_prob = self.fused_gate_logits_process(gate_logits, token_type_ids) + + with profile("dispatch_op"): + if "corr_bias" in inspect.signature(moe_gate_dispatch).parameters: + if self.use_correction_bias: + compat_args = (self.moe_statics.e_score_correction_bias[0],) + else: + compat_args = (None,) + else: + assert not self.use_correction_bias, "correction bias not supported, rebuild moe-ops" + compat_args = () + + if not self.config.use_ep_comm_overlap: + ( + dispatched_input, + combine_weights_unnorm, + scatter_index, + dispatch_mask, + _, + ) = moe_gate_dispatch( + input, + prob, + *compat_args, + k=k, + capacity=capacity, + use_pad=True, + ) + else: + ( + dispatched_input, + combine_weights_unnorm, + scatter_index, + dispatch_mask, + _, + ) = moe_gate_dispatch_permute( + input, + prob, + *compat_args, + k=k, + capacity=capacity, + world_size=self.group.nranks, + ) + + dispatched_input = dispatched_input.cast(input.dtype) + + dispatch_mask = paddle.diff(F.pad(dispatch_mask, (1, 0))) + if self.use_correction_bias: + self.moe_statics.expert_usage[0] += dispatch_mask.detach() + dispatched_input.stop_gradient = False + combine_weights_unnorm.stop_gradient = False + scatter_index.stop_gradient = True + dispatch_mask.stop_gradient = True + + scatter_index = scatter_index.transpose([1, 0]) + if self.group_experts: + if max_prob is not None: + if token_type_ids is not None: + p = paddle.ones_like(combine_weights_unnorm.unsqueeze(-1)) + p = paddle.scatter_nd_add(p, paddle.nonzero(token_type_ids == 0), -1 + max_prob) + else: + p = max_prob + combine_weights_unnorm = (combine_weights_unnorm.unsqueeze(-1) * p).squeeze(-1) + prob = (prob.reshape([p.shape[0], k, -1]) * p).reshape([p.shape[0], -1]) + if self.gate.norm_gate_logits: + combine_weights = combine_weights_unnorm / paddle.clip( + combine_weights_unnorm.sum(-1, keepdim=True), min=1e-12 + ) + else: + combine_weights = combine_weights_unnorm + combine_weights = combine_weights.cast(dispatched_input.dtype) + + if not self.config.use_ep_comm_overlap: + dispatched_input = dispatched_input.reshape( + [ + self.world_size * self.num_local_experts, + capacity, + (d_model), + ] + ) + else: + assert ( + len(dispatched_input.shape) == 4 + and dispatched_input.shape[1] == self.world_size + and dispatched_input.shape[0] == self.num_local_experts + ), ( + f"When using ep_comm_overlap, moe_gate_dispatch_permute is needed. " + f"Expected dispatched_input to have shape[1] == {self.world_size} " + f"and shape[0] == {self.num_local_experts}, " + f"but got shape {dispatched_input.shape}" + ) + dispatched_input = dispatched_input + dispatch_mask.stop_gradient = True + scatter_index.stop_gradient = True + return ( + dispatched_input, + combine_weights, + dispatch_mask, + scatter_index, + router_loss, + gate_logits, + prob, + ) + + def _calc_router_loss( + self, + dispatch_mask, + gate_logits, + gate_prob, + num_experts, + use_group, + layer_idx, + token_type=None, + tokens_type_mask=None, + dispatch_tokens_mask=None, + prefix="", + ): + router_loss, l_aux = 0.0, None + if self.gate.config.moe_aux_loss_lambda: + l_aux = self.gate._cal_aux_loss( + gate_prob, + dispatch_mask, + num_experts, + use_group, + tokens_type_mask, + dispatch_tokens_mask, + ) + router_loss += self.gate.moe_aux_loss_lambda[token_type or 0] * l_aux + else: + router_loss += self.zero * gate_prob[0, 0] + + return router_loss + + def calc_router_loss_and_logging( + self, + router_loss, + combine_weights, + dispatch_mask, + gate_logits, + gate_prob, + token_type_ids, + dispatch_token_type_ids=None, + offload_helper=None, + ): + assert gate_prob is not None + router_loss += self._calc_router_loss( + dispatch_mask, + gate_logits, + gate_prob, + self.gate.num_experts_tensor, + self.group_experts, + self.layer_idx, + ) + + return router_loss + + def combine_expert_output(self, expert_output, combine_weights, scatter_index): + expert_output = expert_output.reshape([-1, expert_output.shape[-1]]) + combined_output = combining_fused(expert_output, combine_weights, scatter_index) + + if self.output_postprocess is not None: + combined_output = self.output_postprocess(combined_output) + return combined_output + + def forward_single_stage(self, dispatched_input, stage_id): + assert isinstance(self.experts, nn.LayerList) + return self.experts[stage_id](dispatched_input) + + def all2all_expert_overlap(self, x, group): + all2all_tasks = [] + all2all_ins = paddle.unbind(x, axis=0) + for stage_id in range(1): + stage_input = all2all_ins[stage_id] + x_out, task = AlltoAll.apply(stage_input, group=self.group, sync_op=False) + all2all_tasks.append((task, x_out)) + + expert_outputs = [] + for stage_id in range(self.num_local_experts): + if stage_id + 1 != self.num_local_experts: + stage_input = all2all_ins[stage_id + 1] + x_out, task = AlltoAll.apply(stage_input, group=self.group, sync_op=False) + all2all_tasks.append((task, x_out)) + + task, dispatched_input = all2all_tasks[stage_id] + task.wait() + expert_outputs_cur_stage = ( + recompute(self.forward_single_stage, dispatched_input, stage_id) + if self.recompute and self.training + else self.forward_single_stage(dispatched_input, stage_id) + ) + expert_outputs.append(expert_outputs_cur_stage) + + expert_output = paddle.stack(expert_outputs, axis=1) + return expert_output + + def forward( + self, + input, + token_type_ids=None, + ): + if input.ndim == 3: + orig_shape = input.shape + input = input.reshape([-1, input.shape[-1]]) + else: + orig_shape = None + assert len(input.shape) == 2, f"input Tensor must have dimensions: (s)equence, (d)im, got:{input.shape}" + + hidden_size = input.shape[1] + if token_type_ids is not None: + token_type_ids = token_type_ids.clone()[:, :-1] + if self.config.sequence_parallel: + token_type_ids = token_type_ids.reshape([-1]) + token_type_ids = ScatterOp.apply(token_type_ids) + token_type_ids.stop_gradient = True + + assert self.gate is not None + if hasattr(self, "rng") and self.rng.random() < self.all_to_all_dropout: + orig_shape_2 = input.shape + output = self.forward_experts(input) + output += self.gate.weight.sum() * 0.0 + output = output.reshape(orig_shape or orig_shape_2) + return output, None, 0 + + is_first_fwd = not framework._dygraph_tracer()._has_grad + use_async = self.shared_experts is not None + gate_input = input + + use_fp8_fuse_node = self.config.use_combine_before_a2a and self.config.use_fp8_fuse_node + use_quant_before_a2a = self.config.use_quant_before_a2a and use_fp8_fuse_node + + with profile("fused_gate_and_dispatch"): + fp8_dispatched_handle = None + if use_quant_before_a2a: + ( + dispatched_input, + combine_weights, + dispatch_mask, + scatter_index, + router_loss, + gate_logits, + gate_prob, + fp8_dispatched_handle, + ) = self.gate_distpach_and_quant(gate_input, token_type_ids) + else: + ( + dispatched_input, + combine_weights, + dispatch_mask, + scatter_index, + router_loss, + gate_logits, + gate_prob, + ) = self.gate_and_distpach(gate_input, token_type_ids) + + if self.config.use_combine_before_a2a: + assert ( + not self.config.use_ep_comm_overlap + ), "Dont support `use_ep_comm_overlap` when enable `use_combine_before_a2a`." + cw_shape = combine_weights.shape + si_shape = scatter_index.shape + scatter_index = scatter_index.reshape([-1]) + + token_combine_weights = paddle.zeros([cw_shape[0] * cw_shape[1]], dtype=combine_weights.dtype) + token_combine_weights = paddle.scatter( + token_combine_weights, + scatter_index, + combine_weights.reshape([-1]), + overwrite=False, + ) + + token_combine_weights = token_combine_weights.reshape([cw_shape[0], cw_shape[1], 1]) + token_combine_weights = AlltoAll.apply(token_combine_weights, self.group) + + if not self.config.use_ep_comm_overlap: + if use_quant_before_a2a: + # To enable backward pass overlap, the all-to-all (a2a) operation is performed inside + # FP8FusedWLCHFunc, eliminating the need for external a2a. However, be careful not + # to skip the computation of shared_experts. + shared_out = self.shared_experts(input) if self.shared_experts is not None else None + else: + with profile("moe_comm_and_shared_expert"): + if use_async: + dispatched_input, shared_out = AlltoAllAsync.apply( + dispatched_input, + input, + group=self.group, + fn=self.shared_experts, + is_first_fwd=is_first_fwd, + ) + else: + dispatched_input = AlltoAll.apply(dispatched_input, self.group) + + if use_fp8_fuse_node: + expert_out = FP8FusedWLCHFunc.apply( + dispatched_input, + token_combine_weights, + self, + recompute_fwd_gate_up=recompute_fwd_gate_up_func(self.config, self.layer_idx), + dequant_input=("dequant_input" in self.config.fp8_mem_configs) + and self.config.fp8_mem_configs["dequant_input"], + quant_before_a2a=use_quant_before_a2a, + is_first_fwd=not framework._dygraph_tracer()._has_grad, + group=self.group, + fp8_dispatched_handle=fp8_dispatched_handle, + ) + else: + + expert_out = ( + recompute(self.forward_experts, dispatched_input) + if self.recompute and self.training + else self.forward_experts(dispatched_input) + ) + + if self.config.use_combine_before_a2a: + token_combine_weights = token_combine_weights.clone().reshape(expert_out.shape[:-1] + [1]) + expert_out = expert_out * token_combine_weights + else: + assert ( + len(dispatched_input.shape) == 4 + and dispatched_input.shape[1] == self.world_size + and dispatched_input.shape[0] == self.num_local_experts + ), ( + f"When using ep_comm_overlap, moe_gate_dispatch_permute is needed. " + f"Expected dispatched_input to have shape[1] == {self.world_size} " + f"and shape[0] == {self.num_local_experts}, " + f"but got shape {dispatched_input.shape}" + ) + with profile("moe_comm_and_forward_expert"): + expert_out = AlltoAllExpertOverlap.apply( + dispatched_input, + self.group, + self.num_local_experts, + self.experts, + is_first_fwd=is_first_fwd, + ) + if self.shared_experts is not None: + shared_out = self.shared_experts(input) + + with profile("moe_comm_and_calc_routerloss"): + expert_out, router_loss2 = AlltoAllAsync.apply( + expert_out, + router_loss, + combine_weights, + dispatch_mask, + gate_logits, + gate_prob, + token_type_ids, + group=self.group, + fn=self.calc_router_loss_and_logging, + is_first_fwd=is_first_fwd, + ) + + with profile("combine"): + if self.config.use_combine_before_a2a: + expert_out = expert_out.reshape([-1, hidden_size]) + scatter_index = scatter_index.reshape(si_shape) + combined_output = paddle.incubate.nn.functional.moe_combine_no_weight( + expert_out, combine_weights, scatter_index, epsilon=1e-15 + ) + else: + combined_output = self.combine_expert_output(expert_out, combine_weights, scatter_index) + + if self.shared_experts is not None: + combined_output += shared_out + + if orig_shape: + combined_output = combined_output.clone().reshape(orig_shape[:-1] + [combined_output.shape[-1]]) + return combined_output, combine_weights, router_loss2, gate_logits + + +class FP8FusedWLCHFunc(paddle.autograd.PyLayer): + @staticmethod + def forward( + ctx, + hidden_states, + dispatched_probs, + custom_map, + recompute_fwd_gate_up=False, + dequant_input=False, + quant_before_a2a=False, + is_first_fwd=False, + group=None, + fp8_dispatched_handle=None, + ): + ctx.node = ExpertsGroupGemmWLCHNode( + custom_map, + recompute_fwd_gate_up=recompute_fwd_gate_up, + dequant_input=dequant_input, + group=group, + ) + ctx.group = group + ctx.quant_before_a2a = quant_before_a2a + num_local_experts = custom_map.num_local_experts + + def a2a_fn(input_fp8, input_scale): + return AlltoAll.apply(input_fp8, group), AlltoAll.apply(input_scale, group) + + if quant_before_a2a: + assert fp8_dispatched_handle is not None + assert hidden_states.dtype == paddle.float8_e4m3fn + hidden_states, scale = a2a_fn(hidden_states, fp8_dispatched_handle["scale"]) + scale = scale.reshape([-1, scale.shape[-1]]) + else: + scale = None + + hidden_states = hidden_states.reshape([-1, hidden_states.shape[-1]]) + dispatched_probs = dispatched_probs.reshape([-1, dispatched_probs.shape[-1]]) + tokens_per_expert = [np.prod(hidden_states.shape[:-1]) // num_local_experts] * num_local_experts + + out = ctx.node.forward(hidden_states, dispatched_probs, tokens_per_expert, tokens_per_expert, scale=scale) + + if is_first_fwd: + ctx.node.reset_status() + + return out + + @staticmethod + def backward(ctx, output_grad): + def a2a_async_fn(input): + return AlltoAll.apply(input, ctx.group, sync_op=False) + + if ctx.quant_before_a2a: + return ctx.node.backward(output_grad, a2a_async_fn=a2a_async_fn) + else: + return ctx.node.backward(output_grad, a2a_async_fn=None) + + +class MlpNode: + def __init__(self, custom_map, max_topk, recompute_fwd_gate_up=False, dequant_input=False): + self.token_dispatcher = custom_map.dispatcher + self.experts = custom_map.experts + self.experts_group_gemm_node = ExpertsGroupGemmContiguousNode( + custom_map, + recompute_fwd_gate_up=recompute_fwd_gate_up, + dequant_input=dequant_input, + ) + self.unzip_node = UnZipNode(self.token_dispatcher) + self.zip_node = ZipNode(self.token_dispatcher) + self.dispatched_indices = None + self.dispatched_probs = None + self.tokens_per_expert = self.token_dispatcher._comm_manager.tokens_per_expert_list + self.router_topk = max_topk + + def reset_status(self): + self.dispatched_indices = None + self.dispatched_probs = None + self.tokens_per_expert = None + self.router_topk = None + self.experts_group_gemm_node.reset_status() + self.experts_group_gemm_node = None + + def release_mem(self): + self.experts_group_gemm_node.reset_status() + self.experts_group_gemm_node = None + + @paddle.no_grad() + def forward(self, hs_2d_dispatched, dispatched_indices, dispatched_probs): + num_experts = len(self.tokens_per_expert) + + self.dispatched_indices = dispatched_indices.to(paddle.int32) + (unzipped_tokens, zipped_expertwise_rowmap, unzipped_probs) = self.unzip_node.forward( + hs_2d_dispatched, + self.dispatched_indices, + dispatched_probs, + topk=self.router_topk, + num_experts=num_experts, + tokens_per_expert=self.tokens_per_expert, + ) + hs_2d_dispatched._record_stream() + dispatched_indices._record_stream() + dispatched_probs._record_stream() + + padding_token_per_experts = [(x + 127) // 128 * 128 for x in self.tokens_per_expert] + expert_out = self.experts_group_gemm_node.forward( + unzipped_tokens, + unzipped_probs, + padding_token_per_experts, + self.tokens_per_expert, + ) + + expert_out_tmp = expert_out.reshape([-1, expert_out.shape[-1]]) + + expert_out_zipped = self.zip_node.forward( + expert_out_tmp, + zipped_expertwise_rowmap, + self.dispatched_indices, + unzipped_probs, + total_zipped_tokens=hs_2d_dispatched.shape[0], + num_experts=num_experts, + ) + + self.dispatched_probs = dispatched_probs + expert_out_zipped.stop_gradient = False + + return expert_out_zipped + + @paddle.no_grad() + def backward(self, hidden_states_out_grad): + unzipped_grad = self.zip_node.backward( + hidden_states_out_grad, + self.dispatched_indices, + self.dispatched_probs, + top_k=self.router_topk, + num_experts=len(self.tokens_per_expert), + tokens_per_expert=self.tokens_per_expert, + ) + hidden_states_out_grad._record_stream() + + expert_out, probs_grad = self.experts_group_gemm_node.backward(unzipped_grad) + + hs_fp8_dispatched_grad, dispatched_probs_grad = self.unzip_node.backward( + expert_out, + hidden_states_out_grad, + probs_grad, + self.dispatched_indices, + num_experts=len(self.tokens_per_expert), + ) + self.reset_status() + return hs_fp8_dispatched_grad, dispatched_probs_grad + + +class Fp8FusedMoeFunc(paddle.autograd.PyLayer): + @staticmethod + def forward( + ctx, + hidden_states, + dispatched_probs, + dispatched_indices, + custom_map, + max_topk, + recompute_fwd_gate_up=False, + dequant_input=False, + is_first_fwd=False, + ): + ctx.node = MlpNode( + custom_map, + max_topk, + recompute_fwd_gate_up=recompute_fwd_gate_up, + dequant_input=dequant_input, + ) + out = ctx.node.forward(hidden_states, dispatched_indices, dispatched_probs) + + if is_first_fwd: + ctx.node.release_mem() + return out + + @staticmethod + def backward(ctx, output_grad): + hidden_states_grad, dispatched_probs_grad = ctx.node.backward(output_grad) + return hidden_states_grad, dispatched_probs_grad, None diff --git a/ernie/ERNIE/examples/pre-training/models/moe/token_dispatcher/__init__.py b/ernie/ERNIE/examples/pre-training/models/moe/token_dispatcher/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9cc79cc9d7f1977efe8e066facf32c20c8ad3af --- /dev/null +++ b/ernie/ERNIE/examples/pre-training/models/moe/token_dispatcher/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/ernie/ERNIE/examples/pre-training/models/moe/token_dispatcher/fp8_utils.py b/ernie/ERNIE/examples/pre-training/models/moe/token_dispatcher/fp8_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a30c6259d87b0b5b6da7ae5cfa7079b9faf14d54 --- /dev/null +++ b/ernie/ERNIE/examples/pre-training/models/moe/token_dispatcher/fp8_utils.py @@ -0,0 +1,1493 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025 DeepSeek +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +FP8 Utilities for Mixture of Experts (MoE) Token Dispatcher + +This module provides optimized operations for FP8 (8-bit floating point) computations +in Mixture of Experts architectures. Key features include: +- FP8 GEMM (General Matrix Multiply) operations for expert computations +- Specialized forward and backward passes for MoE layers +- Memory-efficient quantization and dequantization routines +- Support for both contiguous and non-contiguous memory layouts + +The implementation leverages PaddlePaddle's FP8 incubator operations and provides +additional optimizations specific to MoE workloads. +""" + +import numpy +import paddle +from models.fp8_linear import fp8_gemm +from paddle.incubate.fp8 import deep_gemm +from paddle.incubate.nn.functional import swiglu + +__all__ = [ + "ExpertsGroupGemmNode", + "ExpertsGroupGemmContiguousNode", +] + +def _get_fp8_weight_and_scale(weight, stacked=False, transpose=False): + if stacked: + if transpose: + fp8_weight, fp8_scale = weight.fp8_weight_stacked_transpose, weight.fp8_scale_stacked_transpose + else: + fp8_weight, fp8_scale = weight.fp8_weight_stacked, weight.fp8_scale_stacked + else: + if transpose: + fp8_weight, fp8_scale = weight.fp8_weight_transpose, weight.fp8_scale_transpose + else: + fp8_weight, fp8_scale = weight.fp8_weight, weight.fp8_scale + return fp8_weight, fp8_scale + + +def fused_stack_transpose_quant(weight_list, transpose=False): + """ + Quant BF16 weight to FP8 + + Args: + weight_list (List[Tensor]): Input tensor list in BF16 format + transpose (Boolean): Transpose operation flag + + Returns: + Tuple[Tensor, Tensor]: The weight and scale after quant in FP8 format + """ + if hasattr(weight_list[0], "fp8_weight_stacked"): + w, scale = _get_fp8_weight_and_scale(weight_list[0], stacked=True, transpose=transpose) + else: + w, scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(weight_list, transpose) + return w, scale + + +def split_group_gemm(x_fp8, x_scale, w_fp8, w_scale, tokens_per_expert, gemm_out): + """ + Perform grouped GEMM operation with FP8 tensors, splitting by expert tokens. + + Args: + x_fp8 (Tensor): Input tensor in FP8 format + x_scale (Tensor): Scaling factors for input tensor + w_fp8 (Tensor): Weight tensor in FP8 format + w_scale (Tensor): Scaling factors for weight tensor + tokens_per_expert (List[int]): Number of tokens assigned to each expert + gemm_out (Tensor): Output tensor for GEMM results + + Returns: + Tensor: The GEMM output tensor with expert-specific computations + + Note: + This implementation uses deep_gemm operations optimized for FP8 precision + and handles the case where tokens may be unevenly distributed across experts. + """ + start_idx = 0 + for i, token_num in enumerate(tokens_per_expert): + if token_num == 0: + continue + end_idx = start_idx + token_num + + x_scale_tma_align = x_scale[start_idx:end_idx].T.contiguous().T + + deep_gemm.gemm_fp8_fp8_bf16_nt( + (x_fp8[start_idx:end_idx], x_scale_tma_align), + (w_fp8[i], w_scale[i]), + gemm_out[start_idx:end_idx], + ) + + start_idx = end_idx + + return gemm_out + + +def has_config(config_map, key): + return bool(config_map is not None and key in config_map and config_map[key]) + + +class ExpertsGroupGemmNode: + """ + Node for performing grouped GEMM operations in FP8 precision for MoE layers. + + This class handles both forward and backward passes for expert computations, + including specialized operations for: + - Gate projection (up_gate_proj) + - SwiGLU activation + - Down projection (down_proj) + + The implementation supports both standard and probability-weighted computations. + """ + + def __init__(self, experts, custom_map, name="moe_experts_node"): + """ + Initialize the ExpertsGroupGemmNode. + + Args: + experts (List[Module]): List of expert modules + custom_map (CustomMap): Configuration mapping for expert operations + name (str): Optional name for the node + + Attributes: + o1 (Tensor): Cache for intermediate gate projection results + unzipped_tokens (Tensor): Cache for input tokens + custom_map (CustomMap): Expert configuration mapping + unzipped_probs (Tensor): Cache for expert probabilities + tokens_per_expert (List[int]): Token distribution across experts + fp8_fused_ops_configs (Dict): Configuration for FP8 fused operations + """ + self.o1 = None + self.unzipped_tokens = None + self.custom_map = custom_map + self.unzipped_probs = None + self.tokens_per_expert = None + self.fp8_fused_ops_configs = custom_map.config.fp8_fused_ops_configs + + def reset_status(self): + self.o1 = None + self.unzipped_tokens = None + self.unzipped_probs = None + self.tokens_per_expert = None + + def fwd_gate_up(self, x_bf16, expert_w1, expert_w_count, tokens_per_expert): + """ + Forward pass for gate projection in FP8 precision. + + Args: + x_bf16 (Tensor): Input tensor in bfloat16 format + expert_w1 (List[Tensor]): List of expert weights for gate projection + expert_w_count (int): Number of experts + tokens_per_expert (List[int]): Token distribution across experts + + Returns: + Tensor: Output of gate projection in bfloat16 format + + Note: + - Handles both stacked and individual expert weight quantization + - Supports FP8 fused operations when configured + - Maintains intermediate results for backward pass + """ + if has_config(self.fp8_fused_ops_configs, "stack_quant"): + w1_t_quant, w1_t_scale = fused_stack_transpose_quant( + expert_w1, transpose=True + ) + else: + stacked_w1 = paddle.stack(expert_w1, axis=0) + stacked_w1_t = paddle.transpose(stacked_w1, [0, 2, 1]).contiguous() + concated_w1_t = stacked_w1_t.reshape([-1, stacked_w1_t.shape[-1]]) + + w1_t_quant, w1_t_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + concated_w1_t, + quant_method="1x128", + input_transpose=False, + output_scale_transpose=False, + ) + + w1_t_quant = w1_t_quant.reshape([expert_w_count, -1, w1_t_quant.shape[-1]]) + w1_t_scale = w1_t_scale.reshape([expert_w_count, -1, w1_t_scale.shape[-1]]) + + x_fp8, x_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + x_bf16, + quant_method="1x128", + input_transpose=False, + output_scale_transpose=False, + ) + + x_fp8 = x_fp8.reshape([expert_w_count, -1, x_fp8.shape[-1]]) + x_scale = x_scale.reshape([expert_w_count, -1, x_scale.shape[-1]]) + x_scale = paddle.transpose(paddle.transpose(x_scale, [0, 2, 1]).contiguous(), [0, 2, 1]) + + o1 = paddle.zeros([expert_w_count, x_fp8.shape[1], w1_t_quant.shape[1]], dtype=x_bf16.dtype) + if numpy.prod(x_fp8.shape) != 0: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( + (x_fp8, x_scale), + (w1_t_quant, w1_t_scale), + o1, + tokens_per_expert, + x_fp8.shape[1], + ) + return o1 + + def fwd_swiglu(self, o1): + """ + Compute SwiGLU activation function. + + Args: + o1 (Tensor): Input tensor from gate projection + + Returns: + Tensor: Output after SwiGLU activation + + Note: + Uses PaddlePaddle's optimized swiglu implementation + """ + o2 = swiglu(o1) + return o2 + + def fwd_down(self, o1, unzipped_probs, expert_w_count, tokens_per_expert): + """ + Forward pass for down projection with probability weighting. + + Args: + o1 (Tensor): Input tensor from SwiGLU activation + unzipped_probs (Tensor): Expert probabilities for each token + expert_w_count (int): Number of experts + tokens_per_expert (List[int]): Token distribution across experts + + Returns: + Tuple[Tensor, Tensor]: + - Output tensor after down projection + - Reshaped probabilities tensor + + Note: + - Handles both standard and fused FP8 quantization paths + - Applies probability weighting to expert outputs + - Uses grouped GEMM operations optimized for FP8 + """ + expert_w2 = [x.down_proj.weight for x in self.custom_map.experts if x is not None] + if has_config(self.fp8_fused_ops_configs, "stack_quant"): + w2_quant, w2_scale = fused_stack_transpose_quant(expert_w2, transpose=True) + else: + stacked_w2 = paddle.stack(expert_w2, axis=0) + stacked_w2_t = paddle.transpose(stacked_w2, [0, 2, 1]).contiguous() + concated_w2_t = stacked_w2_t.reshape([-1, stacked_w2_t.shape[-1]]) + + w2_quant, w2_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + concated_w2_t, + quant_method="128x128", + input_transpose=False, + output_scale_transpose=False, + ) + w2_quant = w2_quant.reshape([expert_w_count, -1, w2_quant.shape[-1]]) + w2_scale = w2_scale.reshape([expert_w_count, -1, w2_scale.shape[-1]]) + o2 = self.fwd_swiglu(o1) + unzipped_probs = unzipped_probs.unsqueeze(-1).reshape([expert_w_count, -1, 1]) + o2 = (o2 * unzipped_probs).cast(paddle.bfloat16) + o2_reshape = o2.reshape([-1, o2.shape[-1]]).contiguous() + o2_quant, o2_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + o2_reshape, + quant_method="1x128", + input_transpose=False, + output_scale_transpose=False, + ) + + o2_quant = o2_quant.reshape([expert_w_count, -1, o2_quant.shape[-1]]) + o2_scale = o2_scale.reshape([expert_w_count, -1, o2_scale.shape[-1]]) + o2_scale = paddle.transpose(paddle.transpose(o2_scale, [0, 2, 1]).contiguous(), [0, 2, 1]) + o3 = paddle.zeros([expert_w_count, o2_quant.shape[1], w2_quant.shape[1]], dtype=o1.dtype) + if numpy.prod(o2_quant.shape) != 0: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( + (o2_quant, o2_scale), + (w2_quant, w2_scale), + o3, + tokens_per_expert, + o2_quant.shape[1], + ) + return o3, unzipped_probs + + def fwd_down_no_probs(self, o1, expert_w2, expert_w_count, tokens_per_expert): + """ + Forward pass for down projection without probability weighting. + + Args: + o1 (Tensor): Input tensor from SwiGLU activation + expert_w2 (List[Tensor]): List of expert weights for down projection + expert_w_count (int): Number of experts + tokens_per_expert (List[int]): Token distribution across experts + + Returns: + Tensor: Output tensor after down projection + + Note: + - Simplified version of fwd_down without probability handling + - Still maintains FP8 optimized computation path + """ + expert_w2 = [x.down_proj.weight for x in self.custom_map.experts if x is not None] + if has_config(self.fp8_fused_ops_configs, "stack_quant"): + w2_quant, w2_scale = fused_stack_transpose_quant(expert_w2, transpose=True) + else: + stacked_w2 = paddle.stack(expert_w2, axis=0) + stacked_w2_t = paddle.transpose(stacked_w2, [0, 2, 1]).contiguous() + concated_w2_t = stacked_w2_t.reshape([-1, stacked_w2_t.shape[-1]]) + + w2_quant, w2_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + concated_w2_t, + quant_method="128x128", + input_transpose=False, + output_scale_transpose=False, + ) + w2_quant = w2_quant.reshape([expert_w_count, -1, w2_quant.shape[-1]]) + w2_scale = w2_scale.reshape([expert_w_count, -1, w2_scale.shape[-1]]) + o2 = self.fwd_swiglu(o1) + + o2_reshape = o2.reshape([-1, o2.shape[-1]]).contiguous() + o2_quant, o2_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + o2_reshape, + quant_method="1x128", + input_transpose=False, + output_scale_transpose=False, + ) + + o2_quant = o2_quant.reshape([expert_w_count, -1, o2_quant.shape[-1]]) + o2_scale = o2_scale.reshape([expert_w_count, -1, o2_scale.shape[-1]]) + o2_scale = paddle.transpose(paddle.transpose(o2_scale, [0, 2, 1]).contiguous(), [0, 2, 1]) + + o3 = paddle.zeros([expert_w_count, o2_quant.shape[1], w2_quant.shape[1]], dtype=o1.dtype) + if numpy.prod(o2_quant.shape) != 0: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( + (o2_quant, o2_scale), + (w2_quant, w2_scale), + o3, + tokens_per_expert, + o2_quant.shape[1], + ) + return o3 + + def bwd_down_input(self, expert_w2, unzipped_grad, tokens_per_expert, expected_m): + """ + Backward pass for down projection input gradient computation. + + Args: + expert_w2 (List[Tensor]): List of expert weights for down projection + unzipped_grad (Tensor): Gradient from downstream layer + tokens_per_expert (List[int]): Token distribution across experts + expected_m (int): Expected batch dimension size + + Returns: + Tuple[Tensor, Tensor, Tensor]: + - Input gradient (do1) + - SwiGLU output (o2_s) + - Probability gradients + + Note: + - Handles both standard and fused FP8 backprop paths + - Computes gradients for SwiGLU activation and probability weighting + """ + if has_config(self.fp8_fused_ops_configs, "stack_quant"): + bw_w2_quant, bw_w2_scale = fused_stack_transpose_quant( + expert_w2, transpose=False + ) + else: + stacked_w2 = paddle.stack(expert_w2, axis=0) + concated_w2 = stacked_w2.reshape([-1, stacked_w2.shape[-1]]) + + bw_w2_quant, bw_w2_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + concated_w2, + quant_method="128x128", + input_transpose=False, + output_scale_transpose=False, + ) + bw_w2_quant = bw_w2_quant.reshape([len(expert_w2), -1, bw_w2_quant.shape[-1]]) + bw_w2_scale = bw_w2_scale.reshape([len(expert_w2), -1, bw_w2_scale.shape[-1]]) + + unzipped_grad_fp8, unzipped_grad_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + unzipped_grad, + quant_method="1x128", + input_transpose=False, + output_scale_transpose=False, + ) + unzipped_grad_scale = paddle.transpose( + paddle.transpose(unzipped_grad_scale, [0, 2, 1]).contiguous(), [0, 2, 1] + ) + do2_s = paddle.zeros( + [len(expert_w2), unzipped_grad_fp8.shape[1], bw_w2_quant.shape[1]], + dtype=unzipped_grad.dtype, + ) + if numpy.prod(unzipped_grad_fp8.shape) != 0: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( + (unzipped_grad_fp8, unzipped_grad_scale), + (bw_w2_quant, bw_w2_scale), + do2_s, + tokens_per_expert, + expected_m, + ) + if has_config(self.fp8_fused_ops_configs, "swiglu_probs_bwd"): + do1, probs_grad, o2_s = paddle.incubate.nn.functional.fused_swiglu_weighted_bwd( + self.o1, do2_s, self.unzipped_probs + ) + else: + o2 = self.fwd_swiglu(self.o1) + o2_s = (o2 * self.unzipped_probs).cast(paddle.bfloat16) + do2 = (do2_s.cast(paddle.float32) * self.unzipped_probs).cast(paddle.bfloat16) + + probs_grad = (do2_s.cast(paddle.float32) * (o2.cast(paddle.float32))).sum(axis=-1) + do1 = self.bwd_swiglu(self.o1, do2) + + return do1, o2_s, probs_grad + + def bwd_down_input_no_prob(self, expert_w2, unzipped_grad, tokens_per_expert, expected_m): + o2 = self.fwd_swiglu(self.o1) + o2_s = o2 + + if has_config(self.fp8_fused_ops_configs, "stack_quant"): + bw_w2_quant, bw_w2_scale = fused_stack_transpose_quant( + expert_w2, transpose=False + ) + else: + stacked_w2 = paddle.stack(expert_w2, axis=0) + concated_w2 = stacked_w2.reshape([-1, stacked_w2.shape[-1]]) + + bw_w2_quant, bw_w2_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + concated_w2, + quant_method="128x128", + input_transpose=False, + output_scale_transpose=False, + ) + bw_w2_quant = bw_w2_quant.reshape([len(expert_w2), -1, bw_w2_quant.shape[-1]]) + bw_w2_scale = bw_w2_scale.reshape([len(expert_w2), -1, bw_w2_scale.shape[-1]]) + + unzipped_grad_fp8, unzipped_grad_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + unzipped_grad, + quant_method="1x128", + input_transpose=False, + output_scale_transpose=False, + ) + do2_s = paddle.zeros( + [len(expert_w2), unzipped_grad_fp8.shape[1], bw_w2_quant.shape[1]], + dtype=unzipped_grad.dtype, + ) + if numpy.prod(unzipped_grad_fp8.shape) != 0: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( + (unzipped_grad_fp8, unzipped_grad_scale), + (bw_w2_quant, bw_w2_scale), + do2_s, + tokens_per_expert, + expected_m, + ) + + return do2_s, o2_s + + def bwd_swiglu(self, o1, do2): + """ + Backward pass for SwiGLU activation function. + + Args: + o1 (Tensor): Original input to SwiGLU + do2 (Tensor): Gradient from downstream layer + + Returns: + Tensor: Gradient with respect to SwiGLU input + + Note: + Uses PaddlePaddle's optimized swiglu_grad operation + """ + do1, _ = paddle._C_ops.swiglu_grad(o1, None, do2) + return do1 + + def bwd_gate_up_input(self, do1, expert_w1, tokens_per_expert, expected_m): + """ + Backward pass for gate projection input gradient computation. + + Args: + do1 (Tensor): Gradient from downstream layer + expert_w1 (List[Tensor]): List of expert weights for gate projection + tokens_per_expert (List[int]): Token distribution across experts + expected_m (int): Expected batch dimension size + + Returns: + Tensor: Input gradient (dx) + + Note: + - Performs FP8 optimized GEMM for gradient computation + - Handles both standard and fused quantization paths + """ + if has_config(self.fp8_fused_ops_configs, "stack_quant"): + bw_w1_quant, bw_w1_scale = fused_stack_transpose_quant( + expert_w1, transpose=False + ) + else: + stacked_w1 = paddle.stack(expert_w1, axis=0) + concated_w1_t_2d = stacked_w1.reshape([-1, stacked_w1.shape[-1]]) + + bw_w1_quant, bw_w1_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + concated_w1_t_2d, + quant_method="128x128", + input_transpose=False, + output_scale_transpose=False, + ) + bw_w1_quant = bw_w1_quant.reshape([len(expert_w1), -1, bw_w1_quant.shape[-1]]) + bw_w1_scale = bw_w1_scale.reshape([len(expert_w1), -1, bw_w1_scale.shape[-1]]) + + do1_fp8_reshape = do1.reshape([-1, do1.shape[-1]]).contiguous() + do1_fp8, do1_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + do1_fp8_reshape, + quant_method="1x128", + input_transpose=False, + output_scale_transpose=False, + ) + + do1_fp8 = (do1_fp8.reshape([len(expert_w1), -1, do1_fp8.shape[-1]])).contiguous() + do1_scale = do1_scale.reshape([len(expert_w1), -1, do1_scale.shape[-1]]).contiguous() + do1_scale = paddle.transpose(paddle.transpose(do1_scale, [0, 2, 1]).contiguous(), [0, 2, 1]) + + dx = paddle.zeros( + shape=[len(expert_w1), do1_fp8.shape[1], bw_w1_quant.shape[1]], + dtype=paddle.bfloat16, + ) + if numpy.prod(do1_fp8.shape) != 0: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( + (do1_fp8, do1_scale), + (bw_w1_quant, bw_w1_scale), + dx, + tokens_per_expert, + expected_m, + ) + return dx + + def bwd_down_weight(self, out_grad, o2, expert_w2): + """ + Backward pass for down projection weight gradient computation. + + Args: + out_grad (Tensor): Gradient from downstream layer + o2 (Tensor): Output from SwiGLU activation + expert_w2 (List[Tensor]): List of expert weights for down projection + + Note: + - Computes weight gradients using FP8 optimized GEMM + - Handles both main_grad and standard grad accumulation + - Maintains proper gradient scaling for FP8 precision + """ + group_num = len(expert_w2) + H2 = o2.shape[-1] + + o2_t = ( + o2.reshape([group_num, -1, H2]) + .transpose([0, 2, 1]) + .contiguous() + .reshape([group_num * H2, -1]) + .contiguous() + ) + + o2_t_fp8, o2_t_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + o2_t, + quant_method="1x128", + input_transpose=False, + output_scale_transpose=True, + ) + + o2_t_fp8 = o2_t_fp8.reshape([group_num, int(o2_t_fp8.shape[0] / group_num), o2_t_fp8.shape[-1]]) + o2_t_scale = paddle.split(o2_t_scale, num_or_sections=group_num, axis=-1) + + H1 = out_grad.shape[-1] + out_grad = ( + out_grad.reshape([group_num, -1, H1]) + .transpose([0, 2, 1]) + .contiguous() + .reshape([group_num * H1, -1]) + .contiguous() + ) + + out_grad_fp8, out_grad_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + out_grad, + quant_method="1x128", + input_transpose=False, + output_scale_transpose=True, + ) + + out_grad_fp8 = out_grad_fp8.reshape([group_num, H1, -1]) + out_grad_scale = paddle.split(out_grad_scale, num_or_sections=group_num, axis=-1) + + for i in range(len(expert_w2)): + if hasattr(expert_w2[i], "main_grad"): + if expert_w2[i].main_grad is None: + expert_w2[i].main_grad = paddle.zeros(shape=expert_w2[i].shape, dtype=paddle.float32) + fp8_gemm( + o2_t_fp8[i], + o2_t_scale[i], + out_grad_fp8[i], + out_grad_scale[i], + True, + True, + expert_w2[i].main_grad, + paddle.float32, + ) + else: + if expert_w2[i].grad is None: + expert_w2[i].grad = paddle.zeros(shape=expert_w2[i].shape, dtype=paddle.float32) + fp8_gemm( + o2_t_fp8[i], + o2_t_scale[i], + out_grad_fp8[i], + out_grad_scale[i], + True, + True, + expert_w2[i].grad, + paddle.float32, + ) + + def bwd_gate_up_weight(self, do1, input_x, expert_w1): + group_num = len(expert_w1) + """ + Backward pass for gate projection weight gradient computation. + + Args: + do1 (Tensor): Gradient from downstream layer + input_x (Tensor): Original input to gate projection + expert_w1 (List[Tensor]): List of expert weights for gate projection + + Note: + - Computes weight gradients using FP8 optimized GEMM + - Handles both main_grad and standard grad accumulation + - Maintains proper gradient scaling for FP8 precision + """ + H1 = input_x.shape[-1] + input_x = ( + input_x.reshape([group_num, -1, H1]) + .transpose([0, 2, 1]) + .contiguous() + .reshape([group_num * H1, -1]) + .contiguous() + ) + + input_x_fp8, input_x_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + input_x, + quant_method="1x128", + input_transpose=False, + output_scale_transpose=True, + ) + input_x_scale = paddle.split(input_x_scale, num_or_sections=group_num, axis=-1) + + H2 = do1.shape[-1] + do1 = ( + do1.reshape([group_num, -1, H2]) + .transpose([0, 2, 1]) + .contiguous() + .reshape([group_num * H2, -1]) + .contiguous() + ) + do1_fp8, do1_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + do1, + quant_method="1x128", + input_transpose=False, + output_scale_transpose=True, + ) + do1_scale = paddle.split(do1_scale, num_or_sections=group_num, axis=-1) + + for i in range(len(expert_w1)): + if hasattr(expert_w1[i], "main_grad"): + if expert_w1[i].main_grad is None: + expert_w1[i].main_grad = paddle.zeros(shape=expert_w1[i].shape, dtype=paddle.float32) + fp8_gemm( + input_x_fp8[i], + input_x_scale[i], + do1_fp8[i], + do1_scale[i], + True, + True, + expert_w1[i].main_grad, + paddle.float32, + ) + else: + if expert_w1[i].grad is None: + expert_w1[i].grad = paddle.zeros(shape=expert_w1[i].shape, dtype=paddle.float32) + fp8_gemm( + input_x_fp8[i], + input_x_scale[i], + do1_fp8[i], + do1_scale[i], + True, + True, + expert_w1[i].grad, + paddle.float32, + ) + + @paddle.no_grad() + def forward(self, hs_out, unzipped_probs, tokens_per_expert): + expert_w1 = [x.up_gate_proj.weight for x in self.custom_map.experts if x is not None] + expert_w_count = len(expert_w1) + + o1 = self.fwd_gate_up(hs_out, expert_w1, expert_w_count, tokens_per_expert) + self.o1 = o1 + + o3, unzipped_probs = self.fwd_down( + o1=o1, unzipped_probs=unzipped_probs, expert_w_count=expert_w_count, tokens_per_expert=tokens_per_expert + ) + + self.unzipped_probs = unzipped_probs + self.unzipped_tokens = hs_out + return o3 + + @paddle.no_grad() + def backward(self, out_grad, tokens_per_expert, dispatched_indices, expected_m): + expert_w2 = [x.down_proj.weight for x in self.custom_map.experts if x is not None] + expert_w1 = [x.up_gate_proj.weight for x in self.custom_map.experts if x is not None] + + do1, o2_s, probs_grad = self.bwd_down_input(expert_w2, out_grad, tokens_per_expert, expected_m) + + dx = self.bwd_gate_up_input(do1, expert_w1, tokens_per_expert, expected_m) + dx = dx.reshape([-1, dx.shape[-1]]) + self.bwd_down_weight(out_grad, o2_s, expert_w2) + self.bwd_gate_up_weight(do1, self.unzipped_tokens, expert_w1) + + self.reset_status() + return dx, probs_grad + + @paddle.no_grad() + def forward_no_prob(self, hs_out, tokens_per_expert): + expert_w1 = [x.up_gate_proj.weight for x in self.custom_map.experts if x is not None] + expert_w_count = len(expert_w1) + + expert_w2 = [x.down_proj.weight for x in self.custom_map.experts if x is not None] + o1 = self.fwd_gate_up(hs_out, expert_w1, expert_w_count, tokens_per_expert) + self.o1 = o1 + o3 = self.fwd_down_no_probs(o1, expert_w2, expert_w_count, tokens_per_expert) + self.unzipped_tokens = hs_out + return o3 + + @paddle.no_grad() + def backward_no_prob(self, out_grad, tokens_per_expert): + expert_w2 = [x.down_proj.weight for x in self.custom_map.experts if x is not None] + expert_w1 = [x.up_gate_proj.weight for x in self.custom_map.experts if x is not None] + + expected_m = int(numpy.prod(out_grad.shape[:-1]) // len(expert_w1)) + + out_grad = out_grad.reshape([-1, out_grad.shape[-1]]) + + do2, o2_s = self.bwd_down_input_no_prob(expert_w2, out_grad, tokens_per_expert, expected_m) + + do1 = self.bwd_swiglu(self.o1, do2) + + dx = self.bwd_gate_up_input(do1, expert_w1, tokens_per_expert, expected_m) + dx = dx.reshape([-1, dx.shape[-1]]) + + self.bwd_down_weight(out_grad, o2_s, expert_w2) + self.bwd_gate_up_weight(do1, self.unzipped_tokens, expert_w1) + + self.reset_status() + return dx + + +class ExpertsGroupGemmContiguousNode: + """ + Node for performing grouped GEMM operations with contiguous memory layout. + + This optimized version provides better performance for certain hardware configurations + by ensuring memory access patterns are more cache-friendly. Key differences from + ExpertsGroupGemmNode include: + - Contiguous memory layout for all intermediate tensors + - Specialized handling for recomputation scenarios + - Optional input dequantization support + - Split group GEMM optimization when configured + """ + + def __init__( + self, + custom_map, + recompute_fwd_gate_up=False, + dequant_input=False, + group=None, + name="experts_group_gemm_contiguous_node", + ): + """ + Initialize the ExpertsGroupGemmContiguousNode. + + Args: + custom_map (CustomMap): Configuration mapping for expert operations + recompute_fwd_gate_up (bool): Whether to recompute gate projection in backward pass + dequant_input (bool): Whether to dequantize input tensors + name (str): Optional name for the node + + Attributes: + custom_map (CustomMap): Expert configuration mapping + recompute_fwd_gate_up (bool): Recompute flag + dequant_input (bool): Input dequantization flag + tokens_per_expert (List[int]): Token distribution across experts + m_indices (Tensor): Expert indices for contiguous operations + unzipped_probs (Tensor): Cache for expert probabilities + input (Tensor): Cache for input tensor (bf16) + input_fp8 (Tensor): Cache for input tensor (FP8) + input_scale (Tensor): Cache for input scaling factors + o1 (Tensor): Cache for intermediate gate projection results + fp8_fused_ops_configs (Dict): Configuration for FP8 fused operations + is_split_group_gemm (bool): Whether split group GEMM optimization is enabled + """ + self.custom_map = custom_map + self.recompute_fwd_gate_up = recompute_fwd_gate_up + self.dequant_input = dequant_input + self.tokens_per_expert = None + self.m_indices = None + self.unzipped_probs = None + self.input = None + self.input_fp8 = None + self.input_scale = None + self.o1 = None + self.fp8_fused_ops_configs = custom_map.config.fp8_fused_ops_configs + self.is_split_group_gemm = has_config(self.fp8_fused_ops_configs, "split_group_gemm") + self.group = group + + def reset_status(self): + self.tokens_per_expert = None + self.m_indices = None + self.unzipped_probs = None + self.input = None + self.input_fp8 = None + self.input_scale = None + self.o1 = None + + def gen_m_indices(self, tokens_per_expert): + """ + Generate expert indices tensor for contiguous operations. + + Args: + tokens_per_expert (List[int]): Token distribution across experts + + Returns: + Tensor: Contiguous tensor of expert indices + + Note: + This creates a flat tensor where each element indicates which expert + should process the corresponding token, enabling efficient batched + operations with contiguous memory access. + """ + tokens = [] + for i in range(len(tokens_per_expert)): + tokens.append(paddle.full([tokens_per_expert[i]], i, dtype="int32")) + out = paddle.concat(tokens, axis=0) + return out + + def fwd_gate_up(self, x, expert_w1, num_expert, tokens_per_expert, scale=None): + """ + Forward pass for gate projection with contiguous memory layout. + + Args: + x (Tensor): Input tensor in bfloat16 or float8 format + expert_w1 (List[Tensor]): List of expert weights for gate projection + num_expert (int): Number of experts + tokens_per_expert (List[int]): Token distribution across experts + scale (Tensor|None): Scale tensor for dequantization, optional. + + Returns: + Tensor: Output of gate projection in bfloat16 format + + Note: + - Optimized for contiguous memory access patterns + - Supports both split and non-split group GEMM variants + - Handles input caching for recomputation scenarios + - Maintains FP8 precision for compute-intensive operations + """ + self.tokens_per_expert = tokens_per_expert + if not self.is_split_group_gemm: + self.m_indices = self.gen_m_indices(tokens_per_expert) + if has_config(self.fp8_fused_ops_configs, "stack_quant"): + w1_t_quant, w1_t_scale = fused_stack_transpose_quant( + expert_w1, transpose=True + ) + else: + stacked_w1 = paddle.stack(expert_w1, axis=0) + stacked_w1_t = paddle.transpose(stacked_w1, [0, 2, 1]).contiguous() + concated_w1_t = stacked_w1_t.reshape([-1, stacked_w1_t.shape[-1]]) + w1_t_quant, w1_t_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + concated_w1_t, + quant_method="1x128", + input_transpose=False, + output_scale_transpose=False, + ) + w1_t_quant = w1_t_quant.reshape([num_expert, -1, w1_t_quant.shape[-1]]) + w1_t_scale = w1_t_scale.reshape([num_expert, -1, w1_t_scale.shape[-1]]) + + if x is None: + x_fp8, x_scale = self.input_fp8, self.input_scale + assert x_fp8 is not None and x_scale is not None + elif scale is not None: + x_fp8, x_scale = x, scale + assert self.dequant_input, ( + "If a scale is provided, it indicates that a2a is using fp8. Dequant_input must be enabled." + ) + else: + x_fp8, x_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + x, + quant_method="1x128", + input_transpose=False, + output_scale_transpose=True, + ) + x_scale = x_scale.T + + o1 = paddle.empty([x_fp8.shape[0], w1_t_quant.shape[1]], dtype=expert_w1[0].dtype) + if numpy.prod(x_fp8.shape) != 0: + if self.is_split_group_gemm: + split_group_gemm(x_fp8, x_scale, w1_t_quant, w1_t_scale, tokens_per_expert, o1) + else: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (x_fp8, x_scale), + (w1_t_quant, w1_t_scale), + o1, + m_indices=self.m_indices, + ) + + if self.dequant_input: + self.input_fp8 = x_fp8 + self.input_scale = x_scale + else: + self.input = x + return o1 + + def fwd_swiglu(self, o1): + o2 = swiglu(o1) + return o2 + + def fwd_down(self, o1, unzipped_probs, expert_w2, num_expert): + """ + Forward pass for down projection with contiguous memory layout. + + Args: + o1 (Tensor): Input tensor from SwiGLU activation + unzipped_probs (Tensor): Expert probabilities for each token + expert_w2 (List[Tensor]): List of expert weights for down projection + num_expert (int): Number of experts + + Returns: + Tuple[Tensor, Tensor]: + - Output tensor after down projection + - Reshaped probabilities tensor + + Note: + - Uses contiguous memory layout for all intermediate tensors + - Supports fused SwiGLU activation and quantization when configured + - Handles both split and non-split group GEMM variants + """ + if has_config(self.fp8_fused_ops_configs, "stack_quant"): + w2_quant, w2_scale = fused_stack_transpose_quant(expert_w2, transpose=True) + else: + stacked_w2 = paddle.stack(expert_w2, axis=0) + stacked_w2_t = paddle.transpose(stacked_w2, [0, 2, 1]).contiguous() + concated_w2_t = stacked_w2_t.reshape([-1, stacked_w2_t.shape[-1]]) + w2_quant, w2_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + concated_w2_t, + quant_method="128x128", + input_transpose=False, + output_scale_transpose=False, + ) + w2_quant = w2_quant.reshape([num_expert, -1, w2_quant.shape[-1]]) + w2_scale = w2_scale.reshape([num_expert, -1, w2_scale.shape[-1]]) + + if has_config(self.fp8_fused_ops_configs, "spaq"): + with paddle.amp.auto_cast(False): + o2_fp8, o2_scale = paddle.incubate.nn.functional.fused_weighted_swiglu_act_quant( + o1, unzipped_probs, using_pow2_scaling=True + ) + o2_scale = paddle.transpose(paddle.transpose(o2_scale, [1, 0]).contiguous(), [1, 0]) + unzipped_probs = unzipped_probs.unsqueeze(-1) + else: + o2 = self.fwd_swiglu(o1) + unzipped_probs = unzipped_probs.unsqueeze(-1) + o2 = (o2 * unzipped_probs).cast(paddle.bfloat16) + o2_fp8, o2_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + o2, + quant_method="1x128", + input_transpose=False, + output_scale_transpose=False, + ) + + o3 = paddle.empty([o2_fp8.shape[0], w2_quant.shape[1]], dtype=o1.dtype) + if numpy.prod(o2_fp8.shape) != 0: + if self.is_split_group_gemm: + split_group_gemm(o2_fp8, o2_scale, w2_quant, w2_scale, self.tokens_per_expert, o3) + else: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (o2_fp8, o2_scale), + (w2_quant, w2_scale), + o3, + m_indices=self.m_indices, + ) + return o3, unzipped_probs + + def bwd_down_input(self, expert_w2, unzipped_grad, o1): + """ + Backward pass for down projection input gradient (contiguous version). + + Args: + expert_w2 (List[Tensor]): List of expert weights for down projection + unzipped_grad (Tensor): Gradient from downstream layer + o1 (Tensor): Original input to SwiGLU activation + + Returns: + Tuple[Tensor, Tensor, Tensor]: + - Input gradient (do1) + - SwiGLU output (o2_s) + - Probability gradients + + Note: + - Optimized for contiguous memory access patterns + - Supports both standard and fused backprop paths + - Handles split group GEMM when configured + """ + if has_config(self.fp8_fused_ops_configs, "stack_quant"): + bw_w2_quant, bw_w2_scale = fused_stack_transpose_quant( + expert_w2, transpose=False + ) + else: + stacked_w2 = paddle.stack(expert_w2, axis=0) + concated_w2 = stacked_w2.reshape([-1, stacked_w2.shape[-1]]) + bw_w2_quant, bw_w2_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + concated_w2, + quant_method="128x128", + input_transpose=False, + output_scale_transpose=False, + ) + bw_w2_quant = bw_w2_quant.reshape([len(expert_w2), -1, bw_w2_quant.shape[-1]]) + bw_w2_scale = bw_w2_scale.reshape([len(expert_w2), -1, bw_w2_scale.shape[-1]]) + + unzipped_grad_fp8, unzipped_grad_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + unzipped_grad, + quant_method="1x128", + input_transpose=False, + output_scale_transpose=True, + ) + unzipped_grad_scale = unzipped_grad_scale.T + do2_s = paddle.empty( + [unzipped_grad_fp8.shape[0], bw_w2_quant.shape[1]], + dtype=unzipped_grad.dtype, + ) + if numpy.prod(unzipped_grad_fp8.shape) != 0: + if self.is_split_group_gemm: + split_group_gemm( + unzipped_grad_fp8, + unzipped_grad_scale, + bw_w2_quant, + bw_w2_scale, + self.tokens_per_expert, + do2_s, + ) + else: + + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (unzipped_grad_fp8, unzipped_grad_scale), + (bw_w2_quant, bw_w2_scale), + do2_s, + m_indices=self.m_indices, + ) + + if has_config(self.fp8_fused_ops_configs, "swiglu_probs_bwd"): + do1, probs_grad, o2_s = paddle.incubate.nn.functional.fused_swiglu_weighted_bwd( + o1, do2_s, self.unzipped_probs.squeeze(-1) + ) + else: + o2 = self.fwd_swiglu(o1) + o2_s = (o2 * self.unzipped_probs).cast(paddle.bfloat16) + do2 = (do2_s.cast(paddle.float32) * self.unzipped_probs).cast(paddle.bfloat16) + probs_grad = (do2_s.cast(paddle.float32) * (o2.cast(paddle.float32))).sum(axis=-1) + do1 = self.bwd_swiglu(o1, do2) + + return do1, o2_s, probs_grad + + def bwd_swiglu(self, o1, do2): + do1, _ = paddle._C_ops.swiglu_grad(o1, None, do2) + return do1 + + def bwd_gate_up_input(self, do1, expert_w1): + """ + Args: + do1 (Tensor): Gradient from downstream layer + expert_w1 (List[Tensor]): List of expert weights for gate projection + + Returns: + Tensor: Input gradient (dx) + + Note: + - Uses contiguous memory layout for all operations + - Supports both standard and fused quantization paths + - Handles split group GEMM when configured + """ + if has_config(self.fp8_fused_ops_configs, "stack_quant"): + bw_w1_quant, bw_w1_scale = fused_stack_transpose_quant( + expert_w1, transpose=False + ) + else: + stacked_w1 = paddle.stack(expert_w1, axis=0) + concated_w1_t_2d = stacked_w1.reshape([-1, stacked_w1.shape[-1]]) + bw_w1_quant, bw_w1_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + concated_w1_t_2d, + quant_method="128x128", + input_transpose=False, + output_scale_transpose=False, + ) + bw_w1_quant = bw_w1_quant.reshape([len(expert_w1), -1, bw_w1_quant.shape[-1]]) + bw_w1_scale = bw_w1_scale.reshape([len(expert_w1), -1, bw_w1_scale.shape[-1]]) + + do1_fp8, do1_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + do1, + quant_method="1x128", + input_transpose=False, + output_scale_transpose=True, + ) + do1_scale = do1_scale.T + + dx = paddle.empty(shape=[do1_fp8.shape[0], bw_w1_quant.shape[1]], dtype=paddle.bfloat16) + if numpy.prod(do1_fp8.shape) != 0: + if self.is_split_group_gemm: + split_group_gemm( + do1_fp8, + do1_scale, + bw_w1_quant, + bw_w1_scale, + self.tokens_per_expert, + dx, + ) + else: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (do1_fp8, do1_scale), + (bw_w1_quant, bw_w1_scale), + dx, + m_indices=self.m_indices, + ) + + return dx + + def fused_transpose_split_quant(self, x, tokens_per_expert, pow_2_scales): + """ + Fused operation combining W-L-C-H transpose, split and quantization. + + Args: + x (Tensor): Input tensor to process + tokens_per_expert (List[int]): Token distribution across experts + pow_2_scales (bool): Whether to use power-of-2 scaling + + Returns: + Tuple[Tensor, Tensor]: + - Quantized and split tensor with W-L-C-H layout + - Corresponding scaling factors + + Note: + This optimized operation: + - Reshapes input into [World_size, Local_experts, Channels, Hidden] + - Performs fused transpose/split/quant in single kernel + - Maintains W-L-C-H memory layout throughout + - Reduces memory bandwidth requirements + """ + with paddle.amp.auto_cast(False): + out, scale = paddle.incubate.nn.functional.fused_transpose_split_quant(x, tokens_per_expert, pow_2_scales) + return out, scale + + def bwd_down_weight(self, do3, o2, expert_w2): + """ + Backward pass for down projection weight gradient (contiguous version). + + Args: + do3 (Tensor): Gradient from downstream layer + o2 (Tensor): Output from SwiGLU activation + expert_w2 (List[Tensor]): List of expert weights for down projection + + Note: + - Uses contiguous memory layout for all operations + - Supports both standard and fused transpose/split/quant paths + - Handles both main_grad and standard grad accumulation + """ + if has_config(self.fp8_fused_ops_configs, "transpose_split_quant"): + o2_t_fp8, o2_t_scale = self.fused_transpose_split_quant(o2, self.tokens_per_expert, True) + else: + o2_t = o2.transpose([1, 0]).contiguous() + o2_t_fp8, o2_t_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + o2_t, + quant_method="1x128", + input_transpose=False, + output_scale_transpose=True, + ) + o2_t_scale = paddle.split( + o2_t_scale, + num_or_sections=[int(x / 128) for x in self.tokens_per_expert], + axis=0, + ) + + if has_config(self.fp8_fused_ops_configs, "transpose_split_quant"): + do3_t_fp8, do3_t_scale = self.fused_transpose_split_quant(do3, self.tokens_per_expert, True) + else: + do3_t = do3.transpose([1, 0]).contiguous() + do3_t_fp8, do3_t_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + do3_t, + quant_method="1x128", + input_transpose=False, + output_scale_transpose=True, + ) + do3_t_scale = paddle.split( + do3_t_scale, + num_or_sections=[int(x / 128) for x in self.tokens_per_expert], + axis=0, + ) + + for i in range(len(expert_w2)): + if hasattr(expert_w2[i], "main_grad"): + if expert_w2[i].main_grad is None: + expert_w2[i].main_grad = paddle.zeros(shape=expert_w2[i].shape, dtype=paddle.float32) + fp8_gemm( + o2_t_fp8[i], + o2_t_scale[i], + do3_t_fp8[i], + do3_t_scale[i], + True, + True, + expert_w2[i].main_grad, + paddle.float32, + ) + else: + if expert_w2[i].grad is None: + expert_w2[i].grad = paddle.zeros(shape=expert_w2[i].shape, dtype=paddle.float32) + fp8_gemm( + o2_t_fp8[i], + o2_t_scale[i], + do3_t_fp8[i], + do3_t_scale[i], + True, + True, + expert_w2[i].grad, + paddle.float32, + ) + + def bwd_gate_up_weight(self, do1, input_x, expert_w1): + if has_config(self.fp8_fused_ops_configs, "transpose_split_quant"): + input_x_t_fp8, input_x_t_scale = self.fused_transpose_split_quant(input_x, self.tokens_per_expert, True) + else: + input_x_t = input_x.transpose([1, 0]).contiguous() + input_x_t_fp8, input_x_t_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + input_x_t, + quant_method="1x128", + input_transpose=False, + output_scale_transpose=True, + ) + input_x_t_scale = paddle.split( + input_x_t_scale, + num_or_sections=[int(x / 128) for x in self.tokens_per_expert], + axis=0, + ) + + if has_config(self.fp8_fused_ops_configs, "transpose_split_quant"): + do1_t_fp8, do1_t_scale = self.fused_transpose_split_quant(do1, self.tokens_per_expert, True) + else: + do1_t = do1.transpose([1, 0]).contiguous() + do1_t_fp8, do1_t_scale = paddle.incubate.nn.functional.fp8.fp8_quant_blockwise( + do1_t, + quant_method="1x128", + input_transpose=False, + output_scale_transpose=True, + ) + do1_t_scale = paddle.split( + do1_t_scale, + num_or_sections=[int(x / 128) for x in self.tokens_per_expert], + axis=0, + ) + + for i in range(len(expert_w1)): + if hasattr(expert_w1[i], "main_grad"): + if expert_w1[i].main_grad is None: + expert_w1[i].main_grad = paddle.zeros(shape=expert_w1[i].shape, dtype=paddle.float32) + fp8_gemm( + input_x_t_fp8[i], + input_x_t_scale[i], + do1_t_fp8[i], + do1_t_scale[i], + True, + True, + expert_w1[i].main_grad, + paddle.float32, + ) + else: + if expert_w1[i].grad is None: + expert_w1[i].grad = paddle.zeros(shape=expert_w1[i].shape, dtype=paddle.float32) + fp8_gemm( + input_x_t_fp8[i], + input_x_t_scale[i], + do1_t_fp8[i], + do1_t_scale[i], + True, + True, + expert_w1[i].grad, + paddle.float32, + ) + + @paddle.no_grad() + def forward(self, hs_out, unzipped_probs, tokens_per_expert, origin_token_per_experts, scale=None): + self.origin_token_per_experts = origin_token_per_experts + if hs_out.shape[0] == 0: + o3 = paddle.zeros_like(hs_out) + self.unzipped_probs = unzipped_probs.unsqueeze(-1) + return o3 + expert_w1 = [x.up_gate_proj.weight for x in self.custom_map.experts if x is not None] + expert_w2 = [x.down_proj.weight for x in self.custom_map.experts if x is not None] + num_expert = len(expert_w1) + o1 = self.fwd_gate_up(hs_out, expert_w1, num_expert, tokens_per_expert, scale=scale) + if not self.recompute_fwd_gate_up: + self.o1 = o1 + o3, unzipped_probs = self.fwd_down(o1, unzipped_probs, expert_w2, num_expert) + self.unzipped_probs = unzipped_probs + return o3 + + @paddle.no_grad() + def backward(self, out_grad, a2a_async_fn=None): + if out_grad.shape[0] == 0: + dx = paddle.zeros_like(out_grad) + probs_grad = paddle.zeros_like(self.unzipped_probs) + + for expert in self.custom_map.experts: + if expert is None: + continue + + if hasattr(expert.down_proj.weight, "main_grad"): + if expert.down_proj.weight.main_grad is None: + expert.down_proj.weight.main_grad = paddle.zeros( + shape=expert.down_proj.weight.shape, dtype=paddle.float32 + ) + else: + if expert.down_proj.weight.grad is None: + expert.down_proj.weight.grad = paddle.zeros( + shape=expert.down_proj.weight.shape, dtype=paddle.float32 + ) + + if hasattr(expert.up_gate_proj.weight, "main_grad"): + if expert.up_gate_proj.weight.main_grad is None: + expert.up_gate_proj.weight.main_grad = paddle.zeros( + shape=expert.up_gate_proj.weight.shape, dtype=paddle.float32 + ) + else: + if expert.up_gate_proj.weight.grad is None: + expert.up_gate_proj.weight.grad = paddle.zeros( + shape=expert.up_gate_proj.weight.shape, dtype=paddle.float32 + ) + + if a2a_async_fn: + dx, task = a2a_async_fn(dx) + task.wait() + return dx, probs_grad + + expert_w2 = [x.down_proj.weight for x in self.custom_map.experts if x is not None] + expert_w1 = [x.up_gate_proj.weight for x in self.custom_map.experts if x is not None] + + if self.recompute_fwd_gate_up: + o1 = self.fwd_gate_up(self.input, expert_w1, len(expert_w1), self.tokens_per_expert) + else: + o1 = self.o1 + + do1, o2_s, probs_grad = self.bwd_down_input(expert_w2, out_grad, o1) + del o1 + if not self.recompute_fwd_gate_up: + self.o1 = None + + if self.dequant_input: + input = paddle.incubate.nn.functional.fused_act_dequant(self.input_fp8, self.input_scale) + self.input_scale = None + else: + input = self.input + + if a2a_async_fn is None: + # dw1 + self.bwd_gate_up_weight(do1, input, expert_w1) + del input + + if not self.dequant_input: + self.input = None + # dx + dx = self.bwd_gate_up_input(do1, expert_w1) + + # release do1 and input + del do1 + + self.bwd_down_weight(out_grad, o2_s, expert_w2) + else: + # dx + dx = self.bwd_gate_up_input(do1, expert_w1) + + dx, task = a2a_async_fn(dx) + + # dw1 + self.bwd_gate_up_weight(do1, input, expert_w1) + del input + + if not self.dequant_input: + self.input = None + + # release do1 and input + del do1 + + self.bwd_down_weight(out_grad, o2_s, expert_w2) + + task.wait() + + self.reset_status() + return dx, probs_grad + + +class ExpertsGroupGemmWLCHNode(ExpertsGroupGemmContiguousNode): + """ + Node for performing grouped GEMM operations with W-L-C-H memory layout. + + This specialized version optimizes for distributed MoE scenarios with: + - World-size (W) dimension for distributed expert parallelism + - Local-expert (L) dimension for per-node expert processing + - Channel (C) dimension for feature processing + - Hidden (H) dimension for output features + + Inherits from ExpertsGroupGemmContiguousNode and adds: + - W-L-C-H memory layout optimizations + - Specialized fused transpose/split/quant operations + - Distributed expert parallelism support + """ + + def __init__( + self, + custom_map, + recompute_fwd_gate_up=False, + dequant_input=False, + group=None, + name="experts_group_gemm_WLCH_node", + ): + """ + Initialize the ExpertsGroupGemmWLCHNode. + + Args: + custom_map (CustomMap): Configuration mapping for expert operations + recompute_fwd_gate_up (bool): Whether to recompute gate projection in backward pass + dequant_input (bool): Whether to dequantize input tensors + name (str): Optional name for the node + + Attributes: + w (int): World size for distributed expert parallelism + l (int): Number of local experts per node + fp8_fused_ops_configs (Dict): Configuration for FP8 fused operations + """ + super().__init__( + custom_map, + recompute_fwd_gate_up=recompute_fwd_gate_up, + dequant_input=dequant_input, + group=group, + name=name, + ) + + self.fp8_fused_ops_configs["transpose_split_quant"] = True + self.fp8_fused_ops_configs["split_group_gemm"] = False + + self.w = custom_map.world_size + self.l = custom_map.num_local_experts + + def gen_m_indices(self, tokens_per_expert): + """ + Generate expert indices tensor with W-L-C-H memory layout. + + Args: + tokens_per_expert (List[int]): Token distribution across experts + + Returns: + Tensor: Contiguous tensor of expert indices with W-L-C-H layout + + Note: + - Creates indices tensor optimized for distributed expert parallelism + - Layout: [World_size, Local_experts, Channels, Hidden] + - Ensures contiguous memory access across distributed experts + """ + m_indices = paddle.arange(self.l, dtype=paddle.int32).repeat_interleave(tokens_per_expert[0]) + m_indices = m_indices.reshape([self.w, self.l, -1]).transpose([1, 0, 2]).contiguous().reshape([-1]) + + return m_indices + + def fused_transpose_split_quant(self, x, tokens_per_expert, pow_2_scales): + s, h = x.shape + x = x.reshape([self.w, self.l, -1, h]) + out, scale = paddle.incubate.nn.functional.fused_transpose_wlch_split_quant( + x, tokens_per_expert, pow_2_scales=pow_2_scales + ) + return out, scale diff --git a/ernie/ERNIE/examples/pre-training/models/moe/token_dispatcher/moe_utils.py b/ernie/ERNIE/examples/pre-training/models/moe/token_dispatcher/moe_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a35383440248b1ab45a0642e01ba049180af1fff --- /dev/null +++ b/ernie/ERNIE/examples/pre-training/models/moe/token_dispatcher/moe_utils.py @@ -0,0 +1,369 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025 DeepSeek +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + +import numpy as np +import paddle +from paddle import framework + + +def inplace_offload(x): + """Offload tensor to CPU in-place to save GPU memory. + + Args: + x (paddle.Tensor): The tensor to be offloaded to CPU. + + Note: + This operation modifies the tensor in-place by sharing data with a CPU copy. + """ + if not x.place._equals(paddle.CPUPlace()): + y = x.cpu() + if y is not x: + x_t = x.value().get_tensor() + y_t = y.value().get_tensor() + x_t._share_data_with(y_t) + + +def inplace_offload_if_needed(x, threshold=2 * 1024 * 1024 * 1024): + """Conditionally offload tensor to CPU if it exceeds memory threshold. + + Args: + x (paddle.Tensor): The tensor to potentially offload. + threshold (int, optional): Memory threshold in bytes. Defaults to 2GB. + + Note: + Only offloads tensors during gradient computation when memory usage exceeds threshold. + Issues a warning when offloading occurs. + """ + if not framework._dygraph_tracer()._has_grad: + return + + memory_size = np.prod(x.shape) * paddle.core.size_of_dtype(x.dtype) + if memory_size >= threshold: + inplace_offload(x) + warnings.warn(f"Offload tensor with shape: {x.shape}, dtype: {x.dtype}, memory size {memory_size}") + + +def topk_to_permuted_indices_single(x, num_tokens, expert_id, topk): + """Convert topk indices to permuted indices for a single expert. + + Args: + x (paddle.Tensor): Input tensor containing expert assignments. + num_tokens (int): Number of tokens assigned to this expert. + expert_id (int): ID of the expert to filter for. + topk (int): Number of experts selected per token (top-k value). + + Returns: + tuple: (token_permuted_indices, prob_permuted_indices) + - token_permuted_indices: Indices of tokens assigned to this expert + - prob_permuted_indices: Indices of probabilities for the expert assignments + """ + x = paddle.flatten(x) + prob_permuted_indices = paddle.tensor.search._restrict_nonzero(x == expert_id, num_tokens).flatten() + token_permuted_indices = prob_permuted_indices // topk + return token_permuted_indices, prob_permuted_indices + + +def topk_to_permuted_indices(x, num_tokens_per_expert_list, topk): + """Convert topk indices to permuted indices for all experts. + + Args: + x (paddle.Tensor): Input tensor containing expert assignments. + num_tokens_per_expert_list (list[int]): List of token counts per expert. + topk (int): Number of experts selected per token (top-k value). + + Returns: + tuple: (token_permuted_indices, prob_permuted_indices) + - token_permuted_indices: Indices of tokens assigned to experts + - prob_permuted_indices: Indices of probabilities for all expert assignments + """ + x = paddle.flatten(x) + prob_permuted_indices = paddle.concat( + [ + paddle.tensor.search._restrict_nonzero(x == i, total_true_num) + for i, total_true_num in enumerate(num_tokens_per_expert_list) + ] + ).flatten() + token_permuted_indices = prob_permuted_indices // topk + return token_permuted_indices, prob_permuted_indices + + +def permute( + tokens, + token_permuted_indices, + drop_and_pad: bool = False, +): + """Permute tokens based on expert assignment indices. + + Args: + tokens (paddle.Tensor): Input tokens to be permuted. + token_permuted_indices (paddle.Tensor): Indices for permutation. + drop_and_pad (bool, optional): Whether to drop and pad tokens. Not supported yet. + + Returns: + paddle.Tensor: Permuted tokens. + + Raises: + AssertionError: If drop_and_pad is True (not supported). + """ + assert not drop_and_pad, "token-drop and pads is not supported" + permuted_input = paddle.gather(tokens, token_permuted_indices) + return permuted_input + + +def unpermute( + permuted_tokens: paddle.Tensor, + token_permuted_indices: paddle.Tensor, + prob_permuted_indices: paddle.Tensor, + restore_shape: paddle.shape, + probs: paddle.Tensor = None, + drop_and_pad: bool = False, +): + """Restore original token order from permuted tokens. + + Args: + permuted_tokens (paddle.Tensor): Permuted tokens to be restored. + token_permuted_indices (paddle.Tensor): Original token positions. + prob_permuted_indices (paddle.Tensor): Indices for probability values. + restore_shape (paddle.shape): Original shape of the tensor. + probs (paddle.Tensor, optional): Probability values for weighted restoration. + drop_and_pad (bool, optional): Whether to drop and pad tokens. Not supported yet. + + Returns: + paddle.Tensor: Restored tokens in original order. + + Raises: + AssertionError: If drop_and_pad is True (not supported). + """ + assert not drop_and_pad, "token-drop and pads is not supported" + _, hidden = restore_shape + if probs is not None: + permuted_probs = paddle.gather(probs.flatten(), prob_permuted_indices) + permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1) + + output_tokens = paddle.zeros(restore_shape, dtype=permuted_tokens.dtype) + output_tokens.scatter_(index=token_permuted_indices, updates=permuted_tokens, overwrite=False) + return output_tokens + + +class UnZipNode: + """Handles the unzipping (high performance permute) of tokens for expert processing in Mixture of Experts, + in an efficient, deterministic manner. + + This class manages the process of expanding tokens assigned to experts, including: + - Forward pass: Distributes tokens to experts + - Backward pass: Collects gradients from experts + + Attributes: + token_dispatcher: Reference to the parent token dispatcher. + name (str): Identifier for this node. + unzipped_probs (paddle.Tensor): Probability values after unzipping. + zipped_expertwise_rowmap (paddle.Tensor): Mapping between original and expanded tokens. + """ + + def __init__(self, token_dispatcher, name="unzip"): + """Initialize the UnZipNode. + + Args: + token_dispatcher: Parent token dispatcher instance. + name (str, optional): Name identifier. Defaults to "unzip". + """ + self.token_dispatcher = token_dispatcher + self.name = name + self.unzipped_probs = None + self.zipped_expertwise_rowmap = None + + def reset_status(self): + """Reset internal state between forward/backward passes.""" + self.unzipped_probs = None + self.zipped_expertwise_rowmap = None + + @paddle.no_grad() + def forward( + self, + hs_2d_dispatched, + dispatched_indices, + dispatched_probs, + topk, + num_experts, + tokens_per_expert, + ): + """Forward pass - distribute tokens to experts. + + Args: + hs_2d_dispatched (paddle.Tensor): Dispatched hidden states (2D). + dispatched_indices (paddle.Tensor): Indices of expert assignments. + dispatched_probs (paddle.Tensor): Routing probabilities. + topk (int): Number of experts selected per token. + num_experts (int): Total number of experts. + tokens_per_expert (int): Tokens allocated per expert. + + Returns: + tuple: (unzipped_tokens, zipped_expertwise_rowmap, unzipped_probs) + - unzipped_tokens: Expanded tokens for expert processing + - zipped_expertwise_rowmap: Mapping between original and expanded tokens + - unzipped_probs: Expanded routing probabilities + """ + with paddle.amp.auto_cast(False): + ( + unzipped_tokens, + zipped_expertwise_rowmap, + unzipped_probs, + _, + ) = paddle.nn.functional.moe_permute( + hs_2d_dispatched, + None, + dispatched_indices, + dispatched_probs, + num_experts=num_experts, + tokens_per_expert=tokens_per_expert, + padding_alignment=128, + ) + self.unzipped_probs = unzipped_probs + self.zipped_expertwise_rowmap = zipped_expertwise_rowmap + return ( + unzipped_tokens, + zipped_expertwise_rowmap, + unzipped_probs, + ) + + @paddle.no_grad() + def backward(self, dx, hidden_states_out_grad, probs_grad, dispatched_indices, num_experts): + """Backward pass - collect gradients from experts. + + Args: + dx (paddle.Tensor): Gradient from experts. + hidden_states_out_grad (paddle.Tensor): Gradient of output hidden states. + probs_grad (paddle.Tensor): Gradient of routing probabilities. + dispatched_indices (paddle.Tensor): Original expert assignment indices. + num_experts (int): Total number of experts. + + Returns: + tuple: (weighted_zipped_tokens, probs_grad_zipped) + - weighted_zipped_tokens: Compressed gradients from experts + - probs_grad_zipped: Compressed probability gradients + """ + with paddle.amp.auto_cast(False): + weighted_zipped_tokens, probs_grad_zipped = paddle.nn.functional.moe_unpermute( + dx, + self.zipped_expertwise_rowmap, + dispatched_indices, + probs_grad, + total_zipped_tokens=hidden_states_out_grad.shape[0], + num_experts=num_experts, + ) + self.reset_status() + return weighted_zipped_tokens, probs_grad_zipped + + +class ZipNode: + """Handles the zipping (high performance unpermute) of expert outputs in Mixture of Experts, + in an efficient, deterministic manner. + + This class manages the process of combining expert outputs, including: + - Forward pass: Combines expert outputs + - Backward pass: Distributes gradients to experts + + Attributes: + token_dispatcher: Reference to the parent token dispatcher. + name (str): Identifier for this node. + """ + + def __init__(self, token_dispatcher, name="zip"): + """Initialize the ZipNode. + + Args: + token_dispatcher: Parent token dispatcher instance. + name (str, optional): Name identifier. Defaults to "zip". + """ + self.token_dispatcher = token_dispatcher + self.name = name + + @paddle.no_grad() + def forward( + self, + expert_out, + zipped_expertwise_rowmap, + routemap_topk, + unzipped_probs, + total_zipped_tokens, + num_experts, + ): + """Forward pass - combine expert outputs. + + Args: + expert_out (paddle.Tensor): Outputs from all experts. + zipped_expertwise_rowmap (paddle.Tensor): Mapping between original and expanded tokens. + routemap_topk (paddle.Tensor): Top-k routing information. + unzipped_probs (paddle.Tensor): Expanded routing probabilities. + total_zipped_tokens (int): Total number of original tokens. + num_experts (int): Total number of experts. + + Returns: + paddle.Tensor: Combined expert outputs. + """ + with paddle.amp.auto_cast(False): + expert_out_zipped, zipped_probs_topk = paddle.nn.functional.moe_unpermute( + expert_out, + zipped_expertwise_rowmap, + routemap_topk, + unzipped_probs, + total_zipped_tokens, + num_experts, + ) + return expert_out_zipped + + @paddle.no_grad() + def backward( + self, + grad_output, + dispatched_indices, + dispatched_probs, + top_k, + num_experts, + tokens_per_expert, + ): + """Backward pass - distribute gradients to experts. + + Args: + grad_output (paddle.Tensor): Gradient of the combined output. + dispatched_indices (paddle.Tensor): Original expert assignment indices. + dispatched_probs (paddle.Tensor): Original routing probabilities. + top_k (int): Number of experts selected per token. + num_experts (int): Total number of experts. + tokens_per_expert (int): Tokens allocated per expert. + + Returns: + paddle.Tensor: Expanded gradients to be sent to experts. + """ + with paddle.amp.auto_cast(False): + ( + unzipped_grad, + zipped_expertwise_rowmap_grad, + unzipped_probs_grad, + _, + ) = paddle.nn.functional.moe_permute( + grad_output, + None, + dispatched_indices, + dispatched_probs, + num_experts, + tokens_per_expert, + padding_alignment=128, + ) + + return unzipped_grad diff --git a/ernie/ERNIE/examples/pre-training/models/moe/top2_gate.py b/ernie/ERNIE/examples/pre-training/models/moe/top2_gate.py new file mode 100644 index 0000000000000000000000000000000000000000..a6cef569a23b9b34499428708c22c546b72b7c84 --- /dev/null +++ b/ernie/ERNIE/examples/pre-training/models/moe/top2_gate.py @@ -0,0 +1,434 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from functools import partial + +import paddle +import paddle.distributed as dist +import paddle.nn.functional as F +from paddle import nn +from paddle.incubate.nn.functional import cal_aux_loss, int_bincount +from paddle.utils import unique_name + +try: + from src.utils.misc import global_training_logs +except ModuleNotFoundError: + global_training_logs = {} + +logger = logging.getLogger(__name__) + + +def cal_aux_loss_func( + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + moe_k, + global_aux_loss=False, + rank=None, + group=None, +): + if tokens_mask is not None and tokens_mask.dtype != gate_prob.dtype: + tokens_mask = tokens_mask.astype(gate_prob.dtype) + + scale = None + if dispatch_tokens_mask is not None: + seqlen_float = dispatch_tokens_mask.astype(gate_prob.dtype).sum() + if tokens_mask is not None and gate_prob.shape[0] != dispatch_tokens_mask.shape[0]: + scale = seqlen_float / paddle.clip(tokens_mask.sum(), min=1e-6) + elif tokens_mask is not None: + seqlen_float = tokens_mask.sum() + else: + seqlen_float = gate_prob.numel().astype(gate_prob.dtype) / num_experts + seqlen_float = paddle.clip(seqlen_float, min=1e-6) + + if len(dispatch_mask.shape) == 2: + dispatch_mask = dispatch_mask.sum(0) + ce = dispatch_mask.astype(gate_prob.dtype).detach() / seqlen_float + me = paddle.sum(gate_prob, axis=0) / seqlen_float + if global_aux_loss: + me_list, ce_list = [], [] + dist.all_gather(me_list, me, group=group) + dist.all_gather(ce_list, ce, group=group) + + me_list[rank] = me + ce_list[rank] = ce + me = paddle.stack(me_list).mean(0) + ce = paddle.stack(ce_list).mean(0) + + l_aux = paddle.sum(me * ce) * num_experts + if use_group: + l_aux = l_aux / moe_k + + if scale is not None: + l_aux = l_aux + (scale - 1) * l_aux.detach() + + return l_aux + + +def masked_fill(x, mask, value): + y = paddle.full(x.shape, value, x.dtype) + return paddle.where(mask, y, x) + + +class CalAuxLossFunctor(paddle.autograd.PyLayer): + @staticmethod + def forward( + ctx, + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + moe_k, + clip_min=1e-6, + ): + if tokens_mask is not None and tokens_mask.dtype != gate_prob.dtype: + tokens_mask = tokens_mask.astype(gate_prob.dtype) + loss, seqlen_float, ce = cal_aux_loss( + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + moe_k, + clip_min, + ) + ctx.save_for_backward(gate_prob, seqlen_float, ce) + ctx.num_experts = num_experts + ctx.use_group = use_group + ctx.moe_k = moe_k + return loss + + @staticmethod + def backward(ctx, out_grad): + gate_prob, seqlen_float, ce = ctx.saved_tensor() + num_experts = ctx.num_experts + use_group = ctx.use_group + moe_k = ctx.moe_k + return paddle._C_ops.cal_aux_loss_grad(gate_prob, seqlen_float, ce, out_grad, num_experts, use_group, moe_k) + + +def cast_if_needed(x, dtype): + return x.cast(dtype) if x.dtype != dtype else x + + +class FusedGateDetachMatmul(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, w): + ctx.dtype = paddle.float32 + ctx.save_for_backward(x, w) + return F.linear(cast_if_needed(x, ctx.dtype), cast_if_needed(w, ctx.dtype)) + + @staticmethod + def backward(ctx, y_grad): + x, w = ctx.saved_tensor() + assert ctx.dtype == y_grad.dtype, "dtype not match" + x_g, w_g = paddle._C_ops.matmul_grad( + cast_if_needed(x, ctx.dtype), + cast_if_needed(w, ctx.dtype), + y_grad, + False, + False, + ) + return cast_if_needed(x_g, x.dtype), cast_if_needed(w_g, w.dtype) + + +def gate_detach_matmul(x, weight, use_fuse): + if use_fuse: + return FusedGateDetachMatmul.apply(x, weight) + else: + x = cast_if_needed(x, paddle.float32) + return F.linear(x, weight) + + +@paddle.no_grad() +def compute_optimal_transport(M, r, c, lam=1.0, epsilon=1e-8, max_iters: int = 10): + n, _ = M.shape + P = F.softmax(-M / lam) + u = paddle.zeros(n, "float32") + for _ in range(max_iters): + if (u - P.sum(1)).abs().max() < epsilon: + break + u = P.sum(1) + P *= (r / (u + 1e-8)).reshape((-1, 1)) + P *= (c / (P.sum(0) + 1e-8)).reshape((1, -1)) + P = paddle.where(~P.isnan(), P, paddle.zeros_like(P)) + return P, _ + + +class Top2Gate(nn.Layer): + """Gating network for Top-2 Mixture of Experts (MoE) routing. + + This gate computes routing weights for each token and selects the top-2 experts + for each input token. Supports both standard and balanced routing strategies. + + Attributes: + config: Configuration object containing hyperparameters. + layer_idx (int): Identifier for the layer in the overall model. + group (dist.ProcessGroup): Process group for distributed computation. + gate_weight (nn.Parameter, optional): Learnable gating weights. + """ + + def __init__(self, config, layer_idx: int, group, gate_weight=None) -> None: + """Initialize the Top-2 gating network. + + Args: + config: Configuration object containing: + layer_idx (int): Identifier for this gating layer (used for logging). + group (dist.ProcessGroup): Process group for distributed operations. + gate_weight (nn.Parameter, optional): Pre-initialized gating weight matrix. + If None, will be initialized internally. Shape: (d_model, num_experts). + """ + + super().__init__() + + self.config = config + self.fuse_gate_detach_matmul = config.fuse_gate_detach_matmul + + self.model_dim = config.hidden_size + self.num_experts = config.moe_num_experts + self.num_experts_tensor = config.moe_num_experts + + self.cap = config.moe_capacity + self.group = group + + self.layer_idx = layer_idx + self.global_aux_loss = config.global_aux_loss + if self.global_aux_loss: + self.rank = dist.get_rank(self.group) + + self.use_correction_bias = config.moe_use_aux_free + + if config.moe_gate_act == "softmax": + self.act = partial(F.softmax, axis=-1) + elif config.moe_gate_act == "sigmoid": + self.act = F.sigmoid + else: + raise ValueError(f"{config.moe_gate_act} is not supported.") + + self.expert_drop = False + self.norm_gate_logits = config.moe_norm_gate_logits + self.one = paddle.ones([], dtype="float32") + + self.moe_aux_loss_lambda = paddle.to_tensor(config.moe_aux_loss_lambda, dtype="float32") + if self.moe_aux_loss_lambda.ndim == 0: + self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.unsqueeze(0) + + self.experts_type_ids = None + + self.eps = paddle.to_tensor([1e-12], dtype="float32") + self.num_experts_list = [self.num_experts] + if gate_weight is not None: + self.weight = gate_weight + logger.info("moe use gate_weight from outside") + self._cast_to_low_precision = False + self._cast_to_low_precision = False + else: + self._create_gate_parameter() + + def _create_gate_parameter(self): + self.weight = self.create_parameter( + shape=[self.model_dim, self.num_experts], + dtype="float32", + attr=paddle.ParamAttr(name=unique_name.generate("moe_gate")), + ) + + self._cast_to_low_precision = False + self._cast_to_low_precision = False + + def forward( + self, + input, + token_type_ids, + transform_weight, + correction_bias, + ): + orig_dtype = input.dtype + weight = self.weight + with paddle.amp.auto_cast(False): + logits = gate_detach_matmul(input, weight, self.fuse_gate_detach_matmul) + ( + capacity, + dispatch_mask, + combine_weights, + scatter_index, + l_aux, + ) = self.top2_gating(logits, correction_bias=correction_bias) + router_loss = l_aux * self.moe_aux_loss_lambda + router_loss.stop_gradient = False + + combine_weights = combine_weights.cast(orig_dtype) + return ( + capacity, + dispatch_mask, + combine_weights, + scatter_index, + router_loss, + logits, + ) + + def get_capacity(self, num_tokens, cap_factor=None): + num_experts = self.num_experts + if cap_factor is not None: + cap = cap_factor + else: + if self.training: + cap = self.cap[0] + elif num_tokens < num_experts: + cap = self.cap[2] + else: + cap = self.cap[1] + capacity = int(cap * num_tokens // num_experts) + assert capacity > 0, f"requires capacity to >= 0. cap={cap}, num_tokens={num_tokens}" + return capacity + + def top2_gating(self, logits, cap=None, correction_bias=None): + gates = self.act(logits) + + assert logits.ndim == 2, logits.shape + num_experts = gates.shape[1] + capacity = self.get_capacity(logits.shape[0], cap) + + score_for_argmax = gates + correction_bias.unsqueeze(0) if correction_bias is not None else gates + indices1_s = paddle.argmax(score_for_argmax, axis=1) + mask1 = F.one_hot(indices1_s, num_classes=num_experts).cast(paddle.int64) + + l_aux = self._cal_aux_loss(gates, mask1.sum(axis=0), self.num_experts_tensor) + logits_w_noise = logits + + logits_except1 = masked_fill(logits_w_noise, mask1.cast(paddle.bool), float("-inf")) + score_for_argmax = ( + self.act(logits_except1) + correction_bias.unsqueeze(0) if correction_bias is not None else logits_except1 + ) + indices2_s_original = paddle.argmax(score_for_argmax, axis=1) + + mask2 = F.one_hot(indices2_s_original, num_classes=self.num_experts).cast(paddle.int64) + + locations1 = paddle.cumsum(mask1, axis=0) - 1 + locations2 = paddle.cumsum(mask2, axis=0) - 1 + locations2 += paddle.sum(mask1, axis=0, keepdim=True) + + mask1 *= (locations1 < capacity).cast(paddle.int64) + mask2 *= (locations2 < capacity).cast(paddle.int64) + + locations1_s = paddle.sum(locations1 * mask1, axis=1) + locations2_s = paddle.sum(locations2 * mask2, axis=1) + + mask1_float = mask1.cast(paddle.float32) + mask2_float = mask2.cast(paddle.float32) + gates1_s = (gates * mask1_float).sum(axis=-1) + gates2_s = (gates * mask2_float).sum(axis=-1) + + if self.norm_gate_logits: + denom_s = gates1_s + gates2_s + denom_s = paddle.clip(denom_s, min=1e-6) + gates1_s /= denom_s + gates2_s /= denom_s + if self.training and self.expert_drop: + gates2_s = paddle.where( + 2 * gates2_s < paddle.rand_like(gates2_s), + paddle.zeros_like(gates2_s), + gates2_s, + ) + + gates1 = gates1_s.unsqueeze(1) * mask1_float + gates2 = gates2_s.unsqueeze(1) * mask2_float + + expert1_index = paddle.argmax(gates1, -1) + combine1_weight = paddle.max(gates1, -1, keepdim=True) + scatter1_index = expert1_index * capacity + locations1_s + scatter1_index = scatter1_index.cast("int64") + dispatch1_mask = combine1_weight.cast(paddle.bool).detach() + + expert2_index = paddle.argmax(gates2, -1) + combine2_weight = paddle.max(gates2, -1, keepdim=True) + scatter2_index = expert2_index * capacity + locations2_s + scatter2_index = scatter2_index.cast("int64") + dispatch2_mask = combine2_weight.cast(paddle.bool).detach() + + return ( + capacity, + paddle.concat((dispatch1_mask, dispatch2_mask), 1), + paddle.concat((combine1_weight, combine2_weight), 1), + paddle.stack((scatter1_index, scatter2_index), 1), + l_aux, + ) + + def _cal_aux_loss( + self, + gate_prob, + dispatch_mask, + num_experts=None, + use_group=None, + tokens_mask=None, + dispatch_tokens_mask=None, + ): + if self.act is F.sigmoid: + gate_prob = gate_prob / gate_prob.sum(-1, keepdim=True) + + if self.use_correction_bias: + if tokens_mask is not None: + gate_prob_this_modality = gate_prob[tokens_mask.astype("bool")] + if gate_prob_this_modality.shape[0]: + _, top_idx = gate_prob_this_modality.topk(k=self.config.moe_k, axis=-1) + dispatch_mask = int_bincount(top_idx, 0, gate_prob.shape[-1], paddle.int64) + else: + dispatch_mask = paddle.zeros(gate_prob.shape[-1], dtype="int64") + dist.stream.all_reduce( + dispatch_mask, + group=self.group, + use_calc_stream=True, + ) + else: + _, top_idx = gate_prob.topk(k=self.config.moe_k, axis=-1) + dispatch_mask = int_bincount(top_idx, 0, gate_prob.shape[-1], paddle.int64) + + if num_experts is None: + num_experts = self.num_experts_tensor + if use_group is None: + use_group = self.config.moe_group_experts + + return CalAuxLossFunctor.apply( + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + self.config.moe_k, + clip_min=1e-6, + ) + + +class TopKGateFused(Top2Gate): + def forward( + self, + input, + token_type_ids=None, + transform_weight=True, + ): + capacity = self.get_capacity(input.shape[0]) + weight = self.weight + with paddle.amp.auto_cast(False): + logits = gate_detach_matmul(input, weight, self.fuse_gate_detach_matmul) + router_loss = paddle.zeros([1], dtype="float32") + router_loss.stop_gradient = False + + return logits, capacity, router_loss diff --git a/ernie/ERNIE/examples/pre-training/models/sequence_parallel_utils.py b/ernie/ERNIE/examples/pre-training/models/sequence_parallel_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dfad163688fc3f8131f725b921e5b92be55214c3 --- /dev/null +++ b/ernie/ERNIE/examples/pre-training/models/sequence_parallel_utils.py @@ -0,0 +1,590 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import numpy as np +import paddle +from models.comm_utils import ( + all_gather, + reduce_scatter, + scatter, +) +from paddle import distributed as dist +from paddle.autograd import PyLayer +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker +from paddle.distributed.fleet.utils.hybrid_parallel_util import ( + fused_allreduce_gradients_with_group, +) +from paddle.incubate.tensor.manipulation import create_async_load +from paddle.nn import functional as F +from paddle.nn.layer.layers import Layer + +try: + from paddle.nn.functional import all_gather_gemm, flux, gemm_reduce_scatter +except ImportError: + gemm_reduce_scatter = None + all_gather_gemm = None + flux = None + +logger = logging.getLogger(__name__) + + +def get_hcg(): + return fleet.get_hybrid_communicate_group() + + +async_loader = None + + +def get_async_loader(): + global async_loader + if not hasattr(fleet.fleet, "_hcg"): + if async_loader is None: + async_loader = create_async_load() + return async_loader + + hcg = get_hcg() + if not hasattr(hcg, "async_loader"): + hcg.async_loader = create_async_load() + return hcg.async_loader + + +def hack_offload_wait(task): + task.cpu_wait() + + +def hack_reload_wait(task): + task.cuda_wait() + + +class ScatterOp(PyLayer): + @staticmethod + def forward(ctx, input, axis=0, group=None): + ctx.axis = axis + ctx.group = group + return scatter(input, axis=axis, group=ctx.group) + + @staticmethod + def backward(ctx, grad): + return all_gather(grad, axis=ctx.axis, group=ctx.group) + + +class GatherOp(PyLayer): + @staticmethod + def forward(ctx, input, axis=0, group=None): + ctx.axis = axis + ctx.group = group + return all_gather(input, axis=axis, group=group) + + @staticmethod + def backward(ctx, grad): + return scatter(grad, axis=ctx.axis, group=ctx.group) + + +class AllGatherOp(PyLayer): + @staticmethod + def forward(ctx, input, group=None): + ctx.group = group + return all_gather(input, group=group) + + @staticmethod + def backward(ctx, grad): + return reduce_scatter(grad, group=ctx.group) + + +class ReduceScatterOp(PyLayer): + @staticmethod + def forward(ctx, input, group=None): + + ctx.group = group + return reduce_scatter(input, group=group) + + @staticmethod + def backward(ctx, grad): + return all_gather(grad, group=ctx.group) + + +class AllGatherVarlenOp(PyLayer): + @staticmethod + def forward(ctx, input, group=None): + hcg = fleet.get_hybrid_communicate_group() + if group is None: + group = hcg.get_model_parallel_group() + + shape0 = paddle.to_tensor([input.shape[0]]) + shape0_all = paddle.empty(shape=[group.nranks], dtype=shape0.dtype) + dist.stream.all_gather(shape0_all, shape0, group=group, use_calc_stream=True) + shape0_all = shape0_all.numpy() + max_shape0 = shape0_all.max() + + indices = [] + for idx, s in enumerate(shape0_all): + offset = idx * max_shape0 + indices.append(list(range(offset, offset + s))) + indices = np.concatenate(indices, axis=0) + indices = indices.reshape([-1] + [1] * (len(input.shape) - 1)) + indices = paddle.to_tensor(indices, dtype=paddle.int32) + + padding = max_shape0 - input.shape[0] + + ctx.shape0 = input.shape[0] + ctx.max_shape0 = max_shape0 + ctx.shape0_all = shape0_all + ctx.padding = padding + ctx.indices = indices + ctx.group = group + + if padding > 0: + input_shape = input.shape + input_shape[0] = padding + padding_tensor = paddle.empty(shape=input_shape, dtype=input.dtype) + input = paddle.concat([input, padding_tensor], axis=0) + output = all_gather(input, group) + output = paddle.take_along_axis(output, indices, axis=0) + + return output + + @staticmethod + def backward(ctx, grad): + input_shape = grad.shape + input_shape[0] = ctx.max_shape0 * ctx.shape0_all.shape[0] + output = paddle.zeros(shape=input_shape, dtype=grad.dtype) + + grad = paddle.scatter(output, ctx.indices, grad) + + grad = scatter(grad, ctx.group) + + if ctx.padding > 0: + grad = grad[: ctx.shape0] + return grad + + +class GemmReduceScatterOp(PyLayer): + @staticmethod + def forward(ctx, input, weight, group): + ctx.save_for_backward(input, weight) + ctx.group = group + output = gemm_reduce_scatter(input, weight, group) + return output + + @staticmethod + def backward(ctx, grad): + input, weight = ctx.saved_tensor() + group = ctx.group + if input.stop_gradient and weight.stop_gradient: + return None, None + + if input.stop_gradient: + input_grad = None + grad_parallel = None + else: + input_grad, grad_parallel = all_gather_gemm(grad, weight, group, deepcopy_input_parallel=False) + + if weight.stop_gradient: + weight_grad = None + else: + if grad_parallel is None: + grad_parallel = all_gather(grad) + weight_grad = paddle.matmul(input, grad_parallel, transpose_x=True) + return input_grad, weight_grad + + +class AllGatherGemmOp(PyLayer): + @staticmethod + def forward(ctx, input, weight, group): + output, input_parallel = all_gather_gemm(input, weight, group, deepcopy_input_parallel=True) + ctx.save_for_backward(input_parallel, weight) + ctx.group = group + ctx.input_stop_gradient = input.stop_gradient + return output + + @staticmethod + def backward(ctx, grad): + input_parallel, weight = ctx.saved_tensor() + group = ctx.group + if ctx.input_stop_gradient and weight.stop_gradient: + return None, None + if ctx.input_stop_gradient: + input_grad = None + else: + input_grad = gemm_reduce_scatter(grad, weight, group) + if weight.stop_gradient: + weight_grad = None + else: + weight_grad = paddle.matmul(input_parallel, grad, transpose_x=True) + + return input_grad, weight_grad + + +def sequence_parallel_sparse_mask_labels(labels, ignore_label=-100): + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + labels = labels.flatten() + labels_local = paddle.split(labels, group.nranks)[group.rank] + + tgt_index = paddle.nonzero(labels_local != ignore_label).squeeze() + if tgt_index.numel() == 0: + tgt_index = paddle.to_tensor([0]) + + tgt_index = tgt_index.reshape([-1]).astype(paddle.int32) + labels_local_gather = paddle.take_along_axis(labels_local, tgt_index, axis=0) + labels_all_gather = AllGatherVarlenOp.apply(labels_local_gather) + return labels_all_gather, tgt_index.reshape([-1, 1]) + + +def mark_as_sequence_parallel_parameter(parameter): + parameter.sequence_parallel = True + + +def is_sequence_parallel_parameter(parameter): + return getattr(parameter, "sequence_parallel", False) + + +def create_fused_allreduce_gradient_hook(parameter_list, accumulation_steps): + hcg = get_hcg() + group = hcg.get_model_parallel_group() + + step = [0] + accumulation_steps *= len(parameter_list) + + def __impl__(grad): + step[0] += 1 + if step[0] == accumulation_steps: + step[0] = 0 + fused_allreduce_gradients_with_group(parameter_list, group=group, scale=1.0) + return grad + + return __impl__ + + +def create_non_fused_allreduce_gradient_hook(param, model, verbose=False): + hcg = get_hcg() + pg = hcg.get_model_parallel_group().process_group + step = [0] + + @paddle.autograd.no_grad() + def __impl__(): + step[0] += 1 + accumulation_steps = model.accumulate_steps + if verbose: + logger.info( + f'hook called: acc-step={step[0]}/{accumulation_steps}, use_main_grad={hasattr(param, "main_grad")}' + ) + if (step[0] % accumulation_steps) == 0: + step[0] = 0 + if hasattr(param, "main_grad"): + pg.allreduce(param.main_grad).wait() + else: + pg.allreduce(param.grad).wait() + + return __impl__ + + +def register_sequence_parallel_allreduce_hooks(model, fuse_sequence_parallel_allreduce=False): + logger.warning("DO NOT use sphook unless your PyLayer does not trigger param backward hook") + mp_group = get_hcg().get_model_parallel_group() + if mp_group.nranks <= 1: + return + + params = [] + for n, p in model._layers.named_parameters(): + if is_sequence_parallel_parameter(p): + logger.info(f"register bw hook for:{n}") + params.append(p) + logger.info(f"#-sp-sync param:{len(params)}") + + if fuse_sequence_parallel_allreduce: + raise NotImplementedError + else: + for i, p in enumerate(params): + if p.stop_gradient: + continue + hook = create_non_fused_allreduce_gradient_hook(p, model, verbose=False) + p._register_backward_hook(hook) + + +def is_fused_matmul_bias_supported(): + if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm(): + import_module_error = False + try: + from paddle.base import core + except ModuleNotFoundError: + logger.warning("Unable to import paddle.base, are you using paddle latest build?") + import_module_error = True + + if import_module_error: + try: + from paddle.fluid import core + except ModuleNotFoundError: + logger.warning("Unable to import paddle.fluid, are you using paddle latest build?") + return False + return hasattr(core.eager.ops.legacy, "fused_gemm_epilogue") + else: + return False + + +class ColumnSequenceParallelLinear(Layer): + def __init__( + self, + in_features, + out_features, + weight_attr=None, + has_bias=None, + gather_output=True, + fuse_matmul_bias=False, + mp_group=None, + use_rr=False, + name=None, + use_comm=True, + use_tpsp_comm_overlap=False, + ): + super(ColumnSequenceParallelLinear, self).__init__() + + hcg = get_hcg() + self.model_parallel_group = hcg.get_model_parallel_group() if mp_group is None else mp_group + self.world_size = hcg.get_model_parallel_group().nranks if mp_group is None else mp_group.nranks + self._name = name + self.is_mp = self.world_size > 1 + self.use_comm = use_comm + if not self.use_comm: + assert not use_rr, "The moe allgather not compatibale with rr for now." + + self.use_tpsp_comm_overlap = use_tpsp_comm_overlap + if self.use_tpsp_comm_overlap: + assert all_gather_gemm is not None + assert flux is not None + + assert ( + gather_output is False + ), "If sequence_parallel is True, \ + gather_output is False" + + self.gather_output = gather_output + assert out_features % self.world_size == 0, ( + f"Number of column of the weight for linear ({out_features}) must be" + f" divisible by model parallel size ({self.world_size})" + ) + self.output_size_per_partition = out_features // self.world_size + + self._weight_attr = weight_attr + self._dtype = self._helper.get_default_dtype() + + if self.is_mp and paddle.in_dynamic_mode(): + with get_rng_state_tracker().rng_state(): + self.weight = self.create_parameter( + shape=[in_features, self.output_size_per_partition], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False, + ) + else: + self.weight = self.create_parameter( + shape=[in_features, self.output_size_per_partition], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False, + ) + + self.weight.is_distributed = True if self.is_mp else False + if self.weight.is_distributed: + self.weight.split_axis = 1 + + if has_bias: + self.bias = self.create_parameter( + shape=[self.output_size_per_partition], + attr=paddle.nn.initializer.Constant(value=0.0), + dtype=self._dtype, + is_bias=True, + ) + self.bias.is_distributed = True if self.is_mp else False + if self.bias.is_distributed: + self.bias.split_axis = 0 + else: + self.bias = None + + self.linear = F.linear + + if fuse_matmul_bias: + if not is_fused_matmul_bias_supported(): + raise NotImplementedError( + "You set fuse_matmul_bias=True in ColumnSequenceParallelLinear, " + "however, the paddle you are using not support this operation. " + "Please set fuse_matmul_bias=False or use paddle compiled " + "with cuda 11.6 or higher." + ) + from paddle.incubate.nn.functional import fused_linear + + self.linear = fused_linear + + def forward(self, x, use_comm=True): + if ( + self.use_tpsp_comm_overlap + and self.is_mp + and (use_comm and self.use_comm) + and flux.all_gather_gemm_can_implement(x, self.weight, self.model_parallel_group) + ): + output = AllGatherGemmOp.apply(x, self.weight, self.model_parallel_group) + if self.bias is not None: + output += self.bias + return output + else: + if self.is_mp and (use_comm and self.use_comm): + input_parallel = AllGatherOp.apply(x) + else: + input_parallel = x + + output = self.linear(input_parallel, self.weight, self.bias) + return output + + +class MPScale(PyLayer): + @staticmethod + def forward(ctx, x, mp_degree): + out = paddle.scale(x, 1.0 / mp_degree) + return out + + @staticmethod + def backward(ctx, dout): + return dout + + +class RowSequenceParallelLinear(Layer): + def __init__( + self, + in_features, + out_features, + weight_attr=None, + has_bias=True, + input_is_parallel=False, + fuse_matmul_bias=False, + use_rr=False, + mp_group=None, + name=None, + use_comm=True, + use_tpsp_comm_overlap=False, + ): + super(RowSequenceParallelLinear, self).__init__() + + self.in_features = in_features + self.out_features = out_features + assert ( + input_is_parallel is True + ), "If sequence_parallel is True, \ + input_is_parallel should be true." + + self.input_is_parallel = input_is_parallel + self._weight_attr = weight_attr + self._dtype = self._helper.get_default_dtype() + self._name = name + self.use_comm = use_comm + if not self.use_comm: + assert not use_rr, "The moe allgather not compatibale with rr for now." + + self.use_tpsp_comm_overlap = use_tpsp_comm_overlap + if self.use_tpsp_comm_overlap: + assert gemm_reduce_scatter is not None + assert flux is not None + + hcg = get_hcg() + self.model_parallel_group = hcg.get_model_parallel_group() if mp_group is None else mp_group + self.world_size = hcg.get_model_parallel_group().nranks if mp_group is None else mp_group.nranks + self.rank = hcg.get_model_parallel_group().rank if mp_group is None else mp_group.rank + + self.is_mp = self.world_size > 1 + assert in_features % self.world_size == 0, ( + f"Number of row of the weight for linear ({in_features}) must be" + f" divisible by model parallel size ({self.world_size})" + ) + + self.input_size_per_partition = in_features // self.world_size + + if self.is_mp and paddle.in_dynamic_mode(): + with get_rng_state_tracker().rng_state(): + self.weight = self.create_parameter( + shape=[self.input_size_per_partition, self.out_features], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False, + ) + else: + self.weight = self.create_parameter( + shape=[self.input_size_per_partition, self.out_features], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False, + ) + + self.weight.is_distributed = True if self.is_mp else False + if self.weight.is_distributed: + self.weight.split_axis = 0 + + if has_bias: + self.bias = self.create_parameter( + shape=[self.out_features], + attr=paddle.nn.initializer.Constant(value=0.0), + dtype=self._dtype, + is_bias=True, + ) + if self.is_mp: + mark_as_sequence_parallel_parameter(self.bias) + else: + self.bias = None + + self.linear = F.linear + self.mp_scale = None + + if fuse_matmul_bias: + if not is_fused_matmul_bias_supported(): + raise NotImplementedError( + "You set fuse_matmul_bias=True in RowParallelLinear, " + "however, the paddle you are using not support this operation. " + "Please set fuse_matmul_bias=False or use paddle compiled " + "with cuda 11.6 or higher." + ) + from paddle.incubate.nn.functional import fused_linear + + self.linear = fused_linear + + def forward(self, x): + input_parallel = x + if self.is_mp: + if self.mp_scale is not None: + bias = self.mp_scale(self.bias, self.world_size) + else: + bias = None + + if ( + self.use_tpsp_comm_overlap + and self.use_comm + and flux.gemm_reduce_scatter_can_implement(x, self.weight, self.model_parallel_group) + ): + output_ = GemmReduceScatterOp.apply(x, self.weight, self.model_parallel_group) + if bias is not None: + output_ = output_ + bias + else: + output_parallel = self.linear(input_parallel, self.weight, bias) + if self.use_comm: + output_ = ReduceScatterOp.apply(output_parallel) + else: + output_ = output_parallel + + if bias is None and self.bias is not None and self.use_comm: + output = output_ + self.bias + else: + output = output_ + else: + output = self.linear(input_parallel, self.weight, self.bias) + return output diff --git a/ernie/ERNIE/examples/pre-training/models/utils.py b/ernie/ERNIE/examples/pre-training/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6abda1f79161f0cde7bc71690aaf9f87563c2f26 --- /dev/null +++ b/ernie/ERNIE/examples/pre-training/models/utils.py @@ -0,0 +1,209 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, Callable, List + +import paddle +from paddle import framework + +logger = logging.getLogger(__name__) + +try: + import moe_permutation + +except ImportError: + moe_permutation = None + logger.warning("moe_permutation is not installed.") + + +def get_global_training_logs(): + try: + from src.utils.misc import global_training_logs + + return global_training_logs + except (ImportError, ModuleNotFoundError): + pass + try: + from rl.utils.stat_utils import global_training_logs + + return global_training_logs + except (ImportError, ModuleNotFoundError): + pass + return {} + + +def global_training_logs_enabled(): + global_training_logs = get_global_training_logs() + return isinstance(global_training_logs, dict) or global_training_logs.is_enabled() + + +def inplace_offload(tensor): + tmp = tensor.pin_memory() if paddle.is_compiled_with_cuda() else tensor.cpu() + tmp._share_buffer_to(tensor) + + +def detach_and_requires_grad_(*args): + ret = [a.detach() if a is not None else None for a in args] + for r, a in zip(ret, args): + if a is not None: + r.stop_gradient = a.stop_gradient + return ret + + +class FakeClone(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, input): + if input.is_contiguous(): + fake_output = paddle.empty_like(input) + input._share_buffer_to(fake_output) + else: + fake_output = input.clone() + return fake_output + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + +def manual_backward(f: Callable, is_first_fwd: bool, *args: List[Any]): + tracer = framework._dygraph_tracer() + orig = tracer._has_grad + if not is_first_fwd: + tracer._has_grad = True + + detached_args = detach_and_requires_grad_(*args) + detached_args_clone = [FakeClone.apply(a) if a is not None else None for a in detached_args] + out = f(*detached_args_clone) + if isinstance(out, list): + out = tuple(out) + elif not isinstance(out, tuple): + out = (out,) + + if is_first_fwd: + tracer._has_grad = orig + return None, out + + out_cached = [FakeClone.apply(o) for o in out if o is not None] + + for o in out_cached: + o._clear_dataptr() + tracer._has_grad = orig + + def bwd_f(*grad): + nonlocal out_cached, detached_args, f + grad = list(grad) + grad = [g for g in grad if g is not None] + assert grad and out_cached, (len(grad), len(out_cached)) + grad, out_cached = zip(*[(g, o) for g, o in zip(grad, out_cached) if not o.stop_gradient]) + + assert len(grad) == len(out_cached), (len(grad), len(out_cached), f) + paddle.autograd.backward(out_cached, grad) + return tuple([t.grad for t in detached_args if t is not None]) + + return bwd_f, out + + +class FakeGather(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, input, indices): + assert len(indices.shape) == 1 + ctx.save_for_backward(indices, input.shape) + if indices.shape[0] == 0: + out_shape = input.shape + out_shape[0] = 0 + return paddle.zeros(out_shape, dtype=input.dtype) + return paddle.index_select(input, axis=0, index=indices) + + @staticmethod + def backward(ctx, grad_output): + indices, input_shape = ctx.saved_tensor() + + grad_input = paddle.zeros(input_shape, dtype=grad_output.dtype) + if indices.shape[0] != 0: + paddle.scatter_(grad_input, indices.unsqueeze(-1), grad_output, overwrite=False) + return grad_input, None + + +class FusedUnpermutation(paddle.autograd.PyLayer): + @staticmethod + def forward( + ctx, + output_tokens, + permuted_tokens, + token_permuted_indices, + dispatched_probs, + prob_permuted_indices, + ): + assert token_permuted_indices.stop_gradient, "token_permuted_indices must be stop_gradient" + if dispatched_probs is not None: + assert ( + prob_permuted_indices is not None and prob_permuted_indices.stop_gradient + ), "dispatched_probs must be stop_gradient" + + output_tokens.stop_gradient = False + + src_token_num = permuted_tokens.shape[0] + if src_token_num > 0: + output_tokens = moe_permutation.unpermute( + output_tokens, + permuted_tokens, + token_permuted_indices, + dispatched_probs, + prob_permuted_indices, + ) + else: + output_tokens = FakeClone.apply(output_tokens) + + ctx.save_for_backward( + permuted_tokens, + token_permuted_indices, + dispatched_probs, + prob_permuted_indices, + ) + return output_tokens + + @staticmethod + def backward(ctx, output_tokens_grad): + ( + permuted_tokens, + token_permuted_indices, + dispatched_probs, + prob_permuted_indices, + ) = ctx.saved_tensor() + + src_token_num = permuted_tokens.shape[0] + if src_token_num > 0: + permuted_tokens_grad, dispatched_probs_grad = moe_permutation.unpermute_grad( + output_tokens_grad, + permuted_tokens, + token_permuted_indices, + dispatched_probs, + prob_permuted_indices, + ) + else: + permuted_tokens_grad = paddle.zeros_like(permuted_tokens) + if dispatched_probs is not None: + dispatched_probs_grad = paddle.zeros_like(dispatched_probs) + + if dispatched_probs is None: + return output_tokens_grad, permuted_tokens_grad, None + else: + return ( + output_tokens_grad, + permuted_tokens_grad, + None, + dispatched_probs_grad, + None, + ) diff --git a/ernie/ERNIE/examples/pre-training/tools/README_zh.md b/ernie/ERNIE/examples/pre-training/tools/README_zh.md new file mode 100644 index 0000000000000000000000000000000000000000..e13b5a640c8273f7f42915a68e8a8a646b8b1a1e --- /dev/null +++ b/ernie/ERNIE/examples/pre-training/tools/README_zh.md @@ -0,0 +1,13 @@ +[English](README.md) | 简体中文 + +# 预训练权重转换工具 +这篇文档介绍如何将我们发布的预训练权重转换为当前模型可以加载的权重格式。 + +## 下载预训练权重 +下载已发布的预训练权重,请参考[Introduction to ERNIE 4.5](/README.md)。 + +## 保存当前模型的checkpoint +运行预训练模型,并得到一份当前模型的checkpoint,运行方式参考[ERNIE-4.5-300B-A47B Pre-Training](/examples/pre-training/README.md)。 + +## 转换权重 +`python convert_ckpt.py --org --cur --dst ` diff --git a/ernie/ERNIE/examples/pre-training/yamls/ci_ce/pretrain_4_nodes_ce.yaml b/ernie/ERNIE/examples/pre-training/yamls/ci_ce/pretrain_4_nodes_ce.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7c8fe359cf8109f0b3b20a92b0aa3bdc8606eece --- /dev/null +++ b/ernie/ERNIE/examples/pre-training/yamls/ci_ce/pretrain_4_nodes_ce.yaml @@ -0,0 +1,122 @@ +# -----------环境变量----------------------# +env: + HOME: null + +# ---------------------------model args-------------------------------------------------# +model_args: + model_name_or_path: model_configs/ + tokenizer_name: ./ernie/src/tokenizers/tokenizer_model + output_dir: ./output/eb4p5_turbo + data_load_process_num: 40 + max_seq_length: 4096 + base_seq_length: 4096 + num_consecutive: 32 + + enable_global_training_logs: False + moe_use_aux_free_update_coef: 0.001 + global_logging_interval: 10 + enable_mtp_magic_send: True + + model_config: + num_hidden_layers: 24 + moe_num_experts: 16 + + multi_token_pred_depth: 1 + use_ep_comm_overlap: false + use_combine_before_a2a: true + use_rms_qkv_recompute: true + + moe_logging: True + moe_use_aux_free: true + use_recompute: false + use_fp8_mlp: true + use_fp8_fuse_node: true + fp8_mem_configs: + shared_expert: false + recompute_fwd_gate_up: [4, 5, 6, 7, 8, 10, 11, 12, 13, 16, 17, 18] + dequant_input: true + fp8_fused_ops_configs: + stack_quant: true + swiglu_probs_bwd: true + split_group_gemm: false + spaq: true + transpose_split_quant: true + + moe_gate: top2_fused + + + +# ---------------------------trainer args-------------------------------------------------# +trainer_args: + input_dir: "0.4 ./demo_data/data-1-part0 0.6 ./demo_data/data-1-part0" + split: "998,1,1" + + use_sp_callback: true + moe_gate_lr_ratio: 0.01 + do_train: True + dataloader_num_workers: 8 + prefetch_factor: 32 + overwrite_output_dir: 1 + disable_tqdm: 1 + logging_steps: 1 + eval_steps: 1000 + eval_iters: -1 + save_steps: 100 + max_steps: 100 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_epsilon: 1e-8 + learning_rate: 2.2e-4 + min_lr: 2.2e-5 + + gradient_accumulation_steps: 90 + per_device_train_batch_size: 2 + per_device_eval_batch_size: 1 + + lr_scheduler: wsd:231084 + decay_function: 1-sqrt + max_grad_norm: 1.0 + use_async_save: True + + weight_decay: 0.1 + warmup_steps: 200 + save_total_limit: 5 + bf16: True + fp16_opt_level: "O2" + use_fp8: False + scale_loss: 4096 + seed: 666 + use_train_part_sharding: 1 + pre_alloc_memory: 70 + offload_optim: True + + pipeline_parallel_degree: 4 + tensor_parallel_degree: 1 + virtual_pp_degree: 1 + data_parallel_degree: 1 + expert_parallel_degree: 8 + sharding: "stage1" + sharding_parallel_degree: 8 + amp_master_grad: 1 + pipeline_parallel_config: enable_delay_scale_loss enable_overlap_p2p_comm best_unbalanced_scheduler + sharding_parallel_config: split_param + sharding_comm_buffer_size_MB: 2048 + tensor_parallel_config: sync_param sync_grad sync_moment + hybrid_parallel_topo_order: sharding_first + + skip_profile_timer: True + ignore_data_skip: 0 + shuffle_consecutive: True + load_sharded_model: True + save_sharded_model: True + save_sharding_stage1_model_include_freeze_params: True + ignore_load_lr_and_optim: False + metrics_output_path: ./output/paddle_distributed_logs/ + pdc_download_ckpt: true + pdc_download_timeout: 300 + use_moe: true + moe_with_send_router_loss: False + moe_group: ep + log_global_grad_norm: True + enable_optimizer_timer: False + gc_interval: 100000 diff --git a/ernie/ERNIE/examples/pre-training/yamls/ci_ce/pretrain_8_gpus_ci.yaml b/ernie/ERNIE/examples/pre-training/yamls/ci_ce/pretrain_8_gpus_ci.yaml new file mode 100644 index 0000000000000000000000000000000000000000..28b821153b018fe17006d382553940970653782c --- /dev/null +++ b/ernie/ERNIE/examples/pre-training/yamls/ci_ce/pretrain_8_gpus_ci.yaml @@ -0,0 +1,121 @@ +# -----------环境变量----------------------# +env: + HOME: null + +# ---------------------------model args-------------------------------------------------# +model_args: + model_name_or_path: model_configs/ + tokenizer_name: ./ernie/src/tokenizers/tokenizer_model + output_dir: ./output/ + data_load_process_num: 40 + max_seq_length: 4096 + base_seq_length: 4096 + num_consecutive: 32 + + enable_global_training_logs: False + moe_use_aux_free_update_coef: 0.001 + global_logging_interval: 10 + enable_mtp_magic_send: True + + model_config: + num_hidden_layers: 6 + moe_num_experts: 8 + moe_k: 4 + moe_capacity: [4,4,4] + + multi_token_pred_depth: 1 + use_ep_comm_overlap: false + use_combine_before_a2a: true + use_rms_qkv_recompute: true + + moe_logging: True + moe_use_aux_free: true + use_recompute: false + use_fp8_mlp: true + use_fp8_fuse_node: true + fp8_mem_configs: + shared_expert: false + recompute_fwd_gate_up: [4, 5, 6, 7, 8, 10, 11, 12, 13, 16, 17, 18] + dequant_input: true + fp8_fused_ops_configs: + stack_quant: true + swiglu_probs_bwd: true + split_group_gemm: false + spaq: true + transpose_split_quant: true + + moe_gate: top2_fused + + + +# ---------------------------trainer args-------------------------------------------------# +trainer_args: + input_dir: "0.4 ./demo_data/data-1-part0 0.6 ./demo_data/data-1-part0" + split: "998,1,1" + + use_sp_callback: true + moe_gate_lr_ratio: 0.01 + do_train: True + dataloader_num_workers: 8 + prefetch_factor: 32 + overwrite_output_dir: 1 + disable_tqdm: 1 + logging_steps: 1 + eval_steps: 1000 + eval_iters: -1 + save_steps: 100 + max_steps: 100 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_epsilon: 1e-8 + learning_rate: 2.2e-4 + min_lr: 2.2e-5 + + gradient_accumulation_steps: 2 + per_device_train_batch_size: 1 + per_device_eval_batch_size: 1 + + lr_scheduler: wsd:231084 + decay_function: 1-sqrt + max_grad_norm: 1.0 + use_async_save: True + + weight_decay: 0.1 + warmup_steps: 200 + save_total_limit: 5 + bf16: True + fp16_opt_level: "O2" + use_fp8: False + scale_loss: 4096 + seed: 666 + use_train_part_sharding: 1 + offload_optim: True + + pipeline_parallel_degree: 2 + tensor_parallel_degree: 1 + virtual_pp_degree: 1 + data_parallel_degree: 1 + expert_parallel_degree: 4 + sharding: "stage1" + sharding_parallel_degree: 4 + amp_master_grad: 1 + pipeline_parallel_config: enable_delay_scale_loss enable_overlap_p2p_comm best_unbalanced_scheduler + sharding_parallel_config: split_param + sharding_comm_buffer_size_MB: 2048 + tensor_parallel_config: sync_param sync_grad sync_moment + hybrid_parallel_topo_order: sharding_first + + skip_profile_timer: True + ignore_data_skip: 0 + shuffle_consecutive: True + load_sharded_model: True + save_sharded_model: True + save_sharding_stage1_model_include_freeze_params: True + ignore_load_lr_and_optim: False + metrics_output_path: ./output/paddle_distributed_logs/ + use_moe: true + moe_with_send_router_loss: False + moe_group: ep + log_global_grad_norm: True + enable_optimizer_timer: False + gc_interval: 100000 \ No newline at end of file diff --git a/ernie/ERNIE/examples/pre-training/yamls/pretrain_2016_gpus.yaml b/ernie/ERNIE/examples/pre-training/yamls/pretrain_2016_gpus.yaml new file mode 100644 index 0000000000000000000000000000000000000000..08ae89fba60bcff8e4de4a1d4deb0724837487a0 --- /dev/null +++ b/ernie/ERNIE/examples/pre-training/yamls/pretrain_2016_gpus.yaml @@ -0,0 +1,116 @@ +# -----------环境变量----------------------# +env: + HOME: null + +# ---------------------------model args-------------------------------------------------# +model_args: + model_name_or_path: model_configs/ + tokenizer_name: ./ernie/src/tokenizers/tokenizer_model + output_dir: ./output/ + data_load_process_num: 40 + max_seq_length: 4096 + base_seq_length: 4096 + num_consecutive: 32 + + enable_global_training_logs: False + moe_use_aux_free_update_coef: 0.001 + global_logging_interval: 10 + enable_mtp_magic_send: True + + model_config: + multi_token_pred_depth: 1 + use_ep_comm_overlap: false + use_combine_before_a2a: true + use_rms_qkv_recompute: true + + moe_logging: True + moe_use_aux_free: true + use_recompute: false + use_fp8_mlp: true + use_fp8_fuse_node: true + fp8_mem_configs: + shared_expert: false + recompute_fwd_gate_up: [4, 5, 6, 7, 8, 10, 11, 12, 13, 16, 17, 18] + dequant_input: true + fp8_fused_ops_configs: + stack_quant: true + swiglu_probs_bwd: true + split_group_gemm: false + spaq: true + transpose_split_quant: true + + moe_gate: top2_fused + + + +# ---------------------------trainer args-------------------------------------------------# +trainer_args: + input_dir: "0.4 ./demo_data/data-1-part0 0.6 ./demo_data/data-1-part0" + split: "998,1,1" + + use_sp_callback: true + moe_gate_lr_ratio: 0.01 + do_train: True + dataloader_num_workers: 8 + prefetch_factor: 32 + overwrite_output_dir: 1 + disable_tqdm: 1 + logging_steps: 1 + eval_steps: 1000 + eval_iters: -1 + save_steps: 100 + max_steps: 100 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_epsilon: 1e-8 + learning_rate: 2.2e-4 + min_lr: 2.2e-5 + + gradient_accumulation_steps: 90 + per_device_train_batch_size: 1 + per_device_eval_batch_size: 1 + + lr_scheduler: wsd:231084 + decay_function: 1-sqrt + max_grad_norm: 1.0 + use_async_save: True + + weight_decay: 0.1 + warmup_steps: 200 + save_total_limit: 5 + bf16: True + fp16_opt_level: "O2" + use_fp8: False + scale_loss: 4096 + seed: 666 + use_train_part_sharding: 1 + pre_alloc_memory: 60 + + pipeline_parallel_degree: 12 + tensor_parallel_degree: 1 + virtual_pp_degree: 1 + data_parallel_degree: 1 + expert_parallel_degree: 8 + sharding: "stage1" + sharding_parallel_degree: 168 + amp_master_grad: 1 + pipeline_parallel_config: enable_delay_scale_loss enable_overlap_p2p_comm best_unbalanced_scheduler + sharding_parallel_config: split_param + sharding_comm_buffer_size_MB: 2048 + tensor_parallel_config: sync_param sync_grad sync_moment + hybrid_parallel_topo_order: sharding_first + + skip_profile_timer: True + ignore_data_skip: 0 + shuffle_consecutive: True + load_sharded_model: True + save_sharded_model: True + save_sharding_stage1_model_include_freeze_params: True + ignore_load_lr_and_optim: False + metrics_output_path: ./output/paddle_distributed_logs/ + use_moe: true + moe_with_send_router_loss: False + moe_group: ep + log_global_grad_norm: True + enable_optimizer_timer: False + gc_interval: 100000 diff --git a/ernie/ERNIE/examples/pre-training/yamls/pretrain_96_gpus.yaml b/ernie/ERNIE/examples/pre-training/yamls/pretrain_96_gpus.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b04d96f50cfd93d83f6e1b3a3bac6593713679b4 --- /dev/null +++ b/ernie/ERNIE/examples/pre-training/yamls/pretrain_96_gpus.yaml @@ -0,0 +1,117 @@ +# -----------环境变量----------------------# +env: + HOME: null + +# ---------------------------model args-------------------------------------------------# +model_args: + model_name_or_path: model_configs/ + tokenizer_name: ./ernie/src/tokenizers/tokenizer_model + output_dir: ./output/ + data_load_process_num: 40 + max_seq_length: 4096 + base_seq_length: 4096 + num_consecutive: 32 + + enable_global_training_logs: False + moe_use_aux_free_update_coef: 0.001 + global_logging_interval: 10 + enable_mtp_magic_send: True + + model_config: + multi_token_pred_depth: 1 + use_ep_comm_overlap: false + use_combine_before_a2a: true + use_rms_qkv_recompute: true + + moe_logging: True + moe_use_aux_free: true + use_recompute: false + use_fp8_mlp: true + use_fp8_fuse_node: true + fp8_mem_configs: + shared_expert: false + recompute_fwd_gate_up: [4, 5, 6, 7, 8, 10, 11, 12, 13, 16, 17, 18] + dequant_input: true + fp8_fused_ops_configs: + stack_quant: true + swiglu_probs_bwd: true + split_group_gemm: false + spaq: true + transpose_split_quant: true + + moe_gate: top2_fused + + + +# ---------------------------trainer args-------------------------------------------------# +trainer_args: + input_dir: "0.4 ./demo_data/data-1-part0 0.6 ./demo_data/data-1-part0" + split: "998,1,1" + + use_sp_callback: true + moe_gate_lr_ratio: 0.01 + do_train: True + dataloader_num_workers: 8 + prefetch_factor: 32 + overwrite_output_dir: 1 + disable_tqdm: 1 + logging_steps: 1 + eval_steps: 1000 + eval_iters: -1 + save_steps: 100 + max_steps: 100 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_epsilon: 1e-8 + learning_rate: 2.2e-4 + min_lr: 2.2e-5 + + gradient_accumulation_steps: 1890 + per_device_train_batch_size: 1 + per_device_eval_batch_size: 1 + + lr_scheduler: wsd:231084 + decay_function: 1-sqrt + max_grad_norm: 1.0 + use_async_save: True + + weight_decay: 0.1 + warmup_steps: 200 + save_total_limit: 5 + bf16: True + fp16_opt_level: "O2" + use_fp8: False + scale_loss: 4096 + seed: 666 + use_train_part_sharding: 1 + pre_alloc_memory: 60 + offload_optim: True + + pipeline_parallel_degree: 12 + tensor_parallel_degree: 1 + virtual_pp_degree: 1 + data_parallel_degree: 1 + expert_parallel_degree: 8 + sharding: "stage1" + sharding_parallel_degree: 8 + amp_master_grad: 1 + pipeline_parallel_config: enable_delay_scale_loss enable_overlap_p2p_comm best_unbalanced_scheduler + sharding_parallel_config: split_param + sharding_comm_buffer_size_MB: 2048 + tensor_parallel_config: sync_param sync_grad sync_moment + hybrid_parallel_topo_order: sharding_first + + skip_profile_timer: True + ignore_data_skip: 0 + shuffle_consecutive: True + load_sharded_model: True + save_sharded_model: True + save_sharding_stage1_model_include_freeze_params: True + ignore_load_lr_and_optim: False + metrics_output_path: ./output/paddle_distributed_logs/ + use_moe: true + moe_with_send_router_loss: False + moe_group: ep + log_global_grad_norm: True + enable_optimizer_timer: False + gc_interval: 100000 diff --git a/ernie/ERNIE/examples/pre-training/yamls/pretrain_96_gpus_small_acc.yaml b/ernie/ERNIE/examples/pre-training/yamls/pretrain_96_gpus_small_acc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..acd8cb96e76809f0bb3f9270c11e5100369928c2 --- /dev/null +++ b/ernie/ERNIE/examples/pre-training/yamls/pretrain_96_gpus_small_acc.yaml @@ -0,0 +1,117 @@ +# -----------环境变量----------------------# +env: + HOME: null + +# ---------------------------model args-------------------------------------------------# +model_args: + model_name_or_path: model_configs/ + tokenizer_name: ./ernie/src/tokenizers/tokenizer_model + output_dir: ./output/ + data_load_process_num: 40 + max_seq_length: 4096 + base_seq_length: 4096 + num_consecutive: 32 + + enable_global_training_logs: False + moe_use_aux_free_update_coef: 0.001 + global_logging_interval: 10 + enable_mtp_magic_send: True + + model_config: + multi_token_pred_depth: 1 + use_ep_comm_overlap: false + use_combine_before_a2a: true + use_rms_qkv_recompute: true + + moe_logging: True + moe_use_aux_free: true + use_recompute: false + use_fp8_mlp: true + use_fp8_fuse_node: true + fp8_mem_configs: + shared_expert: false + recompute_fwd_gate_up: [4, 5, 6, 7, 8, 10, 11, 12, 13, 16, 17, 18] + dequant_input: true + fp8_fused_ops_configs: + stack_quant: true + swiglu_probs_bwd: true + split_group_gemm: false + spaq: true + transpose_split_quant: true + + moe_gate: top2_fused + + + +# ---------------------------trainer args-------------------------------------------------# +trainer_args: + input_dir: "0.4 ./demo_data/data-1-part0 0.6 ./demo_data/data-1-part0" + split: "998,1,1" + + use_sp_callback: true + moe_gate_lr_ratio: 0.01 + do_train: True + dataloader_num_workers: 8 + prefetch_factor: 32 + overwrite_output_dir: 1 + disable_tqdm: 1 + logging_steps: 1 + eval_steps: 1000 + eval_iters: -1 + save_steps: 100 + max_steps: 100 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_epsilon: 1e-8 + learning_rate: 2.2e-4 + min_lr: 2.2e-5 + + gradient_accumulation_steps: 90 + per_device_train_batch_size: 1 + per_device_eval_batch_size: 1 + + lr_scheduler: wsd:231084 + decay_function: 1-sqrt + max_grad_norm: 1.0 + use_async_save: True + + weight_decay: 0.1 + warmup_steps: 200 + save_total_limit: 5 + bf16: True + fp16_opt_level: "O2" + use_fp8: False + scale_loss: 4096 + seed: 666 + use_train_part_sharding: 1 + pre_alloc_memory: 60 + offload_optim: True + + pipeline_parallel_degree: 12 + tensor_parallel_degree: 1 + virtual_pp_degree: 1 + data_parallel_degree: 1 + expert_parallel_degree: 8 + sharding: "stage1" + sharding_parallel_degree: 8 + amp_master_grad: 1 + pipeline_parallel_config: enable_delay_scale_loss enable_overlap_p2p_comm best_unbalanced_scheduler + sharding_parallel_config: split_param + sharding_comm_buffer_size_MB: 2048 + tensor_parallel_config: sync_param sync_grad sync_moment + hybrid_parallel_topo_order: sharding_first + + skip_profile_timer: True + ignore_data_skip: 0 + shuffle_consecutive: True + load_sharded_model: True + save_sharded_model: True + save_sharding_stage1_model_include_freeze_params: True + ignore_load_lr_and_optim: False + metrics_output_path: ./output/paddle_distributed_logs/ + use_moe: true + moe_with_send_router_loss: False + moe_group: ep + log_global_grad_norm: True + enable_optimizer_timer: False + gc_interval: 100000