| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import List, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.checkpoint import checkpoint |
| from transformers.cache_utils import DynamicCache |
| from yunchang import EXTRACT_FUNC_DICT |
|
|
| from specforge.core.loss import LogSoftmaxLoss |
| from specforge.distributed import ( |
| gather_outputs_and_unpad, |
| get_sp_ring_group, |
| get_sp_ulysses_group, |
| ) |
| from specforge.modeling.draft import Eagle3DraftModel |
| from specforge.utils import padding |
|
|
|
|
class Eagle3Model(nn.Module):
    """Common base class for the Eagle3 training wrappers defined in this file.

    It adds no state or behavior on top of ``nn.Module``; it only provides a
    shared type for the concrete trainer models.
    """
|
|
|
|
class OnlineEagle3Model(Eagle3Model):
    """
    Eagle3 draft-model trainer using test-time training (TTT).

    In sgl-spec, we implement offline/online training.
    Online training means the target hidden_states are available during training.
    Eagle3 uses the test-time-training technique (TTT) to train the draft model:

    1. The hidden states are extracted from the target model.
    2. The hidden states from 3 aux layers (layer 1, layer num_layers//2,
       layer num_layers-4) are concatenated.
    3. The concatenated hidden states are projected to the target hidden size,
       from (batch, seq_len, 3*hidden_size) to (batch, seq_len, hidden_size).
    4. The projected hidden states and the embedding output are both fed to the
       draft model backbone.
    5. Finally, TTT is run for ``length`` steps to train the draft model.

    NOTE: in this class steps 1-2 are performed by the caller; ``forward``
    receives ``hidden_states`` and ``target`` as arguments and starts at step 3.
    """

    def __init__(
        self,
        draft_model: Eagle3DraftModel,
        length: int = 7,
        attention_backend="sdpa",
        target_model: Optional[Eagle3Model] = None,
    ):
        """
        Args:
            draft_model: the draft model to be trained.
            length: TTT length, i.e. how many turns to unroll during TTT.
            attention_backend: one of "sdpa", "fa", "usp" or "flex_attention";
                any other value makes ``forward`` raise ``ValueError``.
            target_model: optional target model. It is only used by ``forward``
                to compute rope position ids when ``is_vlm=True`` and no
                ``position_ids`` are supplied.
        """
        super().__init__()
        self.draft_model = draft_model
        self.length = length
        self.attention_backend = attention_backend
        self.target_model = target_model

        # Sequence-parallel bookkeeping is only set up for the "usp" backend;
        # these attributes do not exist for the other backends.
        if self.attention_backend == "usp":
            self.extract_func = EXTRACT_FUNC_DICT["basic"]
            self.sp_ring_degree = torch.distributed.get_world_size(get_sp_ring_group())
            self.sp_ulysses_degree = torch.distributed.get_world_size(
                get_sp_ulysses_group()
            )
            self.sp_world_size = self.sp_ring_degree * self.sp_ulysses_degree
            self.sp_rank = torch.distributed.get_rank() % self.sp_world_size

    @torch.compile()
    def prepare_usp_input(self, full_input):
        """Return a cloned copy of this rank's portion of ``full_input``.

        The portion is selected by ``self.extract_func`` using ``self.sp_rank``
        and ``self.sp_world_size``, which are only defined when the model was
        constructed with ``attention_backend="usp"``.
        """
        shared_input = self.extract_func(
            full_input,
            rank=self.sp_rank,
            world_size=self.sp_world_size,
        ).clone()
        return shared_input

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        target: torch.Tensor,
        loss_mask: torch.Tensor,
        hidden_states: torch.Tensor,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        position_ids: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.Tensor] = None,
        is_vlm: bool = False,
        **kwargs,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        """
        Online eagle model trainer, modified from: https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/cnets.py#L711

        Args:
            input_ids: (batch, seq_len)
            attention_mask: (batch, seq_len); if None, an all-ones boolean mask
                is created.
            target: target-model output turned into the target distribution by
                ``_compute_target_p_padded`` (presumably logits of shape
                (batch, seq_len, vocab_size) -- confirm against the caller).
            loss_mask: (batch, seq_len)
            hidden_states: 3-D tensor (batch, seq_len, feature) that is passed
                through ``draft_model.project_hidden_states``.
            past_key_values: We don't use this past_key_values in eagle3, but keep it
                for compatibility. We control kvcache by cache_hidden. It is only
                inspected for its cached length and then replaced.
            position_ids: (batch, seq_len); computed here when None.
            image_grid_thw: forwarded to ``target_model.get_rope_index`` when
                ``is_vlm`` is True and ``position_ids`` is None.
            is_vlm: whether position ids must come from the target model's
                ``get_rope_index``.
            **kwargs: accepted and ignored.

        Returns:
            ``(plosses, vlosses, acces)``: ``plosses`` and ``acces`` hold one loss
            and one accuracy per TTT step (``self.length`` entries each);
            ``vlosses`` is always an empty list.

        Raises:
            ValueError: if ``self.attention_backend`` is not a supported backend,
                or if ``is_vlm`` is True, ``position_ids`` is None and no
                ``target_model`` was provided.
        """
        # Step 1: build the (padded) target distribution and the position mask.
        target_p_padded, position_mask = _compute_target_p_padded(
            target=target,
            t2d=self.draft_model.t2d,
            loss_mask=loss_mask,
            length=self.length,
        )
        # Drop the local reference to the target tensor and release cached blocks.
        del target
        torch.cuda.empty_cache()

        # Step 2: basic shapes. seq_length is taken BEFORE any sequence-parallel
        # split, so it is the full sequence length.
        batch_size, seq_length, _ = hidden_states.shape
        seq_length_with_past = seq_length
        past_key_values_length = 0

        # Step 3: project the concatenated hidden states.
        if self.attention_backend == "usp":
            # Keep only this rank's share of the sequence.
            hidden_states = self.prepare_usp_input(hidden_states)
        hidden_states = self.draft_model.project_hidden_states(hidden_states)

        # Step 4: account for any pre-existing cache length.
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        # Step 5: position ids.
        if position_ids is None:
            if is_vlm:
                if self.target_model is None:
                    # Fail with a clear message instead of an AttributeError on None.
                    raise ValueError(
                        "is_vlm=True requires a target_model to compute position "
                        "ids, but target_model is None"
                    )
                mrope_positions_ids, _ = self.target_model.get_rope_index(
                    input_ids=input_ids, image_grid_thw=image_grid_thw
                )
                position_ids = mrope_positions_ids
            else:
                device = hidden_states.device
                position_ids = torch.arange(
                    past_key_values_length,
                    seq_length + past_key_values_length,
                    dtype=torch.long,
                    device=device,
                )
                position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # Step 6: attention mask.
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=hidden_states.device,
            )
        if self.attention_backend == "sdpa":
            attention_mask = self.draft_model.prepare_decoder_attention_mask(
                attention_mask=attention_mask,
                hidden_states=hidden_states,
                batch_size=batch_size,
                seq_length=seq_length,
                past_key_values_length=past_key_values_length,
            )

        def compute_loss_and_acc_checkpointed(hs, tgt_p, pos_mask, l_mask):
            """Compute loss and accuracy for one TTT step from hidden states."""
            logits_ = self.draft_model.compute_logits(hs)
            # Presumably reassembles sequence-parallel shards along dim 1 and is a
            # pass-through otherwise -- confirm in specforge.distributed.
            logits = gather_outputs_and_unpad(logits_, gather_dim=1)

            loss_val = LogSoftmaxLoss.apply(logits, tgt_p, pos_mask)

            # Accuracy is a metric only; keep it out of the autograd graph.
            with torch.no_grad():
                acc_val = _compute_metric_acc(
                    logits=logits,
                    target_p=tgt_p,
                    position_mask=pos_mask,
                    loss_mask=l_mask,
                )
            return loss_val, acc_val

        # Step 7: TTT loop.
        plosses = []
        vlosses = []  # never populated; returned for interface compatibility
        acces = []

        # Keep the un-split ids so each step can re-derive the per-rank slice.
        global_input_ids = input_ids
        if self.attention_backend in ["sdpa", "fa", "usp"]:
            cache_hidden = [[], []]
            past_key_values = None
        elif self.attention_backend == "flex_attention":
            cache_hidden = None
            past_key_values = DynamicCache()
        else:
            raise ValueError(f"Unknown attention backend: {self.attention_backend}")

        for idx in range(self.length):
            # Window of the padded target distribution for this step.
            target_p = target_p_padded[:, idx : idx + seq_length, :]
            if self.attention_backend == "usp":
                input_ids = self.prepare_usp_input(global_input_ids)
            else:
                input_ids = global_input_ids

            is_last = idx == self.length - 1

            # Step 7.1: embed the input ids.
            inputs_embeds = self.draft_model.embed_input_ids(input_ids)
            inputs_embeds = inputs_embeds.to(hidden_states.dtype)

            # Step 7.2: run the draft model backbone.
            hidden_states_out = self.draft_model.backbone(
                input_embeds=inputs_embeds,
                hidden_states=hidden_states,
                cache_hidden=cache_hidden,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=True,
            )

            # Step 7.3: the output feeds the next TTT step.
            hidden_states = hidden_states_out

            # Step 7.4: loss/accuracy. When gradients are tracked, checkpoint the
            # logits computation so its activations are recomputed during backward
            # instead of being stored.
            if hidden_states.requires_grad:
                loss, acc = checkpoint(
                    compute_loss_and_acc_checkpointed,
                    hidden_states,
                    target_p,
                    position_mask,
                    loss_mask,
                    use_reentrant=False,
                )
            else:
                loss, acc = compute_loss_and_acc_checkpointed(
                    hidden_states, target_p, position_mask, loss_mask
                )

            plosses.append(loss)
            acces.append(acc)
            if not is_last:
                # Step 7.5: advance ids and masks for the next step (presumably a
                # one-token shift with padding on the right -- see
                # specforge.utils.padding).
                global_input_ids = padding(global_input_ids, left=False)
                position_mask = padding(position_mask, left=False)
                loss_mask = padding(loss_mask, left=False)

        return plosses, vlosses, acces
|
|
|
|
class QwenVLOnlineEagle3Model(Eagle3Model):
    """
    Online Eagle3 trainer for Qwen-VL style targets.

    In sgl-spec, we implement offline/online training.
    Online training means the target hidden_states are available during training.
    Eagle3 uses the test-time-training technique (TTT) to train the draft model:

    1. The hidden states are extracted from the target model.
    2. The hidden states from 3 aux layers are concatenated.
    3. The concatenated hidden states are projected to the target hidden size,
       from (batch, seq_len, 3*hidden_size) to (batch, seq_len, hidden_size).
    4. The projected hidden states and the embedding output are both fed to the
       draft model backbone.
    5. Finally, TTT is run for ``length`` steps to train the draft model.
    """

    def __init__(
        self,
        target_model,
        draft_model: Eagle3DraftModel,
        processor,
        length: int = 7,
        attention_backend: str = "sdpa",
    ):
        """
        Args:
            target_model: the target model to extract hidden states.
            draft_model: the draft model to be trained.
            processor: stored on the instance as ``self.processor``.
            length: TTT length, i.e. how many turns to unroll during TTT.
            attention_backend: "sdpa", "fa" or "flex_attention"; any other
                value makes ``forward`` raise ``ValueError``.
        """
        super().__init__()
        self.target_model = target_model
        self.draft_model = draft_model
        self.processor = processor
        self.length = length
        self.attention_backend = attention_backend

    @torch.no_grad()
    def _prepare_data(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.Tensor] = None,
        device: Optional[torch.device] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        modified from: https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/cnets.py#L692
        Run the target model and extract the training inputs from its outputs.

        Args:
            input_ids: (batch, seq_len)
            attention_mask: (batch, seq_len)
            loss_mask: (batch, seq_len)
            pixel_values: image pixel values, used for VLM models
            image_grid_thw: image grid thw, used for VLM models
            device: device that ``target`` and ``loss_mask`` are moved to;
                defaults to the device of ``input_ids``.

        Returns:
            hidden_states: (batch, seq_len, 3*hidden_size)
            target: (batch, seq_len, vocab_size)
            loss_mask: (batch, seq_len, 1) -- a trailing axis is appended
            input_ids: (batch, seq_len)
        """
        target_device = input_ids.device if device is None else device

        # Target forward pass with every layer's hidden state exposed.
        outputs = self.target_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            output_hidden_states=True,
            use_cache=False,
        )

        # An offset of 1 is applied to every layer index; the layer count is the
        # number of returned hidden states minus one. (Presumably entry 0 is the
        # embedding output -- confirm against the target model.)
        offset = 1
        num_layers = len(outputs.hidden_states) - 1

        # Low / mid / late auxiliary layers.
        aux_layer_indices = (
            1 + offset,
            num_layers // 2 - 1 + offset,
            num_layers - 4 + offset,
        )
        hidden_states = torch.cat(
            [outputs.hidden_states[layer] for layer in aux_layer_indices], dim=-1
        )

        # Pass logits and input ids through ``padding`` (presumably a one-token
        # shift with right padding -- see specforge.utils.padding).
        target = padding(outputs.logits, left=False)
        input_ids = padding(input_ids, left=False)

        if target is not None:
            target = target.to(target_device)
            loss_mask = loss_mask[..., None].to(target_device)

        return hidden_states, target, loss_mask, input_ids

    @torch.no_grad()
    def _get_input_embeds(
        self,
        input_ids: torch.Tensor,
        pixel_values: torch.Tensor,
        image_grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and overwrite image-token positions with image features.

        Raises:
            ValueError: if the number of image tokens in ``input_ids`` differs
                from the number of image feature rows.
        """
        image_token_id = self.target_model.model.config.image_token_id

        inputs_embeds = self.draft_model.embed_input_ids(input_ids)
        image_features = self.target_model.model.get_image_features(
            pixel_values, image_grid_thw
        )
        image_embeds = torch.cat(image_features, dim=0)

        is_image_token = input_ids == image_token_id
        n_image_tokens = is_image_token.sum()
        n_image_features = image_embeds.shape[0]
        if n_image_tokens != n_image_features:
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )

        # Broadcast the token mask over the embedding axis and scatter features in.
        image_mask = (
            is_image_token.unsqueeze(-1)
            .expand_as(inputs_embeds)
            .to(inputs_embeds.device)
        )
        image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        return inputs_embeds.masked_scatter(image_mask, image_embeds)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        position_ids: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.Tensor] = None,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        """
        Online eagle model trainer, modified from: https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/cnets.py#L711

        Args:
            input_ids: (batch, seq_len)
            attention_mask: (batch, seq_len)
            loss_mask: (batch, seq_len)
            past_key_values: We don't use this past_key_values in eagle3, but keep it
                for compatibility. We control kvcache by cache_hidden.
            position_ids: (batch, seq_len); when None they are obtained from the
                target model's ``get_rope_index``.
            pixel_values: batch image pixel values, used for VLM models
            image_grid_thw: (batch, 3), image grid thw, used for VLM models

        Returns:
            ``(plosses, vlosses, acces)``: one loss and one accuracy per TTT step;
            ``vlosses`` is always an empty list.

        Raises:
            ValueError: if ``self.attention_backend`` is not supported.
        """
        # Step 1: run the target model.
        hidden_states, target, loss_mask, input_ids = self._prepare_data(
            input_ids, attention_mask, loss_mask, pixel_values, image_grid_thw
        )

        # Step 2: target distribution padded for the TTT unroll.
        target_p_padded, position_mask = _compute_target_p_padded(
            target=target,
            t2d=self.draft_model.t2d,
            loss_mask=loss_mask,
            length=self.length,
        )
        del target

        # Step 3: basic shapes.
        batch_size, seq_length, _ = hidden_states.shape
        seq_length_with_past = seq_length
        past_key_values_length = 0

        # Step 4: project the concatenated aux hidden states.
        hidden_states = self.draft_model.project_hidden_states(hidden_states)

        # Step 5: account for any pre-existing cache length.
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past += past_key_values_length

        # Step 6: position ids. Caller-provided ids are used unchanged.
        if position_ids is None:
            if isinstance(attention_mask, dict):
                attention_mask_tensor = attention_mask["full_attention"]
            else:
                attention_mask_tensor = attention_mask

            if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
                # Reduce a 4-D mask to 2-D: take the diagonal of the first head,
                # divide by the dtype minimum, then flip with (1 - x).
                attention_mask_tensor = torch.diagonal(
                    attention_mask_tensor[:, 0], dim1=1, dim2=2
                )
                attention_mask_tensor = (
                    attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
                )
                attention_mask_tensor = (1.0 - attention_mask_tensor).int()

            position_ids, rope_deltas = self.target_model.model.get_rope_index(
                input_ids,
                image_grid_thw,
                None,
                second_per_grid_ts=None,
                attention_mask=attention_mask_tensor,
            )
            # Side effect: the deltas are kept on the module.
            self.rope_deltas = rope_deltas

        # Step 7: attention mask.
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=hidden_states.device,
            )
        if self.attention_backend == "sdpa":
            attention_mask = self.draft_model.prepare_decoder_attention_mask(
                attention_mask=attention_mask,
                hidden_states=hidden_states,
                batch_size=batch_size,
                seq_length=seq_length,
                past_key_values_length=past_key_values_length,
            )

        # Step 8: TTT loop.
        plosses = []
        vlosses = []  # never populated; returned for interface compatibility
        acces = []
        if self.attention_backend in ("sdpa", "fa"):
            cache_hidden = [[], []]
            past_key_values = None
        elif self.attention_backend == "flex_attention":
            cache_hidden = None
            past_key_values = DynamicCache()
        else:
            raise ValueError(f"Unknown attention backend: {self.attention_backend}")

        last_idx = self.length - 1
        for idx in range(self.length):
            # Window of the padded target distribution for this step.
            target_p = target_p_padded[:, idx : idx + seq_length, :].contiguous()

            # Step 8.1: embed the input ids.
            inputs_embeds = self.draft_model.embed_input_ids(input_ids)
            inputs_embeds = inputs_embeds.to(hidden_states.dtype)

            # Step 8.2: run the backbone; its output feeds the next TTT step.
            hidden_states = self.draft_model.backbone(
                input_embeds=inputs_embeds,
                hidden_states=hidden_states,
                cache_hidden=cache_hidden,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=True,
            )

            # Step 8.3: logits.
            logits = self.draft_model.compute_logits(hidden_states)

            # Step 8.4: accuracy is a metric only; no gradients needed.
            with torch.no_grad():
                step_acc = _compute_metric_acc(
                    logits=logits,
                    target_p=target_p,
                    position_mask=position_mask,
                    loss_mask=loss_mask,
                )
                acces.append(step_acc)

            # Step 8.5: loss.
            plosses.append(LogSoftmaxLoss.apply(logits, target_p, position_mask))

            if idx != last_idx:
                # Step 8.6: advance ids and masks for the next step.
                input_ids = padding(input_ids, left=False)
                position_mask = padding(position_mask, left=False)
                loss_mask = padding(loss_mask, left=False)

        return plosses, vlosses, acces
|
|
|
|
def _compute_target_p_padded(target, t2d, loss_mask, length):
    """Build the target distribution and extend it by ``length`` positions.

    The distribution and position mask come from ``_compute_target_p``. The
    distribution must be 3-D; ``length`` extra entries are appended at the end
    of its second-to-last axis, each filled with ``1 / size_of_last_axis``
    (i.e. a uniform probability over the last axis).

    Returns:
        ``(target_p_padded, position_mask)``; the mask is returned unpadded.
    """
    with torch.no_grad():
        target_p, position_mask = _compute_target_p(
            target=target, t2d=t2d, loss_mask=loss_mask
        )

        assert target_p.dim() == 3
        uniform_prob = 1 / target_p.size(-1)
        # pad tuple is (last-dim left, last-dim right, seq left, seq right).
        target_p_padded = F.pad(
            target_p, (0, 0, 0, length), mode="constant", value=uniform_prob
        )

    return target_p_padded, position_mask
|
|
|
|
@torch.compile(dynamic=None)
def _compute_target_p(target, t2d, loss_mask):
    """Turn target scores into a distribution over the ``t2d``-selected entries.

    ``t2d`` is used in two ways: it is looked up at each position's argmax
    token, and it indexes the last axis of ``target`` (presumably a boolean
    mask over the target vocabulary marking tokens kept in the draft
    vocabulary -- confirm against the draft model).

    Returns:
        target_p: detached float softmax (over dim 2) of the selected scores.
        position_mask: ``t2d`` looked up at the argmax token, with a trailing
            axis added, cast to int and multiplied by ``loss_mask``.
    """
    top_token = target.argmax(-1)
    top_token_selected = t2d[top_token].unsqueeze(-1).int()
    position_mask = top_token_selected * loss_mask

    selected_scores = target[..., t2d].float()
    target_p = torch.softmax(selected_scores, dim=2).detach()
    return target_p, position_mask
|
|
|
|
@torch.compile(dynamic=None)
def _compute_metric_acc(logits, target_p, position_mask, loss_mask):
    """Return the masked top-1 agreement between ``logits`` and ``target_p``.

    Positions where the argmax over the last axis matches are weighted by
    ``position_mask`` (its trailing axis is squeezed) and summed. The sum is
    divided by ``loss_mask.sum()``, clamped to at least 1e-6 so that an
    all-zero loss mask does not cause a division by zero.
    """
    predicted = logits.argmax(-1)
    expected = target_p.argmax(-1)
    hits = (predicted == expected) * position_mask.squeeze(-1)
    denom = loss_mask.sum().clamp_min(1e-6)
    return hits.sum() / denom
|
|