Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- ms-swift/.ipynb_checkpoints/clean_transcripts-checkpoint.py +95 -0
- ms-swift/.ipynb_checkpoints/dataset_new-checkpoint.json +0 -0
- ms-swift/silence_overlaps/delete_transcript.json +0 -0
- ms-swift/swift/llm/train/__pycache__/__init__.cpython-310.pyc +0 -0
- ms-swift/swift/llm/train/__pycache__/kto.cpython-310.pyc +0 -0
- ms-swift/swift/megatron/__init__.py +35 -0
- ms-swift/swift/megatron/argument/megatron_args.py +253 -0
- ms-swift/swift/megatron/model/gpt/mcore2hf.py +70 -0
- ms-swift/swift/megatron/train/__init__.py +2 -0
- ms-swift/swift/megatron/train/pt.py +19 -0
- ms-swift/swift/megatron/train/sft.py +65 -0
- ms-swift/swift/plugin/agent_template/__pycache__/glm4.cpython-310.pyc +0 -0
- ms-swift/swift/plugin/agent_template/__pycache__/llama.cpython-310.pyc +0 -0
- ms-swift/swift/plugin/agent_template/__pycache__/qwen.cpython-310.pyc +0 -0
- ms-swift/swift/plugin/agent_template/__pycache__/react.cpython-310.pyc +0 -0
- ms-swift/swift/plugin/agent_template/hermes.py +78 -0
- ms-swift/swift/plugin/agent_template/react.py +66 -0
- ms-swift/swift/plugin/loss_scale/__init__.py +1 -0
- ms-swift/swift/plugin/loss_scale/__pycache__/__init__.cpython-310.pyc +0 -0
- ms-swift/swift/plugin/loss_scale/__pycache__/loss_scale.cpython-310.pyc +0 -0
- ms-swift/swift/plugin/loss_scale/config/agentflan.json +22 -0
- ms-swift/swift/plugin/loss_scale/config/hermes.json +3 -0
- ms-swift/swift/plugin/prm.py +154 -0
- ms-swift/swift/plugin/rm_plugin.py +229 -0
- ms-swift/swift/trainers/__init__.py +49 -0
- ms-swift/swift/trainers/__pycache__/__init__.cpython-310.pyc +0 -0
- ms-swift/swift/trainers/__pycache__/callback.cpython-310.pyc +0 -0
- ms-swift/swift/trainers/__pycache__/trainer_factory.cpython-310.pyc +0 -0
- ms-swift/swift/trainers/__pycache__/trainers.cpython-310.pyc +0 -0
- ms-swift/swift/trainers/callback.py +124 -0
- ms-swift/swift/trainers/mixin.py +516 -0
- ms-swift/swift/trainers/optimizers/__init__.py +1 -0
- ms-swift/swift/trainers/optimizers/__pycache__/__init__.cpython-310.pyc +0 -0
- ms-swift/swift/trainers/optimizers/galore/__init__.py +28 -0
- ms-swift/swift/trainers/optimizers/galore/__pycache__/__init__.cpython-310.pyc +0 -0
- ms-swift/swift/trainers/optimizers/galore/adafactor.py +272 -0
- ms-swift/swift/trainers/optimizers/galore/galore_projector.py +109 -0
- ms-swift/swift/trainers/optimizers/galore/utils.py +214 -0
- ms-swift/swift/trainers/rlhf_arguments.py +63 -0
- ms-swift/swift/trainers/rlhf_trainer/kto_trainer.py +69 -0
- ms-swift/swift/trainers/rlhf_trainer/orpo_trainer.py +19 -0
- ms-swift/swift/trainers/rlhf_trainer/ppo_trainer.py +65 -0
- ms-swift/swift/trainers/rlhf_trainer/reward_trainer.py +78 -0
- ms-swift/swift/trainers/rlhf_trainer/rlhf_mixin.py +104 -0
- ms-swift/swift/trainers/rlhf_trainer/utils.py +132 -0
- ms-swift/swift/trainers/rlhf_trainer/vllm_client.py +212 -0
- ms-swift/swift/trainers/sequence_parallel/base.py +45 -0
- ms-swift/swift/trainers/sequence_parallel/ulysses.py +594 -0
- ms-swift/swift/trainers/sequence_parallel/xtuner.py +127 -0
- ms-swift/swift/trainers/torchacc_mixin.py +156 -0
ms-swift/.ipynb_checkpoints/clean_transcripts-checkpoint.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
import re
from typing import Dict, List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
def parse_timestamp(timestamp: str) -> int:
    """Convert a 'MM:SS' timestamp string to a total number of seconds.

    The original annotation claimed `Tuple[int, int]`, but the function has
    always returned a single integer; the annotation is corrected here.

    Raises:
        ValueError: if the string is not two ':'-separated integers.
    """
    minutes, seconds = map(int, timestamp.split(':'))
    return minutes * 60 + seconds
|
| 9 |
+
|
| 10 |
+
def extract_time_and_speaker(line: str) -> Tuple[Optional[Tuple[int, int]], Optional[str]]:
    """Extract the (start, end) time range in seconds and the speaker label.

    Expects lines of the form '[MM:SS - MM:SS] Speaker X: ...'.

    Returns:
        ((start_seconds, end_seconds), 'Speaker X') on a match, or
        (None, None) when the line does not match the expected prefix.
        (The previous annotation claimed a non-optional result, which was
        inaccurate for the no-match path.)
    """
    time_match = re.match(r'\[(\d{2}:\d{2}) - (\d{2}:\d{2})\] (Speaker [A-Z]):', line)
    if not time_match:
        return None, None

    start_time = parse_timestamp(time_match.group(1))
    end_time = parse_timestamp(time_match.group(2))
    speaker = time_match.group(3)

    return (start_time, end_time), speaker
|
| 22 |
+
|
| 23 |
+
def has_overlap(range1: Tuple[int, int], range2: Tuple[int, int]) -> bool:
    """Return True when the two time ranges share at least one instant.

    Ranges are treated as half-open: touching endpoints (end1 == start2)
    do not count as an overlap.
    """
    start1, end1 = range1
    start2, end2 = range2
    # Two intervals intersect iff the later start precedes the earlier end.
    return max(start1, start2) < min(end1, end2)
|
| 28 |
+
|
| 29 |
+
def has_same_speaker_overlap(transcript: str) -> bool:
    """Return True if any speaker in the transcript has two overlapping
    timestamp ranges.

    Lines that are blank or do not parse as '[MM:SS - MM:SS] Speaker X:'
    are ignored.
    """
    # Ranges seen so far, keyed by speaker label.
    seen = {}

    for raw_line in transcript.split('\n'):
        if not raw_line.strip():
            continue

        time_range, speaker = extract_time_and_speaker(raw_line)
        if time_range is None or speaker is None:
            continue

        ranges = seen.setdefault(speaker, [])
        if any(has_overlap(time_range, existing) for existing in ranges):
            return True
        ranges.append(time_range)

    return False
|
| 54 |
+
|
| 55 |
+
def process_file(input_file: str, output_file: str, delete_file: str):
    """Split a transcript JSON file into clean entries and rejected entries.

    Entries whose 'model_output' contains overlapping timestamps for the
    same speaker are written to `delete_file`; everything else is written
    to `output_file`.

    Args:
        input_file: path to a JSON list (or single object) of entries.
        output_file: destination for entries that pass the overlap check.
        delete_file: destination for entries that fail the overlap check.
    """
    with open(input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # A single top-level object is treated as a one-entry list.
    if isinstance(data, dict):
        data = [data]

    cleaned_data = []
    deleted_data = []
    removed_count = 0

    for entry in data:
        # BUGFIX: entries without a 'model_output' field used to be silently
        # dropped from BOTH output files. They cannot be checked, so keep
        # them in the cleaned output instead of losing them.
        if 'model_output' not in entry:
            cleaned_data.append(entry)
            continue
        if has_same_speaker_overlap(entry['model_output']):
            deleted_data.append(entry)
            removed_count += 1
            print(f"Removing entry with key: {entry.get('key', 'unknown')}")
        else:
            cleaned_data.append(entry)

    # Save cleaned data
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(cleaned_data, f, ensure_ascii=False, indent=2)

    # Save deleted data
    with open(delete_file, 'w', encoding='utf-8') as f:
        json.dump(deleted_data, f, ensure_ascii=False, indent=2)

    print(f"\nProcessing Summary:")
    print(f"Processed {len(data)} entries")
    print(f"Removed {removed_count} entries with same-speaker overlapping timestamps")
    print(f"Remaining entries: {len(cleaned_data)}")
|
| 88 |
+
|
| 89 |
+
if __name__ == '__main__':
    # Hard-coded I/O locations for this cleaning pass; the '2' suffix keeps
    # re-runs from clobbering earlier outputs in the same directory.
    input_file = 'silence_overlaps/transcriptions.json'
    output_file = 'silence_overlaps/cleaned_transcriptions2.json'
    delete_file = 'silence_overlaps/delete_transcript2.json'
    process_file(input_file, output_file, delete_file)
    print(f"\nCleaned transcriptions have been saved to {output_file}")
    print(f"Deleted entries have been saved to {delete_file}")
|
ms-swift/.ipynb_checkpoints/dataset_new-checkpoint.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/silence_overlaps/delete_transcript.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/swift/llm/train/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (392 Bytes). View file
|
|
|
ms-swift/swift/llm/train/__pycache__/kto.cpython-310.pyc
ADDED
|
Binary file (2.94 kB). View file
|
|
|
ms-swift/swift/megatron/__init__.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.

# Megatron environment setup must run before any submodule import below.
# NOTE(review): the except clause re-raises unchanged — the try/except exists
# only so linters see init_megatron_env as used, per the original comment.
try:
    from .init import init_megatron_env
    init_megatron_env()
except Exception:
    # allows lint pass.
    raise

from typing import TYPE_CHECKING

from swift.utils.import_utils import _LazyModule

if TYPE_CHECKING:
    # Static imports visible to type checkers / IDEs only.
    from .train import megatron_sft_main, megatron_pt_main
    from .utils import convert_hf2mcore, convert_mcore2hf
    from .argument import MegatronTrainArguments
    from .model import MegatronModelType, MegatronModelMeta, get_megatron_model_meta, register_megatron_model
else:
    # At runtime the package is replaced with a lazy module: each submodule
    # is imported on first attribute access, keeping `import swift.megatron`
    # cheap.
    _import_structure = {
        'train': ['megatron_sft_main', 'megatron_pt_main'],
        'utils': ['convert_hf2mcore', 'convert_mcore2hf'],
        'argument': ['MegatronTrainArguments'],
        'model': ['MegatronModelType', 'MegatronModelMeta', 'get_megatron_model_meta', 'register_megatron_model']
    }

    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
|
ms-swift/swift/megatron/argument/megatron_args.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
from dataclasses import asdict, dataclass, field
|
| 5 |
+
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from transformers.utils.versions import require_version
|
| 9 |
+
|
| 10 |
+
from swift.llm.argument.base_args import to_abspath
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
class ExtraMegatronArguments:
    """Swift-specific arguments that Megatron's own argument parser does not
    accept; `MegatronArguments._args_to_argv` routes inherited fields like
    these into `extra_args` instead of the generated argv."""
    # Vocab size after padding to a parallel-friendly multiple.
    padded_vocab_size: Optional[int] = None
    # Either a dict or a JSON/str form; parsed to a dict in __post_init__.
    rope_scaling: Optional[Union[dict, str]] = None
    torch_dtype: Optional[torch.dtype] = None

    dataloader_persistent_workers: bool = True
    dataloader_prefetch_factor: int = 10

    model_type: Optional[str] = None
    max_epochs: Optional[int] = None
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass
class MegatronArguments(ExtraMegatronArguments):
    """Dataclass mirror of Megatron-LM's CLI arguments.

    Fields declared directly on this class are converted to `--flag value`
    argv entries by `_args_to_argv`; fields inherited from
    `ExtraMegatronArguments` are returned separately as `extra_args`.
    """
    # training
    micro_batch_size: int = 1
    global_batch_size: int = 16
    recompute_granularity: Literal['selective', 'full'] = 'selective'
    # None means "not passed to Megatron" (annotation fixed: default is None).
    recompute_method: Optional[Literal['uniform', 'block']] = None
    recompute_num_layers: Optional[int] = None
    recompute_modules: List[str] = field(default_factory=lambda: ['core_attn'])
    use_cpu_initialization: bool = False
    deterministic_mode: bool = False
    train_iters: Optional[int] = None
    log_interval: int = 5
    tensorboard_dir: Optional[str] = None
    no_masked_softmax_fusion: bool = False
    no_bias_dropout_fusion: bool = False
    no_bias_swiglu_fusion: bool = False
    no_rope_fusion: bool = False
    no_gradient_accumulation_fusion: bool = False
    cross_entropy_loss_fusion: bool = False
    calculate_per_token_loss: bool = True
    use_flash_attn: bool = False
    attention_backend: str = 'auto'  # flash, fused, unfused, local, auto
    optimizer: Literal['adam', 'sgd'] = 'adam'
    dataloader_type: Literal['single', 'cyclic', 'external'] = 'cyclic'
    manual_gc: bool = False
    manual_gc_interval: int = 0

    # learning rate
    lr: float = 1e-5
    lr_decay_style: Literal['cosine', 'linear', 'constant'] = 'cosine'
    # The default is None, which will be set to `train_iters`.
    lr_decay_iters: Optional[int] = None
    lr_warmup_iters: int = 0
    min_lr: float = 0

    # regularization
    weight_decay: float = 0.1
    clip_grad: float = 1.
    adam_beta1: float = 0.9
    adam_beta2: float = 0.95
    adam_eps: float = 1e-8
    sgd_momentum: float = 0.9

    # checkpoint
    save: Optional[str] = None
    save_interval: int = 500
    no_save_optim: bool = False
    no_save_rng: bool = False
    load: Optional[str] = None
    no_load_optim: bool = False
    no_load_rng: bool = False
    finetune: bool = False
    ckpt_format: Literal['torch', 'torch_dist', 'zarr'] = 'torch_dist'
    no_initialization: bool = True
    auto_detect_ckpt_format: bool = True
    exit_on_missing_checkpoint: bool = True

    # dist
    distributed_backend: Literal['nccl', 'gloo'] = 'nccl'
    use_distributed_optimizer: bool = True
    tensor_model_parallel_size: int = 1
    pipeline_model_parallel_size: int = 1
    decoder_first_pipeline_num_layers: Optional[int] = None
    decoder_last_pipeline_num_layers: Optional[int] = None
    sequence_parallel: bool = False
    context_parallel_size: int = 1
    tp_comm_overlap: bool = False
    overlap_grad_reduce: bool = False
    overlap_param_gather: bool = False
    distributed_timeout_minutes: int = 60

    # model
    num_layers: Optional[int] = None
    hidden_size: Optional[int] = None
    ffn_hidden_size: Optional[int] = None
    num_attention_heads: Optional[int] = None
    group_query_attention: Optional[bool] = None
    num_query_groups: Optional[int] = None
    max_position_embeddings: Optional[int] = None
    position_embedding_type: Literal['learned_absolute', 'rope', 'relative', 'none'] = 'rope'
    rotary_base: Optional[int] = None
    rotary_percent: float = 1.
    normalization: Literal['LayerNorm', 'RMSNorm'] = 'RMSNorm'
    norm_epsilon: Optional[float] = None
    swiglu: Optional[bool] = None
    untie_embeddings_and_output_weights: Optional[bool] = None
    disable_bias_linear: Optional[bool] = None
    add_qkv_bias: Optional[bool] = None
    attention_dropout: Optional[float] = None
    hidden_dropout: float = 0.
    kv_channels: Optional[int] = None
    qk_layernorm: Optional[bool] = None
    transformer_impl: Literal['local', 'transformer_engine'] = 'transformer_engine'

    # moe
    num_experts: Optional[int] = None
    moe_ffn_hidden_size: Optional[int] = None
    moe_shared_expert_intermediate_size: Optional[int] = None
    moe_router_topk: Optional[int] = None
    moe_router_pre_softmax: Optional[bool] = None
    moe_aux_loss_coeff: Optional[float] = None

    expert_model_parallel_size: int = 1
    moe_token_dispatcher_type: Literal['allgather', 'alltoall', 'alltoall_seq'] = 'alltoall'
    moe_grouped_gemm: bool = False
    moe_router_load_balancing_type: Literal['aux_loss', 'seq_aux_loss', 'sinkhorn', 'none'] = 'aux_loss'
    moe_z_loss_coeff: Optional[float] = None
    moe_expert_capacity_factor: Optional[float] = None
    moe_shared_expert_overlap: bool = False

    # mixed precision
    fp16: Optional[bool] = None
    bf16: Optional[bool] = None
    apply_query_key_layer_scaling: Optional[bool] = None
    attention_softmax_in_fp32: bool = True

    # logging
    log_params_norm: bool = False
    log_throughput: bool = True
    tensorboard_log_interval: int = 1
    tensorboard_queue_size: int = 50
    log_timers_to_tensorboard: bool = True
    no_log_learning_rate_to_tensorboard: bool = False
    log_validation_ppl_to_tensorboard: bool = True
    log_memory_to_tensorboard: bool = True
    logging_level: Optional[str] = None
    wandb_project: Optional[str] = None
    wandb_exp_name: Optional[str] = None
    wandb_save_dir: Optional[str] = None

    # evaluate
    eval_iters: int = 100
    eval_interval: Optional[int] = None

    # other
    seed: int = 42
    seq_length: Optional[int] = None
    num_workers: int = 4
    no_create_attention_mask_in_dataloader: bool = True

    def _set_default(self) -> None:
        """Fill defaults for model-structure options left as None."""
        if self.num_query_groups is None:
            self.num_query_groups = 1
        if self.norm_epsilon is None:
            self.norm_epsilon = 1e-5
        if self.rotary_base is None:
            self.rotary_base = 10000
        if self.attention_dropout is None:
            self.attention_dropout = 0.
        if self.untie_embeddings_and_output_weights is None:
            self.untie_embeddings_and_output_weights = True
        if self.swiglu is None:
            self.swiglu = True
        if self.add_qkv_bias is None:
            self.add_qkv_bias = True
        if self.disable_bias_linear is None:
            self.disable_bias_linear = True
        if self.moe_router_topk is None:
            self.moe_router_topk = 2
        if self.moe_router_pre_softmax is None:
            self.moe_router_pre_softmax = False
        if self.moe_aux_loss_coeff is None:
            self.moe_aux_loss_coeff = 0.
        if self.qk_layernorm is None:
            self.qk_layernorm = False

    def _init_mixed_precision(self) -> None:
        """Resolve fp16/bf16 flags and enable TE's QK layer scaling for fp16."""
        from swift.llm.argument.base_args.model_args import ModelArguments
        ModelArguments._init_mixed_precision(self)
        if self.apply_query_key_layer_scaling is None:
            self.apply_query_key_layer_scaling = self.fp16
        if self.apply_query_key_layer_scaling:
            os.environ['NVTE_APPLY_QK_LAYER_SCALING'] = '1'

    def _init_moe(self) -> None:
        """Normalize MoE FFN sizes; 0 for the shared-expert size means unset."""
        if self.moe_shared_expert_intermediate_size == 0:
            self.moe_shared_expert_intermediate_size = None
        if self.moe_ffn_hidden_size is None:
            self.moe_ffn_hidden_size = self.ffn_hidden_size
        else:
            self.ffn_hidden_size = self.moe_ffn_hidden_size

    def __post_init__(self):
        """Derive dependent settings after dataclass initialization."""
        from swift.llm.argument.base_args.model_args import ModelArguments
        if self.use_flash_attn or self.attention_backend == 'flash':
            require_version('flash-attn')
        # Required by Megatron for correct comm/compute overlap behavior.
        os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1'
        self._set_default()
        self.group_query_attention = self.num_query_groups > 1
        if self.rope_scaling is not None:
            self.rope_scaling = ModelArguments.parse_to_dict(self.rope_scaling)
        if self.eval_interval is None:
            self.eval_interval = self.save_interval
        if self.seq_length is None:
            self.seq_length = self.max_position_embeddings
        if self.tensorboard_dir is None and self.save is not None:
            self.tensorboard_dir = f'{self.save}/runs'
        self._init_moe()
        self._init_mixed_precision()

        self.tensorboard_dir = to_abspath(self.tensorboard_dir)

    def _args_to_argv(self) -> Tuple[List[Any], Dict[str, Any]]:
        """Split fields into Megatron argv and swift-only extra args.

        Fields not annotated directly on MegatronArguments (i.e. inherited
        ones) go to `extra_args`; None/False fields are omitted; True fields
        become bare flags; lists expand to multiple values.
        """
        new_args = []
        args_dict = asdict(self)
        extra_args = {}
        for k, value in args_dict.items():
            if k not in MegatronArguments.__annotations__:
                extra_args[k] = value
                continue
            if value is None or value is False:
                continue
            new_args.append(f"--{k.replace('_', '-')}")
            if isinstance(value, list):
                new_args += [str(v) for v in value]
            elif value is not True:
                new_args.append(str(value))

        return new_args, extra_args

    def parse_to_megatron(self):
        """Install the generated argv into sys.argv (old argv stashed in
        `sys._old_argv`) and return the extra, swift-only arguments."""
        new_args, extra_args = self._args_to_argv()
        sys._old_argv = sys.argv
        sys.argv = sys.argv[:1] + new_args
        # parameter conflict
        extra_args.pop('loss_scale', None)
        return extra_args
|
ms-swift/swift/megatron/model/gpt/mcore2hf.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from megatron.training import get_args
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def set_attn_state(args, mg_attn, hf_attn):
    """Copy attention weights (and optional biases / QK norms) from a
    Megatron-core attention module into the matching HF attention module.

    The fused Megatron qkv weight is reshaped to (num_query_groups, -1,
    hidden_size); within each group the rows are sliced as [q | k | v]
    (first q_dim rows are q, last kv_dim rows are v, the middle is k —
    assumes k and v have equal per-group width, TODO confirm).
    """
    num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads)
    # Copy weights
    mg_attn_weight = mg_attn.linear_qkv.weight.reshape((num_query_groups, -1, args.hidden_size))
    q_dim, kv_dim = hf_attn.q_proj.weight.shape[0] // num_query_groups, hf_attn.k_proj.weight.shape[
        0] // num_query_groups
    hf_attn.q_proj.weight.data.copy_(mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size))
    hf_attn.k_proj.weight.data.copy_(mg_attn_weight[:, q_dim:-kv_dim, :].reshape(-1, args.hidden_size))
    hf_attn.v_proj.weight.data.copy_(mg_attn_weight[:, -kv_dim:, :].reshape(-1, args.hidden_size))
    hf_attn.o_proj.weight.data.copy_(mg_attn.linear_proj.weight)

    # Copy bias (same per-group [q | k | v] layout as the weights)
    if args.add_qkv_bias:
        mg_attn_bias = mg_attn.linear_qkv.bias.reshape((num_query_groups, -1))
        hf_attn.q_proj.bias.data.copy_(mg_attn_bias[:, :q_dim].reshape(-1))
        hf_attn.k_proj.bias.data.copy_(mg_attn_bias[:, q_dim:-kv_dim].reshape(-1))
        hf_attn.v_proj.bias.data.copy_(mg_attn_bias[:, -kv_dim:].reshape(-1))

    if args.qk_layernorm:
        hf_attn.q_norm.weight.data.copy_(mg_attn.q_layernorm.weight)
        hf_attn.k_norm.weight.data.copy_(mg_attn.k_layernorm.weight)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _set_mlp_state(mg_mlp, hf_mlp):
    """Copy one dense MLP: Megatron's fused fc1 holds gate_proj stacked on
    top of up_proj, so it is split at gate_proj's output width."""
    ffn_hidden_size = hf_mlp.gate_proj.weight.shape[0]
    hf_mlp.gate_proj.weight.data.copy_(mg_mlp.linear_fc1.weight[:ffn_hidden_size])
    hf_mlp.up_proj.weight.data.copy_(mg_mlp.linear_fc1.weight[ffn_hidden_size:])
    hf_mlp.down_proj.weight.data.copy_(mg_mlp.linear_fc2.weight)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def set_mlp_state(args, mg_mlp, hf_mlp):
    """Copy MLP weights, dispatching between dense and MoE layer layouts."""
    if args.num_experts:
        # MoE: router/gate, optional shared-expert gate, each local expert,
        # then the shared expert itself.
        hf_mlp.gate.weight.data.copy_(mg_mlp.router.weight)
        if mg_mlp.shared_experts is not None:
            hf_mlp.shared_expert_gate.weight.data.copy_(mg_mlp.shared_experts.gate_weight)
        for expert_idx in range(args.num_experts):
            _set_mlp_state(mg_mlp.experts.local_experts[expert_idx], hf_mlp.experts[expert_idx])

        if mg_mlp.shared_experts is not None:
            _set_mlp_state(mg_mlp.shared_experts, hf_mlp.shared_expert)
    else:
        _set_mlp_state(mg_mlp, hf_mlp)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def set_layer_state(args, mg_model, hf_model, layer_idx):
    """Copy one decoder layer (attention, MLP, and both layernorms) from the
    Megatron model into the HF model.

    The post-attention layernorm lives in different places depending on the
    layer type: MoE layers keep a standalone pre_mlp_layernorm, while dense
    layers fuse it into linear_fc1.
    """
    mg_layer = mg_model.decoder.layers[layer_idx]
    hf_layer = hf_model.model.layers[layer_idx]
    set_attn_state(args, mg_layer.self_attention, hf_layer.self_attn)
    set_mlp_state(args, mg_layer.mlp, hf_layer.mlp)

    post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight
    if args.num_experts:
        post_attention_layernorm_weight.data.copy_(mg_layer.pre_mlp_layernorm.weight)
    else:
        post_attention_layernorm_weight.data.copy_(mg_layer.mlp.linear_fc1.layer_norm_weight)
    # Input layernorm is fused into the qkv linear on the Megatron side.
    hf_layer.input_layernorm.weight.data.copy_(mg_layer.self_attention.linear_qkv.layer_norm_weight)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def convert_mcore2hf(hf_model, mg_model):
    """Copy all weights from a Megatron-core GPT model into the matching HF
    model: embeddings, lm_head (only when untied), final norm, then every
    decoder layer."""
    args = get_args()
    hf_model.model.embed_tokens.weight.data.copy_(mg_model.embedding.word_embeddings.weight)
    if args.untie_embeddings_and_output_weights:
        # Tied models share embed_tokens, already copied above.
        hf_model.lm_head.weight.data.copy_(mg_model.output_layer.weight)
    hf_model.model.norm.weight.data.copy_(mg_model.decoder.final_layernorm.weight)
    for layer_idx in range(args.num_layers):
        set_layer_state(args, mg_model, hf_model, layer_idx)
|
ms-swift/swift/megatron/train/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .pt import megatron_pt_main
|
| 2 |
+
from .sft import megatron_sft_main
|
ms-swift/swift/megatron/train/pt.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from typing import List, Union
|
| 3 |
+
|
| 4 |
+
from ..argument import MegatronTrainArguments
|
| 5 |
+
from .sft import MegatronSft
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class MegatronPt(MegatronSft):
    """Megatron pre-training entry point: reuses the SFT pipeline with the
    chat template disabled and loss computed on every token."""
    args_class = MegatronTrainArguments
    args: args_class

    def _prepare_template(self) -> None:
        # Pre-training consumes raw text: no chat formatting, and
        # loss_scale 'all' counts every token in the loss instead of only
        # assistant responses.
        self.args.use_chat_template = False
        super()._prepare_template()
        self.template.loss_scale = 'all'
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def megatron_pt_main(args: Union[List[str], MegatronTrainArguments, None] = None):
    """CLI/program entry: build a MegatronPt pipeline and run it.

    `args` may be a parsed MegatronTrainArguments, a raw argv list, or None
    (parse from sys.argv).
    """
    return MegatronPt(args).main()
|
ms-swift/swift/megatron/train/sft.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import os
|
| 3 |
+
from typing import List, Union
|
| 4 |
+
|
| 5 |
+
from megatron.core.enums import ModelType
|
| 6 |
+
from megatron.training import pretrain
|
| 7 |
+
|
| 8 |
+
from swift.llm.train import SwiftSft
|
| 9 |
+
from swift.utils import get_logger, is_master, plot_images
|
| 10 |
+
from ..argument import MegatronTrainArguments
|
| 11 |
+
from ..utils import patch_megatron_tokenizer
|
| 12 |
+
from .patcher import patch_megatron_data_collator, patch_training_log
|
| 13 |
+
from .utils import build_streaming_dataloader, forward_step, get_swift_datasets_provider
|
| 14 |
+
|
| 15 |
+
logger = get_logger()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class MegatronSft(SwiftSft):
    """Supervised fine-tuning entry point that feeds swift datasets and
    templates through Megatron-LM's `pretrain` training loop."""
    args_class = MegatronTrainArguments
    args: args_class

    def __init__(self, args: Union[List[str], MegatronTrainArguments, None] = None) -> None:
        self.train_msg = {}
        # Deliberately skip SwiftSft.__init__ (which would load the HF
        # model); Megatron constructs the model itself, so only the
        # processor/tokenizer is loaded here.
        super(SwiftSft, self).__init__(args)
        args = self.args
        _, self.processor = args.get_model_processor(load_model=False)
        patch_megatron_tokenizer(self.processor)
        args.init_model_args(self.processor.model_info.config)
        self._prepare_template()
        self.template.use_megatron = True
        args.save_args(args.save)

    def run(self):
        """Prepare/encode datasets, then launch Megatron's `pretrain` with
        swift's data provider and collator patched in."""
        args = self.args

        train_dataset, val_dataset = self._get_dataset()
        train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset)
        data_collator = self.template.data_collator
        if args.streaming:
            # Streaming datasets are wrapped into dataloaders up front.
            train_dataset = build_streaming_dataloader(args, train_dataset, data_collator)
            if val_dataset is not None:
                val_dataset = build_streaming_dataloader(args, val_dataset, data_collator)
        datasets_provider = get_swift_datasets_provider(train_dataset, val_dataset)
        datasets_provider.is_distributed = True

        logging_path = os.path.join(args.save, 'logging.jsonl')
        logger.info(f'The logging file will be saved in: {logging_path}')
        try:
            with patch_training_log(), patch_megatron_data_collator(data_collator):
                pretrain(
                    datasets_provider,
                    args.megatron_model_meta.model_provider,
                    ModelType.encoder_or_decoder,
                    forward_step,
                    args_defaults=args.extra_args)
        finally:
            # Visualization — runs even if training raised, so partial runs
            # still produce plots on the master rank.
            if is_master():
                images_dir = os.path.join(args.save, 'images')
                logger.info(f'images_dir: {images_dir}')
                plot_images(images_dir, args.tensorboard_dir)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def megatron_sft_main(args: Union[List[str], MegatronTrainArguments, None] = None):
    """CLI/program entry: build a MegatronSft pipeline and run it.

    `args` may be a parsed MegatronTrainArguments, a raw argv list, or None
    (parse from sys.argv).
    """
    return MegatronSft(args).main()
|
ms-swift/swift/plugin/agent_template/__pycache__/glm4.cpython-310.pyc
ADDED
|
Binary file (3.32 kB). View file
|
|
|
ms-swift/swift/plugin/agent_template/__pycache__/llama.cpython-310.pyc
ADDED
|
Binary file (3.39 kB). View file
|
|
|
ms-swift/swift/plugin/agent_template/__pycache__/qwen.cpython-310.pyc
ADDED
|
Binary file (4.2 kB). View file
|
|
|
ms-swift/swift/plugin/agent_template/__pycache__/react.cpython-310.pyc
ADDED
|
Binary file (2.55 kB). View file
|
|
|
ms-swift/swift/plugin/agent_template/hermes.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import re
|
| 3 |
+
from typing import TYPE_CHECKING, List, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
|
| 7 |
+
from .base import BaseAgentTemplate
|
| 8 |
+
|
| 9 |
+
if TYPE_CHECKING:
|
| 10 |
+
from swift.llm.infer import Function
|
| 11 |
+
from swift.llm.template import Prompt
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class HermesAgentTemplate(BaseAgentTemplate):
    """Agent template using Hermes-style ``<tool_call>``/``<tool_response>`` XML tags."""

    def get_toolcall(self, response: str) -> List['Function']:
        """Parse every ``<tool_call>`` block in *response* into a Function.

        Falls back to the base (ReAct-style) parser when no valid block is found.
        """
        from swift.llm.infer import Function
        parsed_blocks = (self._parse_json(block)
                         for block in re.findall(r'<tool_call>(.+?)</tool_call>', response, re.DOTALL))
        functions = [
            Function(name=obj['name'], arguments=obj['arguments']) for obj in parsed_blocks
            if isinstance(obj, dict) and 'name' in obj and 'arguments' in obj
        ]
        if not functions:
            # compat react_en
            return super().get_toolcall(response)
        return functions

    def _format_tool_responses(
        self,
        assistant_content: str,
        tool_messages,
    ) -> Tuple[str, 'Prompt']:
        """Wrap tool outputs in ``<tool_response>`` tags and splice them into a new user turn."""
        keyword = self.keyword
        if keyword.action in assistant_content and keyword.action_input in assistant_content:
            # ReAct-style assistant turn: delegate to the base implementation.
            return super()._format_tool_responses(assistant_content, tool_messages)
        if hasattr(self, 'template_meta'):
            prompt = self.template_meta.prompt
            chat_sep = self.template_meta.chat_sep
        else:
            # No template meta attached: fall back to the ChatML layout.
            prompt = ['<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n']
            chat_sep = ['<|im_end|>\n']
        total_tool = '\n'.join(f"<tool_response>\n{msg['content']}\n</tool_response>" for msg in tool_messages)
        res = list(chat_sep)
        for context in prompt:
            res.append(context.replace('{{QUERY}}', total_tool) if isinstance(context, str) else context)
        return assistant_content, res

    def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
        """Render the system prompt with a Hermes-style ``<tools>`` function-signature section."""
        signatures = '\n'.join(json.dumps(self.wrap_tool(tool), ensure_ascii=False) for tool in tools)
        header = f"""{system}

# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
"""
        footer = """
</tools>

For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>"""
        return header + signatures + footer

    def _format_tool_calls(self, tool_call_messages):
        """Serialize parsed tool calls into newline-joined ``<tool_call>`` blocks."""
        rendered = []
        for message in tool_call_messages:
            parsed = self._parse_tool_call(message['content'])
            rendered.append(f'<tool_call>\n{json.dumps(parsed, ensure_ascii=False)}\n</tool_call>')
        return '\n'.join(rendered)
|
ms-swift/swift/plugin/agent_template/react.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from typing import List, Union
|
| 3 |
+
|
| 4 |
+
from .base import BaseAgentTemplate
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ReactEnAgentTemplate(BaseAgentTemplate):
    """English ReAct-style agent template (Thought/Action/Action Input/Observation)."""

    def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
        """Build the English ReAct system prompt listing every available tool."""
        names = []
        descriptions = []
        for raw_tool in tools:
            info = self._parse_tool(raw_tool, 'en')
            names.append(info.name_for_model)
            descriptions.append(
                f'{info.name_for_model}: Call this tool to interact with the {info.name_for_human} API. '
                f'What is the {info.name_for_human} API useful for? {info.description_for_model} '
                f'Parameters: {info.parameters} {info.args_format}')

        prologue = """Answer the following questions as best you can. You have access to the following tools:

"""
        epilogue = f"""

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{','.join(names)}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Begin!
"""
        return prologue + '\n\n'.join(descriptions) + epilogue
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class ReactZnAgentTemplate(BaseAgentTemplate):
    """Chinese ReAct-style agent template (Thought/Action/Action Input/Observation)."""

    def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
        """Build the Chinese ReAct system prompt listing every available tool."""
        names = []
        descriptions = []
        for raw_tool in tools:
            info = self._parse_tool(raw_tool, 'zh')
            names.append(info.name_for_model)
            descriptions.append(f'{info.name_for_model}: 调用此工具与 {info.name_for_human} API 进行交互。'
                                f'{info.name_for_human} 有什么用?{info.description_for_model} '
                                f'输入参数:{info.parameters} {info.args_format}')

        prologue = """尽可能地回答以下问题。你可以使用以下工具:

"""
        epilogue = f"""

请按照以下格式进行:

Question: 需要你回答的输入问题
Thought: 你应该总是思考该做什么
Action: 需要使用的工具,应该是[{','.join(names)}]中的一个
Action Input: 传入工具的内容
Observation: 行动的结果
... (这个Thought/Action/Action Input/Observation可以重复N次)
Thought: 我现在知道最后的答案
Final Answer: 对原始输入问题的最终答案

现在开始!
"""
        return prologue + '\n\n'.join(descriptions) + epilogue
|
ms-swift/swift/plugin/loss_scale/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .loss_scale import loss_scale_map
|
ms-swift/swift/plugin/loss_scale/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (227 Bytes). View file
|
|
|
ms-swift/swift/plugin/loss_scale/__pycache__/loss_scale.cpython-310.pyc
ADDED
|
Binary file (4.7 kB). View file
|
|
|
ms-swift/swift/plugin/loss_scale/config/agentflan.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"response":{
|
| 3 |
+
"Name:": [1.0, 3.0],
|
| 4 |
+
"Action:": [1.0, 3.0],
|
| 5 |
+
"ACTION:": [1.0,3.0],
|
| 6 |
+
"Tool:": [1.0, 3.0],
|
| 7 |
+
"Command": [1.0, 3.0],
|
| 8 |
+
"Arguments:": [1.0, 3.0],
|
| 9 |
+
"action input": [1.0, 3.0],
|
| 10 |
+
"ACTION_INPUT:":[1.0, 3.0],
|
| 11 |
+
"Action Input:": [1.0, 3.0],
|
| 12 |
+
"Thought:": [1.0, 1.0],
|
| 13 |
+
"Final Answer:": [1.0, 1.0],
|
| 14 |
+
"Observation:": [2.0, 0.0]
|
| 15 |
+
},
|
| 16 |
+
"query":{
|
| 17 |
+
"What is the tool you want to use": [3.0],
|
| 18 |
+
"What are the required parameter names": [3.0],
|
| 19 |
+
"What is the value of": [3.0],
|
| 20 |
+
"What are the required parameter names for this tool": [3.0]
|
| 21 |
+
}
|
| 22 |
+
}
|
ms-swift/swift/plugin/loss_scale/config/hermes.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"<tool_call>.+?</tool_call>": [2.0]
|
| 3 |
+
}
|
ms-swift/swift/plugin/prm.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Any, Dict, List, Union
|
| 3 |
+
|
| 4 |
+
import json
|
| 5 |
+
|
| 6 |
+
from swift.llm import InferRequest
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class PRM:
    """Abstract base for process reward models.

    Subclasses implement ``__call__`` and return one reward per input request.
    """

    def __call__(self, **kwargs) -> List[Any]:
        raise NotImplementedError
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# System prompt instructing a judge LLM to emit a reward in [-1.0, 1.0].
# NOTE(review): 'appare' is a typo for 'appear'; left unchanged because this text is
# sent verbatim to the reward model, so editing it would change runtime behavior.
SYSTEM = """
You are a process reward model, give the reward value of the answer, you must follow the instructions below:

1. Output a float reward value between -1.0 and 1.0, -1.0 means the worst answer, 1.0 means the best answer, please think step by step to give your reasons and thoughts, but the reward must appare at the end with this format: **Reward: your-reward-value**.

2. The answer may be incomplete, you must give the reward by the existing part of the answer, taking into account semantic coherence, logical correctness, and clarity.

3. A ground truth answer will be given to you, it may be not the best one, consider it as a reference example.

Begin!
"""  # noqa

# User-prompt skeleton; '#query#', '#ground_truth#' and '#response#' are substituted
# via str.replace before the request is sent to the judge model.
QUERY = """
The original question or the previous conversation:

#query#

Here is the ground truth as the reference:

#ground_truth#

Given the upper information, give your reward(-1.0~1.0) of the following answer:

#response#
"""
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class QwenMaxPRM(PRM):
    """PRM that grades answers by querying qwen-max through DashScope's OpenAI-compatible API."""

    @staticmethod
    def _parse_reward(content: str) -> float:
        """Extract the trailing '**Reward: x**' value from *content*; 0.0 when absent or unparsable."""
        if 'Reward:' not in content:
            return 0.
        try:
            return float(content.split('Reward:')[1].strip().replace('*', ''))
        except Exception:
            return 0.

    def __call__(self, infer_requests: List[Union[InferRequest, Dict]], ground_truths: List[str],
                 **kwargs) -> List[float]:
        """Score each (request, ground-truth) pair; returns one float reward per request."""
        # TODO: check request_config
        from openai import OpenAI

        client = OpenAI(
            api_key=os.getenv('DASHSCOPE_API_KEY'),
            base_url='https://dashscope.aliyuncs.com/compatible-mode/v1',
        )

        rewards = []
        for request, ground_truth in zip(infer_requests, ground_truths):
            history = request['messages'][:-1]
            if history[0]['role'] == 'system':
                history = history[1:]
            # The turn being graded must be the assistant's answer.
            assert request['messages'][-1]['role'] == 'assistant'
            prompt = (
                QUERY.replace('#query#', json.dumps(history)).replace('#ground_truth#', ground_truth).replace(
                    '#response#', request['messages'][-1]['content']))
            completion = client.chat.completions.create(
                model='qwen-max',
                messages=[
                    {
                        'role': 'system',
                        'content': SYSTEM
                    },
                    {
                        'role': 'user',
                        'content': prompt
                    },
                ],
            )
            rewards.append(self._parse_reward(completion.choices[0].message.content))
        return rewards
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class ClientPRM(PRM):
    """PRM that scores answers through a remote OpenAI-compatible endpoint via ``InferClient``.

    Defaults to DashScope's qwen-plus model when no endpoint/model is supplied.
    """

    def __init__(self, api_key=None, base_url=None, model=None):
        """
        Args:
            api_key: API key; defaults to the DASHSCOPE_API_KEY environment variable.
            base_url: OpenAI-compatible endpoint; defaults to DashScope's compatible-mode URL.
            model: Model name to request; defaults to 'qwen-plus'.
        """
        from swift.llm import InferClient
        # NOTE: the redundant function-local `import os` was removed; `os` is already
        # imported at module level.
        if api_key is None:
            api_key = os.getenv('DASHSCOPE_API_KEY')
        if base_url is None:
            base_url = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
        if model is None:
            model = 'qwen-plus'
        self.infer_engine = InferClient(base_url=base_url, api_key=api_key)
        # Best-effort: do not raise on malformed responses; unparsable rewards degrade to 0.0.
        self.infer_engine.strict = False
        self.infer_kwargs = {
            'model': model,
        }

    def __call__(self, infer_requests: List[Union[InferRequest, Dict]], ground_truths: List[str],
                 **kwargs) -> List[float]:
        """Score each (request, ground-truth) pair remotely; returns one float reward per request.

        Args:
            infer_requests: Requests whose last message is the assistant answer to grade.
            ground_truths: Reference answers, aligned with ``infer_requests``.
            **kwargs: May carry ``request_config`` to forward to the inference client.
        """
        prm_infer_requests = []
        request_config = kwargs.get('request_config')
        for request, ground_truth in zip(infer_requests, ground_truths):
            previous = request['messages'][:-1]
            if previous[0]['role'] == 'system':
                previous = previous[1:]
            # The turn being graded must be the assistant's answer.
            assert request['messages'][-1]['role'] == 'assistant'
            query = QUERY.replace('#query#', json.dumps(previous))
            query = query.replace('#ground_truth#', ground_truth)
            query = query.replace('#response#', request['messages'][-1]['content'])
            messages = [
                {
                    'role': 'system',
                    'content': SYSTEM
                },
                {
                    'role': 'user',
                    'content': query
                },
            ]
            prm_infer_requests.append(InferRequest(messages=messages))

        responses = self.infer_engine.infer(prm_infer_requests, request_config=request_config, **self.infer_kwargs)
        rewards = []
        for response in responses:
            content = response.choices[0].message.content
            if 'Reward:' not in content:
                rewards.append(0.)
            else:
                try:
                    # Strip markdown bold ('**Reward: 0.5**') before parsing.
                    reward = float(content.split('Reward:')[1].strip().replace('*', ''))
                    rewards.append(reward)
                except Exception:
                    # Unparsable reward -> neutral score.
                    rewards.append(0.)
        return rewards
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# Registry of built-in PRM implementations, looked up by name.
prms = {
    'qwen_max': QwenMaxPRM,
    'client': ClientPRM,
}
|
ms-swift/swift/plugin/rm_plugin.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import textwrap
|
| 3 |
+
from copy import deepcopy
|
| 4 |
+
from typing import Dict, List
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from swift.llm import PtEngine, RequestConfig, Template, to_device
|
| 9 |
+
from swift.llm.infer.protocol import ChatCompletionResponse
|
| 10 |
+
from swift.utils import get_logger
|
| 11 |
+
|
| 12 |
+
logger = get_logger()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class DefaultRMPlugin:
    """
    Default Reward Model Plugin

    This class implements the default processing logic for reward models.
    It assumes that `self.model` is a classification model with a value head (output dimension 1).
    The first logits value from the model's output is used as the reward score.
    """

    def __init__(self, model, template):
        """
        Args:
            model: The reward model (classification head, output dimension 1).
            template (Template): Template used to encode requests into model inputs.
        """
        self.model = model
        self.template: Template = template

    def __call__(self, inputs):
        """Encode ``inputs`` with the template and return one scalar reward per request."""
        # deepcopy: template.encode may mutate the request in place.
        batched_inputs = [self.template.encode(deepcopy(infer_request)) for infer_request in inputs]
        reward_inputs = to_device(self.template.data_collator(batched_inputs), self.model.device)
        # The collator may attach training labels; the reward model must not see them.
        # Fix: use a default so the call doesn't raise KeyError when no labels are present.
        reward_inputs.pop('labels', None)

        with torch.inference_mode():
            # First logit of each sample is the scalar reward.
            return self.model(**reward_inputs).logits[:, 0]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class GenRMPlugin(DefaultRMPlugin):
    """Generative Reward Model Plugin.

    Instead of reading a scalar value head, this plugin prompts a generative chat model to
    grade each dialogue and parses a 'Reward: x' line (x in [0, 1]) out of the generated text.
    """

    def __init__(self, model, template):
        """
        Set up the plugin: initialize a PtEngine for inference, the request config, and the
        grading system prompt.

        Args:
            model (torch.nn.Module): The generative reward model.
            template (Template): The template used for encoding input data.
        """
        super().__init__(model, template)
        # initialize PtEngine to infer
        self.engine = PtEngine.from_model_template(self.model, self.template, max_batch_size=0)  # 0: no limit
        self.request_config = RequestConfig()  # customise your request config here
        self.system = textwrap.dedent("""
        Based on the dialogue history, analyze in detail whether the model's response is accurate, complete, and relevant.
        Assign a reward score between 0 and 1, where 0 indicates completely incorrect and 1 indicates fully correct.
        Before finishing your response, please assign a reward using the following format:

        Reward: {reward}

        For example:
        Reward: 0.85
        """)  # noqa

    def __call__(self, inputs):
        """
        Compute reward scores for the provided inputs.

        Each input's messages are flattened into a query, sent to the reward model, and the
        parsed 'Reward:' values of all returned choices are averaged per input.

        Args:
            inputs (List[Dict]): Input requests; each contains 'messages' (role/content dicts)
                plus any additional dataset columns (e.g. 'solutions', 'images').

        Returns:
            torch.Tensor: Float32 tensor of shape (N,) with one average reward per input.
        """
        rm_inputs = self.prepare_rm_inputs(inputs)
        results = self.engine.infer(rm_inputs, self.request_config, use_tqdm=False)
        rewards = self.compute_rewards(results)
        return torch.tensor(rewards, dtype=torch.float32)

    def prepare_rm_inputs(self, inputs: List[Dict]) -> List[Dict]:
        """
        Rewrite each request so the reward model sees the grading system prompt plus the
        flattened dialogue as a single user turn.

        Args:
            inputs (List[Dict]): A list of input requests.

        Returns:
            List[Dict]: Processed inputs for the reward model.
        """
        rm_inputs = []
        for infer_request in inputs:
            # Deep copy to prevent modification of original input
            rm_infer_request = deepcopy(infer_request)
            # Extract and convert messages to a single query string
            query = self.messages_to_query(rm_infer_request.get('messages'))
            # Construct new messages tailored for the reward model
            rm_infer_request['messages'] = [{'role': 'system', 'content': self.system},
                                            {'role': 'user', 'content': query}]
            rm_inputs.append(rm_infer_request)
        return rm_inputs

    @staticmethod
    def extract_reward(model_output: str) -> float:
        """
        Extract the reward score from the model's output.

        Args:
            model_output (str): Expected to contain "Reward: {reward}" with reward in [0, 1].

        Returns:
            float: The extracted reward score, or None when no score could be parsed
            (a warning is logged; no exception is raised).
        """
        match = re.search(r'Reward:\s*([0-1](?:\.\d+)?)', model_output)
        if match:
            return float(match.group(1))
        logger.warning("Unable to extract reward score from the model's output, set reward to 0")
        return None

    @staticmethod
    def messages_to_query(messages):
        """
        Compress a list of message dictionaries into a single 'Role: content' transcript.

        Messages with empty content are skipped; unknown roles are capitalized.

        Args:
            messages (list[dict]): Message dicts with 'role' and 'content' keys.

        Returns:
            str: All messages concatenated, one per line.

        Example:
            >>> messages = [
            ...     {'role': 'user', 'content': 'Hello, how are you?'},
            ...     {'role': 'assistant', 'content': 'I am fine, thank you! How can I assist you today?'},
            ...     {'role': 'user', 'content': 'Can you help me with my homework?'}
            ... ]
            >>> print(GenRMPlugin.messages_to_query(messages))
            User: Hello, how are you?
            Assistant: I am fine, thank you! How can I assist you today?
            User: Can you help me with my homework?
        """
        formatted_messages = []
        # Canonical capitalization for known roles; others fall back to str.capitalize.
        role_mapping = {
            'user': 'User',
            'assistant': 'Assistant',
            'system': 'System'
            # Add more roles here as needed
        }
        for idx, message in enumerate(messages):
            if not isinstance(message, dict):
                raise TypeError(f'Each message must be a dictionary. Found {type(message)} at index {idx}.')
            role = message.get('role')
            content = message.get('content')
            if not content:
                continue
            role_formatted = role_mapping.get(role.lower(), role.capitalize())
            formatted_messages.append(f'{role_formatted}: {content}')
        return '\n'.join(formatted_messages)

    def compute_rewards(self, results: List[ChatCompletionResponse]) -> List[float]:
        """
        Average the parsed rewards of each response's choices.

        Args:
            results (List[ChatCompletionResponse]): Responses from the reward model.

        Returns:
            List[float]: One average reward per response; 0.0 when nothing parses or on error.
        """
        rewards = []
        for output in results:
            try:
                cur_rewards = [self.extract_reward(choice.message.content) for choice in output.choices]
                cur_rewards = [r for r in cur_rewards if r is not None]
                if cur_rewards:
                    average_reward = sum(cur_rewards) / len(cur_rewards)
                else:
                    average_reward = 0.0
                    logger.warning('No valid rewards extracted. Assigning reward score of 0.0.')
                rewards.append(average_reward)
            except Exception as e:
                logger.error(f'Error computing reward: {e}')
                rewards.append(0.0)  # Assign default reward score on failure
        return rewards
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
# Registry of reward-model plugins, selectable by name.
rm_plugins = {
    'default': DefaultRMPlugin,
    'genrm': GenRMPlugin,
}
|
ms-swift/swift/trainers/__init__.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from typing import TYPE_CHECKING
|
| 3 |
+
|
| 4 |
+
from transformers.trainer_callback import TrainerCallback
|
| 5 |
+
from transformers.trainer_utils import (EvaluationStrategy, FSDPOption, HPSearchBackend, HubStrategy, IntervalStrategy,
|
| 6 |
+
SchedulerType)
|
| 7 |
+
|
| 8 |
+
from swift.utils.import_utils import _LazyModule
|
| 9 |
+
from . import callback
|
| 10 |
+
|
| 11 |
+
try:
    # https://github.com/huggingface/transformers/pull/25702
    from transformers.trainer_utils import ShardedDDPOption
except ImportError:
    # Removed in newer transformers versions; keep the name importable for compatibility.
    ShardedDDPOption = None

if TYPE_CHECKING:
    # Static imports for type checkers / IDEs; kept in sync with _import_structure below.
    from .arguments import Seq2SeqTrainingArguments, TrainingArguments
    from .rlhf_trainer import (CPOTrainer, DPOTrainer, KTOTrainer, ORPOTrainer, RLHFTrainerMixin, PPOTrainer,
                               RewardTrainer, GRPOTrainer)
    # Fix: GRPOConfig is exported by the lazy module but was missing from this import.
    from .rlhf_arguments import DPOConfig, CPOConfig, KTOConfig, ORPOConfig, PPOConfig, RewardConfig, GRPOConfig
    from .trainer_factory import TrainerFactory
    from .trainers import Seq2SeqTrainer, Trainer, EmbeddingTrainer
    from .mixin import SwiftMixin

else:
    # At runtime, replace this module with a lazy loader so heavy trainer
    # dependencies are only imported when first accessed.
    _extra_objects = {k: v for k, v in globals().items() if not k.startswith('_')}
    _import_structure = {
        'arguments': ['Seq2SeqTrainingArguments', 'TrainingArguments'],
        'rlhf_arguments':
        ['DPOConfig', 'CPOConfig', 'KTOConfig', 'ORPOConfig', 'PPOConfig', 'RewardConfig', 'GRPOConfig'],
        'rlhf_trainer': [
            'CPOTrainer', 'DPOTrainer', 'KTOTrainer', 'ORPOTrainer', 'RLHFTrainerMixin', 'PPOTrainer', 'RewardTrainer',
            'GRPOTrainer'
        ],
        'trainer_factory': ['TrainerFactory'],
        'trainers': ['Seq2SeqTrainer', 'Trainer', 'EmbeddingTrainer'],
        'mixin': ['SwiftMixin'],
    }

    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects=_extra_objects,
    )
|
ms-swift/swift/trainers/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (1.78 kB). View file
|
|
|
ms-swift/swift/trainers/__pycache__/callback.cpython-310.pyc
ADDED
|
Binary file (4.96 kB). View file
|
|
|
ms-swift/swift/trainers/__pycache__/trainer_factory.cpython-310.pyc
ADDED
|
Binary file (2.22 kB). View file
|
|
|
ms-swift/swift/trainers/__pycache__/trainers.cpython-310.pyc
ADDED
|
Binary file (7.93 kB). View file
|
|
|
ms-swift/swift/trainers/callback.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import math
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from transformers import trainer
|
| 8 |
+
from transformers.trainer_callback import (DefaultFlowCallback, PrinterCallback, ProgressCallback, TrainerControl,
|
| 9 |
+
TrainerState)
|
| 10 |
+
from transformers.trainer_utils import IntervalStrategy, has_length, speed_metrics
|
| 11 |
+
|
| 12 |
+
from swift.utils import append_to_jsonl, is_pai_training_job, use_torchacc
|
| 13 |
+
from ..utils.utils import format_time
|
| 14 |
+
from .arguments import TrainingArguments
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def add_train_message(logs, state, start_time) -> None:
    """Augment ``logs`` in place with step progress, elapsed/remaining time, and rounded floats."""
    logs['global_step/max_steps'] = f'{state.global_step}/{state.max_steps}'
    completed = state.global_step / state.max_steps if state.max_steps else 0.
    logs['percentage'] = f'{completed * 100:.2f}%'
    elapsed = time.time() - start_time
    logs['elapsed_time'] = format_time(elapsed)
    if completed != 0:
        # Linear extrapolation of the remaining wall-clock time.
        logs['remaining_time'] = format_time(elapsed / completed - elapsed)
    for key, value in logs.items():
        if isinstance(value, float):
            logs[key] = round(value, 8)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class ProgressCallbackNew(ProgressCallback):
    """Progress callback with named tqdm bars, jsonl logging and torchacc
    warmup speed metrics.

    Differences from the upstream ``ProgressCallback``: the train/val bars
    are labeled, every log payload is enriched via ``add_train_message`` and
    appended to ``<output_dir>/logging.jsonl``, and under torchacc a
    post-warmup throughput metric is computed at the final step.
    """

    def on_train_begin(self, args, state, control, **kwargs):
        # The bar only exists on rank 0; start_time is set on every rank
        # because on_log -> add_train_message runs unguarded on all ranks.
        if state.is_world_process_zero:
            self.training_bar = tqdm(desc='Train', total=state.max_steps, dynamic_ncols=True)
            self.current_step = 0
        self.start_time = time.time()
        if use_torchacc():
            # warmup_start_time == 0 doubles as a "not yet started" flag.
            self.warmup_start_time = 0
            self.warmup_metric = None
            # metric_warmup_step < 1 is interpreted as a fraction of max_steps.
            self.metric_warmup_step = int(args.metric_warmup_step
                                          * state.max_steps) if args.metric_warmup_step < 1 else args.metric_warmup_step

    def on_prediction_step(self, args, state: TrainerState, control, eval_dataloader=None, **kwargs):
        # Only draw a bar when the dataloader length is known.
        if state.is_world_process_zero and has_length(eval_dataloader):
            if self.prediction_bar is None:
                if self.training_bar is not None:
                    # Break the line so the val bar doesn't overwrite the train bar.
                    self.training_bar.fp.write('\n')
                self.prediction_bar = tqdm(
                    desc='Val', total=len(eval_dataloader), leave=True, dynamic_ncols=True, position=0)
            self.prediction_bar.update()

    def on_log(self, args: TrainingArguments, state: TrainerState, control, logs=None, **kwargs):

        if use_torchacc():
            # Start timing post-warmup throughput once the warmup step is reached.
            if state.global_step >= self.metric_warmup_step and self.warmup_start_time == 0:
                self.warmup_start_time = time.time()
                self.metric_warmup_step = state.global_step
            # At the last step, compute speed metrics over the post-warmup window.
            if state.max_steps == state.global_step and self.warmup_metric is None:
                num_steps = state.max_steps - self.metric_warmup_step
                num_total_samples = args.train_dataset_sample
                # Assumes samples are evenly spread over steps — TODO confirm.
                num_after_warmup_samples = int(num_total_samples / state.max_steps * num_steps)
                self.warmup_metric = speed_metrics('warmup_train', self.warmup_start_time, num_after_warmup_samples,
                                                   num_steps)
                self.warmup_metric['num_total_samples'] = num_total_samples
                self.warmup_metric['num_after_warmup_samples'] = num_after_warmup_samples
            # Merge warmup metrics into the final train summary log entry.
            if 'train_samples_per_second' in logs:
                logs.update(self.warmup_metric)
                state.log_history[-1] = logs

        add_train_message(logs, state, self.start_time)
        if not is_pai_training_job() and state.is_world_process_zero:
            jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
            append_to_jsonl(jsonl_path, logs)
        super().on_log(args, state, control, logs, **kwargs)
        if state.is_world_process_zero and self.training_bar is not None:
            # Redraw after the printed log line so the bar stays visible.
            self.training_bar.refresh()
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class DefaultFlowCallbackNew(DefaultFlowCallback):
    """Flow callback that forces a final evaluation and checkpoint.

    Ensures the very last training step — and, when ``args.max_epochs`` is
    set, the last epoch — triggers evaluation and saving (subject to the
    configured strategies), and stops training at ``max_epochs``.
    """

    @staticmethod
    def _force_eval_and_save(args: TrainingArguments, control: TrainerControl) -> None:
        """Flip should_evaluate/should_save unless the strategy disables them."""
        # transformers renamed evaluation_strategy -> eval_strategy; support both.
        eval_strategy = args.eval_strategy if hasattr(args, 'eval_strategy') else args.evaluation_strategy
        if eval_strategy != IntervalStrategy.NO:
            control.should_evaluate = True
        if args.save_strategy != IntervalStrategy.NO:
            control.should_save = True

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        control = super().on_step_end(args, state, control, **kwargs)
        # Save/evaluate on the very last checkpoint.
        if state.global_step == state.max_steps:
            self._force_eval_and_save(args, control)
        return control

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        control = super().on_epoch_end(args, state, control, **kwargs)
        if args.max_epochs is not None and args.max_epochs <= math.ceil(state.epoch):
            self._force_eval_and_save(args, control)
            control.should_training_stop = True
        return control
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class PrinterCallbackNew(PrinterCallback):
    """Printer callback that enriches logs with progress/timing fields and
    mirrors them to ``<output_dir>/logging.jsonl`` (outside PAI jobs)."""

    def on_train_begin(self, args, state, control, **kwargs):
        # Remember the wall-clock start so on_log can report elapsed time.
        self.start_time = time.time()
        return super().on_train_begin(args, state, control, **kwargs)

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Mutates `logs` in place before it is persisted/printed.
        add_train_message(logs, state, self.start_time)
        if not is_pai_training_job() and state.is_world_process_zero:
            append_to_jsonl(os.path.join(args.output_dir, 'logging.jsonl'), logs)

        logs.pop('total_flos', None)
        if state.is_world_process_zero:
            print(logs, flush=True)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# monkey patching
# Replace the callbacks transformers installs by default so every Trainer
# created after this module is imported picks up the swift variants
# (jsonl logging, final-step save/eval, max_epochs early stop).
trainer.DEFAULT_PROGRESS_CALLBACK = ProgressCallbackNew
trainer.DEFAULT_CALLBACKS = [DefaultFlowCallbackNew]
trainer.PrinterCallback = PrinterCallbackNew
|
ms-swift/swift/trainers/mixin.py
ADDED
|
@@ -0,0 +1,516 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
# Part of the implementation is borrowed from huggingface/transformers.
|
| 3 |
+
import inspect
|
| 4 |
+
import os
|
| 5 |
+
import shutil
|
| 6 |
+
import time
|
| 7 |
+
from contextlib import contextmanager
|
| 8 |
+
from copy import copy
|
| 9 |
+
from functools import partial
|
| 10 |
+
from types import MethodType
|
| 11 |
+
from typing import Callable, Dict, List, Optional, Tuple, Union
|
| 12 |
+
|
| 13 |
+
import safetensors
|
| 14 |
+
import torch
|
| 15 |
+
import torch.distributed as dist
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import transformers
|
| 18 |
+
from datasets import Dataset as HfDataset
|
| 19 |
+
from modelscope import check_local_model_is_latest
|
| 20 |
+
from packaging import version
|
| 21 |
+
from peft import PeftModel
|
| 22 |
+
from torch.nn import Module
|
| 23 |
+
from torch.utils.data import DataLoader
|
| 24 |
+
from transformers import PreTrainedModel
|
| 25 |
+
from transformers.data.data_collator import DataCollator
|
| 26 |
+
from transformers.integrations import is_deepspeed_zero3_enabled
|
| 27 |
+
from transformers.modeling_utils import unwrap_model
|
| 28 |
+
from transformers.trainer import TrainerCallback
|
| 29 |
+
from transformers.trainer_utils import EvalPrediction, IntervalStrategy
|
| 30 |
+
from transformers.utils import is_torch_npu_available
|
| 31 |
+
|
| 32 |
+
from swift.hub import get_hub
|
| 33 |
+
from swift.llm import BatchSamplerShard, DataLoaderDispatcher, DataLoaderShard, Template
|
| 34 |
+
from swift.plugin import MeanMetric, compute_acc, extra_tuners
|
| 35 |
+
from swift.tuners import SwiftModel
|
| 36 |
+
from swift.utils import get_logger, is_mp_ddp, use_torchacc
|
| 37 |
+
from swift.utils.torchacc_utils import ta_trim_graph
|
| 38 |
+
from ..utils.torch_utils import get_device_count
|
| 39 |
+
from .arguments import TrainingArguments
|
| 40 |
+
from .utils import can_return_loss, find_labels, get_function, is_instance_of_ms_model
|
| 41 |
+
|
| 42 |
+
# trl is an optional dependency; fall back to None so isinstance checks
# against AutoModelForCausalLMWithValueHead can be guarded at runtime.
# (RuntimeError is caught too because trl can raise it at import time.)
try:
    from trl import AutoModelForCausalLMWithValueHead
except (ImportError, RuntimeError):
    AutoModelForCausalLMWithValueHead = None

logger = get_logger()
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class SwiftMixin:
    """Shared functionality mixed into ms-swift trainers.

    Adds on top of ``transformers.Trainer``: template/tokenizer wiring, hub
    patching, tuner-aware saving (LoRA/PISSA/OLoRA/LoRA-GA conversion,
    value-head reward models, SentenceTransformer), deepspeed ZeRO-3
    state-dict fixes, NaN-grad protection, custom accuracy metrics, and
    optional evalscope evaluation.
    """

    def __init__(self,
                 model: Union[PreTrainedModel, Module] = None,
                 args: TrainingArguments = None,
                 data_collator: Optional[DataCollator] = None,
                 train_dataset: Optional[HfDataset] = None,
                 eval_dataset: Optional[Union[HfDataset, Dict[str, HfDataset]]] = None,
                 template: Optional[Template] = None,
                 model_init: Optional[Callable[[], PreTrainedModel]] = None,
                 compute_loss_func: Optional[Callable] = None,
                 compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
                 callbacks: Optional[List[TrainerCallback]] = None,
                 optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
                 preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
                 **kwargs) -> None:
        # IterableDataset has no __len__; multiple workers would duplicate data.
        if not hasattr(train_dataset, '__len__') and args.dataloader_num_workers > 1:
            args.dataloader_num_workers = 1
            logger.warning('Using IterableDataset, setting args.dataloader_num_workers to 1.')

        if args.check_model and hasattr(model, 'model_dir'):
            from swift.utils.logger import ms_logger_ignore_error
            with ms_logger_ignore_error():
                check_local_model_is_latest(
                    model.model_dir, user_agent={
                        'invoked_by': 'local_trainer',
                        'third_party': 'swift',
                    })
        if eval_dataset is None and args:
            # No eval dataset: disable evaluation (both old and new arg names).
            args.evaluation_strategy = IntervalStrategy.NO
            args.eval_strategy = IntervalStrategy.NO

        self._custom_metrics = {}
        self.template = template
        self.max_memory = 0
        self.hub = get_hub()

        self.model_meta = model.model_meta
        # patch_hub redirects HF hub calls to the configured hub during init.
        with self.hub.patch_hub():
            super().__init__(
                model=model,
                args=args,
                data_collator=data_collator,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                tokenizer=template.tokenizer,
                model_init=model_init,
                compute_metrics=compute_metrics,
                callbacks=callbacks,
                optimizers=optimizers,
                preprocess_logits_for_metrics=preprocess_logits_for_metrics,
                **kwargs)

        self.compute_loss_func = compute_loss_func
        # If the instance's forward was wrapped/replaced, re-derive label info.
        if get_function(model.__class__.forward) is not get_function(model.forward):
            self.label_names = find_labels(model)
            self.can_return_loss = can_return_loss(model)
        self.label_names = self.label_names or ['labels']
        self.start_time = time.time()
        if self.template.sequence_parallel_size > 1:
            from swift.trainers.sequence_parallel import sequence_parallel
            sequence_parallel.prepare_trainer(self)

    def _save_initial_model(self, output_dir) -> None:
        """Snapshot the untrained adapter for pissa/olora/lora-ga conversion.

        Those init schemes need the initial weights later to convert the
        trained adapter back into a standard LoRA checkpoint.
        """
        model = unwrap_model(self.model)
        if isinstance(model, PeftModel):
            config = model.peft_config.get('default')
            init_lora_weights = getattr(config, 'init_lora_weights', None)
            if (isinstance(init_lora_weights, str)
                    and any(s in init_lora_weights for s in ('pissa', 'olora', 'lora-ga'))):
                # Temporarily mark as plain-init so save_pretrained stores raw weights.
                config.init_lora_weights = True
                model.save_pretrained(os.path.join(output_dir, 'initial_model'))
                config.init_lora_weights = init_lora_weights

    def _save_converted_model(self, output_dir) -> None:
        """Save a LoRA-converted copy of a pissa/olora/lora-ga adapter.

        Writes to ``<output_dir>/converted/default`` using the initial-model
        snapshot produced by :meth:`_save_initial_model`.
        """
        model = unwrap_model(self.model)
        if isinstance(model, PeftModel):
            config = model.peft_config.get('default')
            init_lora_weights = getattr(config, 'init_lora_weights', None)
            if isinstance(init_lora_weights, str):
                # save_pretrained mutates peft_config; keep a copy to restore.
                config = copy(config)
                os.makedirs(os.path.join(output_dir, 'converted'), exist_ok=True)
                if 'lora-ga' in init_lora_weights:
                    try:
                        from lora_ga.entrypoint import LoraGAContext
                        with LoraGAContext(model):
                            model.save_pretrained(
                                os.path.join(output_dir, 'converted', 'default'),
                                path_initial_model_for_weight_conversion=os.path.join(
                                    os.path.dirname(output_dir), 'initial_model'),
                            )
                            model.peft_config['default'] = config
                    except ImportError as e:
                        error_message = """
                        Since 'LoRA-GA' is not implemented by PEFT, you will need to install it directly from GitHub.
                        Command: 'pip install git+https://github.com/lxline/LoRA-GA.git'.
                        """
                        logger.info(error_message)
                        raise RuntimeError(error_message) from e
                elif 'pissa' in init_lora_weights or 'olora' in init_lora_weights:
                    model.save_pretrained(
                        os.path.join(output_dir, 'converted', 'default'),
                        path_initial_model_for_weight_conversion=os.path.join(
                            os.path.dirname(output_dir), 'initial_model'),
                    )
                    model.peft_config['default'] = config

    def _load_optimizer_and_scheduler(self, *args, **kwargs) -> None:
        """Load optimizer/scheduler state, then work around an MP+DDP AdamW bug."""
        super()._load_optimizer_and_scheduler(*args, **kwargs)
        if is_mp_ddp():
            # fix mp+ddp adamw: 'step' may land on a different device than the
            # moment tensors, breaking the fused update; move it to CPU.
            for v in self.optimizer.state.values():
                if 'step' in v:
                    device_set = {t.device for t in v.values()} - {v['step'].device, torch.device('cpu')}
                    if len(device_set) >= 1:
                        v['step'] = v['step'].to('cpu')

    def _save_model(self, output_dir: Optional[str] = None, state_dict=None) -> None:
        """Persist the model, dispatching on its wrapper type.

        Handles plain PreTrainedModel/SwiftModel/PeftModel, trl value-head
        reward models (split decoder vs. v_head), modelscope models, extra
        tuners, and SentenceTransformer wrappers.
        """
        supported_classes = (SwiftModel, PreTrainedModel, PeftModel)
        # Must be a 1-tuple: without the trailing comma this would be a plain
        # string and `name not in supported_names` would be a substring test,
        # wrongly matching e.g. 'Transformer'.
        supported_names = ('SentenceTransformer', )
        if AutoModelForCausalLMWithValueHead is not None:
            supported_classes = supported_classes + (AutoModelForCausalLMWithValueHead, )
        save_safetensors = self.args.save_safetensors
        if not isinstance(self.model, supported_classes) and self.model.__class__.__name__ not in supported_names:
            if state_dict is None:
                state_dict = self.model.state_dict()

            _unwrap_model = unwrap_model(self.model)
            if isinstance(_unwrap_model, supported_classes):
                _unwrap_model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=save_safetensors)
            else:
                logger.info('Trainer.model is not a `PreTrainedModel`, only saving its state dict.')
                if save_safetensors:
                    safetensors.torch.save_file(state_dict, os.path.join(output_dir, 'model.safetensors'))
                else:
                    torch.save(state_dict, os.path.join(output_dir, 'pytorch_model.bin'))
        elif AutoModelForCausalLMWithValueHead and isinstance(self.model, AutoModelForCausalLMWithValueHead):
            # save reward model: decoder weights and value head go to separate files.
            state_dict = self.model.state_dict()
            decoder_state_dict, v_head_state_dict = {}, {}
            for name, param in state_dict.items():
                if name.startswith('v_head.'):
                    v_head_state_dict[name] = param
                else:
                    decoder_state_dict[name.replace('pretrained_model.', '', 1)] = param
            self.model.pretrained_model.save_pretrained(
                output_dir, state_dict=decoder_state_dict or None, safe_serialization=save_safetensors)
            if save_safetensors:
                from safetensors.torch import save_file
                save_file(
                    v_head_state_dict, os.path.join(output_dir, 'value_head.safetensors'), metadata={'format': 'pt'})
            else:
                torch.save(v_head_state_dict, os.path.join(output_dir, 'value_head.bin'))
        elif is_instance_of_ms_model(self.model):
            PreTrainedModel.save_pretrained(
                self.model, output_dir, state_dict=state_dict, safe_serialization=save_safetensors)
        elif self.args.train_type in extra_tuners:
            extra_tuners[self.args.train_type].save_pretrained(
                self.model, output_dir, state_dict=state_dict, safe_serialization=save_safetensors)
        else:
            if self.model.__class__.__name__ != 'SentenceTransformer':
                self.model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=save_safetensors)
            else:

                @contextmanager
                def save_context():
                    # SentenceTransformer prefixes keys with '0.auto_model.';
                    # strip it and inject the state_dict via a temporary patch.
                    save_pretrained = self.model[0].auto_model.save_pretrained
                    _state_dict = {
                        key[len('0.auto_model.'):] if 'auto_model' in key else key: value
                        for key, value in state_dict.items()
                    }
                    self.model[0].auto_model.save_pretrained = partial(
                        self.model[0].auto_model.save_pretrained, state_dict=_state_dict)
                    yield
                    self.model[0].auto_model.save_pretrained = save_pretrained

                with save_context():
                    self.model.save_pretrained(output_dir, safe_serialization=save_safetensors)
                # copy sentencetransformers files
                from swift.utils import copy_files_by_pattern
                copy_files_by_pattern(self.model.model_dir, output_dir, '*.py')
                copy_files_by_pattern(self.model.model_dir, output_dir, '*.json')

    def _save(self, output_dir: Optional[str] = None, state_dict=None) -> None:
        """Compatible with swift and peft"""
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        self._save_model(output_dir, state_dict)
        # training_args.bin
        torch.save(self.args, os.path.join(output_dir, 'training_args.bin'))
        self._save_converted_model(output_dir)
        # args.json
        args_path = os.path.join(os.path.dirname(output_dir), 'args.json')
        if os.path.exists(args_path):
            shutil.copy(args_path, os.path.join(output_dir, 'args.json'))
        # predict.jsonl
        predict_jsonl = os.path.join(os.path.dirname(output_dir), 'predict.jsonl')
        if os.path.exists(predict_jsonl):
            shutil.move(predict_jsonl, os.path.join(output_dir, 'predict.jsonl'))

        is_adapter = isinstance(self.model, (SwiftModel, PeftModel))
        # tokenizer: only full checkpoints carry processor/config files.
        if not is_adapter:
            from swift.llm import save_checkpoint
            additional_saved_files = self.model_meta.additional_saved_files
            save_checkpoint(
                None,
                self.template.processor,
                output_dir,
                model_dirs=[self.model.model_dir],
                additional_saved_files=additional_saved_files)
            if getattr(self.model, 'origin_generation_config', None):
                self.model.origin_generation_config.save_pretrained(output_dir)

    def _fix_zero3_gather_all_parameters(self) -> None:
        """Patch deepspeed ZeRO-3 state-dict gathering to skip frozen params.

        For adapter training (SwiftModel/PeftModel) only trainable weights
        should be consolidated; the original method would gather everything.
        Idempotent: the ``_origin`` attribute guards double patching.
        """
        if is_deepspeed_zero3_enabled() and not hasattr(self.deepspeed, '_zero3_consolidated_16bit_state_dict_origin'):
            parameters = inspect.signature(self.deepspeed._zero3_consolidated_16bit_state_dict).parameters
            if 'exclude_frozen_parameters' in parameters:

                def _zero3_consolidated_16bit_state_dict(model, exclude_frozen_parameters=False):
                    unwrapped = unwrap_model(model)
                    exclude_frozen_parameters = False
                    if isinstance(unwrapped, SwiftModel) and unwrapped.has_additional_modules:
                        exclude_frozen_parameters = True
                    if isinstance(unwrapped, PeftModel):
                        exclude_frozen_parameters = True
                    return model._zero3_consolidated_16bit_state_dict_origin(exclude_frozen_parameters)

                self.deepspeed._zero3_consolidated_16bit_state_dict_origin = (
                    self.deepspeed._zero3_consolidated_16bit_state_dict)
                self.deepspeed._zero3_consolidated_16bit_state_dict = MethodType(_zero3_consolidated_16bit_state_dict,
                                                                                 self.deepspeed)

    def _save_checkpoint(self, *args, **kwargs):
        """Record the checkpoint path, apply the ZeRO-3 fix, then delegate."""
        self.state.last_model_checkpoint = os.path.join(self.args.output_dir, f'checkpoint-{self.state.global_step}')
        self._fix_zero3_gather_all_parameters()
        result = super()._save_checkpoint(*args, **kwargs)
        logger.info(f'Saving model checkpoint to {self.state.last_model_checkpoint}')
        return result

    @staticmethod
    @contextmanager
    def _fix_grad_norm_nan():
        """Context manager that drops a step's gradients when grad-norm is NaN."""
        from accelerate import Accelerator
        origin_clip_grad_norm_ = Accelerator.clip_grad_norm_

        def clip_grad_norm_(self, parameters, *args, **kwargs):
            # If NaN occurs, ignore weight updates.
            parameters = list(parameters)
            grad_norm = origin_clip_grad_norm_(self, parameters, *args, **kwargs)
            if isinstance(grad_norm, torch.Tensor) and grad_norm.isnan().item():
                for p in parameters:
                    p.grad = None
            return grad_norm

        Accelerator.clip_grad_norm_ = clip_grad_norm_
        try:
            yield
        finally:
            # Always restore the original, even if training raised.
            Accelerator.clip_grad_norm_ = origin_clip_grad_norm_

    def train(self, *args, **kwargs):
        """Run training with hub patching, NaN-grad protection, and (for
        multimodal models) post-encode hooks registered on every model."""
        if self.model_meta.is_multimodal:
            models = []
            for model_name in ['model', 'ref_model', 'value_model']:
                model = getattr(self, model_name, None)
                if isinstance(model, nn.Module):
                    models.append(model)

            reward_model = getattr(self, 'reward_model', None)
            if reward_model is not None:
                if isinstance(reward_model, list):
                    models.extend([m for m in reward_model if isinstance(m, nn.Module)])
                elif isinstance(reward_model, nn.Module):
                    models.append(reward_model)

            models = list(set(models))  # Deduplicate
            self.template.register_post_encode_hook(models)
            logger.info(f'Successfully registered post_encode hook: {[model.__class__.__name__ for model in models]}.')
        self._save_initial_model(self.args.output_dir)
        with self.hub.patch_hub(), self._fix_grad_norm_nan():
            res = super().train(*args, **kwargs)
        self.template.remove_post_encode_hook()
        return res

    def push_to_hub(self, *args, **kwargs):
        """Push to the configured hub (modelscope/HF) via the hub patch."""
        with self.hub.patch_hub():
            return super().push_to_hub(*args, **kwargs)

    def get_max_cuda_memory(self, device: Optional[Union[torch.device, int]] = None) -> float:
        """Return reserved CUDA memory in GiB (summed over all devices when
        ``device`` is None) and track the running maximum in ``self.max_memory``."""
        if device is None:
            # Avoid shadowing the `device` parameter inside the comprehension.
            mems = [torch.cuda.max_memory_reserved(device=idx) for idx in range(get_device_count())]
        else:
            mems = [torch.cuda.max_memory_reserved(device=device)]
        mem = sum(mems) / 1024**3
        self.max_memory = max(self.max_memory, mem)
        return mem

    def _maybe_log_save_evaluate(self, tr_loss, *args, **kwargs):
        """Emit the swift-flavored training log, then delegate to the base class.

        Mirrors transformers' implementation but adds custom metrics, memory
        usage and train speed, and optionally runs evalscope evaluation.
        """
        if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
            self.control.should_log = False

            # all_gather + mean() to get average loss over all processes
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
            loss = tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged)
            logs: Dict[str, float] = {'loss': loss}  # loss first

            # Flush accumulated custom metrics (e.g. token accuracy).
            for k, metric in self._custom_metrics.items():
                value = metric.compute()
                if len(value) == 1:
                    val = list(value.values())[0]
                    logs[k] = val
                else:
                    for k_suffix, val in value.items():
                        new_k = f'{k}_{k_suffix}'
                        logs[new_k] = val
                metric.reset()

            # transformers >= 4.38 passes grad_norm as the first extra positional arg.
            if version.parse(transformers.__version__) >= version.parse('4.38'):
                grad_norm = args[0]
                if grad_norm is not None:
                    logs['grad_norm'] = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
            logs['learning_rate'] = self._get_learning_rate()
            if not is_torch_npu_available():
                logs['memory(GiB)'] = round(self.get_max_cuda_memory(), 2)

            elapse_time = time.time() - self.start_time
            logs['train_speed(iter/s)'] = round(self.state.global_step / elapse_time, 6)
            # Drop None entries; iterate over a copy of keys since we mutate.
            for k in list(logs.keys()):
                if logs[k] is None:
                    logs.pop(k)
            # Zero the accumulator in place (standard transformers pattern).
            tr_loss -= tr_loss
            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step
            self.store_flos()
            self.log(logs)

        if self.args.eval_use_evalscope and self.control.should_evaluate:
            self._evalscope_eval()
        super()._maybe_log_save_evaluate(tr_loss, *args, **kwargs)

    def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
        """Build optimizer/scheduler, honoring a plugin optimizer if configured."""
        if self.args.optimizer is not None:
            from swift.plugin import optimizers_map
            optimizer_callback = optimizers_map[self.args.optimizer]
            self.optimizer, self.lr_scheduler = optimizer_callback(self.args, self.model, self.train_dataset)
            # The plugin may supply only one of the two; fill in the other.
            if self.optimizer is None:
                self.create_optimizer()
            if self.lr_scheduler is None:
                self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
        else:
            super().create_optimizer_and_scheduler(num_training_steps=num_training_steps)

    def _compute_acc(self, outputs, labels) -> None:
        """Accumulate token/sequence accuracy into the custom metrics, every
        ``args.acc_steps`` global steps."""
        args = self.args
        acc_steps = args.acc_steps
        preds = outputs.logits.argmax(dim=-1)
        if self.state.global_step % acc_steps == 0:
            if use_torchacc():
                ta_trim_graph()
                preds = preds.to('cpu')
                labels = labels.to('cpu')
            metrics = compute_acc(
                preds, labels, acc_strategy=args.acc_strategy, is_encoder_decoder=self.template.is_encoder_decoder)
            for k, v in metrics.items():
                if k not in self._custom_metrics:
                    self._custom_metrics[k] = MeanMetric(nan_value=None)
                self._custom_metrics[k].update(v)

    @torch.no_grad()
    def _evalscope_eval(self):
        """Run evalscope on the current model and log the scores.

        Returns the ``{f'test_{dataset}': score}`` dict that was logged.
        """
        from ..llm.eval.utils import EvalModel
        from evalscope import TaskConfig, run_task
        from evalscope.constants import EvalType

        self.model.eval()
        max_batch_size = self.args.per_device_eval_batch_size
        custom_model = EvalModel(
            self.model, self.template, max_batch_size=max_batch_size, model_name=f'model-step{self.state.global_step}')
        task_config = TaskConfig(
            model=custom_model,
            eval_type=EvalType.CUSTOM,
            datasets=self.args.eval_datasets,
            dataset_args=self.args.eval_datasets_args,
            limit=self.args.eval_limit,
            work_dir=os.path.join(self.args.output_dir, 'eval'),
            eval_batch_size=max_batch_size,
            generation_config=self.args.eval_generation_config or {'max_tokens': 512},
        )
        # start evaluation
        eval_report = run_task(task_config)
        # convert to dict
        eval_dict = {f'test_{k}': v.score for k, v in eval_report.items()}
        self.log(eval_dict)

        # Restore training mode before returning to the training loop.
        self.model.train()
        return eval_dict

    def get_batch_samples(self, *args, **kwargs):
        """Delegate to the base class; under sequence parallelism, all-reduce
        ``num_items_in_batch`` across the SP group so loss scaling matches."""
        res = super().get_batch_samples(*args, **kwargs)
        if self.template.sequence_parallel_size == 1:
            return res

        batch_samples, num_items_in_batch = res
        if num_items_in_batch is None:
            # args[2] is assumed to be a device — TODO confirm against caller.
            num_items_in_batch = torch.tensor(0).to(args[2])
        from swift.trainers.sequence_parallel import sequence_parallel
        dist.all_reduce(num_items_in_batch, dist.ReduceOp.SUM, sequence_parallel.sp_group)
        return batch_samples, num_items_in_batch
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
class DataLoaderMixin:
|
| 467 |
+
|
| 468 |
+
    def get_train_dataloader(self):
        """Build the training dataloader.

        Prefers a sequence-parallel dataloader when SP is enabled; otherwise
        uses swift's sharded sampler/dataloader for sized datasets, or a
        dispatcher-wrapped plain DataLoader for IterableDataset.
        """
        dataloader = None
        if self.template.sequence_parallel_size > 1:
            from swift.trainers.sequence_parallel import sequence_parallel
            # May return None, in which case we fall through to the default path.
            dataloader = sequence_parallel.get_dataloader(self, self.train_dataset, self._train_batch_size)
        if dataloader is None:
            # Higher efficiency
            if self.train_dataset is None:
                raise ValueError('Trainer: training requires a train_dataset.')
            args = self.args
            train_dataset = self.train_dataset

            dataloader_params = {
                'collate_fn': self.data_collator,
                'num_workers': args.dataloader_num_workers,
                'pin_memory': args.dataloader_pin_memory,
                'persistent_workers': args.dataloader_persistent_workers,
                'prefetch_factor': args.dataloader_prefetch_factor
            }
            batch_sampler_params = {
                'drop_last': args.dataloader_drop_last,
                'shuffle': args.train_dataloader_shuffle,
                'data_seed': args.data_seed,
            }

            if hasattr(train_dataset, '__len__'):
                # Sized dataset: shard batches across ranks with swift's sampler.
                batch_sampler = BatchSamplerShard(
                    len(train_dataset), batch_size=self._train_batch_size, **batch_sampler_params)
                dataloader = DataLoaderShard(train_dataset, batch_sampler, **dataloader_params)
            else:
                # IterableDataset
                if dist.is_initialized() and dataloader_params['prefetch_factor']:
                    # Scale prefetch so the dispatcher can feed every rank.
                    dataloader_params['prefetch_factor'] = dataloader_params['prefetch_factor'] * dist.get_world_size()
                dataloader = DataLoader(train_dataset, batch_size=self._train_batch_size, **dataloader_params)
                dataloader = DataLoaderDispatcher(dataloader)

        return dataloader
|
| 505 |
+
|
| 506 |
+
def get_eval_dataloader(self, eval_dataset=None):
|
| 507 |
+
dataloader = None
|
| 508 |
+
if self.template.sequence_parallel_size > 1:
|
| 509 |
+
from swift.trainers.sequence_parallel import sequence_parallel
|
| 510 |
+
if eval_dataset is None and self.eval_dataset is None:
|
| 511 |
+
raise ValueError('Trainer: evaluation requires an eval_dataset.')
|
| 512 |
+
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
| 513 |
+
dataloader = sequence_parallel.get_dataloader(self, eval_dataset, self.args.eval_batch_size)
|
| 514 |
+
if dataloader is None:
|
| 515 |
+
return super().get_eval_dataloader(eval_dataset=eval_dataset)
|
| 516 |
+
return dataloader
|
ms-swift/swift/trainers/optimizers/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
ms-swift/swift/trainers/optimizers/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (180 Bytes). View file
|
|
|
ms-swift/swift/trainers/optimizers/galore/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import TYPE_CHECKING

from swift.utils.import_utils import _LazyModule

if TYPE_CHECKING:
    # Real imports only for static analyzers / IDEs; at runtime the module is
    # replaced by a lazy proxy below so heavy deps load on first attribute use.
    from .utils import create_optimizer_and_scheduler, GaLoreConfig
    from .adafactor import GaLoreAdafactor
    from .adamw8bit import GaLoreAdamW8bit
    from .adamw import GaLoreAdamW
else:
    # Maps submodule name -> public names it provides, consumed by _LazyModule.
    _import_structure = {
        'utils': ['GaLoreConfig', 'create_optimizer_and_scheduler'],
        'adafactor': ['GaLoreAdafactor'],
        'adamw8bit': ['GaLoreAdamW8bit'],
        'adamw': ['GaLoreAdamW'],
    }

    import sys

    # Replace this module object in sys.modules with a lazy proxy that imports
    # submodules on demand.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
|
ms-swift/swift/trainers/optimizers/galore/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (765 Bytes). View file
|
|
|
ms-swift/swift/trainers/optimizers/galore/adafactor.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# copy dependencies from transformers/optimization.py
|
| 2 |
+
# code borrowed from https://github.com/jiaweizzhao/GaLore
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.optim import Optimizer
|
| 7 |
+
from transformers.utils.versions import require_version
|
| 8 |
+
|
| 9 |
+
from .galore_projector import GaLoreProjector
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Adafactor(Optimizer):
    """
    AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code:
    https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py

    Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that
    this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and
    `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
    `relative_step=False`.

    This copy additionally supports GaLore low-rank gradient projection: when a
    param group carries a `rank` key, gradients are projected to a low-rank
    subspace before the Adafactor update and projected back afterwards.

    Arguments:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (`float`, *optional*):
            The external learning rate.
        eps (`Tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`):
            Regularization constants for square gradient and parameter scale respectively
        clip_threshold (`float`, *optional*, defaults to 1.0):
            Threshold of root mean square of final gradient update
        decay_rate (`float`, *optional*, defaults to -0.8):
            Coefficient used to compute running averages of square
        beta1 (`float`, *optional*):
            Coefficient used for computing running averages of gradient
        weight_decay (`float`, *optional*, defaults to 0.0):
            Weight decay (L2 penalty)
        scale_parameter (`bool`, *optional*, defaults to `True`):
            If True, learning rate is scaled by root mean square
        relative_step (`bool`, *optional*, defaults to `True`):
            If True, time-dependent learning rate is computed instead of external learning rate
        warmup_init (`bool`, *optional*, defaults to `False`):
            Time-dependent learning rate computation depends on whether warm-up initialization is being used

    This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested.

    Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):

    - Training without LR warmup or clip_threshold is not recommended.

       - use scheduled LR warm-up to fixed LR
       - use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235)
    - Disable relative updates
    - Use scale_parameter=False
    - Additional optimizer operations like gradient clipping should not be used alongside Adafactor

    Example:

    ```python
    Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
    ```

    Others reported the following combination to work well:

    ```python
    Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
    ```

    When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`]
    scheduler as following:

    ```python
    from transformers.optimization import Adafactor, AdafactorSchedule

    optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
    lr_scheduler = AdafactorSchedule(optimizer)
    trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
    ```

    Usage:

    ```python
    # replace AdamW with Adafactor
    optimizer = Adafactor(
        model.parameters(),
        lr=1e-3,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        relative_step=False,
        scale_parameter=False,
        warmup_init=False,
    )
    ```"""

    def __init__(
        self,
        params,
        lr=None,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        scale_parameter=True,
        relative_step=True,
        warmup_init=False,
    ):
        require_version('torch>=1.5.0')  # add_ with alpha
        if lr is not None and relative_step:
            raise ValueError('Cannot combine manual `lr` and `relative_step=True` options')
        if warmup_init and not relative_step:
            raise ValueError('`warmup_init=True` requires `relative_step=True`')

        defaults = {
            'lr': lr,
            'eps': eps,
            'clip_threshold': clip_threshold,
            'decay_rate': decay_rate,
            'beta1': beta1,
            'weight_decay': weight_decay,
            'scale_parameter': scale_parameter,
            'relative_step': relative_step,
            'warmup_init': warmup_init,
        }
        super().__init__(params, defaults)

    @staticmethod
    def _get_lr(param_group, param_state):
        # Effective step size: either the external lr, or a time-dependent one
        # (relative_step), optionally scaled by the parameter's RMS.
        rel_step_sz = param_group['lr']
        if param_group['relative_step']:
            min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2
            rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state['step']))
        param_scale = 1.0
        if param_group['scale_parameter']:
            param_scale = max(param_group['eps'][1], param_state['RMS'])
        return param_scale * rel_step_sz

    @staticmethod
    def _get_options(param_group, param_shape):
        # Factored second-moment statistics are used for matrices (>=2 dims);
        # first moment is kept only when beta1 is provided.
        factored = len(param_shape) >= 2
        use_first_moment = param_group['beta1'] is not None
        return factored, use_first_moment

    @staticmethod
    def _rms(tensor):
        # Root mean square of a tensor.
        return tensor.norm(2) / (tensor.numel()**0.5)

    @staticmethod
    def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
        # copy from fairseq's adafactor implementation:
        # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError('Adafactor does not support sparse gradients.')

                state = self.state[p]

                if 'step' not in state:
                    state['step'] = 0

                # GaLore Projection
                if 'rank' in group:
                    if 'projector' not in state:
                        state['projector'] = GaLoreProjector(
                            group['rank'],
                            update_proj_gap=group['update_proj_gap'],
                            scale=group['scale'],
                            proj_type=group['proj_type'])

                    grad = state['projector'].project(grad, state['step'])

                # NOTE: options are derived from the (possibly projected) grad
                # shape, so factoring applies in the low-rank space for GaLore.
                grad_shape = grad.shape

                factored, use_first_moment = self._get_options(group, grad_shape)
                # State Initialization
                if 'RMS' not in state:
                    state['step'] = 0

                    if use_first_moment:
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(grad)
                    if factored:
                        state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad)
                        state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
                    else:
                        state['exp_avg_sq'] = torch.zeros_like(grad)

                    state['RMS'] = 0
                else:
                    # Keep stored statistics on the grad's device/dtype.
                    if use_first_moment:
                        state['exp_avg'] = state['exp_avg'].to(grad)
                    if factored:
                        state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad)
                        state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad)
                    else:
                        state['exp_avg_sq'] = state['exp_avg_sq'].to(grad)

                p_data_fp32 = p
                if p.dtype in {torch.float16, torch.bfloat16}:
                    # Update in fp32 and copy back at the end for stability.
                    p_data_fp32 = p_data_fp32.float()

                state['step'] += 1
                state['RMS'] = self._rms(p_data_fp32)
                lr = self._get_lr(group, state)

                beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
                update = (grad**2) + group['eps'][0]
                if factored:
                    exp_avg_sq_row = state['exp_avg_sq_row']
                    exp_avg_sq_col = state['exp_avg_sq_col']

                    exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
                    exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))

                    # Approximation of exponential moving average of square of gradient
                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state['exp_avg_sq']

                    exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
                    update = exp_avg_sq.rsqrt().mul_(grad)

                # Clip the update's RMS to clip_threshold, then scale by lr.
                update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0))
                update.mul_(lr)

                if use_first_moment:
                    exp_avg = state['exp_avg']
                    exp_avg.mul_(group['beta1']).add_(update, alpha=(1 - group['beta1']))
                    update = exp_avg

                # GaLore Projection Back
                if 'rank' in group:
                    update = state['projector'].project_back(update)

                if group['weight_decay'] != 0:
                    # Decoupled weight decay: p <- p - wd * lr * p.
                    p_data_fp32.add_(p_data_fp32, alpha=(-group['weight_decay'] * lr))

                p_data_fp32.add_(-update)

                if p.dtype in {torch.float16, torch.bfloat16}:
                    p.copy_(p_data_fp32)

        return loss
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
# Export under the GaLore-prefixed name used by swift's optimizer factory.
GaLoreAdafactor = Adafactor
|
ms-swift/swift/trainers/optimizers/galore/galore_projector.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# code borrowed from https://github.com/jiaweizzhao/GaLore
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class GaLoreProjector:
    """Low-rank gradient projector for GaLore (https://arxiv.org/abs/2403.03507).

    Maintains an orthogonal basis (from a truncated SVD of the gradient) that is
    refreshed every ``update_proj_gap`` steps, and maps gradients between the
    full-rank and low-rank spaces.

    Args:
        rank: Target rank of the projection.
        verbose: Kept for API compatibility; not used here.
        update_proj_gap: Number of steps between basis refreshes.
        scale: Multiplier applied to the gradient on projection back.
        proj_type: One of ``'std'``, ``'reverse_std'``, ``'right'``, ``'left'``,
            ``'full'``.
    """

    def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type='std'):
        self.rank = rank
        self.verbose = verbose
        self.update_proj_gap = update_proj_gap
        self.scale = scale
        self.ortho_matrix = None
        self.proj_type = proj_type

    def _needs_refresh(self, iter):
        # Recompute the basis on the first call and then every update_proj_gap
        # steps thereafter.
        return self.ortho_matrix is None or iter % self.update_proj_gap == 0

    def project(self, full_rank_grad, iter):
        """Project ``full_rank_grad`` into the low-rank subspace.

        Args:
            full_rank_grad: 2-D gradient tensor.
            iter: Current optimization step, used to decide basis refreshes.

        Returns:
            The low-rank projected gradient.

        Raises:
            ValueError: if ``proj_type`` is not a recognized mode.
        """
        if self.proj_type == 'std':
            # Project along the smaller dimension so the basis stays small.
            if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                if self._needs_refresh(iter):
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
                low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
            else:
                if self._needs_refresh(iter):
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
                low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
        elif self.proj_type == 'reverse_std':
            # Same as 'std' but with the side choice inverted.
            if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                if self._needs_refresh(iter):
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
                low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
            else:
                if self._needs_refresh(iter):
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
                low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
        elif self.proj_type == 'right':
            if self._needs_refresh(iter):
                self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
            low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
        elif self.proj_type == 'left':
            if self._needs_refresh(iter):
                self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
            low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
        elif self.proj_type == 'full':
            if self._needs_refresh(iter):
                self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='full')
            low_rank_grad = torch.matmul(self.ortho_matrix[0].t(), full_rank_grad) @ self.ortho_matrix[1].t()
        else:
            # Previously an unknown proj_type fell through and raised an
            # UnboundLocalError on `low_rank_grad`; fail with a clear message.
            raise ValueError(f"Unknown proj_type: {self.proj_type!r}; "
                             "expected one of 'std', 'reverse_std', 'right', 'left', 'full'")

        return low_rank_grad

    def project_back(self, low_rank_grad):
        """Map a low-rank update back to the full-rank space and apply scale.

        Raises:
            ValueError: if ``proj_type`` is not a recognized mode.
        """
        if self.proj_type == 'std':
            if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
                full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
            else:
                full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
        elif self.proj_type == 'reverse_std':
            if low_rank_grad.shape[0] <= low_rank_grad.shape[1]:  # note this is different from std
                full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
            else:
                full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
        elif self.proj_type == 'right':
            full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
        elif self.proj_type == 'left':
            full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
        elif self.proj_type == 'full':
            full_rank_grad = torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1]
        else:
            raise ValueError(f"Unknown proj_type: {self.proj_type!r}; "
                             "expected one of 'std', 'reverse_std', 'right', 'left', 'full'")

        return full_rank_grad * self.scale

    # svd decomposition
    def get_orthogonal_matrix(self, weights, rank, type):
        """Return rank-truncated orthogonal factor(s) of ``weights`` via SVD.

        Args:
            weights: Tensor (or parameter) to factorize.
            rank: Number of singular vectors to keep.
            type: ``'right'`` -> top-``rank`` rows of ``Vh``; ``'left'`` ->
                top-``rank`` columns of ``U``; ``'full'`` -> ``[U_r, Vh_r]``.

        Raises:
            ValueError: for any other ``type``.
        """
        module_params = weights

        # SVD runs in float32; remember original dtype/device to cast back.
        if module_params.data.dtype != torch.float:
            float_data = False
            original_type = module_params.data.dtype
            original_device = module_params.data.device
            matrix = module_params.data.float()
        else:
            float_data = True
            matrix = module_params.data

        U, s, Vh = torch.linalg.svd(matrix, full_matrices=False)

        # make the smaller matrix always to be orthogonal matrix
        if type == 'right':
            B = Vh[:rank, :]
            if not float_data:
                B = B.to(original_device).type(original_type)
            return B
        elif type == 'left':
            A = U[:, :rank]
            if not float_data:
                A = A.to(original_device).type(original_type)
            return A
        elif type == 'full':
            A = U[:, :rank]
            B = Vh[:rank, :]
            if not float_data:
                A = A.to(original_device).type(original_type)
                B = B.to(original_device).type(original_type)
            return [A, B]
        else:
            raise ValueError('type should be left, right or full')
|
ms-swift/swift/trainers/optimizers/galore/utils.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import importlib
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Dict, List, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn
|
| 8 |
+
from torch.optim import Optimizer
|
| 9 |
+
from transformers import Trainer, TrainingArguments, get_scheduler
|
| 10 |
+
|
| 11 |
+
from swift.utils import get_logger
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
| 15 |
+
except ImportError:
|
| 16 |
+
from torch.optim.lr_scheduler import LRScheduler
|
| 17 |
+
|
| 18 |
+
logger = get_logger()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
class GaLoreConfig:
    """
    The configuration class for the Galore module.

    See https://arxiv.org/abs/2403.03507

    Args:
        rank (`int`): The galore rank
        target_modules (`Union[str, List[str]]`): The target modules to use, if `None`,
            will use all attn and mlp linears
        update_proj_gap(`int`): The projection update interval for galore
        proj_type(`str`) The project type of Galore, valid values are `std`,
            `reverse_std`, `right`, `left`, `full`
        galore_scale(float): the scale of gradient
        optim_per_parameter(bool): Gives one optimizer per parameter
        quantize(bool): Enable Q-GaLore (quantized projection); the fields below
            only take effect when this is True
        proj_quant(bool): Quantize the projection matrix
        proj_bits(int): Bit width used for projection quantization
        proj_group_size(int): Group size for projection quantization
        cos_threshold(float): Cosine-similarity threshold for lazy basis updates
        gamma_proj(int): Growth factor for the projection update interval
        queue_size(int): Size of the similarity history queue
    """
    rank: int = 128
    target_modules: Union[str, List[str]] = None
    update_proj_gap: int = 50
    galore_scale: float = 1.0
    proj_type: str = 'std'
    optim_per_parameter: bool = False
    # --- Q-GaLore options (used only when quantize=True); they are forwarded
    # verbatim to q_galore_torch param groups by create_optimizer_and_scheduler.
    quantize: bool = False
    proj_quant: bool = False
    proj_bits: int = 4
    proj_group_size: int = 256
    cos_threshold: float = 0.4
    gamma_proj: int = 2
    queue_size: int = 5
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class GaloreOptimizerWrapper(Optimizer):
    """Drives a collection of per-parameter optimizers through one interface.

    GaLore's ``optim_per_parameter`` mode creates one optimizer per trainable
    tensor; this wrapper lets the Trainer treat them as a single ``Optimizer``.
    """

    def __init__(self, optimizers: Dict[Any, Optimizer]):
        self.optimizers = optimizers
        # The Optimizer base class insists on at least one parameter group, so
        # register a throwaway tensor; real parameters live in the wrapped
        # optimizers.
        super().__init__([torch.tensor([1., 2., 3.])], {'lr': 1.})

    def zero_grad(self, *args, **kwargs) -> None:
        """Clear gradients on every wrapped optimizer."""
        for wrapped in self.optimizers.values():
            wrapped.zero_grad(*args, **kwargs)

    def step(self, *args, **kwargs) -> None:
        """Run an update step on every wrapped optimizer."""
        for wrapped in self.optimizers.values():
            wrapped.step(*args, **kwargs)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class GaloreSchedulerWrapper(LRScheduler):
    """Steps a collection of per-parameter LR schedulers as one scheduler.

    Note: ``LRScheduler.__init__`` is deliberately not invoked — this object
    only fans ``step()`` out to the wrapped schedulers and mirrors the most
    recently stepped scheduler's learning rate in ``_last_lr``.
    """

    def __init__(self, lr_schedulers: Dict[Any, LRScheduler]):
        self.lr_schedulers = lr_schedulers

    def step(self, *args, **kwargs) -> None:
        """Advance every wrapped scheduler and record the last learning rate."""
        for wrapped in self.lr_schedulers.values():
            wrapped.step(*args, **kwargs)
        self._last_lr = wrapped.get_last_lr()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def create_optimizer_and_scheduler(model: nn.Module, args: TrainingArguments, config: GaLoreConfig, max_steps,
                                   **defaults):
    """Build a GaLore optimizer and LR scheduler for ``model``.

    Linear/Embedding weights whose module name matches ``config.target_modules``
    get GaLore low-rank projection settings; all other trainable parameters are
    optimized normally. Depending on ``config.optim_per_parameter`` this returns
    either wrappers around per-parameter optimizers/schedulers or a single
    optimizer with grouped parameters.

    Args:
        model: The model whose parameters are optimized.
        args: HF TrainingArguments (lr scheduler type, warmup, optim, ...).
        config: GaLore configuration (rank, target modules, quantization, ...).
        max_steps: Total number of training steps for the scheduler.
        **defaults: Extra per-group optimizer defaults (e.g. weight_decay).

    Returns:
        Tuple of (optimizer, lr_scheduler).
    """
    # Collect the weights that get GaLore treatment: trainable Linear/Embedding
    # weights whose module name contains one of the target keys.
    galore_params = []
    for module_name, module in model.named_modules():
        if not isinstance(module, (nn.Linear, nn.Embedding)) or \
                not any(target_key in module_name for target_key in config.target_modules):
            continue

        if not module.weight.requires_grad:
            continue

        logger.info(f'Enable GaLore for weights in module: {module_name}')
        galore_params.append(module.weight)

    id_galore_params = [id(p) for p in galore_params]
    # Per-group settings consumed by the GaLore optimizers' step().
    galore_defaults = {
        'rank': config.rank,
        'update_proj_gap': config.update_proj_gap,
        'scale': config.galore_scale,
        'proj_type': config.proj_type,
        **defaults
    }
    if config.quantize:
        # Q-GaLore extra knobs, forwarded to q_galore_torch optimizers.
        galore_defaults['quant'] = config.proj_quant
        galore_defaults['quant_n_bit'] = config.proj_bits
        galore_defaults['quant_group_size'] = config.proj_group_size
        galore_defaults['cos_threshold'] = config.cos_threshold
        galore_defaults['gamma_proj'] = config.gamma_proj
        galore_defaults['queue_size'] = config.queue_size
    optim_cls, optim_kwargs = get_optimizer(args, config)

    if config.optim_per_parameter and not config.quantize:
        # q-galore does not support optim_per_parameter
        optimizer_dict = {}
        # presumably doubled because each per-parameter optimizer sees every
        # step but only half the effective updates — TODO confirm intent.
        galore_defaults['update_proj_gap'] = galore_defaults['update_proj_gap'] * 2
        for p in model.parameters():
            if p.requires_grad:
                if id(p) in id_galore_params:
                    optimizer_dict[p] = optim_cls([{'params': [p], **galore_defaults}], **optim_kwargs)
                else:
                    optimizer_dict[p] = optim_cls([{'params': [p], **defaults}], **optim_kwargs)

        # get scheduler dict
        scheduler_dict = {}
        for p in model.parameters():
            if p.requires_grad:
                scheduler_dict[p] = get_scheduler(
                    optimizer=optimizer_dict[p],
                    name=args.lr_scheduler_type,
                    num_training_steps=max_steps * 2,
                    num_warmup_steps=args.warmup_steps * 2,
                    scheduler_specific_kwargs=args.lr_scheduler_kwargs,
                )

        return GaloreOptimizerWrapper(optimizer_dict), GaloreSchedulerWrapper(scheduler_dict)
    else:
        # Single optimizer: one GaLore group plus standard decay/no-decay groups
        # for the remaining trainable parameters.
        decay_parameters = Trainer.get_decay_parameter_names(Trainer, model)
        param_groups = [{
            'params': galore_params,
            **galore_defaults,
        }]
        param_groups.extend([
            {
                'params': [
                    p for n, p in model.named_parameters()
                    if (n in decay_parameters and id(p) not in id_galore_params and p.requires_grad)
                ],
                'weight_decay':
                defaults['weight_decay'],
            },
            {
                'params': [
                    p for n, p in model.named_parameters()
                    if (n not in decay_parameters and id(p) not in id_galore_params and p.requires_grad)
                ],
                'weight_decay':
                0.0,
            },
        ])
        optim = optim_cls(param_groups, **optim_kwargs)
        scheduler = get_scheduler(
            optimizer=optim,
            name=args.lr_scheduler_type,
            num_training_steps=max_steps,
            num_warmup_steps=args.warmup_steps,
            scheduler_specific_kwargs=args.lr_scheduler_kwargs,
        )
        return optim, scheduler
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def get_optimizer(args: TrainingArguments, config: GaLoreConfig) -> Tuple[Any, Any]:
    """Resolve the GaLore optimizer class and constructor kwargs for ``args.optim``.

    Supports 'adafactor', 'adamw_hf'/'adamw_torch' (optionally Q-GaLore when
    ``config.quantize``), and bitsandbytes 8-bit AdamW variants.

    Returns:
        Tuple of (optimizer class, kwargs to construct it with).

    Raises:
        ValueError: if ``args.optim`` names an unsupported optimizer, or a bnb
            optimizer was requested but bitsandbytes is not installed.
    """
    # parse args.optim_args
    # NOTE(review): optim_args is parsed here but never merged into
    # optimizer_kwargs — confirm whether that is intentional.
    optim_args = {}
    if args.optim_args:
        for mapping in args.optim_args.replace(' ', '').split(','):
            key, value = mapping.split('=')
            optim_args[key] = value

    optimizer_kwargs = {'lr': args.learning_rate}

    adam_kwargs = {
        'betas': (args.adam_beta1, args.adam_beta2),
        'eps': args.adam_epsilon,
    }
    if args.optim == 'adafactor':
        from .adafactor import GaLoreAdafactor
        optimizer_cls = GaLoreAdafactor
        # External LR schedule: disable Adafactor's internal lr adaptation.
        optimizer_kwargs.update({'scale_parameter': False, 'relative_step': False})
    elif args.optim in ('adamw_hf', 'adamw_torch'):
        if config.quantize:
            assert importlib.util.find_spec('q_galore_torch') is not None, \
                'Please install q-galore by `pip install q_galore_torch`'
            logger.info('If you encounter `absmax2` error, please downgrade your bitsandbytes to 0.40.0')
            from swift.utils import get_dist_setting
            _, _, world_size, _ = get_dist_setting()
            # Both branches currently import the same class; the simulate
            # variant is kept commented for reference.
            if world_size > 1:
                # from q_galore_torch import QGaLoreAdamW8bit_simulate as GaLoreAdamW
                from q_galore_torch import QGaLoreAdamW8bit as GaLoreAdamW
            else:
                from q_galore_torch import QGaLoreAdamW8bit as GaLoreAdamW
        else:
            from .adamw import GaLoreAdamW
        optimizer_cls = GaLoreAdamW
        optimizer_kwargs.update(adam_kwargs)
    elif 'adamw' in args.optim and '8bit' in args.optim:
        try:
            from .adamw8bit import GaLoreAdamW8bit
            optimizer_cls = GaLoreAdamW8bit
            optimizer_kwargs.update(adam_kwargs)
            optimizer_kwargs.update({'optim_bits': 8, 'is_paged': 'paged' in args.optim})
        except ImportError:
            raise ValueError('Trainer tried to instantiate bnb optimizer but bnb is not installed!')
    else:
        raise ValueError(f'Galore not supported for optimizer type: {args.optim}')
    return optimizer_cls, optimizer_kwargs
|
ms-swift/swift/trainers/rlhf_arguments.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
from trl import CPOConfig as HfCPOConfig
|
| 5 |
+
from trl import DPOConfig as HfDPOConfig
|
| 6 |
+
from trl import GRPOConfig as HfGRPOConfig
|
| 7 |
+
from trl import KTOConfig as HfKTOConfig
|
| 8 |
+
from trl import ORPOConfig as HfORPOConfig
|
| 9 |
+
from trl import PPOConfig as HfPPOConfig
|
| 10 |
+
from trl import RewardConfig as HfRewardConfig
|
| 11 |
+
|
| 12 |
+
from .arguments import GRPOArgumentsMixin, SwiftArgumentsMixin
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
class DPOConfig(SwiftArgumentsMixin, HfDPOConfig):
    """trl's ``DPOConfig`` extended with swift's shared training arguments."""
    pass
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
class CPOConfig(SwiftArgumentsMixin, HfCPOConfig):
    """trl's ``CPOConfig`` extended with swift's shared training arguments."""
    pass
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
class ORPOConfig(SwiftArgumentsMixin, HfORPOConfig):
    """trl's ``ORPOConfig`` extended with swift's shared training arguments."""
    pass
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
class KTOConfig(SwiftArgumentsMixin, HfKTOConfig):
    """trl's ``KTOConfig`` extended with swift's shared training arguments."""
    pass
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
class RewardConfig(SwiftArgumentsMixin, HfRewardConfig):
    """trl's ``RewardConfig`` extended with swift's shared training arguments."""
    pass
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@dataclass
class PPOConfig(SwiftArgumentsMixin, HfPPOConfig):
    """trl's ``PPOConfig`` extended with swift's shared training arguments."""
    pass
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@dataclass
class GRPOConfig(GRPOArgumentsMixin, SwiftArgumentsMixin, HfGRPOConfig):
    """trl's ``GRPOConfig`` plus swift's GRPO-specific arguments."""

    # Extra generation stop words for rollout sampling.
    stop_words: List[str] = field(default_factory=list)

    def __post_init__(self):
        # Local import to avoid a circular import at module load time.
        from swift.llm.argument.base_args.model_args import ModelArguments
        super().__post_init__()
        # Default the cosine-reward max length to the completion budget.
        if self.cosine_max_len is None:
            self.cosine_max_len = self.max_completion_length
        # Accept either a JSON string or a dict for the vLLM multimodal limit.
        self.vllm_limit_mm_per_prompt = ModelArguments.parse_to_dict(self.vllm_limit_mm_per_prompt)

        # NOTE(review): assumes self.deepspeed is already a dict here (not a
        # json path) when truthy — confirm against SwiftArgumentsMixin parsing.
        if self.deepspeed and 'zero_optimization' in self.deepspeed and self.deepspeed['zero_optimization'][
                'stage'] == 3:
            # ZeRO-3 prefetch conflicts with GRPO generation:
            # https://github.com/modelscope/ms-swift/issues/3237
            self.deepspeed['zero_optimization']['stage3_prefetch_bucket_size'] = 0
            self.deepspeed_plugin.hf_ds_config.config['zero_optimization']['stage3_prefetch_bucket_size'] = 0

        # Uneven last batches break GRPO grouping:
        # https://github.com/modelscope/ms-swift/issues/3863
        self.dataloader_drop_last = True
|
ms-swift/swift/trainers/rlhf_trainer/kto_trainer.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from contextlib import contextmanager
|
| 3 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from peft import PeftModel
|
| 8 |
+
from transformers import PreTrainedModel
|
| 9 |
+
from trl import KTOTrainer as HFKTOTrainer
|
| 10 |
+
|
| 11 |
+
from swift.utils import get_logger
|
| 12 |
+
from ..mixin import SwiftMixin
|
| 13 |
+
from .rlhf_mixin import RLHFTrainerMixin
|
| 14 |
+
|
| 15 |
+
logger = get_logger()
|
| 16 |
+
|
| 17 |
+
# Remove trl's __init__ so construction of KTOTrainer below flows through the
# swift mixins' cooperative __init__ chain instead.
del HFKTOTrainer.__init__
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class KTOTrainer(RLHFTrainerMixin, SwiftMixin, HFKTOTrainer):
    """KTO trainer wired to swift's template/data pipeline.

    ``HFKTOTrainer.__init__`` was deleted above, so initialization runs
    through ``RLHFTrainerMixin``/``SwiftMixin`` while this __init__ sets the
    attributes trl's KTO loss code reads.
    """

    def __init__(self,
                 model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
                 ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
                 *_args,
                 **kwargs):
        args = kwargs['args']
        # Force dropout off so policy/reference log-probs are comparable.
        args.disable_dropout = True
        self.desirable_weight = args.desirable_weight
        self.undesirable_weight = args.undesirable_weight
        self.precompute_ref_log_probs = args.precompute_ref_log_probs
        self.is_peft_model = isinstance(model, PeftModel)
        if hasattr(args, 'loss_type'):
            self.loss_type = args.loss_type
        else:
            self.loss_type = 'kto'

        self.ref_adapter_name = None
        # Not all losses require a KL calculation
        self.calculate_KL = True
        if self.loss_type in ['apo_zero_unpaired']:
            self.calculate_KL = False
        super().__init__(model, ref_model, *_args, **kwargs)

    def forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        # The batch carries two sub-batches keyed 'KL_completion_*' and
        # 'completion_*'. A forward pre-hook swaps the right tensors into each
        # model call; the flag alternates on every call. Assumes the parent's
        # forward invokes the model with the KL batch first — confirm against
        # the installed trl version.
        is_kl = True

        def _add_data_hook(model, args, kwargs):
            nonlocal is_kl
            if is_kl:
                kwargs = {k[len('KL_completion_'):]: v for k, v in batch.items() if k.startswith('KL_completion_')}
            else:
                kwargs = {k[len('completion_'):]: v for k, v in batch.items() if k.startswith('completion_')}
            is_kl = not is_kl  # flip for the next model call
            return (), kwargs

        @contextmanager
        def _patch_model_call():
            # prepend=True so the kwargs replacement runs before other hooks.
            handle = model.register_forward_pre_hook(_add_data_hook, with_kwargs=True, prepend=True)

            try:
                yield
            finally:
                handle.remove()

        with _patch_model_call():
            return super().forward(model, batch)
|
ms-swift/swift/trainers/rlhf_trainer/orpo_trainer.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from typing import Optional, Union
|
| 3 |
+
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from transformers import PreTrainedModel
|
| 6 |
+
from trl import ORPOTrainer as HFORPOTrainer
|
| 7 |
+
|
| 8 |
+
from ..mixin import SwiftMixin
|
| 9 |
+
from .rlhf_mixin import RLHFTrainerMixin
|
| 10 |
+
|
| 11 |
+
# Remove trl's __init__ so ORPOTrainer below initializes via the swift mixins.
del HFORPOTrainer.__init__
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ORPOTrainer(RLHFTrainerMixin, SwiftMixin, HFORPOTrainer):
    """ORPO trainer; the odds-ratio loss is reference-free, so passing a
    ``ref_model`` is rejected up front."""

    def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, *_args, **kwargs):
        # ORPO folds the preference signal into a single model's loss.
        assert kwargs.get('ref_model') is None, 'ORPO does not require a ref_model.'
        super().__init__(model, *_args, **kwargs)
|
ms-swift/swift/trainers/rlhf_trainer/ppo_trainer.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import inspect
|
| 3 |
+
from contextlib import contextmanager
|
| 4 |
+
|
| 5 |
+
import transformers
|
| 6 |
+
from packaging import version
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
+
from transformers import PreTrainedModel
|
| 9 |
+
from trl import PPOTrainer as HFPPOTrainer
|
| 10 |
+
|
| 11 |
+
from swift.utils import patch_getattr
|
| 12 |
+
from ..mixin import SwiftMixin
|
| 13 |
+
|
| 14 |
+
# Keep a handle on trl's original __init__ (invoked explicitly inside
# PPOTrainer.__init__ below) and remove it from the class so that
# `super().__init__` reaches SwiftMixin instead.
ppo_trainer_init = HFPPOTrainer.__init__
del HFPPOTrainer.__init__
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class PPOTrainer(SwiftMixin, HFPPOTrainer):
    """PPO trainer bridging swift's mixin machinery with trl's PPOTrainer."""

    @staticmethod
    @contextmanager
    def _patch_dataloader(collate_fn):
        # trl's PPO init builds its own DataLoaders internally; temporarily
        # patch DataLoader.__init__ globally so they pick up our collator.
        # NOTE: process-wide patch — safe only because construction is
        # single-threaded here.
        __init__ = DataLoader.__init__

        def __new_init__(self, *args, **kwargs):
            kwargs['collate_fn'] = collate_fn
            __init__(self, *args, **kwargs)

        DataLoader.__init__ = __new_init__
        try:
            yield
        finally:
            DataLoader.__init__ = __init__

    def __init__(self, model: PreTrainedModel, ref_model: PreTrainedModel, *_args, **kwargs):
        # First run SwiftMixin's __init__ (HFPPOTrainer.__init__ was deleted);
        # strip kwargs only trl's original __init__ understands.
        super().__init__(model, *_args, **{k: v for k, v in kwargs.items() if k not in {'reward_model', 'value_model'}})
        with self._patch_dataloader(kwargs['data_collator']):
            new_kwargs = {
                k: v
                for k, v in kwargs.items()
                if k in ['train_dataset', 'data_collator', 'reward_model', 'value_model', 'eval_dataset']
            }
            # trl renamed config->args and tokenizer->processing_class across
            # versions; adapt to whichever signature is installed.
            parameters = inspect.signature(ppo_trainer_init).parameters
            if 'config' in parameters:
                new_kwargs['config'] = kwargs['args']
            else:
                new_kwargs['args'] = kwargs['args']
            if 'processing_class' in parameters:
                new_kwargs['processing_class'] = self.tokenizer
            else:
                new_kwargs['tokenizer'] = self.tokenizer
            ppo_trainer_init(self, model=model, ref_model=ref_model, **new_kwargs)
        unwrap_model = self.accelerator.unwrap_model(self.model)
        # Forward unknown attribute access to the wrapped `policy` model.
        patch_getattr(unwrap_model.__class__, 'policy')

    def train(self, *args, **kwargs):
        # remove args that are not needed for the HFPPOTrainer
        super().train()

    def _save_checkpoint(self, *args, **kwargs):
        # transformers>=4.47 expects best-metric bookkeeping before saving;
        # replicate it here since trl's loop does not.
        if version.parse(transformers.__version__) >= version.parse('4.47'):
            metrics = kwargs.pop('metrics', None)
            trial = kwargs.get('trial')
            self._determine_best_metric(metrics=metrics, trial=trial)
        return super()._save_checkpoint(*args, **kwargs)
|
ms-swift/swift/trainers/rlhf_trainer/reward_trainer.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
from typing import Any, Dict, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from accelerate.utils import gather_object
|
| 9 |
+
from transformers import PreTrainedModel
|
| 10 |
+
from trl import RewardTrainer as HFRewardTrainer
|
| 11 |
+
from trl.trainer.utils import print_rich_table
|
| 12 |
+
|
| 13 |
+
from ..mixin import SwiftMixin
|
| 14 |
+
from .rlhf_mixin import RLHFTrainerMixin
|
| 15 |
+
|
| 16 |
+
# Remove trl's __init__ so RewardTrainer below initializes via the swift mixins.
del HFRewardTrainer.__init__
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class RewardTrainer(RLHFTrainerMixin, SwiftMixin, HFRewardTrainer):
    """Pairwise reward-model trainer (Bradley–Terry style loss over
    chosen/rejected pairs stacked in one batch)."""

    def compute_loss(self,
                     model: Union[PreTrainedModel, nn.Module],
                     inputs: Dict[str, Union[torch.Tensor, Any]],
                     return_outputs=False,
                     num_items_in_batch=None) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
        """Return -log sigmoid(r_chosen - r_rejected) averaged over the batch.

        The input batch is assumed to hold the chosen examples in the first
        half and the rejected ones in the second half.
        """
        inputs.pop('labels', None)  # not use
        attention_mask = inputs['attention_mask']
        batch_size = attention_mask.shape[0] // 2  # chosen | rejected halves
        rewards = model(**inputs).logits
        rewards_chosen, rewards_rejected = torch.split(rewards, batch_size, dim=0)
        if 'margin' in inputs:
            # Optional per-pair margin shifts the decision boundary.
            loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs['margin']).mean()
        else:
            loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
        if self.args.center_rewards_coefficient is not None:
            # L2 penalty keeping rewards centered around zero.
            loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected)**2)
        # compat transformers>=4.46.*
        if num_items_in_batch is not None and self.model_accepts_loss_kwargs:
            loss /= self.args.gradient_accumulation_steps
        if return_outputs:
            return loss, {
                'rewards_chosen': rewards_chosen,
                'rewards_rejected': rewards_rejected,
            }
        return loss

    def visualize_samples(self, num_print_samples: int):
        """
        Visualize the reward model logits prediction

        Args:
            num_print_samples (`int`, defaults to `4`):
                The number of samples to print. Set to `-1` to print all samples.
        """
        eval_dataloader = self.get_eval_dataloader()
        table = defaultdict(list)
        for _, inputs in enumerate(eval_dataloader):
            _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
            input_ids = inputs['input_ids']
            attention_mask = inputs['attention_mask']
            # Last non-padded index per row; assumes right padding — confirm
            # against the template's padding side.
            sequence_lengths = ((torch.eq(attention_mask, 0).int().argmax(-1) - 1) % attention_mask.shape[1]).tolist()
            text = [self.template.safe_decode(tokens[:sequence_lengths[i]]) for i, tokens in enumerate(input_ids)]
            batch_size = input_ids.shape[0] // 2
            chosen_text, rejected_text = text[:batch_size], text[batch_size:]
            table['chosen_text'].extend(gather_object(chosen_text))
            table['rejected_text'].extend(gather_object(rejected_text))
            table['logits'].extend(
                gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()]))
            if 0 <= num_print_samples <= len(table['chosen_text']):
                break
        df = pd.DataFrame(table)
        # Only rank 0 prints/logs to avoid duplicated output.
        if self.accelerator.process_index == 0:
            print_rich_table(df[:num_print_samples])
            if 'wandb' in self.args.report_to:
                import wandb

                if wandb.run is not None:
                    wandb.log({'completions': wandb.Table(dataframe=df)})
|
ms-swift/swift/trainers/rlhf_trainer/rlhf_mixin.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
from contextlib import contextmanager, nullcontext
|
| 4 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from transformers import PreTrainedModel
|
| 9 |
+
from transformers.integrations import is_deepspeed_zero3_enabled
|
| 10 |
+
|
| 11 |
+
# trl can raise RuntimeError (not just ImportError) from its own transitive
# imports; tolerate both so this module stays importable without a value head.
try:
    from trl import AutoModelForCausalLMWithValueHead
except (ImportError, RuntimeError):
    AutoModelForCausalLMWithValueHead = None
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class RLHFTrainerMixin:
    """Shared glue between swift's data pipeline and trl's RLHF trainers.

    Sets the attributes trl trainer code reads (beta, encoder-decoder flags,
    padding ids, ...) without running trl's own __init__, then defers to the
    next class in the MRO.
    """

    def __init__(self,
                 model: Optional[Union[PreTrainedModel, nn.Module]] = None,
                 ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
                 *_args,
                 **kwargs):
        # Local imports avoid circular imports at module load time.
        from trl.trainer import disable_dropout_in_model
        from swift.llm import HfConfigFactory
        self.ref_model = ref_model
        self._stored_metrics = defaultdict(lambda: defaultdict(list))
        args = kwargs['args']
        self.beta = getattr(args, 'beta', 0.0)
        if getattr(args, 'disable_dropout', False):
            disable_dropout_in_model(model)
            if self.ref_model is not None:
                disable_dropout_in_model(self.ref_model)

        self.is_encoder_decoder = kwargs['template'].is_encoder_decoder
        self.aux_loss_enabled = getattr(model.config, 'output_router_logits', False)
        self._peft_has_been_casted_to_bf16 = False
        self.generate_during_eval = getattr(args, 'generate_during_eval', False)
        if self.is_encoder_decoder:
            self.decoder_start_token_id = HfConfigFactory.get_config_attr(model.config, 'decoder_start_token_id')
            self.pad_token_id = HfConfigFactory.get_config_attr(model.config, 'pad_token_id')
        # not use
        self.is_vision_model = False
        self.label_pad_token_id = -100
        self.use_dpo_data_collator = True
        super().__init__(model, *_args, **kwargs)
        if is_deepspeed_zero3_enabled() and ref_model is not None:
            try:
                from trl.models.utils import prepare_deepspeed
            except ImportError as e:
                raise ImportError('Please install trl>=0.14 via `pip install "trl>=0.14"`') from e
            prepare_deepspeed(self.ref_model, self.accelerator)  # Does not wrap DeepSpeedEngine
        self.padding_value = self.tokenizer.pad_token_id

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        # Run the model ourselves, then trick trl's concatenated_forward into
        # consuming our precomputed outputs (see the patch context below).
        model_kwargs = batch.copy()
        labels = model_kwargs.pop('labels', None)
        if self.is_encoder_decoder:
            model_kwargs['labels'] = labels

        if self.aux_loss_enabled:
            model_kwargs['output_router_logits'] = True
        outputs = model(**model_kwargs, use_cache=False)
        model_kwargs['labels'] = labels
        model_kwargs['chosen_labels'] = torch.zeros(model_kwargs['labels'].shape[0] // 2)  # just get shape
        if outputs.logits.shape[1] != labels.shape[1]:
            # for llava, the model returns logits for the entire sequence, including the image tokens
            # (placed before the text tokens)
            outputs.logits = outputs.logits[:, -labels.shape[1]:]
        for key in ['input_ids', 'attention_mask', 'labels']:
            model_kwargs[f'concatenated_{key}'] = model_kwargs.pop(key, None)
        if self.__class__.__name__ == 'ORPOTrainer':  # Pass-through labels
            model_kwargs['concatenated_input_ids'] = model_kwargs['concatenated_labels']

        @contextmanager
        def _patch_concatenated_forward():
            # While patched: concatenated_inputs returns our kwarg dict and any
            # model call returns the already-computed outputs, so trl does not
            # re-run the forward pass.
            _old_concatenated_inputs = self.concatenated_inputs
            _old_model_call = model.__class__.__call__
            self.concatenated_inputs = lambda *args, **kwargs: model_kwargs
            model.__class__.__call__ = lambda *args, **kwargs: outputs
            try:
                yield
            finally:
                self.concatenated_inputs = _old_concatenated_inputs
                model.__class__.__call__ = _old_model_call

        with _patch_concatenated_forward():
            return super().concatenated_forward(model, model_kwargs)

    def get_batch_logps(self, logits: torch.FloatTensor, labels: torch.LongTensor, *args, **kwargs):
        if self.is_encoder_decoder:
            labels = labels.clone()  # fix trl bug
        return super().get_batch_logps(logits, labels, *args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        res = super().compute_loss(model, inputs, return_outputs=return_outputs)
        # compat transformers>=4.46.*: scale by grad-accum steps ourselves when
        # the model accepts loss kwargs.
        if num_items_in_batch is not None and self.model_accepts_loss_kwargs:
            loss = res[0] if return_outputs else res
            loss /= self.args.gradient_accumulation_steps
            return (loss, res[1:]) if return_outputs else loss
        return res
|
ms-swift/swift/trainers/rlhf_trainer/utils.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from contextlib import contextmanager
|
| 2 |
+
from types import MethodType
|
| 3 |
+
from typing import Any, List, Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from peft.tuners import lora
|
| 7 |
+
from peft.tuners.lora import LoraLayer
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def round_robin(num_reqs, num_workers):
    """Distribute requests evenly across workers using round-robin algorithm.

    Args:
        num_reqs (int): Total number of requests to distribute
        num_workers (int): Number of available workers

    Returns:
        list: A list of lists where each sublist contains the request indices
            assigned to that particular node
    """
    # Worker w receives indices w, w + num_workers, w + 2*num_workers, ...
    return [list(range(worker_id, num_reqs, num_workers)) for worker_id in range(num_workers)]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@contextmanager
def patch_lora_merge(model, parameter_group=None):
    """Patch LoraLayer's merge and get_delta_weight methods for controlled merging.

    Args:
        model: The PEFT model to patch
        parameter_group: Optional list of parameter names to restrict merging

    Yields:
        The patched model (context manager ensures cleanup)
    """
    from peft.tuners.tuners_utils import check_adapters_to_merge

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
        # Restrict merging to modules named in parameter_group, if given.
        if parameter_group and all(self.name not in pg for pg in parameter_group):
            return  # Skip if not in target parameter group
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            return

        # DoRA magnitude vectors may live on a different device; move them to
        # the base layer's device before delegating to the original merge.
        for active_adapter in adapter_names:
            if active_adapter in self.lora_A.keys():
                base_layer = self.get_base_layer()
                if self.use_dora.get(active_adapter, False):
                    self.lora_magnitude_vector[active_adapter].weight.data = \
                        self.lora_magnitude_vector[active_adapter].weight.data.to(base_layer.weight.device)

        return self.merge_origin(safe_merge, adapter_names)

    def get_delta_weight(self, adapter) -> torch.Tensor:
        # Ensure tensors are on correct device
        if isinstance(self, lora.Embedding):
            self.lora_embedding_A[adapter].data = self.lora_embedding_A[adapter].data.to(self.base_layer.weight.device)
            self.lora_embedding_B[adapter].data = self.lora_embedding_B[adapter].data.to(self.base_layer.weight.device)
        else:
            self.lora_A[adapter].weight.data = self.lora_A[adapter].weight.data.to(self.base_layer.weight.device)
            self.lora_B[adapter].weight.data = self.lora_B[adapter].weight.data.to(self.base_layer.weight.device)
        return self.get_delta_weight_origin(adapter).to(self.base_layer.weight.device)

    def _cache_pop(self, key: str) -> Any:
        # Cached tensors (e.g. pre-merge weights) follow the base layer device.
        value = self._caches.pop(key).to(self.base_layer.weight.device)
        return value

    # Patch all LoraLayer instances
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            module.name = name
            # Guard on merge_origin so re-entering the context doesn't
            # double-patch and lose the true originals.
            if not hasattr(module, 'merge_origin') and hasattr(module, 'base_layer'):
                module.merge_origin = module.merge
                module.merge = MethodType(merge, module)
                module.get_delta_weight_origin = module.get_delta_weight
                module.get_delta_weight = MethodType(get_delta_weight, module)
                module._cache_pop_origin = module._cache_pop
                module._cache_pop = MethodType(_cache_pop, module)

    try:
        yield model
    finally:
        # Cleanup: restore original methods
        for module in model.modules():
            if isinstance(module, LoraLayer):
                if hasattr(module, 'merge_origin'):
                    module.merge = module.merge_origin
                    del module.merge_origin
                    module.get_delta_weight = module.get_delta_weight_origin
                    del module.get_delta_weight_origin
                    module._cache_pop = module._cache_pop_origin
                    del module._cache_pop_origin
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@contextmanager
def patch_lora_unmerge(model):
    """Patch the unmerge method to ensure proper device handling."""

    def _cache_pop_patched(self, key: str) -> Any:
        # Cached tensors follow the base layer's device on retrieval.
        value = self._caches.pop(key).to(self.base_layer.weight.device)
        return value

    def unmerge_patched(self):
        if not self.merged:
            return
        # Move magnitude vectors to correct device first
        for adapter in list(self.merged_adapters):
            if self.use_dora.get(adapter, False):
                self.lora_magnitude_vector[adapter].weight.data = \
                    self.lora_magnitude_vector[adapter].weight.data.to(self.base_layer.weight.device)

        return self.unmerge_origin()

    # Patch each LoraLayer once; unmerge_origin acts as the already-patched guard.
    for module in model.modules():
        if isinstance(module, LoraLayer) and not hasattr(module, 'unmerge_origin'):
            module.unmerge_origin = module.unmerge
            module.unmerge = MethodType(unmerge_patched, module)
            module._cache_pop_origin = module._cache_pop
            module._cache_pop = MethodType(_cache_pop_patched, module)

    try:
        yield model
    finally:
        # Restore the original methods on exit.
        for module in model.modules():
            if isinstance(module, LoraLayer) and hasattr(module, 'unmerge_origin'):
                module.unmerge = module.unmerge_origin
                del module.unmerge_origin
                module._cache_pop = module._cache_pop_origin
                del module._cache_pop_origin
|
ms-swift/swift/trainers/rlhf_trainer/vllm_client.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# Code partially sourced from Hugging Face TRL
|
| 4 |
+
|
| 5 |
+
import atexit
|
| 6 |
+
import logging
|
| 7 |
+
import time
|
| 8 |
+
from typing import List, Optional
|
| 9 |
+
|
| 10 |
+
import requests
|
| 11 |
+
import torch
|
| 12 |
+
from dacite import from_dict
|
| 13 |
+
from requests import ConnectionError
|
| 14 |
+
from torch import nn
|
| 15 |
+
|
| 16 |
+
from swift.llm import AdapterRequest, InferRequest, Template
|
| 17 |
+
from swift.llm.infer.protocol import ChatCompletionResponse, RequestConfig
|
| 18 |
+
from swift.plugin import Metric
|
| 19 |
+
from swift.utils import is_vllm_ascend_available, is_vllm_available
|
| 20 |
+
|
| 21 |
+
if is_vllm_available():
|
| 22 |
+
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
| 23 |
+
from vllm.distributed.utils import StatelessProcessGroup
|
| 24 |
+
|
| 25 |
+
if is_vllm_ascend_available():
|
| 26 |
+
from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator # noqa
|
| 27 |
+
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class VLLMClient:
    """
    A client class to interact with a vLLM server.

    This class provides methods to infer completions, initialize and manage weight update groups, and update model
    weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`.

    Args:
        host (`str`, *optional*, defaults to `"0.0.0.0"`):
            IP address of the vLLM server.
        server_port (`int`, *optional*, defaults to `8000`):
            Port number of the vLLM server.
        group_port (`int`, *optional*, defaults to `51216`):
            Port number for the weight update group.
        connection_timeout (`float`, *optional*, defaults to `0.0`):
            Total timeout duration in seconds to wait for the server to be up. If the server is not up after the
            timeout, a `ConnectionError` is raised.
    """

    def __init__(self,
                 host: str = '0.0.0.0',
                 server_port: int = 8000,
                 group_port: int = 51216,
                 connection_timeout: float = 0.0):
        if not is_vllm_available():
            raise ImportError('vLLM is not installed. Please install it with `pip install vllm`.')

        # Persistent session so HTTP keep-alive is reused across the many small RPCs below.
        self.session = requests.Session()
        self.host = host
        self.server_port = server_port
        self.group_port = group_port
        self.check_server(connection_timeout)  # check server and fail after timeout

    def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
        """
        Check server availability with retries on failure, within a total timeout duration. If the server is not up
        after the total timeout duration, raise a `ConnectionError`.

        Args:
            retry_interval (`float`, *optional*, defaults to `2.0`):
                Interval in seconds between retries.
            total_timeout (`float`, *optional*, defaults to `0.0`):
                Total timeout duration in seconds.
        """
        url = f'http://{self.host}:{self.server_port}/health/'
        start_time = time.time()  # Record the start time

        while True:
            try:
                # NOTE(review): no `timeout=` is passed, so a server that accepts the
                # connection but never answers can block here past `total_timeout` — confirm intended.
                response = requests.get(url)
            except requests.exceptions.RequestException as exc:
                # Check if the total timeout duration has passed
                elapsed_time = time.time() - start_time
                if elapsed_time >= total_timeout:
                    raise ConnectionError(
                        f"The vLLM server can't be reached at {self.host}:{self.server_port} after {total_timeout} "
                        'seconds. Make sure the server is running by running `swift deploy`.') from exc
            else:
                if response.status_code == 200:
                    logger.info('Server is up!')
                    return None
                # NOTE(review): a reachable server that keeps returning non-200 retries forever;
                # the `total_timeout` check only runs on connection errors — confirm intended.

            # Retry logic: wait before trying again
            logger.info(f'Server is not up yet. Retrying in {retry_interval} seconds...')
            time.sleep(retry_interval)

    def infer(
        self,
        infer_requests: List[InferRequest],
        request_config: Optional[RequestConfig] = None,
        metrics: Optional[List[Metric]] = None,
        *,
        template: Optional[Template] = None,
        use_tqdm: Optional[bool] = None,
        adapter_request: Optional[AdapterRequest] = None,
    ):
        """Send a batched inference request to the server's `/infer/` endpoint.

        Returns a list of `ChatCompletionResponse`, one per request.
        Raises a generic `Exception` on any non-200 status.
        """
        url = f'http://{self.host}:{self.server_port}/infer/'
        # NOTE(review): the payload objects are handed to `json=` directly, so they are
        # assumed to be JSON-serializable by `requests` — verify against the server schema.
        response = self.session.post(
            url,
            json={
                'infer_requests': infer_requests,
                'request_config': request_config,
                'metrics': metrics,
                'template': template,
                'use_tqdm': use_tqdm,
                'adapter_request': adapter_request,
            },
        )
        if response.status_code == 200:
            # Rehydrate plain dicts into typed responses.
            return [from_dict(data_class=ChatCompletionResponse, data=resp) for resp in response.json()]
        else:
            raise Exception(f'Request failed: {response.status_code}, {response.text}')

    def init_communicator(self):
        """
        Initializes the weight update group in a distributed setup for model synchronization.
        """
        # Get the tensor parallel size from the server
        url = f'http://{self.host}:{self.server_port}/get_world_size/'
        response = requests.get(url)
        if response.status_code == 200:
            vllm_world_size = response.json()['world_size']
        else:
            raise Exception(f'Request failed: {response.status_code}, {response.text}')

        world_size = vllm_world_size + 1  # add the client to the world
        self.rank = vllm_world_size  # the client's rank is the last process

        # Initialize weight update group
        url = f'http://{self.host}:{self.server_port}/init_communicator/'
        # In the server side, the host is set to 0.0.0.0
        response = self.session.post(url, json={'host': '0.0.0.0', 'port': self.group_port, 'world_size': world_size})
        if response.status_code != 200:
            raise Exception(f'Request failed: {response.status_code}, {response.text}')

        # Brief delay to allow server initialization. While not strictly required (client socket will retry on
        # connection failure), this prevents log warnings like:
        # [W416 23:24:57.460001114 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
        time.sleep(0.1)

        # Set up the communication group for weight broadcasting
        pg = StatelessProcessGroup.create(host=self.host, port=self.group_port, rank=self.rank, world_size=world_size)
        # NOTE(review): the communicator is pinned to device 0 — confirm for multi-GPU clients.
        self.pynccl_comm = PyNcclCommunicator(pg, device=0)

        # When the client object is deleted, close the weight update group
        atexit.register(self.close_communicator)

    def update_named_param(self, name: str, weights: torch.Tensor):
        """
        Updates a specific named parameter in the model and broadcasts it to other processes.

        Args:
            name (`str`):
                Name of the layer whose weights are being updated.
            weights (`torch.Tensor`):
                Tensor containing the updated weights.
        """
        # Announce metadata over HTTP first so the server can post a matching receive.
        dtype, shape = str(weights.dtype), tuple(weights.shape)
        url = f'http://{self.host}:{self.server_port}/update_named_param/'
        response = self.session.post(url, json={'name': name, 'dtype': dtype, 'shape': shape})
        if response.status_code != 200:
            raise Exception(f'Request failed: {response.status_code}, {response.text}')

        # Broadcast the weights to the other processes
        self.pynccl_comm.broadcast(weights, src=self.rank)
        # Barrier keeps the client from racing ahead of the server's receive.
        self.pynccl_comm.group.barrier()

    def update_model_params(self, model: nn.Module):
        """
        Updates all parameters of the given model by calling `update_named_param` for each parameter in the model.

        Args:
            model (`nn.Module`):
                Model whose parameters (weights/biases) are to be updated.
        """
        for name, param in model.named_parameters():
            # Update each parameter individually
            self.update_named_param(name, param.data)

    def reset_prefix_cache(self):
        """
        Resets the prefix cache for the model.
        """
        url = f'http://{self.host}:{self.server_port}/reset_prefix_cache/'
        response = self.session.post(url)
        if response.status_code != 200:
            raise Exception(f'Request failed: {response.status_code}, {response.text}')

    def close_communicator(self):
        """
        Closes the weight update group and cleans up the communication group.
        """
        url = f'http://{self.host}:{self.server_port}/close_communicator/'

        try:
            response = self.session.post(url)
        except ConnectionError:
            # The server might be already down, so we don't need to close the communicator
            pass
        else:
            if response.status_code != 200:
                raise Exception(f'Request failed: {response.status_code}, {response.text}')
|
ms-swift/swift/trainers/sequence_parallel/base.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
from abc import abstractmethod
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class SequenceParallel(abc.ABC):
    """Abstract interface of a sequence-parallel training backend.

    A backend owns the communication-group setup, patches the model so
    attention works on sequence shards, splits/pads the inputs, and reduces
    the per-shard outputs back into a loss.
    """

    @abstractmethod
    def init_sequence_parallel(self, size):
        """Set up the sequence-parallel communication group of `size` ranks."""
        pass

    @abstractmethod
    def prepare_model(self, model, tokenizer, split_in_forward):
        """Patch `model` for sequence parallel; `split_in_forward` defers input
        splitting into the model's forward (used by multi-modal models)."""
        pass

    @abstractmethod
    def pad_and_split_inputs(self,
                             tokenizer,
                             input_ids,
                             input_embeds,
                             labels,
                             position_ids,
                             attention_mask,
                             loss_scale,
                             embed_tokens=None):
        """Pad each input to a multiple of the sp world size and return this rank's shard."""
        pass

    @abstractmethod
    def reduce_outputs(self, loss, labels):
        """Combine per-shard losses/labels into the final loss."""
        pass

    @property
    def sp_group(self):
        # Default: no sequence-parallel process group available.
        return None

    @abstractmethod
    def world_size(self):
        """Return the number of ranks in the sequence-parallel group."""
        pass

    @abstractmethod
    def prepare_trainer(self, trainer):
        """Hook to adapt the trainer (e.g. its loss computation) to sequence parallel."""
        pass

    @abstractmethod
    def get_dataloader(self, trainer, dataset, batch_size):
        """Build a dataloader consistent with this parallel layout."""
        pass
|
ms-swift/swift/trainers/sequence_parallel/ulysses.py
ADDED
|
@@ -0,0 +1,594 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from functools import partial
|
| 3 |
+
from types import MethodType
|
| 4 |
+
from typing import Any, Dict, Iterator, List, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import datasets
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.distributed as dist
|
| 10 |
+
from peft import PeftModel
|
| 11 |
+
from torch.distributed.device_mesh import init_device_mesh
|
| 12 |
+
from torch.nn import CrossEntropyLoss
|
| 13 |
+
from torch.utils.data import DataLoader, Sampler
|
| 14 |
+
from transformers.trainer_utils import seed_worker
|
| 15 |
+
|
| 16 |
+
from swift.llm import DataLoaderDispatcher, get_model_arch
|
| 17 |
+
from swift.tuners import SwiftModel
|
| 18 |
+
from swift.utils import get_current_device, get_device, get_dist_setting
|
| 19 |
+
from .base import SequenceParallel
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class GatherLoss(torch.autograd.Function):
    """Gather loss from sequence group"""

    @staticmethod
    def forward(ctx, loss, labels, process_group, gather_idx=None):
        """
        Args:
            loss: loss tensor after splitting
            labels: labels tensor after splitting
            process_group: the sequence parallel group
            gather_idx: gather the tensors on this dim
        """
        ctx.process_group = process_group
        shape0 = labels.shape[0]
        # Remember this rank's extent along the gather dim so backward can split
        # the incoming gradient back into per-rank chunks.
        ctx.scatter_shape = labels.shape[gather_idx or 0]
        ctx.gather_idx = gather_idx or 0
        world_size = dist.get_world_size(group=process_group)  # the sp world size
        output = torch.empty((shape0 * world_size, *loss.shape[1:]), dtype=loss.dtype, device=loss.device)
        # gather all from sp group
        dist.all_gather_into_tensor(output, loss, group=process_group)
        if gather_idx is not None:
            # all_gather stacks rank chunks along dim 0; re-concatenate them on the requested dim.
            output = torch.cat(output.split(shape0, dim=0), dim=gather_idx)
        labels_output = torch.empty((shape0 * world_size, *labels.shape[1:]), dtype=labels.dtype, device=labels.device)
        dist.all_gather_into_tensor(labels_output, labels, group=process_group)
        if gather_idx is not None:
            labels_output = torch.cat(labels_output.split(shape0, dim=0), dim=gather_idx)
        return output, labels_output

    @staticmethod
    def backward(ctx, *grad_output):
        # Scale by the world size (the forward gather averaged contributions across
        # ranks implicitly via the downstream mean), then keep only this rank's slice.
        # NOTE(review): the world-size scaling assumes the downstream reduction divides
        # by the full gathered token count — confirm against loss_scale_sp_func.
        _grad = grad_output[0] * dist.get_world_size(group=ctx.process_group)
        return _grad.split(
            ctx.scatter_shape, dim=ctx.gather_idx)[dist.get_rank(ctx.process_group)].contiguous(), None, None, None
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# For nll loss
|
| 58 |
+
def loss_scale_sp_func(outputs, labels, loss_scale=None, num_items_in_batch=None, process_group=None) -> torch.Tensor:
    """Token-level NLL loss under sequence parallel.

    Computes unreduced cross-entropy on this rank's sequence shard, gathers the
    per-token losses and labels across `process_group`, then averages over
    valid tokens (`labels != -100`) or divides by `num_items_in_batch`.
    """
    if hasattr(outputs, 'logits'):
        logits = outputs.logits
    else:
        logits = outputs
    device = logits.device
    logits = logits.view(-1, logits.shape[-1])
    labels = labels.flatten().to(device)
    # Flatten the tokens
    loss_fct = CrossEntropyLoss(reduction='none')
    # flatten loss
    loss = loss_fct(logits, labels)

    if loss_scale is not None:
        # Per-token weighting applied before the cross-rank gather.
        loss_scale = loss_scale.flatten().to(loss.device)
        loss = (loss_scale * loss)
    # Collect the full-sequence per-token losses/labels from all sp ranks.
    loss, labels = GatherLoss.apply(loss, labels, process_group)
    loss = loss[labels != -100].sum()
    if num_items_in_batch is None:
        loss = loss / (labels != -100).sum()
    else:
        loss = loss / num_items_in_batch
    return loss
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# For DPO
|
| 84 |
+
def get_batch_logps(logits: torch.FloatTensor,
                    labels: torch.LongTensor,
                    label_pad_token_id: int = -100,
                    is_encoder_decoder: bool = False,
                    process_group=None) -> Tuple[torch.FloatTensor, torch.LongTensor]:
    """Per-sample summed log-probs and valid-token counts for DPO under sequence parallel.

    Gathers the per-token log-probs and loss mask across `process_group`
    (concatenated on the sequence dim) before reducing over the sequence.
    NOTE: `is_encoder_decoder` is accepted for API parity but unused here.
    """
    labels = labels.clone()  # No need to shift, pad and split has shifted the inputs.
    loss_mask = labels != label_pad_token_id
    # Replace padding by a valid index so the gather below stays in range;
    # those positions are masked out via loss_mask anyway.
    labels[labels == label_pad_token_id] = 0
    labels = labels.to(logits.device)
    loss_mask = loss_mask.to(logits.device)
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
    # Gather on dim 1 (sequence) so each rank sees the full-sequence values.
    total_per_token_logps, total_loss_mask = GatherLoss.apply(per_token_logps, loss_mask, process_group, 1)
    return (total_per_token_logps * total_loss_mask).sum(-1), total_loss_mask.sum(-1)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class UlyssesSampler(Sampler):
    """Data-parallel sampler aware of the Ulysses device mesh.

    Splits dataset indices across the data-parallel dimension of the mesh;
    ranks that share a sequence-parallel group draw identical indices.
    (Adapted from mmengine.)
    """

    def __init__(self, ulysses, dataset, shuffle: bool = True, seed=None, round_up: bool = True) -> None:
        self.ulysses = ulysses
        data_group = ulysses.device_mesh['data'].get_group()
        self.rank = dist.get_rank(data_group)
        self.world_size = ulysses.device_mesh['data'].size()

        self.dataset = dataset
        self.shuffle = shuffle
        assert seed is not None
        self.seed = seed
        self.epoch = 0
        self.round_up = round_up

        total = len(self.dataset)
        if self.round_up:
            # Pad so every data-parallel rank draws the same number of samples.
            self.num_samples = math.ceil(total / self.world_size)
            self.total_size = self.num_samples * self.world_size
        else:
            self.num_samples = math.ceil((total - self.rank) / self.world_size)
            self.total_size = total

    def __iter__(self) -> Iterator[int]:
        count = len(self.dataset)
        if self.shuffle:
            gen = torch.Generator()
            gen.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(count, generator=gen).tolist()
        else:
            indices = torch.arange(count).tolist()

        if self.round_up:
            # Repeat the index list enough times, then truncate to total_size.
            repeats = int(self.total_size / len(indices) + 1)
            indices = (indices * repeats)[:self.total_size]

        # Strided slice: rank r takes positions r, r + world_size, ...
        return iter(indices[self.rank:self.total_size:self.world_size])

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Advance the shuffle seed so each epoch uses a fresh permutation."""
        self.epoch = epoch
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class UlyssesDispatcher(DataLoaderDispatcher):
    """Dispatcher that fans one shared dataloader out across the data-parallel group.

    Each rank iterates the same base dataloader, consumes `dp_world_size`
    batches per step and keeps only the `dp_rank`-th one.
    """

    def __init__(self, base_dataloader, ulysses):
        super().__init__(base_dataloader)
        self.ulysses = ulysses

    def __iter__(self):
        base_iter = iter(self.base_dataloader)
        while True:
            data = None
            try:
                for i in range(self.ulysses.dp_world_size):
                    data = next(base_iter)
                    if i == self.ulysses.dp_rank:
                        break
            except StopIteration:
                # NOTE(review): if the iterator is exhausted before reaching this rank's
                # index, `data` still holds a batch fetched for a lower rank and is yielded
                # below. That duplicates a sample, but it presumably keeps the batch count
                # identical on every rank at the epoch tail (avoiding collective hangs) —
                # confirm this is intentional before "fixing" it.
                pass
            if data is None:
                # Only rank-synchronized exhaustion (StopIteration at i == 0) ends the epoch.
                break
            yield data
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# Code borrowed from deepspeed, here is why:
|
| 168 |
+
# 1. Reduce the dependency
|
| 169 |
+
# 2. The original code is complex
|
| 170 |
+
def _generate_layout_params(scatter_idx, seq_world_size, input):
|
| 171 |
+
if scatter_idx < 2:
|
| 172 |
+
bs, global_seq_len, num_local_head, head_dim = input.shape
|
| 173 |
+
pre_all2all_inp_shape = [bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, head_dim]
|
| 174 |
+
pre_all2all_permute_idx = (1, 0, 2, 3, 4)
|
| 175 |
+
|
| 176 |
+
post_all2all_permute_idx = (1, 2, 0, 3, 4)
|
| 177 |
+
post_all2all_res_shape = [bs, global_seq_len // seq_world_size, seq_world_size * num_local_head, head_dim]
|
| 178 |
+
else:
|
| 179 |
+
bs, local_seq_len, num_total_head, head_dim = input.shape
|
| 180 |
+
assert num_total_head % seq_world_size == 0, (f'Number of heads ({num_total_head}) must be divisible '
|
| 181 |
+
f'by the sequence parallel size ({seq_world_size})!')
|
| 182 |
+
pre_all2all_inp_shape = [bs, local_seq_len, seq_world_size, num_total_head // seq_world_size, head_dim]
|
| 183 |
+
pre_all2all_permute_idx = (2, 0, 1, 3, 4)
|
| 184 |
+
|
| 185 |
+
post_all2all_permute_idx = (1, 0, 2, 3, 4)
|
| 186 |
+
post_all2all_res_shape = [bs, seq_world_size * local_seq_len, num_total_head // seq_world_size, head_dim]
|
| 187 |
+
|
| 188 |
+
return pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def post_all2all(permute_idx, res_shape):
    """
    Build the post-processing callable applied to an `all2all` result.

    The returned function optionally permutes its argument by `permute_idx`
    and reshapes it to `res_shape`, returning a contiguous tensor.
    """

    def _reorder(tensor):
        if permute_idx is not None:
            tensor = tensor.permute(permute_idx).contiguous()
        return tensor.reshape(res_shape).contiguous()

    return _reorder
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def pre_all2all_fun(permute_idx, inp_shape, input):
    """
    Pre-process `input` for an `all2all`: reshape to `inp_shape`, then
    optionally permute by `permute_idx`. Always returns a contiguous tensor.
    """
    reshaped = input.reshape(inp_shape).contiguous()
    if permute_idx is None:
        return reshaped
    return reshaped.permute(permute_idx).contiguous()
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def single_all_to_all(input, scatter_idx, gather_idx, group, **kwargs):
    """One all-to-all exchange of a (bs, seq, heads, head_dim) tensor over `group`.

    With `scatter_idx < 2` the sequence dim is scattered and heads gathered;
    otherwise heads are scattered and the sequence dim gathered.
    NOTE(review): `gather_idx` is not read here — the layout is derived entirely
    from `scatter_idx` (see `_generate_layout_params`); confirm intended.
    """
    seq_world_size = dist.get_world_size(group)
    num_heads = input.shape[2]
    # Head-scatter mode requires the head count to divide evenly across ranks.
    if num_heads % seq_world_size != 0 and not scatter_idx < 2:
        raise NotImplementedError
    pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape = (
        _generate_layout_params(scatter_idx, seq_world_size, input))

    # Pack the scattered dim into a leading per-rank axis before the exchange.
    input_t = pre_all2all_fun(pre_all2all_permute_idx, pre_all2all_inp_shape, input)

    post_all2all_fun = post_all2all(post_all2all_permute_idx, post_all2all_res_shape)
    output = torch.empty_like(input_t)
    dist.all_to_all_single(output, input_t, group=group)

    # Unpack the received chunks back into the target layout.
    res = post_all2all_fun(output)
    return res
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class _SeqAllToAll(torch.autograd.Function):
    """Autograd-aware all-to-all: backward performs the inverse exchange."""

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        input: torch.Tensor,
        scatter_idx: int,
        gather_idx: int,
    ) -> torch.Tensor:
        # Stash the layout so backward can invert the exchange.
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx
        res = single_all_to_all(input, scatter_idx, gather_idx, group)
        return res

    @staticmethod
    def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, None, None]:
        # The gradient of an all-to-all is an all-to-all with the scatter/gather
        # dims swapped, applied to the incoming gradient.
        return None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class DistributedAttention(torch.nn.Module):
    """Ulysses-style distributed attention wrapper.

    All-to-alls q/k/v so each rank holds the full sequence for a subset of
    heads, runs `local_attention`, then all-to-alls the context back so each
    rank again holds its sequence shard with all heads.
    """

    def __init__(
        self,
        local_attention,
        sequence_process_group: dist.ProcessGroup,
        scatter_idx: int = 2,
        gather_idx: int = 1,
    ) -> None:
        super(DistributedAttention, self).__init__()
        self.local_attn = local_attention
        self.spg = sequence_process_group
        # Defaults follow the (bs, seq, heads, head_dim) layout of single_all_to_all:
        # scatter dim 2 (heads), gather dim 1 (sequence).
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor,
                *args: Any, **kwargs) -> torch.Tensor:
        # After these exchanges every rank holds the full sequence with its head subset.
        query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx)
        key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
        value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)
        position_ids = kwargs.pop('position_ids', None)
        if position_ids is not None:
            # Gather position_ids along the sequence dim to match the gathered q/k/v.
            shape0 = position_ids.shape[0]
            position_ids_output = torch.empty((shape0 * dist.get_world_size(self.spg), position_ids.shape[1]),
                                              dtype=position_ids.dtype,
                                              device=position_ids.device)
            dist.all_gather_into_tensor(position_ids_output, position_ids, group=self.spg)
            position_ids = torch.cat(position_ids_output.split(shape0, dim=0), dim=1)
        context_layer = self.local_attn(
            query_layer, key_layer, value_layer, attention_mask, *args, position_ids=position_ids, **kwargs)
        # Inverse exchange: back to the local sequence shard with all heads.
        output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)
        return output
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
class Ulysses(SequenceParallel):
|
| 290 |
+
|
| 291 |
+
def __init__(self):
|
| 292 |
+
self.split_in_forward = None
|
| 293 |
+
self.dp_world_size = None
|
| 294 |
+
self.sp_world_size = None
|
| 295 |
+
self.model_dtype = None
|
| 296 |
+
self.causal_mask_func = None
|
| 297 |
+
self.device_mesh = None
|
| 298 |
+
self._inited = False
|
| 299 |
+
|
| 300 |
+
def init_sequence_parallel(self, size):
|
| 301 |
+
if self._inited:
|
| 302 |
+
return
|
| 303 |
+
self._inited = True
|
| 304 |
+
self.sp_world_size = size
|
| 305 |
+
rank, local_rank, world_size, local_world_size = get_dist_setting()
|
| 306 |
+
self.dp_world_size = world_size // size
|
| 307 |
+
self.device_mesh = init_device_mesh(
|
| 308 |
+
get_device().split(':')[0], mesh_shape=(world_size // size, size), mesh_dim_names=['data', 'sequence'])
|
| 309 |
+
|
| 310 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
| 311 |
+
ALL_ATTENTION_FUNCTIONS['flash_attention_2_origin'] = ALL_ATTENTION_FUNCTIONS['flash_attention_2']
|
| 312 |
+
ALL_ATTENTION_FUNCTIONS['sdpa_origin'] = ALL_ATTENTION_FUNCTIONS['sdpa']
|
| 313 |
+
|
| 314 |
+
def local_flash_attn(module: torch.nn.Module, query_states, key_states, value_states, attention_mask, *args,
|
| 315 |
+
dist_attn, **kwargs):
|
| 316 |
+
if dist_attn.local_attn is None:
|
| 317 |
+
|
| 318 |
+
def _attention(query, key, value, *args, **kwargs):
|
| 319 |
+
query = query.transpose(1, 2)
|
| 320 |
+
key = key.transpose(1, 2)
|
| 321 |
+
value = value.transpose(1, 2)
|
| 322 |
+
return ALL_ATTENTION_FUNCTIONS['flash_attention_2_origin'](module, query, key, value, *args,
|
| 323 |
+
**kwargs)[0]
|
| 324 |
+
|
| 325 |
+
dist_attn.local_attn = _attention
|
| 326 |
+
|
| 327 |
+
return dist_attn(
|
| 328 |
+
query_states.transpose(1, 2), key_states.transpose(1, 2), value_states.transpose(1, 2), attention_mask,
|
| 329 |
+
*args, **kwargs), None
|
| 330 |
+
|
| 331 |
+
def local_sdpa_attn(module: torch.nn.Module, query_states, key_states, value_states, attention_mask, *args,
|
| 332 |
+
dist_attn, **kwargs):
|
| 333 |
+
if dist_attn.local_attn is None:
|
| 334 |
+
|
| 335 |
+
def _attention(query, key, value, *args, **kwargs):
|
| 336 |
+
query = query.transpose(1, 2)
|
| 337 |
+
key = key.transpose(1, 2)
|
| 338 |
+
value = value.transpose(1, 2)
|
| 339 |
+
return ALL_ATTENTION_FUNCTIONS['sdpa_origin'](module, query, key, value, *args, **kwargs)[0]
|
| 340 |
+
|
| 341 |
+
dist_attn.local_attn = _attention
|
| 342 |
+
return dist_attn(
|
| 343 |
+
query_states.transpose(1, 2), key_states.transpose(1, 2), value_states.transpose(1, 2), attention_mask,
|
| 344 |
+
*args, **kwargs), None
|
| 345 |
+
|
| 346 |
+
ALL_ATTENTION_FUNCTIONS['flash_attention_2'] = partial(
|
| 347 |
+
local_flash_attn, dist_attn=DistributedAttention(None, self.sp_group))
|
| 348 |
+
ALL_ATTENTION_FUNCTIONS['sdpa'] = partial(local_sdpa_attn, dist_attn=DistributedAttention(None, self.sp_group))
|
| 349 |
+
|
| 350 |
+
from transformers.modeling_flash_attention_utils import is_flash_attn_available
|
| 351 |
+
if is_flash_attn_available():
|
| 352 |
+
# TODO this works for multi-modal models like qwen2.5-vl
|
| 353 |
+
# SDPA is not supported, because we need to copy the code to our project, which will bring
|
| 354 |
+
# more works for maintaining.
|
| 355 |
+
from transformers import modeling_flash_attention_utils
|
| 356 |
+
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
| 357 |
+
_distributed_flash_attention = DistributedAttention(_flash_attention_forward, self.sp_group)
|
| 358 |
+
|
| 359 |
+
def flash_attention_forward(query_states: torch.Tensor, key_states: torch.Tensor,
|
| 360 |
+
value_states: torch.Tensor, attention_mask: Optional[torch.Tensor], q_len,
|
| 361 |
+
*args, **kwargs):
|
| 362 |
+
return _distributed_flash_attention(query_states, key_states, value_states, attention_mask,
|
| 363 |
+
q_len * self.sp_world_size, *args, **kwargs)
|
| 364 |
+
|
| 365 |
+
modeling_flash_attention_utils._flash_attention_forward = flash_attention_forward
|
| 366 |
+
|
| 367 |
+
def prepare_model(self, model, tokenizer, split_in_forward):
|
| 368 |
+
self.split_in_forward = split_in_forward
|
| 369 |
+
|
| 370 |
+
def forward(_self, **kwargs):
|
| 371 |
+
# Split embedding here for multi-modal
|
| 372 |
+
inputs_embeds = kwargs['inputs_embeds']
|
| 373 |
+
position_ids = kwargs['position_ids']
|
| 374 |
+
attention_mask = kwargs['attention_mask']
|
| 375 |
+
_, inputs_embeds, _, position_ids, attention_mask, _ = self.pad_and_split_inputs(
|
| 376 |
+
tokenizer,
|
| 377 |
+
None,
|
| 378 |
+
inputs_embeds,
|
| 379 |
+
None,
|
| 380 |
+
position_ids,
|
| 381 |
+
attention_mask,
|
| 382 |
+
None,
|
| 383 |
+
embed_tokens=_self.embed_tokens)
|
| 384 |
+
kwargs['inputs_embeds'] = inputs_embeds
|
| 385 |
+
kwargs['position_ids'] = position_ids
|
| 386 |
+
kwargs['attention_mask'] = attention_mask
|
| 387 |
+
return _self.forward_origin(**kwargs)
|
| 388 |
+
|
| 389 |
+
if isinstance(model, (SwiftModel, PeftModel)):
|
| 390 |
+
model = model.model
|
| 391 |
+
model_meta = model.model_meta
|
| 392 |
+
llm_prefix = getattr(get_model_arch(model_meta.model_arch), 'language_model', None)
|
| 393 |
+
if llm_prefix:
|
| 394 |
+
llm_model = getattr(model, llm_prefix[0])
|
| 395 |
+
else:
|
| 396 |
+
llm_model = model
|
| 397 |
+
|
| 398 |
+
if 'CausalLM' not in llm_model.__class__.__name__:
|
| 399 |
+
llm_model = model
|
| 400 |
+
|
| 401 |
+
base_model = llm_model.model
|
| 402 |
+
self.causal_mask_func = base_model._update_causal_mask
|
| 403 |
+
if self.split_in_forward:
|
| 404 |
+
# for multi modal models
|
| 405 |
+
base_model.forward_origin = base_model.forward
|
| 406 |
+
base_model.forward = MethodType(forward, base_model)
|
| 407 |
+
|
| 408 |
+
self.model_dtype = next(model.parameters()).dtype
|
| 409 |
+
|
| 410 |
+
def _pad_sp(self, tensor, padding_value, dim=-1):
|
| 411 |
+
# code borrowed from xtuner
|
| 412 |
+
length = tensor.shape[dim]
|
| 413 |
+
if length % self.sp_world_size == 0:
|
| 414 |
+
return tensor
|
| 415 |
+
|
| 416 |
+
pad_num = self.sp_world_size - (length % self.sp_world_size)
|
| 417 |
+
if not isinstance(padding_value, torch.Tensor):
|
| 418 |
+
# ids
|
| 419 |
+
pad_shape = ((*tensor.shape[:dim], pad_num, *tensor.shape[dim + 1:]) if dim != -1 else
|
| 420 |
+
(*tensor.shape[:dim], pad_num))
|
| 421 |
+
pad = torch.full(pad_shape, padding_value, dtype=tensor.dtype, device=tensor.device)
|
| 422 |
+
tensor = torch.cat([tensor, pad], dim=dim)
|
| 423 |
+
else:
|
| 424 |
+
# For embeddings
|
| 425 |
+
tensor = torch.cat([tensor, padding_value.unsqueeze(0).repeat(tensor.shape[0], pad_num, 1)], dim=dim)
|
| 426 |
+
return tensor
|
| 427 |
+
|
| 428 |
+
    def world_size(self):
        """Return the sequence-parallel world size."""
        return self.sp_world_size
|
| 430 |
+
|
| 431 |
+
def _split_sp(self, input, dim: int, sp_group: dist.ProcessGroup):
|
| 432 |
+
# code borrowed from xtuner
|
| 433 |
+
if self.sp_world_size == 1:
|
| 434 |
+
return input
|
| 435 |
+
|
| 436 |
+
rank = dist.get_rank(sp_group)
|
| 437 |
+
dim_size = input.size(dim)
|
| 438 |
+
assert dim_size % self.sp_world_size == 0, (f'The dimension to split ({dim_size}) is not a multiple of '
|
| 439 |
+
f'world size ({self.sp_world_size}), cannot split tensor evenly')
|
| 440 |
+
|
| 441 |
+
tensor_list = torch.split(input, dim_size // self.sp_world_size, dim=dim)
|
| 442 |
+
output = tensor_list[rank].contiguous()
|
| 443 |
+
|
| 444 |
+
return output
|
| 445 |
+
|
| 446 |
+
    def pad_and_split_inputs(self,
                             tokenizer,
                             input_ids,
                             input_embeds,
                             labels,
                             position_ids,
                             attention_mask,
                             loss_scale,
                             embed_tokens=None):
        """Pad every sequence-shaped input to a multiple of sp_world_size, then shard
        each along the sequence dimension for this rank.

        Any argument may be None and is then passed through unchanged (returned as
        None for input_ids/input_embeds when absent). Returns the tuple
        (input_ids, input_embeds, labels, position_ids, attention_mask, loss_scale).
        """
        sp_group = self.sp_group
        split_inputs = False
        if (input_ids is not None and not self.split_in_forward) or input_embeds is not None:
            # Whether split the model inputs
            # cannot split input_ids for multi-modal models
            split_inputs = True
        if input_ids is not None and split_inputs:
            input_ids = self._pad_sp(input_ids, padding_value=tokenizer.pad_token_id, dim=-1)
        if input_embeds is not None:
            # Pad embeddings with the embedded pad token rather than a scalar.
            pad_emb = embed_tokens(torch.tensor(tokenizer.pad_token_id).to(embed_tokens.weight.device)).unsqueeze(0)
            input_embeds = self._pad_sp(input_embeds, padding_value=pad_emb, dim=1)
        if position_ids is not None and split_inputs:
            position_ids = self._pad_sp(position_ids, padding_value=0, dim=-1)
        if split_inputs:
            inputs = input_ids if input_ids is not None else input_embeds
            attn_shape = inputs.shape[1]  # The sequence length
            if attention_mask is None:
                attention_mask = torch.ones_like(position_ids)
            attention_mask = self._pad_sp(attention_mask, padding_value=0, dim=-1)
            cache_position = torch.arange(0, attn_shape, device=inputs.device)
            # pad attention mask to 4d to avoid calculation errors
            # causal_mask_func is the base model's _update_causal_mask (set in prepare_model).
            attention_mask = self.causal_mask_func(attention_mask, inputs.to(self.model_dtype), cache_position, None,
                                                   None)
        if input_ids is not None and split_inputs:
            input_ids = self._split_sp(input_ids, dim=1, sp_group=sp_group)
        if input_embeds is not None:
            input_embeds = self._split_sp(input_embeds, dim=1, sp_group=sp_group)
        if position_ids is not None and split_inputs:
            position_ids = self._split_sp(position_ids, dim=-1, sp_group=sp_group)
        if labels is not None:
            labels = self._pad_sp(labels, padding_value=-100, dim=-1)
            # Labels are shifted left here (instead of inside the loss), so set
            # position 0 to -100 first: after the roll it lands on the last token,
            # making the last invalid, so we do not need to cut the loss of last token.
            labels[:, 0] = -100
            labels = torch.roll(labels, shifts=-1, dims=1)
            labels = self._split_sp(labels, dim=1, sp_group=sp_group)

        if loss_scale is not None:
            # loss_scale follows labels: pad, shift left by one, shard.
            loss_scale = self._pad_sp(loss_scale, padding_value=0., dim=-1)
            loss_scale = torch.roll(loss_scale, shifts=-1, dims=-1)
            loss_scale = self._split_sp(loss_scale, dim=-1, sp_group=sp_group)

        return input_ids, input_embeds, labels, position_ids, attention_mask, loss_scale
|
| 496 |
+
|
| 497 |
+
    def reduce_outputs(self, loss, labels):
        """Return the loss unchanged; no cross-rank reduction is performed here."""
        return loss
|
| 499 |
+
|
| 500 |
+
    @property
    def sp_rank(self):
        """This process's rank within the sequence-parallel group."""
        return dist.get_rank(self.device_mesh['sequence'].get_group())
|
| 503 |
+
|
| 504 |
+
    @property
    def dp_rank(self):
        """This process's rank within the data-parallel group."""
        return dist.get_rank(self.device_mesh['data'].get_group())
|
| 507 |
+
|
| 508 |
+
    @property
    def sp_group(self):
        """The torch.distributed process group for the 'sequence' mesh dimension."""
        return self.device_mesh['sequence'].get_group()
|
| 511 |
+
|
| 512 |
+
    @property
    def dp_group(self):
        """The torch.distributed process group for the 'data' mesh dimension."""
        return self.device_mesh['data'].get_group()
|
| 515 |
+
|
| 516 |
+
    def get_dataloader(self, trainer, dataset, batch_size):
        """Build a dataloader whose sampling/dispatch is sequence-parallel aware.

        Sized datasets get a UlyssesSampler (all SP ranks see the same samples);
        unsized (streaming) datasets are wrapped in a UlyssesDispatcher that
        broadcasts batches within the SP group.
        """
        data_collator = trainer.data_collator
        if isinstance(dataset, datasets.Dataset):
            dataset = trainer._remove_unused_columns(dataset, description='training')
        else:
            data_collator = trainer._get_collator_with_removed_columns(data_collator, description='training')
        if hasattr(dataset, '__len__'):
            sampler = UlyssesSampler(self, dataset, seed=42)
            dataloader_params = {
                'batch_size': batch_size,
                'collate_fn': data_collator,
                'num_workers': trainer.args.dataloader_num_workers,
                'pin_memory': trainer.args.dataloader_pin_memory,
                'persistent_workers': trainer.args.dataloader_persistent_workers,
            }

            # NOTE(review): this branch is only reached when the dataset is sized,
            # but an IterableDataset can also define __len__ — in that case the
            # sampler above is built and then discarded here. Looks intentional
            # (mirrors HF Trainer), but confirm.
            if not isinstance(dataset, torch.utils.data.IterableDataset):
                dataloader_params['sampler'] = sampler
                dataloader_params['drop_last'] = trainer.args.dataloader_drop_last
                dataloader_params['worker_init_fn'] = seed_worker

            return DataLoader(dataset, **dataloader_params)
        else:
            dataloader_params = {
                'collate_fn': data_collator,
                'num_workers': trainer.args.dataloader_num_workers,
                'pin_memory': trainer.args.dataloader_pin_memory,
                'persistent_workers': trainer.args.dataloader_persistent_workers,
                'prefetch_factor': trainer.args.dataloader_prefetch_factor
            }
            if dist.is_initialized() and dataloader_params['prefetch_factor']:
                # Scale prefetch so every rank keeps its workers busy.
                dataloader_params['prefetch_factor'] = dataloader_params['prefetch_factor'] * dist.get_world_size()
            dataloader = DataLoader(dataset, batch_size=batch_size, **dataloader_params)
            # Dispatcher replicates each batch to all ranks in the SP group.
            dataloader = UlyssesDispatcher(dataloader, self)
            return dataloader
|
| 551 |
+
|
| 552 |
+
    def prepare_trainer(self, trainer):
        """Patch trainer loss/metric hooks so they reduce across the SP group.

        Replaces the trainer's loss function (and RLHF logps/NLL hooks when present)
        with SP-aware versions, and monkey-patches ``compute_acc`` in the metric and
        mixin modules to gather predictions/labels across the SP group first.
        """
        if trainer.train_dataset is None:
            raise ValueError('Trainer: training requires a train_dataset.')

        trainer.compute_loss_func = partial(loss_scale_sp_func, process_group=self.sp_group)
        if hasattr(trainer, 'get_batch_logps'):
            trainer.get_batch_logps = partial(get_batch_logps, process_group=self.sp_group)
        if hasattr(trainer, 'get_nll_loss'):

            def rlhf_loss_scale_sp_func(_, *args, **kwargs):
                # Bound-method shim: drop the implicit trainer arg, forward the rest.
                return loss_scale_sp_func(*args, process_group=self.sp_group, **kwargs)

            trainer.get_nll_loss = MethodType(rlhf_loss_scale_sp_func, trainer)

        from swift.plugin import metric
        from swift.trainers import mixin
        compute_acc_origin = metric.compute_acc

        def compute_acc(preds, labels, *args, **kwargs) -> Dict[str, List[float]]:
            """SP-aware accuracy: gather shards from every rank, reassemble the full
            sequence, then delegate to the original compute_acc."""
            # Gather preds and labels across the sp group
            if isinstance(preds, np.ndarray):
                preds = torch.from_numpy(preds).to(get_current_device())
            if isinstance(labels, np.ndarray):
                labels = torch.from_numpy(labels).to(get_current_device())
            shape0 = preds.shape[0]
            preds_output = torch.empty((shape0 * self.sp_world_size, preds.shape[1]),
                                       dtype=preds.dtype,
                                       device=preds.device)
            dist.all_gather_into_tensor(preds_output, preds, group=self.sp_group)
            # all_gather stacks rank shards along dim 0; re-concatenate along the
            # sequence dim to restore the original layout.
            preds_output = torch.cat(preds_output.split(shape0, dim=0), dim=1)
            shape0 = labels.shape[0]
            labels_output = torch.empty((shape0 * self.sp_world_size, labels.shape[1]),
                                        dtype=labels.dtype,
                                        device=labels.device)
            dist.all_gather_into_tensor(labels_output, labels, group=self.sp_group)
            labels_output = torch.cat(labels_output.split(shape0, dim=0), dim=1)
            # roll back to fit compute_acc (labels were shifted left in
            # pad_and_split_inputs; undo that shift here).
            labels_output = torch.roll(labels_output, shifts=1, dims=1)
            return compute_acc_origin(preds_output, labels_output, *args, **kwargs)

        metric.compute_acc = compute_acc
        mixin.compute_acc = compute_acc
|
ms-swift/swift/trainers/sequence_parallel/xtuner.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from typing import Any
|
| 3 |
+
|
| 4 |
+
import datasets
|
| 5 |
+
import torch
|
| 6 |
+
import torch.distributed as dist
|
| 7 |
+
from datasets import Dataset
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
from transformers.trainer_utils import seed_worker
|
| 10 |
+
|
| 11 |
+
from .base import SequenceParallel
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class XTuner(SequenceParallel):
    """Sequence-parallel backend that delegates to the XTuner library."""

    @staticmethod
    def assert_xtuner_runtime_condition():
        # All XTuner features below require the xtuner package and an
        # initialized torch.distributed process group.
        from swift.utils import is_xtuner_available
        assert is_xtuner_available(), \
            ('Please install XTuner first to pack dataset to `max_length`.'
             '`pip install -U \'xtuner[deepspeed]\'`')
        assert dist.is_initialized(), 'pack_to_max_length is only available with distributed training.'

    def pack_dataset_xtuner(self, dataset: Dataset, args: Any) -> Any:
        """Pack the dataset to args.max_length on rank 0 and broadcast it."""
        self.assert_xtuner_runtime_condition()
        if dist.get_rank() == 0:
            # dataset.data rows are wrapped in 1-tuples; unwrap before rebuilding.
            ds = [i[0] for i in dataset.data]
            train_dataset = Dataset.from_list(ds)
            from xtuner.dataset.huggingface import pack_dataset
            train_dataset = pack_dataset(
                train_dataset,
                max_length=args.max_length,
                use_varlen_attn=False,
                shuffle_before_pack=True,
                map_num_proc=16)
            objects = [train_dataset]
            # NOTE(review): hard-coded relative path written on every run —
            # looks like leftover debug code; confirm whether it should stay.
            train_dataset.save_to_disk('alpaca_pack')
        else:
            objects = [None]
        # Broadcast the whole packed dataset object from rank 0 to all ranks.
        dist.broadcast_object_list(objects, src=0)
        train_dataset = objects[0]
        return train_dataset

    @property
    def sp_group(self):
        """The XTuner-managed sequence-parallel process group."""
        from xtuner.parallel.sequence import get_sequence_parallel_group
        return get_sequence_parallel_group()

    def init_sequence_parallel(self, size):
        """Initialize XTuner's sequence-parallel groups with the given size."""
        self.assert_xtuner_runtime_condition()
        from xtuner.parallel.sequence import init_sequence_parallel
        init_sequence_parallel(size)

    def prepare_model(self, model, tokenizer, split_in_forward):
        """Dispatch XTuner's sequence-parallel attention modules into the model."""
        self.assert_xtuner_runtime_condition()
        from xtuner.model.modules.dispatch import dispatch_modules
        dispatch_modules(model)

    def pad_and_split_inputs(self,
                             tokenizer,
                             input_ids,
                             input_embeds,
                             labels,
                             position_ids,
                             attention_mask,
                             loss_scale,
                             embed_tokens=None):
        """Pad all inputs to a multiple of the SP world size and shard them.

        Note: ``input_embeds`` is ignored here (this backend splits ids only);
        the returned tuple always has None in the input_embeds slot.
        """
        self.assert_xtuner_runtime_condition()
        from xtuner.parallel.sequence import (pad_for_sequence_parallel, split_for_sequence_parallel,
                                              get_sequence_parallel_group)
        input_ids = pad_for_sequence_parallel(input_ids, padding_value=tokenizer.pad_token_id, dim=-1)
        labels = pad_for_sequence_parallel(labels, padding_value=-100, dim=-1)
        position_ids = pad_for_sequence_parallel(position_ids, padding_value=0, dim=-1)
        if attention_mask is not None:
            attention_mask = pad_for_sequence_parallel(attention_mask, padding_value=0, dim=-1)

        sp_group = get_sequence_parallel_group()
        input_ids = split_for_sequence_parallel(input_ids, dim=1, sp_group=sp_group)
        labels = split_for_sequence_parallel(labels, dim=1, sp_group=sp_group)
        position_ids = split_for_sequence_parallel(position_ids, dim=1, sp_group=sp_group)
        if attention_mask is not None:
            attention_mask = split_for_sequence_parallel(attention_mask, dim=-1, sp_group=sp_group)
        if loss_scale is not None:
            loss_scale = pad_for_sequence_parallel(loss_scale, padding_value=0., dim=-1)
            loss_scale = split_for_sequence_parallel(loss_scale, dim=1, sp_group=sp_group)

        return input_ids, None, labels, position_ids, attention_mask, loss_scale

    def reduce_outputs(self, loss, labels):
        """Reduce the sharded loss across the SP group, weighted by valid tokens."""
        from xtuner.parallel.sequence import (reduce_sequence_parallel_loss, get_sequence_parallel_group)
        # reduce loss for logging correctly
        num_tokens = (labels != -100).sum()
        return reduce_sequence_parallel_loss(loss, num_tokens, get_sequence_parallel_group())

    def world_size(self):
        """Return XTuner's sequence-parallel world size."""
        self.assert_xtuner_runtime_condition()
        from xtuner.parallel.sequence import get_sequence_parallel_world_size
        return get_sequence_parallel_world_size()

    def prepare_trainer(self, trainer):
        # No trainer-level patching needed for the XTuner backend.
        pass

    def get_dataloader(self, trainer, dataset, batch_size):
        # modified from HFTrainer.get_train_dataloader
        # RandomSampler -> SequenceParallelSampler
        self.assert_xtuner_runtime_condition()
        data_collator = trainer.data_collator
        if isinstance(dataset, datasets.Dataset):
            dataset = trainer._remove_unused_columns(dataset, description='training')
        else:
            data_collator = trainer._get_collator_with_removed_columns(data_collator, description='training')

        dataloader_params = {
            'batch_size': batch_size,
            'collate_fn': data_collator,
            'num_workers': trainer.args.dataloader_num_workers,
            'pin_memory': trainer.args.dataloader_pin_memory,
            'persistent_workers': trainer.args.dataloader_persistent_workers,
        }

        if not isinstance(dataset, torch.utils.data.IterableDataset):
            from xtuner.parallel import SequenceParallelSampler
            # Sampler ensures every rank in an SP group draws the same samples.
            dataloader_params['sampler'] = SequenceParallelSampler(dataset, seed=1024)
            dataloader_params['drop_last'] = trainer.args.dataloader_drop_last
            dataloader_params['worker_init_fn'] = seed_worker

        return DataLoader(dataset, **dataloader_params)
|
ms-swift/swift/trainers/torchacc_mixin.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import os
|
| 3 |
+
import shutil
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
from transformers import PreTrainedModel, is_datasets_available
|
| 7 |
+
|
| 8 |
+
from swift.utils import use_torchacc
|
| 9 |
+
from swift.utils.torchacc_utils import (patch_clip_grad_norm, save_ta_ddp_checkpoint, save_ta_fsdp_checkpoint,
|
| 10 |
+
ta_eval_dataloader, ta_load_optimizer_and_scheduler,
|
| 11 |
+
ta_save_optimizer_and_scheduler, ta_test_dataloader, ta_train_dataloader,
|
| 12 |
+
ta_trim_graph)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class TorchAccMixin:
|
| 16 |
+
|
| 17 |
+
    def __init__(self, *args, **kwargs):
        """Install TorchAcc's grad-norm patch (when active) before normal init."""
        if use_torchacc():
            # NOTE(review): self.accelerator is referenced before super().__init__
            # runs — verify the accelerator is already set up at this point in the
            # MRO, otherwise this would raise AttributeError.
            patch_clip_grad_norm(self.accelerator)
        super().__init__(*args, **kwargs)
|
| 21 |
+
|
| 22 |
+
def get_train_dataloader(self):
|
| 23 |
+
if not use_torchacc():
|
| 24 |
+
return super().get_train_dataloader()
|
| 25 |
+
|
| 26 |
+
if is_datasets_available():
|
| 27 |
+
import datasets
|
| 28 |
+
|
| 29 |
+
if self.train_dataset is None:
|
| 30 |
+
raise ValueError('Trainer: training requires a train_dataset.')
|
| 31 |
+
|
| 32 |
+
train_dataset = self.train_dataset
|
| 33 |
+
data_collator = self.data_collator
|
| 34 |
+
|
| 35 |
+
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
|
| 36 |
+
train_dataset = self._remove_unused_columns(train_dataset, description='training')
|
| 37 |
+
else:
|
| 38 |
+
data_collator = self._get_collator_with_removed_columns(data_collator, description='training')
|
| 39 |
+
|
| 40 |
+
return ta_train_dataloader(train_dataset, data_collator, self._get_train_sampler(), self.args,
|
| 41 |
+
self._train_batch_size)
|
| 42 |
+
|
| 43 |
+
def get_eval_dataloader(self, eval_dataset=None):
|
| 44 |
+
|
| 45 |
+
if not use_torchacc():
|
| 46 |
+
return super().get_eval_dataloader(eval_dataset)
|
| 47 |
+
|
| 48 |
+
if is_datasets_available():
|
| 49 |
+
import datasets
|
| 50 |
+
|
| 51 |
+
if eval_dataset is None and self.eval_dataset is None:
|
| 52 |
+
raise ValueError('Trainer: evaluation requires an eval_dataset.')
|
| 53 |
+
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
| 54 |
+
data_collator = self.data_collator
|
| 55 |
+
|
| 56 |
+
if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
|
| 57 |
+
eval_dataset = self._remove_unused_columns(eval_dataset, description='evaluation')
|
| 58 |
+
else:
|
| 59 |
+
data_collator = self._get_collator_with_removed_columns(data_collator, description='evaluation')
|
| 60 |
+
|
| 61 |
+
return ta_eval_dataloader(eval_dataset, data_collator, self._get_eval_sampler(eval_dataset), self.args)
|
| 62 |
+
|
| 63 |
+
def get_test_dataloader(self, test_dataset):
|
| 64 |
+
|
| 65 |
+
if not use_torchacc():
|
| 66 |
+
return super().get_test_dataloader(test_dataset)
|
| 67 |
+
|
| 68 |
+
if is_datasets_available():
|
| 69 |
+
import datasets
|
| 70 |
+
|
| 71 |
+
data_collator = self.data_collator
|
| 72 |
+
|
| 73 |
+
if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
|
| 74 |
+
test_dataset = self._remove_unused_columns(test_dataset, description='test')
|
| 75 |
+
else:
|
| 76 |
+
data_collator = self._get_collator_with_removed_columns(data_collator, description='test')
|
| 77 |
+
|
| 78 |
+
return ta_test_dataloader(test_dataset, data_collator, self._get_eval_sampler(test_dataset), self.args)
|
| 79 |
+
|
| 80 |
+
def _save_tpu(self, output_dir: Optional[str] = None):
|
| 81 |
+
|
| 82 |
+
if not use_torchacc():
|
| 83 |
+
return super()._save_tpu(output_dir)
|
| 84 |
+
|
| 85 |
+
import torch_xla.core.xla_model as xm
|
| 86 |
+
|
| 87 |
+
# Compatible with swift and peft
|
| 88 |
+
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
| 89 |
+
|
| 90 |
+
if xm.is_master_ordinal(local=False):
|
| 91 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 92 |
+
# configuration.json
|
| 93 |
+
model_dir = getattr(self.model, 'model_dir', None)
|
| 94 |
+
if model_dir is not None:
|
| 95 |
+
src_path = os.path.join(model_dir, 'configuration.json')
|
| 96 |
+
dst_path = os.path.join(output_dir, 'configuration.json')
|
| 97 |
+
if os.path.exists(src_path):
|
| 98 |
+
shutil.copy(src_path, dst_path)
|
| 99 |
+
else:
|
| 100 |
+
self._create_configuration_file(self.model, output_dir)
|
| 101 |
+
self._save_sft_args(output_dir)
|
| 102 |
+
# generation_config
|
| 103 |
+
generation_config = getattr(self.args, 'generation_config', None)
|
| 104 |
+
if generation_config is not None:
|
| 105 |
+
generation_config.save_pretrained(output_dir)
|
| 106 |
+
|
| 107 |
+
# model
|
| 108 |
+
if self.args.fsdp_num > 1:
|
| 109 |
+
save_ta_fsdp_checkpoint(self.model, self.tokenizer, self.args, output_dir)
|
| 110 |
+
else:
|
| 111 |
+
save_ta_ddp_checkpoint(self.model, self.tokenizer, self.args, output_dir)
|
| 112 |
+
|
| 113 |
+
# additional files
|
| 114 |
+
if xm.is_master_ordinal(local=False):
|
| 115 |
+
if self.args is not None and self.args.sft_type == 'full':
|
| 116 |
+
additional_files = getattr(self.args, 'additional_saved_files',
|
| 117 |
+
None) or [] + ['preprocessor_config.json']
|
| 118 |
+
if model_dir is not None:
|
| 119 |
+
for file in additional_files:
|
| 120 |
+
src_path = os.path.join(model_dir, file)
|
| 121 |
+
dst_path = os.path.join(output_dir, file)
|
| 122 |
+
if os.path.isfile(src_path):
|
| 123 |
+
shutil.copy(src_path, dst_path)
|
| 124 |
+
elif os.path.isdir(src_path):
|
| 125 |
+
shutil.copytree(src_path, dst_path)
|
| 126 |
+
|
| 127 |
+
    def _load_optimizer_and_scheduler(self, checkpoint):
        """Restore optimizer/scheduler state; use TorchAcc's loader only under FSDP.

        Non-TorchAcc runs and single-shard (fsdp_num == 1) runs delegate to the
        parent trainer.
        """
        if not use_torchacc() or self.args.fsdp_num == 1:
            return super()._load_optimizer_and_scheduler(checkpoint)

        self.optimizer, self.lr_scheduler = ta_load_optimizer_and_scheduler(self.optimizer, self.lr_scheduler,
                                                                            checkpoint, self.args.device)
| 134 |
+
|
| 135 |
+
def _save_optimizer_and_scheduler(self, output_dir):
|
| 136 |
+
if not use_torchacc() or not self.args.fsdp_num == 1:
|
| 137 |
+
return super()._save_optimizer_and_scheduler(output_dir)
|
| 138 |
+
|
| 139 |
+
return ta_save_optimizer_and_scheduler(self.optimizer, self.lr_scheduler, output_dir)
|
| 140 |
+
|
| 141 |
+
    def _maybe_log_save_evaluate(self, tr_loss, *args, **kwargs):
        """Trim the lazy XLA graph before a log step, then run the parent hook."""
        if use_torchacc() and self.control.should_log:
            # Materialize pending XLA ops so logged metrics reflect real values.
            ta_trim_graph()
        super()._maybe_log_save_evaluate(tr_loss, *args, **kwargs)
|
| 145 |
+
|
| 146 |
+
    def _load_from_checkpoint(self, resume_from_checkpoint: str, model=None) -> None:
        """Skip HF checkpoint loading for full-tuned TorchAcc models.

        Under TorchAcc, full-model weights were already restored in tuner.py, so
        this returns early when the (unwrapped) model is a plain PreTrainedModel;
        everything else falls through to the parent implementation.
        """
        if use_torchacc():
            if model is None:
                model = self.model
            # Loading checkpoint of TorchAcc has been done in tuner.py when
            # sft_type is 'full'.
            if self.args.fsdp_num > 1:
                # Unwrap the FSDP container to reach the underlying HF model.
                model = model._get_underlay_model().module.module
            if isinstance(model, PreTrainedModel):
                return
        return super()._load_from_checkpoint(resume_from_checkpoint, model)
|