| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch |
| from torch.nested._internal.nested_tensor import NestedTensor |
|
|
| from verl.utils.megatron_utils import unwrap_model |
| from verl.workers.config import MtpConfig |
|
|
| from .util import ( |
| postprocess_bshd, |
| postprocess_bshd_no_padding, |
| postprocess_packed_seqs, |
| postprocess_thd_no_padding, |
| preprocess_bshd, |
| preprocess_bshd_no_padding, |
| preprocess_packed_seqs, |
| preprocess_thd_no_padding, |
| ) |
|
|
|
|
def model_forward_gen(vision_model: bool = False):
    """Create a ``model_forward`` function for text-only or vision models.

    Args:
        vision_model: When True the returned function ignores the model's own
            ``pre_process`` flag (``False`` is used for "thd" packing, ``True`` for
            "bshd" preprocessing), never forwards ``position_ids`` to the model and,
            in the "thd" format, hands the original unpacked ``input_ids`` and
            ``attention_mask`` to the model.

    Returns:
        The ``model_forward`` closure.
    """

    def model_forward(
        model,
        input_ids,
        attention_mask,
        position_ids,
        multi_modal_inputs: dict,
        logits_processor=None,
        logits_processor_args: dict = None,
        value_model=False,
        data_format: str = "thd",
        mtp_config: MtpConfig = None,
    ):
        """Forward pass for models with sequence packing.

        Args:
            model: Possibly wrapped model; ``unwrap_model(model)`` must expose
                ``pre_process``, ``post_process``, ``config.sequence_parallel`` and
                ``config.fp8``.
            input_ids: Token ids handed to the preprocessing helper.
            attention_mask: Mask whose first two dimensions are read as
                ``(batch_size, seq_len)``.
            position_ids: Position ids; not forwarded to vision models.
            multi_modal_inputs: Dict that may hold ``pixel_values``, ``image_grid_thw``,
                ``pixel_values_videos`` and ``video_grid_thw``; present entries are moved
                to ``input_ids.device`` and passed to the model as keyword arguments.
            logits_processor: Optional callable invoked on the post-process stage as
                ``logits_processor(output_orig, **args)``; it must return a dict.
            logits_processor_args: Tensors preprocessed the same way as ``input_ids``
                before being passed to ``logits_processor``. ``None`` is treated as an
                empty dict. With MTP training on the post-process stage ("thd" only) it
                must contain ``"label"`` and ``"label_mask"``.
            value_model: When True, ``output[..., 0]`` is returned on the post-process
                stage.
            data_format: ``"thd"`` (packed sequences) or ``"bshd"`` (padded batch).
            mtp_config: Optional MTP configuration; its ``enable_train`` attribute turns
                on passing ``labels`` / ``loss_mask`` to the model in the "thd" format.

        Returns:
            The post-processed model output, or a dict of post-processed tensors keyed
            like the dict returned by ``logits_processor``.

        Raises:
            AssertionError: If ``data_format`` is unknown, or if fp8 is configured
                together with the "bshd" format.
        """
        assert data_format in ["thd", "bshd"], "data_format must be 'thd' or 'bshd'"
        unwrapped = unwrap_model(model)
        pre_process = False if vision_model else unwrapped.pre_process
        post_process = unwrapped.post_process
        sp = unwrapped.config.sequence_parallel
        fp8 = unwrapped.config.fp8
        use_fp8_padding = fp8 in ["e4m3", "hybrid"]

        # The default is None; normalise so the `.items()` calls below cannot fail on it.
        if logits_processor_args is None:
            logits_processor_args = {}

        # Forward whichever multimodal tensors are present, on the input's device.
        model_kwargs = {
            key: multi_modal_inputs[key].to(input_ids.device)
            for key in ("pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw")
            if key in multi_modal_inputs
        }

        batch_size, seq_len = attention_mask.shape[:2]
        # bool() keeps a missing config from leaking None into `pre_process=` below.
        mtp_enable_train = bool(mtp_config and mtp_config.enable_train)

        if data_format == "thd":
            input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(
                input_ids,
                attention_mask,
                pre_process=pre_process or (post_process and mtp_enable_train),
                use_fp8_padding=use_fp8_padding,
            )
            input_ids_rmpad = input_ids_rmpad.contiguous()

            # With MTP training the post-process stage also receives packed labels and
            # their mask through the model's keyword arguments.
            if mtp_enable_train and post_process:
                args = {
                    k: preprocess_packed_seqs(v, attention_mask, pre_process=True, use_fp8_padding=use_fp8_padding)[0]
                    for k, v in logits_processor_args.items()
                }
                model_kwargs["labels"] = args["label"].contiguous()
                model_kwargs["loss_mask"] = args["label_mask"].contiguous()

            input_args = dict(
                input_ids=input_ids_rmpad,
                attention_mask=None,
                position_ids=position_ids if not vision_model else None,
                packed_seq_params=packed_seq_params,
                **model_kwargs,
            )

            if vision_model:
                # Vision models get the original, unpacked ids and mask;
                # packed_seq_params is still supplied alongside them.
                # NOTE(review): presumably the model packs the sequence itself after
                # computing the vision embeddings - confirm against the model code.
                input_args["input_ids"] = input_ids
                input_args["attention_mask"] = attention_mask

            output_orig = model(**input_args)

            if post_process and logits_processor is not None:
                args = {
                    k: preprocess_packed_seqs(v, attention_mask, pre_process=True, use_fp8_padding=use_fp8_padding)[0]
                    for k, v in logits_processor_args.items()
                }
                output_dict = logits_processor(output_orig, **args)
                output = {
                    k: postprocess_packed_seqs(
                        v, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process
                    )
                    for k, v in output_dict.items()
                }
            else:
                output = postprocess_packed_seqs(
                    output_orig, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process
                )
        elif data_format == "bshd":
            # data_format: "thd" or "bshd", default is "thd".
            # Why is this needed? Some new models (e.g. GPT-OSS) do not support the thd
            # format, so the bshd format has to be used for them. With bshd, input_ids
            # are padded up to the longest sequence length, so it is recommended to
            # disable dynamic batch size and set batch size to 1.
            assert fp8 is None, "fp8 is not supported for bshd format yet"

            # Vision models get freshly generated 0..seq_len-1 positions for
            # preprocessing; the model itself is then called without position_ids.
            position_ids_for_preprocess = (
                torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(batch_size, -1)
                if vision_model
                else position_ids
            )
            pre_process_for_bshd = True if vision_model else pre_process
            new_input_ids, new_attention_mask, new_position_ids = preprocess_bshd(
                input_ids,
                attention_mask,
                position_ids_for_preprocess,
                sequence_parallel=sp,
                pre_process=pre_process_for_bshd,
            )
            output_orig = model(
                input_ids=new_input_ids,
                position_ids=None if vision_model else new_position_ids,
                attention_mask=new_attention_mask,
                **model_kwargs,
            )
            if post_process and logits_processor is not None:
                args = {
                    k: preprocess_bshd(
                        v, attention_mask, position_ids_for_preprocess, sequence_parallel=sp, pre_process=True
                    )[0]
                    for k, v in logits_processor_args.items()
                }
                output_dict = logits_processor(output_orig, **args)
                output = {
                    k: postprocess_bshd(v, new_attention_mask, attention_mask, seq_len, post_process=post_process)
                    for k, v in output_dict.items()
                }
            else:
                output = postprocess_bshd(
                    output_orig, new_attention_mask, attention_mask, seq_len, post_process=post_process
                )
        if value_model and post_process:
            # Keep only index 0 of the last dimension for value models.
            output = output[..., 0]
        return output

    return model_forward
|
|
|
|
| def _convert_to_nested_tensor(v, input_ids_lengths): |
| """Convert regular tensor to NestedTensor, slicing according to input_ids_lengths. |
| |
| Args: |
| v: Tensor to convert, shape [batch, seq_len] |
| input_ids_lengths: List of valid lengths for each sample |
| |
| Returns: |
| Converted NestedTensor |
| """ |
| if isinstance(v, NestedTensor): |
| return v |
|
|
| batch_size = v.shape[0] |
| assert len(input_ids_lengths) == batch_size, ( |
| f"len(input_ids_lengths)={len(input_ids_lengths)} != batch_size={batch_size}" |
| ) |
|
|
| v_split_list = [] |
| for i in range(batch_size): |
| vi = v[i] |
| target_len = input_ids_lengths[i] |
| if vi.shape[0] > target_len: |
| vi = vi[:target_len] |
| elif vi.shape[0] < target_len: |
| vi = torch.cat([vi, torch.ones(target_len - vi.shape[0], dtype=vi.dtype, device=vi.device)]) |
| v_split_list.append(vi) |
|
|
| v = torch.nested.nested_tensor(v_split_list, layout=torch.jagged) |
| return v |
|
|
|
|
def gptmodel_forward_no_padding(
    model,
    input_ids,
    multi_modal_inputs: dict,
    logits_processor=None,
    logits_processor_args: dict = None,
    value_model=False,
    vision_model=False,
    pad_token_id=None,
    data_format: str = "thd",
    mtp_enable_train: bool = False,
):
    """Default forward pass for GPT models with optional sequence packing.

    Args:
        model: Possibly wrapped model; ``unwrap_model(model)`` must expose
            ``pre_process``, ``post_process`` and ``config.fp8``.
        input_ids: Un-padded token ids. The MTP and vision paths call ``.offsets()`` and
            ``.to_padded_tensor()`` on it, so it is expected to be a jagged nested tensor.
        multi_modal_inputs: Dict that may hold ``pixel_values``, ``image_grid_thw``,
            ``pixel_values_videos`` and ``video_grid_thw``; present entries are moved to
            ``input_ids.device`` and passed to the model as keyword arguments.
        logits_processor: Optional callable invoked on the post-process stage as
            ``logits_processor(output_orig, **args)``; it must return a dict.
        logits_processor_args: Tensors preprocessed like ``input_ids`` (the ``"label"``
            entry additionally with ``need_roll=True``) before being handed to
            ``logits_processor``. With MTP training on the post-process stage it must
            contain ``"label"`` and ``"loss_mask"``. ``"loss_mask"`` is never forwarded
            to ``logits_processor``. The caller's dict is not modified; ``None`` is
            treated as empty.
        value_model: When True, the last dimension of the output is squeezed on the
            post-process stage.
        vision_model: When True, ``position_ids`` are not passed to the model; in the
            "thd" format the model additionally receives padded ids plus a boolean
            attention mask instead of the packed ids.
        pad_token_id: Padding value used to pad ``input_ids`` for vision models in the
            "thd" format.
        data_format: ``"thd"`` (packed sequences) or ``"bshd"`` (padded batch).
        mtp_enable_train: When True, labels and loss mask are passed to the model on the
            post-process stage.

    Returns:
        The post-processed model output, or a dict of post-processed tensors keyed like
        the dict returned by ``logits_processor``.

    Raises:
        AssertionError: If ``data_format`` is neither ``"thd"`` nor ``"bshd"``.
    """
    assert data_format in ["thd", "bshd"], "data_format must be 'thd' or 'bshd'"
    unwrapped = unwrap_model(model)
    pre_process = unwrapped.pre_process
    post_process = unwrapped.post_process

    fp8 = unwrapped.config.fp8
    use_fp8_padding = fp8 in ["e4m3", "hybrid"]

    # Work on a shallow copy: the code below replaces entries and drops "loss_mask",
    # which must not leak back into the caller's dict. Also tolerates the None default.
    logits_processor_args = dict(logits_processor_args) if logits_processor_args else {}

    # Forward whichever multimodal tensors are present, on the input's device.
    model_kwargs = {
        key: multi_modal_inputs[key].to(input_ids.device)
        for key in ("pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw")
        if key in multi_modal_inputs
    }

    batch_size = input_ids.shape[0]
    if data_format == "thd":
        input_ids_rmpad, packed_seq_params, position_ids_rmpad = preprocess_thd_no_padding(
            input_ids, pre_process=pre_process or (post_process and mtp_enable_train), use_fp8_padding=use_fp8_padding
        )
        input_ids_rmpad = input_ids_rmpad.contiguous()

        if mtp_enable_train and post_process:
            # Per-sample lengths taken from the offsets of the nested input_ids.
            input_ids_lengths = input_ids.offsets().diff().tolist()

            mtp_args = {}
            for k in ["label", "loss_mask"]:
                v = _convert_to_nested_tensor(logits_processor_args[k], input_ids_lengths)
                # Keep the converted tensor so the logits-processor path below
                # preprocesses the nested version as well.
                logits_processor_args[k] = v
                mtp_args[k] = preprocess_thd_no_padding(
                    v, pre_process=True, need_roll=True, use_fp8_padding=use_fp8_padding
                )[0]

            model_kwargs["labels"] = mtp_args["label"].contiguous()
            model_kwargs["loss_mask"] = mtp_args["loss_mask"].contiguous()

        # "loss_mask" is only consumed by the model, never by the logits processor.
        logits_processor_args.pop("loss_mask", None)

        attention_mask = None
        if vision_model:
            # Vision models get padded ids plus a boolean mask marking the valid prefix
            # of every row; packed_seq_params is still supplied.
            input_ids_rmpad = input_ids.to_padded_tensor(pad_token_id)
            seqlens_in_batch = input_ids.offsets().diff().to(input_ids_rmpad.device)
            # Built with one broadcast comparison instead of a per-sample loop.
            positions = torch.arange(input_ids_rmpad.shape[1], device=input_ids_rmpad.device)
            attention_mask = positions.unsqueeze(0) < seqlens_in_batch.unsqueeze(1)

        output_orig = model(
            input_ids=input_ids_rmpad,
            attention_mask=attention_mask,
            position_ids=position_ids_rmpad if not vision_model else None,
            packed_seq_params=packed_seq_params,
            **model_kwargs,
        )

        if post_process and logits_processor is not None:
            args = {
                k: preprocess_thd_no_padding(
                    v, pre_process=True, need_roll=(k == "label"), use_fp8_padding=use_fp8_padding
                )[0]
                for k, v in logits_processor_args.items()
            }
            output_dict = logits_processor(output_orig, **args)
            output = {
                k: postprocess_thd_no_padding(v, packed_seq_params, input_ids, batch_size, post_process=post_process)
                for k, v in output_dict.items()
            }
        else:
            output = postprocess_thd_no_padding(
                output_orig, packed_seq_params, input_ids, batch_size, post_process=post_process
            )
    else:
        # data_format: "thd" or "bshd", default is "thd".
        # Why is this needed? Some new models (e.g. GPT-OSS) do not support the thd
        # format, so the bshd format has to be used for them. With bshd, input_ids are
        # padded up to the longest sequence length, so it is recommended to disable
        # dynamic batch size and set batch size to 1.
        input_ids_bshd, attention_mask_bshd, position_ids_bshd = preprocess_bshd_no_padding(
            input_ids, pre_process=pre_process or (post_process and mtp_enable_train), use_fp8_padding=use_fp8_padding
        )

        if mtp_enable_train and post_process:
            # Per-sample lengths taken from the offsets of the nested input_ids.
            input_ids_lengths = input_ids.offsets().diff().tolist()

            mtp_args = {}
            for k in ["label", "loss_mask"]:
                v = _convert_to_nested_tensor(logits_processor_args[k], input_ids_lengths)
                logits_processor_args[k] = v
                mtp_args[k] = preprocess_bshd_no_padding(
                    v, pre_process=True, need_roll=True, use_fp8_padding=use_fp8_padding
                )[0]
            model_kwargs["labels"] = mtp_args["label"].contiguous()
            model_kwargs["loss_mask"] = mtp_args["loss_mask"].contiguous()

        # "loss_mask" is only consumed by the model, never by the logits processor.
        logits_processor_args.pop("loss_mask", None)

        output_orig = model(
            input_ids=input_ids_bshd,
            attention_mask=attention_mask_bshd,
            position_ids=None if vision_model else position_ids_bshd,
            **model_kwargs,
        )
        if post_process and logits_processor is not None:
            args = {
                k: preprocess_bshd_no_padding(
                    v, pre_process=True, need_roll=(k == "label"), use_fp8_padding=use_fp8_padding
                )[0]
                for k, v in logits_processor_args.items()
            }
            output_dict = logits_processor(output_orig, **args)
            output = {
                k: postprocess_bshd_no_padding(v, attention_mask_bshd, post_process=post_process)
                for k, v in output_dict.items()
            }
        else:
            output = postprocess_bshd_no_padding(output_orig, attention_mask_bshd, post_process=post_process)

    if value_model and post_process:
        # NOTE(review): squeeze(-1) is used here whereas model_forward uses
        # output[..., 0]; presumably because the output may be a nested tensor on
        # this path - confirm before unifying the two.
        output = output.squeeze(-1)

    return output
|
|