import math
from functools import partial
from types import MethodType
from typing import Any, Dict, Iterator, List, Optional, Tuple

import datasets
import numpy as np
import torch
import torch.distributed as dist
from peft import PeftModel
from torch.distributed.device_mesh import init_device_mesh
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, Sampler
from transformers.trainer_utils import seed_worker

from swift.llm import DataLoaderDispatcher, get_model_arch
from swift.tuners import SwiftModel
from swift.utils import get_current_device, get_device, get_dist_setting
from .base import SequenceParallel


class GatherLoss(torch.autograd.Function):
    """Gather the loss and labels from the sequence parallel group."""

    @staticmethod
    def forward(ctx, loss, labels, process_group, gather_idx=None):
        """
        Args:
            loss: loss tensor after splitting
            labels: labels tensor after splitting
            process_group: the sequence parallel group
            gather_idx: gather the tensors on this dim
        """
        ctx.process_group = process_group
        shape0 = labels.shape[0]
        ctx.scatter_shape = labels.shape[gather_idx or 0]
        ctx.gather_idx = gather_idx or 0
        world_size = dist.get_world_size(group=process_group)
        output = torch.empty((shape0 * world_size, *loss.shape[1:]), dtype=loss.dtype, device=loss.device)
        dist.all_gather_into_tensor(output, loss, group=process_group)
        if gather_idx is not None:
            output = torch.cat(output.split(shape0, dim=0), dim=gather_idx)
        labels_output = torch.empty((shape0 * world_size, *labels.shape[1:]), dtype=labels.dtype, device=labels.device)
        dist.all_gather_into_tensor(labels_output, labels, group=process_group)
        if gather_idx is not None:
            labels_output = torch.cat(labels_output.split(shape0, dim=0), dim=gather_idx)
        return output, labels_output

    @staticmethod
    def backward(ctx, *grad_output):
        # Return only this rank's slice of the incoming gradient, scaled by the sequence parallel world size.
        _grad = grad_output[0] * dist.get_world_size(group=ctx.process_group)
        return _grad.split(
            ctx.scatter_shape, dim=ctx.gather_idx)[dist.get_rank(ctx.process_group)].contiguous(), None, None, None

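
# Illustrative sketch (comments only, not executed): with a sequence parallel group of size P,
# each rank holds the per-token loss of its local sequence shard. GatherLoss.apply reassembles
# the full sequence on every rank so masking and normalization can be done globally, e.g.:
#
#   local_loss = loss_fct(local_logits, local_labels)            # shape: (local_tokens,)
#   loss, labels = GatherLoss.apply(local_loss, local_labels, sp_group)
#   loss = loss[labels != -100].sum() / (labels != -100).sum()   # same scalar on every rank
#
# The variable names above are hypothetical; see loss_scale_sp_func below for the real call site.
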
def loss_scale_sp_func(outputs, labels, loss_scale=None, num_items_in_batch=None, process_group=None) -> torch.Tensor:
    """Cross entropy loss for sequence parallel training: compute the per-token loss on the local shard,
    gather it across the sequence parallel group, then mask and normalize globally."""
    if hasattr(outputs, 'logits'):
        logits = outputs.logits
    else:
        logits = outputs
    device = logits.device
    logits = logits.view(-1, logits.shape[-1])
    labels = labels.flatten().to(device)
    loss_fct = CrossEntropyLoss(reduction='none')
    loss = loss_fct(logits, labels)

    if loss_scale is not None:
        loss_scale = loss_scale.flatten().to(loss.device)
        loss = (loss_scale * loss)
    loss, labels = GatherLoss.apply(loss, labels, process_group)
    loss = loss[labels != -100].sum()
    if num_items_in_batch is None:
        loss = loss / (labels != -100).sum()
    else:
        loss = loss / num_items_in_batch
    return loss


def get_batch_logps(logits: torch.FloatTensor,
                    labels: torch.LongTensor,
                    label_pad_token_id: int = -100,
                    is_encoder_decoder: bool = False,
                    process_group=None) -> Tuple[torch.FloatTensor, torch.LongTensor]:
    labels = labels.clone()
    loss_mask = labels != label_pad_token_id
    labels[labels == label_pad_token_id] = 0
    labels = labels.to(logits.device)
    loss_mask = loss_mask.to(logits.device)
    # Labels are expected to be already shifted (see pad_and_split_inputs), so logits and labels align here.
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
    total_per_token_logps, total_loss_mask = GatherLoss.apply(per_token_logps, loss_mask, process_group, 1)
    return (total_per_token_logps * total_loss_mask).sum(-1), total_loss_mask.sum(-1)

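
# Note: get_batch_logps mirrors the per-sequence log-probability helper used by DPO-style trainers,
# except that the per-token log-probs and the loss mask are gathered along the sequence dimension
# (gather_idx=1) before the final sum, so every sequence parallel rank returns the log-probability
# of the full sequence rather than of its local shard.
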
class UlyssesSampler(Sampler):
    """Distributed sampler for Ulysses sequence parallel training.

    Rank and world size are taken from the data-parallel dimension of the device mesh, so the ranks
    inside one sequence parallel group share the same data-parallel rank and receive identical samples.
    """

    def __init__(self, ulysses, dataset, shuffle: bool = True, seed=None, round_up: bool = True) -> None:
        self.ulysses = ulysses
        rank = dist.get_rank(ulysses.device_mesh['data'].get_group())
        world_size = ulysses.device_mesh['data'].size()
        self.rank = rank
        self.world_size = world_size

        self.dataset = dataset
        self.shuffle = shuffle
        assert seed is not None
        self.seed = seed
        self.epoch = 0
        self.round_up = round_up

        if self.round_up:
            self.num_samples = math.ceil(len(self.dataset) / world_size)
            self.total_size = self.num_samples * self.world_size
        else:
            self.num_samples = math.ceil((len(self.dataset) - rank) / world_size)
            self.total_size = len(self.dataset)

    def __iter__(self) -> Iterator[int]:
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        if self.round_up:
            indices = (indices * int(self.total_size / len(indices) + 1))[:self.total_size]

        indices = indices[self.rank:self.total_size:self.world_size]

        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch


class UlyssesDispatcher(DataLoaderDispatcher):
    """Dispatch batches of an iterable (non-sized) dataloader across the data-parallel ranks."""

    def __init__(self, base_dataloader, ulysses):
        super().__init__(base_dataloader)
        self.ulysses = ulysses

    def __iter__(self):
        base_iter = iter(self.base_dataloader)
        while True:
            data = None
            try:
                for i in range(self.ulysses.dp_world_size):
                    data = next(base_iter)
                    if i == self.ulysses.dp_rank:
                        break
            except StopIteration:
                pass
            if data is None:
                break
            yield data


def _generate_layout_params(scatter_idx, seq_world_size, input):
    """Generate the permute/reshape parameters used before and after `all_to_all_single`."""
    if scatter_idx < 2:
        bs, global_seq_len, num_local_head, head_dim = input.shape
        pre_all2all_inp_shape = [bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, head_dim]
        pre_all2all_permute_idx = (1, 0, 2, 3, 4)

        post_all2all_permute_idx = (1, 2, 0, 3, 4)
        post_all2all_res_shape = [bs, global_seq_len // seq_world_size, seq_world_size * num_local_head, head_dim]
    else:
        bs, local_seq_len, num_total_head, head_dim = input.shape
        assert num_total_head % seq_world_size == 0, (f'Number of heads ({num_total_head}) must be divisible '
                                                      f'by the sequence parallel size ({seq_world_size})!')
        pre_all2all_inp_shape = [bs, local_seq_len, seq_world_size, num_total_head // seq_world_size, head_dim]
        pre_all2all_permute_idx = (2, 0, 1, 3, 4)

        post_all2all_permute_idx = (1, 0, 2, 3, 4)
        post_all2all_res_shape = [bs, seq_world_size * local_seq_len, num_total_head // seq_world_size, head_dim]

    return pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape

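
# Shape summary for the two all-to-all layouts above (sp world size P, a sketch derived from the code):
#   scatter_idx >= 2 (before attention): [bs, seq/P, H, d] -> all-to-all -> [bs, seq, H/P, d]
#       the sequence dimension is gathered while the attention heads are scattered across ranks.
#   scatter_idx < 2 (after attention):   [bs, seq, H/P, d] -> all-to-all -> [bs, seq/P, H, d]
#       the inverse transform, restoring the local sequence shard with all heads.
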
def post_all2all(permute_idx, res_shape):
    """
    Post-processing function for `all2all` communication.
    """

    def post_func(input):
        if permute_idx is not None:
            input = input.permute(permute_idx).contiguous()
        output = input.reshape(res_shape).contiguous()
        return output

    return post_func


def pre_all2all_fun(permute_idx, inp_shape, input):
    """
    Pre-processing function for `all2all` communication.
    """
    input_t = input.reshape(inp_shape).contiguous()
    if permute_idx is not None:
        input_t = input_t.permute(permute_idx).contiguous()
    return input_t


def single_all_to_all(input, scatter_idx, gather_idx, group, **kwargs):
    seq_world_size = dist.get_world_size(group)
    num_heads = input.shape[2]
    if num_heads % seq_world_size != 0 and not scatter_idx < 2:
        raise NotImplementedError
    pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape = (
        _generate_layout_params(scatter_idx, seq_world_size, input))

    input_t = pre_all2all_fun(pre_all2all_permute_idx, pre_all2all_inp_shape, input)

    post_all2all_fun = post_all2all(post_all2all_permute_idx, post_all2all_res_shape)
    output = torch.empty_like(input_t)
    dist.all_to_all_single(output, input_t, group=group)

    res = post_all2all_fun(output)
    return res


class _SeqAllToAll(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        input: torch.Tensor,
        scatter_idx: int,
        gather_idx: int,
    ) -> torch.Tensor:
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx
        res = single_all_to_all(input, scatter_idx, gather_idx, group)
        return res

    @staticmethod
    def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, None, None]:
        return None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None


class DistributedAttention(torch.nn.Module):

    def __init__(
        self,
        local_attention,
        sequence_process_group: dist.ProcessGroup,
        scatter_idx: int = 2,
        gather_idx: int = 1,
    ) -> None:
        super(DistributedAttention, self).__init__()
        self.local_attn = local_attention
        self.spg = sequence_process_group
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor,
                *args: Any, **kwargs) -> torch.Tensor:
        query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx)
        key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
        value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)
        position_ids = kwargs.pop('position_ids', None)
        if position_ids is not None:
            shape0 = position_ids.shape[0]
            position_ids_output = torch.empty((shape0 * dist.get_world_size(self.spg), position_ids.shape[1]),
                                              dtype=position_ids.dtype,
                                              device=position_ids.device)
            dist.all_gather_into_tensor(position_ids_output, position_ids, group=self.spg)
            position_ids = torch.cat(position_ids_output.split(shape0, dim=0), dim=1)
        context_layer = self.local_attn(
            query_layer, key_layer, value_layer, attention_mask, *args, position_ids=position_ids, **kwargs)
        output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)
        return output

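
# Data flow through DistributedAttention, assuming the default scatter_idx=2 / gather_idx=1 (a sketch):
#
#   q/k/v per rank: [bs, seq/P, H, d]
#       -> _SeqAllToAll (gather sequence, scatter heads) -> [bs, seq, H/P, d]
#       -> local_attn over the full sequence with H/P heads
#       -> _SeqAllToAll (gather heads, scatter sequence) -> [bs, seq/P, H, d]
#
# position_ids are all-gathered along the sequence dimension so the local attention sees the
# positions of the full sequence; the module is shape-preserving from the caller's point of view.
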
class Ulysses(SequenceParallel):
    """Ulysses sequence parallelism: shard the sequence across ranks and exchange shards with
    all-to-all inside attention."""

    def __init__(self):
        self.split_in_forward = None
        self.dp_world_size = None
        self.sp_world_size = None
        self.model_dtype = None
        self.causal_mask_func = None
        self.device_mesh = None
        self._inited = False

    def init_sequence_parallel(self, size):
        if self._inited:
            return
        self._inited = True
        self.sp_world_size = size
        rank, local_rank, world_size, local_world_size = get_dist_setting()
        self.dp_world_size = world_size // size
        self.device_mesh = init_device_mesh(
            get_device().split(':')[0], mesh_shape=(world_size // size, size), mesh_dim_names=['data', 'sequence'])

        # Keep the original attention implementations around, then register sequence parallel
        # wrappers under the names transformers dispatches to.
        from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
        ALL_ATTENTION_FUNCTIONS['flash_attention_2_origin'] = ALL_ATTENTION_FUNCTIONS['flash_attention_2']
        ALL_ATTENTION_FUNCTIONS['sdpa_origin'] = ALL_ATTENTION_FUNCTIONS['sdpa']

        def local_flash_attn(module: torch.nn.Module, query_states, key_states, value_states, attention_mask, *args,
                             dist_attn, **kwargs):
            if dist_attn.local_attn is None:

                def _attention(query, key, value, *args, **kwargs):
                    query = query.transpose(1, 2)
                    key = key.transpose(1, 2)
                    value = value.transpose(1, 2)
                    return ALL_ATTENTION_FUNCTIONS['flash_attention_2_origin'](module, query, key, value, *args,
                                                                               **kwargs)[0]

                dist_attn.local_attn = _attention

            return dist_attn(
                query_states.transpose(1, 2), key_states.transpose(1, 2), value_states.transpose(1, 2), attention_mask,
                *args, **kwargs), None

        def local_sdpa_attn(module: torch.nn.Module, query_states, key_states, value_states, attention_mask, *args,
                            dist_attn, **kwargs):
            if dist_attn.local_attn is None:

                def _attention(query, key, value, *args, **kwargs):
                    query = query.transpose(1, 2)
                    key = key.transpose(1, 2)
                    value = value.transpose(1, 2)
                    return ALL_ATTENTION_FUNCTIONS['sdpa_origin'](module, query, key, value, *args, **kwargs)[0]

                dist_attn.local_attn = _attention

            return dist_attn(
                query_states.transpose(1, 2), key_states.transpose(1, 2), value_states.transpose(1, 2), attention_mask,
                *args, **kwargs), None

        ALL_ATTENTION_FUNCTIONS['flash_attention_2'] = partial(
            local_flash_attn, dist_attn=DistributedAttention(None, self.sp_group))
        ALL_ATTENTION_FUNCTIONS['sdpa'] = partial(local_sdpa_attn, dist_attn=DistributedAttention(None, self.sp_group))

        from transformers.modeling_flash_attention_utils import is_flash_attn_available
        if is_flash_attn_available():
            from transformers import modeling_flash_attention_utils
            from transformers.modeling_flash_attention_utils import _flash_attention_forward
            _distributed_flash_attention = DistributedAttention(_flash_attention_forward, self.sp_group)

            def flash_attention_forward(query_states: torch.Tensor, key_states: torch.Tensor,
                                        value_states: torch.Tensor, attention_mask: Optional[torch.Tensor], q_len,
                                        *args, **kwargs):
                # The wrapped kernel sees the gathered sequence, so q_len is scaled accordingly.
                return _distributed_flash_attention(query_states, key_states, value_states, attention_mask,
                                                    q_len * self.sp_world_size, *args, **kwargs)

            modeling_flash_attention_utils._flash_attention_forward = flash_attention_forward

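
    # Note: init_sequence_parallel patches attention globally (ALL_ATTENTION_FUNCTIONS and, when flash
    # attention is available, transformers' modeling_flash_attention_utils._flash_attention_forward), so
    # every attention call made after this point runs through DistributedAttention. This is process-wide
    # state, which is why the method guards against being applied twice via self._inited.
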
    def prepare_model(self, model, tokenizer, split_in_forward):
        self.split_in_forward = split_in_forward

        def forward(_self, **kwargs):
            # Pad and split the inputs along the sequence dimension right before the base model forward.
            inputs_embeds = kwargs['inputs_embeds']
            position_ids = kwargs['position_ids']
            attention_mask = kwargs['attention_mask']
            _, inputs_embeds, _, position_ids, attention_mask, _ = self.pad_and_split_inputs(
                tokenizer,
                None,
                inputs_embeds,
                None,
                position_ids,
                attention_mask,
                None,
                embed_tokens=_self.embed_tokens)
            kwargs['inputs_embeds'] = inputs_embeds
            kwargs['position_ids'] = position_ids
            kwargs['attention_mask'] = attention_mask
            return _self.forward_origin(**kwargs)

        if isinstance(model, (SwiftModel, PeftModel)):
            model = model.model
        model_meta = model.model_meta
        llm_prefix = getattr(get_model_arch(model_meta.model_arch), 'language_model', None)
        if llm_prefix:
            llm_model = getattr(model, llm_prefix[0])
        else:
            llm_model = model

        if 'CausalLM' not in llm_model.__class__.__name__:
            llm_model = model

        base_model = llm_model.model
        self.causal_mask_func = base_model._update_causal_mask
        if self.split_in_forward:
            # Patch the base model forward so the splitting happens inside the forward call,
            # where inputs_embeds is available.
            base_model.forward_origin = base_model.forward
            base_model.forward = MethodType(forward, base_model)

        self.model_dtype = next(model.parameters()).dtype

    def _pad_sp(self, tensor, padding_value, dim=-1):
        # Pad `tensor` along `dim` so its length is a multiple of the sequence parallel world size.
        length = tensor.shape[dim]
        if length % self.sp_world_size == 0:
            return tensor

        pad_num = self.sp_world_size - (length % self.sp_world_size)
        if not isinstance(padding_value, torch.Tensor):
            # Scalar padding (ids, labels, position_ids, masks, loss scales).
            pad_shape = ((*tensor.shape[:dim], pad_num, *tensor.shape[dim + 1:]) if dim != -1 else
                         (*tensor.shape[:dim], pad_num))
            pad = torch.full(pad_shape, padding_value, dtype=tensor.dtype, device=tensor.device)
            tensor = torch.cat([tensor, pad], dim=dim)
        else:
            # Tensor padding (an embedding vector), repeated along the sequence dimension.
            tensor = torch.cat([tensor, padding_value.unsqueeze(0).repeat(tensor.shape[0], pad_num, 1)], dim=dim)
        return tensor

    def world_size(self):
        return self.sp_world_size

    def _split_sp(self, input, dim: int, sp_group: dist.ProcessGroup):
        # Split the (padded) tensor along `dim` and keep only this rank's shard.
        if self.sp_world_size == 1:
            return input

        rank = dist.get_rank(sp_group)
        dim_size = input.size(dim)
        assert dim_size % self.sp_world_size == 0, (f'The dimension to split ({dim_size}) is not a multiple of '
                                                    f'world size ({self.sp_world_size}), cannot split tensor evenly')

        tensor_list = torch.split(input, dim_size // self.sp_world_size, dim=dim)
        output = tensor_list[rank].contiguous()

        return output

    def pad_and_split_inputs(self,
                             tokenizer,
                             input_ids,
                             input_embeds,
                             labels,
                             position_ids,
                             attention_mask,
                             loss_scale,
                             embed_tokens=None):
        sp_group = self.sp_group
        split_inputs = False
        if (input_ids is not None and not self.split_in_forward) or input_embeds is not None:
            # Split here unless splitting is deferred to the patched model forward (see prepare_model);
            # embeddings are always split here.
            split_inputs = True
        if input_ids is not None and split_inputs:
            input_ids = self._pad_sp(input_ids, padding_value=tokenizer.pad_token_id, dim=-1)
        if input_embeds is not None:
            pad_emb = embed_tokens(torch.tensor(tokenizer.pad_token_id).to(embed_tokens.weight.device)).unsqueeze(0)
            input_embeds = self._pad_sp(input_embeds, padding_value=pad_emb, dim=1)
        if position_ids is not None and split_inputs:
            position_ids = self._pad_sp(position_ids, padding_value=0, dim=-1)
        if split_inputs:
            inputs = input_ids if input_ids is not None else input_embeds
            attn_shape = inputs.shape[1]
            if attention_mask is None:
                attention_mask = torch.ones_like(position_ids)
            attention_mask = self._pad_sp(attention_mask, padding_value=0, dim=-1)
            cache_position = torch.arange(0, attn_shape, device=inputs.device)
            # Build the causal mask for the full padded sequence before splitting.
            attention_mask = self.causal_mask_func(attention_mask, inputs.to(self.model_dtype), cache_position, None,
                                                   None)
        if input_ids is not None and split_inputs:
            input_ids = self._split_sp(input_ids, dim=1, sp_group=sp_group)
        if input_embeds is not None:
            input_embeds = self._split_sp(input_embeds, dim=1, sp_group=sp_group)
        if position_ids is not None and split_inputs:
            position_ids = self._split_sp(position_ids, dim=-1, sp_group=sp_group)
        if labels is not None:
            labels = self._pad_sp(labels, padding_value=-100, dim=-1)
            labels[:, 0] = -100
            labels = torch.roll(labels, shifts=-1, dims=1)
            labels = self._split_sp(labels, dim=1, sp_group=sp_group)

        if loss_scale is not None:
            loss_scale = self._pad_sp(loss_scale, padding_value=0., dim=-1)
            loss_scale = torch.roll(loss_scale, shifts=-1, dims=-1)
            loss_scale = self._split_sp(loss_scale, dim=-1, sp_group=sp_group)

        return input_ids, input_embeds, labels, position_ids, attention_mask, loss_scale

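
    # Label alignment (sketch): labels are shifted left by one (torch.roll with shifts=-1) *before*
    # splitting, and the original first label, set to -100 above, wraps around to the last position
    # where it is ignored by the loss mask. After the shift, position i of the labels corresponds to
    # the prediction made from token i, so each rank can compute the loss of its own sequence shard
    # without needing the first token of the neighbouring rank's shard.
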
    def reduce_outputs(self, loss, labels):
        return loss

    @property
    def sp_rank(self):
        return dist.get_rank(self.device_mesh['sequence'].get_group())

    @property
    def dp_rank(self):
        return dist.get_rank(self.device_mesh['data'].get_group())

    @property
    def sp_group(self):
        return self.device_mesh['sequence'].get_group()

    @property
    def dp_group(self):
        return self.device_mesh['data'].get_group()

    def get_dataloader(self, trainer, dataset, batch_size):
        data_collator = trainer.data_collator
        if isinstance(dataset, datasets.Dataset):
            dataset = trainer._remove_unused_columns(dataset, description='training')
        else:
            data_collator = trainer._get_collator_with_removed_columns(data_collator, description='training')
        if hasattr(dataset, '__len__'):
            sampler = UlyssesSampler(self, dataset, seed=42)
            dataloader_params = {
                'batch_size': batch_size,
                'collate_fn': data_collator,
                'num_workers': trainer.args.dataloader_num_workers,
                'pin_memory': trainer.args.dataloader_pin_memory,
                'persistent_workers': trainer.args.dataloader_persistent_workers,
            }

            if not isinstance(dataset, torch.utils.data.IterableDataset):
                dataloader_params['sampler'] = sampler
                dataloader_params['drop_last'] = trainer.args.dataloader_drop_last
                dataloader_params['worker_init_fn'] = seed_worker

            return DataLoader(dataset, **dataloader_params)
        else:
            dataloader_params = {
                'collate_fn': data_collator,
                'num_workers': trainer.args.dataloader_num_workers,
                'pin_memory': trainer.args.dataloader_pin_memory,
                'persistent_workers': trainer.args.dataloader_persistent_workers,
                'prefetch_factor': trainer.args.dataloader_prefetch_factor
            }
            if dist.is_initialized() and dataloader_params['prefetch_factor']:
                dataloader_params['prefetch_factor'] = dataloader_params['prefetch_factor'] * dist.get_world_size()
            dataloader = DataLoader(dataset, batch_size=batch_size, **dataloader_params)
            dataloader = UlyssesDispatcher(dataloader, self)
            return dataloader

    def prepare_trainer(self, trainer):
        if trainer.train_dataset is None:
            raise ValueError('Trainer: training requires a train_dataset.')

        trainer.compute_loss_func = partial(loss_scale_sp_func, process_group=self.sp_group)
        if hasattr(trainer, 'get_batch_logps'):
            trainer.get_batch_logps = partial(get_batch_logps, process_group=self.sp_group)
        if hasattr(trainer, 'get_nll_loss'):

            def rlhf_loss_scale_sp_func(_, *args, **kwargs):
                return loss_scale_sp_func(*args, process_group=self.sp_group, **kwargs)

            trainer.get_nll_loss = MethodType(rlhf_loss_scale_sp_func, trainer)

        from swift.plugin import metric
        from swift.trainers import mixin
        compute_acc_origin = metric.compute_acc

        def compute_acc(preds, labels, *args, **kwargs) -> Dict[str, List[float]]:
            # Gather preds and labels along the sequence dimension so accuracy is computed on the
            # full sequences rather than on the local shards.
            if isinstance(preds, np.ndarray):
                preds = torch.from_numpy(preds).to(get_current_device())
            if isinstance(labels, np.ndarray):
                labels = torch.from_numpy(labels).to(get_current_device())
            shape0 = preds.shape[0]
            preds_output = torch.empty((shape0 * self.sp_world_size, preds.shape[1]),
                                       dtype=preds.dtype,
                                       device=preds.device)
            dist.all_gather_into_tensor(preds_output, preds, group=self.sp_group)
            preds_output = torch.cat(preds_output.split(shape0, dim=0), dim=1)
            shape0 = labels.shape[0]
            labels_output = torch.empty((shape0 * self.sp_world_size, labels.shape[1]),
                                        dtype=labels.dtype,
                                        device=labels.device)
            dist.all_gather_into_tensor(labels_output, labels, group=self.sp_group)
            labels_output = torch.cat(labels_output.split(shape0, dim=0), dim=1)
            # Labels were shifted left by one in pad_and_split_inputs; roll right by one to undo the shift.
            labels_output = torch.roll(labels_output, shifts=1, dims=1)
            return compute_acc_origin(preds_output, labels_output, *args, **kwargs)

        metric.compute_acc = compute_acc
        mixin.compute_acc = compute_acc