# specforge/modeling/target/eagle3_target_model.py
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Tuple
import sglang.srt.managers.mm_utils as mm_utils
import torch
import torch.distributed as dist
import torch.nn as nn
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
init_mm_embedding_cache,
)
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
Req,
ScheduleBatch,
)
from sglang.srt.managers.scheduler import Scheduler
from sglang.srt.mem_cache.cache_init_params import CacheInitParams
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch
from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import require_mlp_sync, require_mlp_tp_gather
from transformers import AutoModelForCausalLM
from specforge.distributed import get_tp_device_mesh, get_tp_group
from specforge.utils import padding
from .sglang_backend import SGLangRunner, wrap_eagle3_logits_processors_in_module
from .sglang_backend.utils import LogitsProcessorForEAGLE3


@dataclass
class Eagle3TargetOutput:
    hidden_states: torch.Tensor  # concatenated aux hidden states, (batch, seq_len, 3 * hidden_size)
    target: torch.Tensor  # target logits, (batch, seq_len, vocab_size)
    loss_mask: torch.Tensor  # (batch, seq_len, 1)
    input_ids: torch.Tensor  # (batch, seq_len)
    attention_mask: torch.Tensor  # (batch, seq_len)
    last_hidden_states: Optional[torch.Tensor] = None  # kept only when generating hidden states offline


class Eagle3TargetModel(ABC):
    """
    This class offers a layer of abstraction over the target model backend. The
    user can choose the backend that best suits their needs:
    1. SGLang backend: for mainstream model support with the fastest inference speed.
    2. HuggingFace backend: for models that are not supported by SGLang but can be loaded by HuggingFace.
    3. Custom backend: for models with a customized architecture and inference plan.
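
    Example (a minimal sketch; the model path is illustrative):
        >>> target = get_eagle3_target_model("meta-llama/Llama-3.1-8B-Instruct")
        >>> target.set_aux_hidden_states_layers()  # defaults to 3 layers
        >>> out = target.generate_eagle3_data(input_ids, attention_mask, loss_mask)
        >>> out.hidden_states.shape  # (batch, seq_len, 3 * hidden_size)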
"""
def __init__(self):
self.aux_hidden_states_layers = None
@classmethod
@abstractmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
        torch_dtype: Optional[torch.dtype] = None,
        device: Optional[str] = None,
cache_dir: Optional[str] = None,
**kwargs,
) -> "Eagle3TargetModel":
"""
Initialize the target model backend from a pretrained model path.
"""
@abstractmethod
def generate_eagle3_data(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
loss_mask: torch.Tensor,
) -> Eagle3TargetOutput:
"""
        Generate the EAGLE3 training data from the target model.
"""
def set_aux_hidden_states_layers(
self, aux_hidden_states_layers: Optional[List[int]] = None
) -> None:
"""
Set the layers to capture the aux hidden states from the target model outputs.
"""
if aux_hidden_states_layers is None:
if hasattr(self.model.config, "num_hidden_layers"):
num_layers = self.model.config.num_hidden_layers
else:
raise ValueError(
f"Failed to set aux hidden states layers as model config {self.model.config} does not have num_hidden_layers"
)
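            # e.g., for a 32-layer model this defaults to layers [1, 15, 28]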
aux_hidden_states_layers = [
1,
num_layers // 2 - 1,
num_layers - 4,
]
self.aux_hidden_states_layers = aux_hidden_states_layers
assert (
len(self.aux_hidden_states_layers) == 3
), "aux_hidden_states_layers is expected to be 3 layers for EAGLE3"


class HFEagle3TargetModel(Eagle3TargetModel):
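    """
    Backend that loads the target model with HuggingFace's AutoModelForCausalLM.
    Useful for models that are not supported by SGLang; supports tensor
    parallelism via the HF `tp_plan` path when the TP group size is > 1.
    """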
def __init__(self, model: nn.Module):
super().__init__()
self.model = model
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
        torch_dtype: Optional[torch.dtype] = None,
        device: Optional[str] = None,
cache_dir: Optional[str] = None,
**kwargs,
) -> "HFEagle3TargetModel":
"""
Initialize the HuggingFace target model backend from a pretrained model path.
"""
tp_size = get_tp_group().size()
if tp_size > 1:
device_kwargs = {
"tp_plan": "auto",
"tp_size": tp_size,
"device_mesh": get_tp_device_mesh(),
}
else:
device_kwargs = {
"device_map": device,
}
target_model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path,
torch_dtype=torch_dtype,
cache_dir=cache_dir,
**device_kwargs,
**kwargs,
)
return cls(target_model)
def _get_transformer_layers(self):
"""
Helper to find the module list containing the transformer layers.
Adapts to common architectures (Llama, Qwen, Mistral, OPT, etc.)
"""
if hasattr(self.model, "model") and hasattr(self.model.model, "layers"):
return self.model.model.layers
elif hasattr(self.model, "layers"):
return self.model.layers
elif hasattr(self.model, "transformer") and hasattr(
self.model.transformer, "h"
):
return self.model.transformer.h
else:
raise ValueError(
"Could not locate transformer layers in the model architecture to register hooks."
)
@torch.no_grad()
def generate_eagle3_data(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
loss_mask: torch.Tensor,
) -> Eagle3TargetOutput:
"""
Optimized HF backend:
Instead of returning all hidden states (memory heavy), we use forward hooks
to capture only the specific layers required by Eagle3.
"""
captured_states = {}
handles = []
def get_hook(layer_idx):
def hook(module, input, output):
# HF outputs for layers are usually tuples (hidden_states, present_key_value, ...)
# We only need the hidden_states (first element)
if isinstance(output, tuple):
hidden = output[0]
else:
hidden = output
captured_states[layer_idx] = hidden
return hook
# Locate the transformer layers ModuleList
        layers = self._get_transformer_layers()
        if self.aux_hidden_states_layers is None:
            # fall back to the default EAGLE3 layers if none were set explicitly
            self.set_aux_hidden_states_layers()
        target_indices = self.aux_hidden_states_layers
# Register hooks
for idx in target_indices:
# Ensure index is within bounds
if 0 <= idx < len(layers):
handles.append(layers[idx].register_forward_hook(get_hook(idx)))
else:
raise ValueError(
f"Layer index {idx} out of bounds for model with {len(layers)} layers."
)
try:
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=False,
output_attentions=False,
output_router_logits=False,
use_cache=False,
)
target = outputs.logits
finally:
# Always remove hooks to prevent memory leaks or side effects on subsequent calls
for handle in handles:
handle.remove()
# Verify we captured everything
if len(captured_states) != 3:
raise RuntimeError(
f"Expected to capture 3 layers, but captured {len(captured_states)}"
)
# Extract in the correct order
hidden_states0 = captured_states[target_indices[0]]
hidden_states1 = captured_states[target_indices[1]]
hidden_states2 = captured_states[target_indices[2]]
hidden_states = torch.cat(
(hidden_states0, hidden_states1, hidden_states2), dim=-1
)
        # apply padding
target = padding(target, left=False)
input_ids = padding(input_ids, left=False)
loss_mask = loss_mask[..., None].to(target.device)
return Eagle3TargetOutput(
hidden_states=hidden_states,
target=target,
loss_mask=loss_mask,
input_ids=input_ids,
attention_mask=attention_mask,
)


class SGLangEagle3TargetModel(Eagle3TargetModel):
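    """
    Backend that runs the target model through an SGLang model runner for the
    fastest inference speed. Also supports vision-language models (e.g.
    Qwen2.5-VL) via the `extend_vlm` path.
    """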
def __init__(self, model_runner: SGLangRunner, hf_config=None):
super().__init__()
self.model_runner = model_runner
self.hf_config = hf_config
# VLM-specific attributes (initialized from hf_config if available)
self._init_vlm_attributes()
def _init_vlm_attributes(self):
"""Initialize VLM-specific attributes from hf_config for models like Qwen2.5-VL"""
if self.hf_config is None:
self.is_vlm = False
return
# Check if this is a VLM model by looking for vision_config
self.is_vlm = hasattr(self.hf_config, "vision_config")
if not self.is_vlm:
return
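        # initialize sglang's multimodal embedding cache with a 1024 * 1024 * 512 (512M) budget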
init_mm_embedding_cache(1024 * 1024 * 512)
# Model type (e.g., "qwen2_5_vl", "qwen2_vl")
self.model_type = getattr(self.hf_config, "model_type", None)
# Vision config attributes
vision_config = self.hf_config.vision_config
self.spatial_merge_size = getattr(vision_config, "spatial_merge_size", 2)
self.tokens_per_second = getattr(vision_config, "tokens_per_second", None)
# Special token IDs from hf_config
self.image_token_id = getattr(self.hf_config, "image_token_id", None)
self.video_token_id = getattr(self.hf_config, "video_token_id", None)
self.vision_start_token_id = getattr(
self.hf_config, "vision_start_token_id", None
)
self.vision_end_token_id = getattr(self.hf_config, "vision_end_token_id", None)
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
        torch_dtype: Optional[torch.dtype] = None,
        device: Optional[str] = None,
cache_dir: Optional[str] = None,
trust_remote_code: bool = False,
**kwargs,
) -> "SGLangEagle3TargetModel":
tp_size = dist.get_world_size(get_tp_group())
server_args = ServerArgs(
model_path=pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
dtype=torch_dtype,
enable_return_hidden_states=True,
disable_cuda_graph=True, # we use piecewise cuda graph for prefill instead
tp_size=tp_size,
pp_size=1,
**kwargs,
)
tp_rank = dist.get_rank(get_tp_group())
moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
model_config = ModelConfig.from_server_args(server_args)
model_runner = SGLangRunner(
model_config=model_config,
mem_fraction_static=server_args.mem_fraction_static,
gpu_id=torch.cuda.current_device(),
            tp_rank=tp_rank,
tp_size=server_args.tp_size,
moe_ep_rank=moe_ep_rank,
moe_ep_size=server_args.ep_size,
pp_rank=0,
pp_size=1,
server_args=server_args,
nccl_port=None,
)
wrap_eagle3_logits_processors_in_module(
model_runner.model, return_full_logits=False
)
# Get hf_config from model_config for VLM attributes
hf_config = getattr(model_config, "hf_config", None)
return cls(model_runner, hf_config=hf_config)
def set_aux_hidden_states_layers(
self, aux_hidden_states_layers: Optional[List[int]] = None
) -> None:
self.model_runner.model.set_eagle3_layers_to_capture(aux_hidden_states_layers)
    @torch.no_grad()
def _extend(
self,
reqs,
capture_aux_hidden_states: bool = True,
return_last_hidden_states: bool = False,
return_logits: bool = False,
):
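        """
        Run a single prefill (extend) pass through the SGLang model runner for
        the given requests, then split the captured logits and hidden states
        back into per-request tensors.
        """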
# set the logits processor for the model runner
for name, module in self.model_runner.model.named_modules():
if isinstance(module, LogitsProcessorForEAGLE3):
module.return_last_hidden_states = return_last_hidden_states
module.return_logits = return_logits
cache_params = CacheInitParams(
disable=False,
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
page_size=self.model_runner.server_args.page_size,
)
tree_cache = RadixCache(cache_params)
batch = ScheduleBatch.init_new(
reqs=reqs,
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
tree_cache=tree_cache,
model_config=self.model_runner.model_config,
enable_overlap=False,
spec_algorithm=SpeculativeAlgorithm.NONE,
)
batch.prepare_for_extend()
self._maybe_prepare_mlp_sync_batch(batch)
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
forward_batch.capture_hidden_mode = CaptureHiddenMode.FULL
eagle3_output, _ = self.model_runner.forward(forward_batch)
aux_hidden_states_list = None
input_lens = [len(req.origin_input_ids) for req in reqs]
if return_logits:
logits = torch.split(eagle3_output.logits, input_lens, dim=0)
else:
logits = [None] * len(reqs)
if capture_aux_hidden_states:
aux_hidden_states_list = torch.split(
eagle3_output.aux_hidden_states, input_lens, dim=0
)
else:
aux_hidden_states_list = [None] * len(reqs)
if return_last_hidden_states:
last_hidden_states = torch.split(
eagle3_output.last_hidden_states, input_lens, dim=0
)
else:
last_hidden_states = [None] * len(reqs)
# TODO: can we not clear?
self.model_runner.req_to_token_pool.clear()
self.model_runner.token_to_kv_pool_allocator.clear()
return logits, aux_hidden_states_list, last_hidden_states
def _maybe_prepare_mlp_sync_batch(self, batch: ScheduleBatch):
if require_mlp_sync(self.model_runner.server_args):
Scheduler.prepare_mlp_sync_batch_raw(
batch,
dp_size=self.model_runner.server_args.dp_size,
attn_tp_size=1,
tp_group=self.model_runner.tp_group,
get_idle_batch=None,
disable_cuda_graph=self.model_runner.server_args.disable_cuda_graph,
spec_algorithm=SpeculativeAlgorithm.NONE,
speculative_num_draft_tokens=None,
require_mlp_tp_gather=require_mlp_tp_gather(
self.model_runner.server_args
),
disable_overlap_schedule=self.model_runner.server_args.disable_overlap_schedule,
offload_tags=set(),
)
def extend(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
loss_mask: torch.Tensor,
return_last_hidden_states: bool = False,
return_logits: bool = True,
):
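        """
        Run a prefill pass over a batch of text-only samples.

        Args:
            input_ids: (batch_size, seq_len) tensor, or list of (1, seq_len) tensors
            attention_mask: (batch_size, seq_len) tensor, or list of (1, seq_len) tensors
            loss_mask: (batch_size, seq_len) tensor, or list of (1, seq_len) tensors

        Returns:
            Tuple of (data_cache, logits_list, aux_hidden_states_list,
            last_hidden_states_list), each a list with one entry per sample.
        """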
sampling_params = SamplingParams(temperature=0, max_new_tokens=1, top_k=1)
reqs, data_cache = [], []
if isinstance(input_ids, torch.Tensor):
input_ids = torch.split(input_ids, 1, dim=0)
attention_mask = torch.split(attention_mask, 1, dim=0)
loss_mask = torch.split(loss_mask, 1, dim=0)
for idx, (input_id_, attention_mask_, loss_mask_) in enumerate(
zip(
input_ids,
attention_mask,
loss_mask,
)
):
req = Req(
rid=str(idx),
origin_input_text="",
origin_input_ids=input_id_.view(-1).tolist(),
sampling_params=sampling_params,
)
req.fill_ids = req.origin_input_ids
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
req.logprob_start_len = len(req.origin_input_ids) - 1
data_cache.append([input_id_, attention_mask_, loss_mask_])
reqs.append(req)
logits_list, aux_hidden_states_list, last_hidden_states_list = self._extend(
reqs,
capture_aux_hidden_states=True,
return_last_hidden_states=return_last_hidden_states,
return_logits=return_logits,
)
return data_cache, logits_list, aux_hidden_states_list, last_hidden_states_list
def get_rope_index(
self,
input_ids: torch.Tensor,
image_grid_thw: Optional[torch.Tensor] = None,
video_grid_thw: Optional[torch.Tensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Get M-RoPE position indices for VLM models like Qwen2.5-VL.
This is a wrapper around MRotaryEmbedding.get_rope_index that uses
the VLM-specific attributes initialized from hf_config.
Args:
input_ids: (batch_size, seq_len) input token IDs
image_grid_thw: (num_images, 3) image grid dimensions (t, h, w)
video_grid_thw: (num_videos, 3) video grid dimensions (t, h, w)
second_per_grid_ts: Optional temporal information for videos
attention_mask: (batch_size, seq_len) attention mask
Returns:
position_ids: (3, batch_size, seq_len) M-RoPE position IDs
rope_deltas: Optional position deltas for incremental decoding
"""
if not self.is_vlm:
raise ValueError("get_rope_index is only available for VLM models")
position_ids, rope_deltas = MRotaryEmbedding.get_rope_index(
spatial_merge_size=self.spatial_merge_size,
image_token_id=self.image_token_id,
video_token_id=self.video_token_id,
vision_start_token_id=self.vision_start_token_id,
model_type=self.model_type,
input_ids=input_ids,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
attention_mask=attention_mask,
tokens_per_second=self.tokens_per_second,
)
return position_ids, rope_deltas
def extend_vlm(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
loss_mask: torch.Tensor,
return_last_hidden_states: bool = False,
return_logits: bool = True,
pixel_values: Optional[List[torch.Tensor]] = None,
image_grid_thw: Optional[List[torch.Tensor]] = None,
):
"""
Args:
input_ids: (batch_size, seq_len) or List of (1, seq_len) tensors
attention_mask: (batch_size, seq_len) or List of (1, seq_len) tensors
loss_mask: (batch_size, seq_len) or List of (1, seq_len) tensors
pixel_values: List of pixel_values tensors, one per sample in batch
image_grid_thw: List of image_grid_thw tensors, one per sample in batch
"""
mm_utils.embedding_cache.clear()
sampling_params = SamplingParams(temperature=0, max_new_tokens=1, top_k=1)
reqs, data_cache = [], []
# Split tensors if needed
if isinstance(input_ids, torch.Tensor):
batch_size = input_ids.shape[0]
input_ids = torch.split(input_ids, 1, dim=0)
attention_mask = torch.split(attention_mask, 1, dim=0)
loss_mask = torch.split(loss_mask, 1, dim=0)
else:
batch_size = len(input_ids)
# Process image_grid_thw - convert to list if needed
if image_grid_thw is None:
image_grid_thw = [None] * batch_size
elif not isinstance(image_grid_thw, (list, tuple)):
image_grid_thw = [image_grid_thw]
# pixel_values is a single 2D tensor (total_patches, patch_dim) for Qwen2.5-VL
# We need to track offset and slice it based on image_grid_thw for each sample
pixel_values_offset = 0 # Track current offset in pixel_values
for idx, (input_id_, attention_mask_, loss_mask_, image_grid_thw_) in enumerate(
zip(
input_ids,
attention_mask,
loss_mask,
image_grid_thw,
)
):
# Compute num_patches for this sample from image_grid_thw_
# image_grid_thw_: (num_images, 3) where each row is (t, h, w)
if image_grid_thw_ is not None:
# Ensure image_grid_thw_ is 2D: (num_images, 3)
if image_grid_thw_.dim() == 1:
image_grid_thw_ = image_grid_thw_.unsqueeze(0) # (3,) -> (1, 3)
elif image_grid_thw_.dim() == 0:
raise ValueError(
f"image_grid_thw_ is 0-dim tensor, expected at least 1D. Value: {image_grid_thw_}"
)
# Calculate num_patches for this sample: sum(t * h * w) for all images
num_patches = (
(
image_grid_thw_[:, 0]
* image_grid_thw_[:, 1]
* image_grid_thw_[:, 2]
)
.sum()
.item()
)
num_patches = int(num_patches)
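                # e.g., for Qwen2.5-VL (14x14 patches, an assumed illustrative setup), a single
                # 448x448 image has grid (t, h, w) = (1, 32, 32) -> 1024 rows of pixel_values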
# Slice pixel_values for this sample
pixel_value_ = pixel_values[
pixel_values_offset : pixel_values_offset + num_patches
]
pixel_values_offset += num_patches
else:
pixel_value_ = None
num_patches = 0
# Compute mrope positions for VLM models (e.g., Qwen2.5-VL)
input_id_flat = input_id_.view(-1)
mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
spatial_merge_size=self.spatial_merge_size,
image_token_id=self.image_token_id,
video_token_id=self.video_token_id,
vision_start_token_id=self.vision_start_token_id,
model_type=self.model_type,
input_ids=input_id_flat.unsqueeze(0),
image_grid_thw=(
image_grid_thw_.cpu() if image_grid_thw_ is not None else None
),
tokens_per_second=self.tokens_per_second,
)
            if image_grid_thw_ is not None:
                offset = BaseMultimodalProcessor.get_mm_items_offset(
                    input_id_flat, self.image_token_id
                )
                mm_item = MultimodalDataItem(
                    modality=Modality.IMAGE,
                    feature=pixel_value_,  # torch.Tensor: (num_patches, patch_dim)
                    pad_value=self.image_token_id,  # required for placeholder tensor creation
                    offsets=offset,  # list of (start, end) tuples
                )
                mm_item.set("image_grid_thw", image_grid_thw_.cpu())
                mm_item.set_pad_value()
                mm_inputs = MultimodalInputs(
                    mm_items=[mm_item],
                    im_token_id=self.image_token_id,
                    im_start_id=self.vision_start_token_id,
                    im_end_id=self.vision_end_token_id,
                    mrope_positions=(
                        mrope_positions.squeeze(1)
                        if mrope_positions is not None
                        else None
                    ),
                    mrope_position_delta=mrope_position_delta,
                )
                pattern = MultiModalityDataPaddingPatternMultimodalTokens()
                input_id_list = pattern.pad_input_tokens(
                    input_id_.view(-1).tolist(), mm_inputs
                )
            else:
                # text-only sample: no multimodal item to construct
                mm_inputs = None
                input_id_list = input_id_.view(-1).tolist()
            req = Req(
                rid=str(idx),
                origin_input_text="",
                origin_input_ids=input_id_list,
                sampling_params=sampling_params,
            )
            req.fill_ids = req.origin_input_ids
            req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
            req.logprob_start_len = len(req.origin_input_ids) - 1
            if mm_inputs is not None:
                req.multimodal_inputs = mm_inputs
data_cache.append([input_id_, attention_mask_, loss_mask_])
reqs.append(req)
logits_list, aux_hidden_states_list, last_hidden_states_list = self._extend(
reqs,
capture_aux_hidden_states=True,
return_last_hidden_states=return_last_hidden_states,
return_logits=return_logits,
)
return data_cache, logits_list, aux_hidden_states_list, last_hidden_states_list
@torch.no_grad()
def generate_eagle3_data(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
loss_mask: torch.Tensor,
pixel_values: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.Tensor] = None,
is_vlm: bool = False,
) -> Eagle3TargetOutput:
"""
return:
data_for_draft: List[Dict[str, torch.Tensor]] of draft_batch_size, draft_micro_batch_size = 1
- input_ids: (1, seq_len)
- attention_mask: (1, seq_len)
- loss_mask: (1, seq_len)
- target: (1, seq_len, vocab_size) or (1, seq_len, hidden_size)
- hidden_states: (1, seq_len, hidden_size)
- pixel_values: (patch_len, patch_width)
- image_grid_thw (batch_size, 3)
"""
if is_vlm:
data_cache, logits_list, aux_hidden_states_list, last_hidden_states_list = (
self.extend_vlm(
input_ids,
attention_mask,
loss_mask,
return_last_hidden_states=False,
return_logits=True,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
)
)
else:
data_cache, logits_list, aux_hidden_states_list, last_hidden_states_list = (
self.extend(
input_ids,
attention_mask,
loss_mask,
return_last_hidden_states=False,
return_logits=True,
)
)
aux_hidden_states_out = []
target_out = []
loss_mask_out = []
input_ids_out = []
last_hidden_states_out = []
for idx, (data, logits, aux_hidden_states, last_hidden_states) in enumerate(
zip(
data_cache, logits_list, aux_hidden_states_list, last_hidden_states_list
)
):
aux_hidden_states_out.append(aux_hidden_states.unsqueeze(0))
loss_mask_out.append(data[2])
input_ids_out.append(data[0])
# when generating hidden states for offline training, we don't compute logits and only keep the last_hidden_states
# when training online, we don't keep the last_hidden_states and only keep the logits
if logits is not None:
target_out.append(logits.unsqueeze(0))
else:
target_out.append(None)
if last_hidden_states is not None:
last_hidden_states_out.append(last_hidden_states.unsqueeze(0))
else:
last_hidden_states_out.append(None)
aux_hidden_states_out = torch.cat(aux_hidden_states_out, dim=0)
loss_mask_out = torch.cat(loss_mask_out, dim=0)
input_ids_out = torch.cat(input_ids_out, dim=0)
if target_out[0] is not None:
target_out = torch.cat(target_out, dim=0)
else:
target_out = None
if last_hidden_states_out[0] is not None:
last_hidden_states_out = torch.cat(last_hidden_states_out, dim=0)
else:
last_hidden_states_out = None
        if target_out is not None:
            target_out = padding(target_out, left=False)
input_ids_out = padding(input_ids_out, left=False)
loss_mask_out = loss_mask_out[..., None]
return Eagle3TargetOutput(
hidden_states=aux_hidden_states_out,
target=target_out,
loss_mask=loss_mask_out,
input_ids=input_ids_out,
attention_mask=attention_mask,
last_hidden_states=last_hidden_states_out,
)


class CustomEagle3TargetModel(Eagle3TargetModel):
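    """
    Backend for models with a customized architecture and inference plan,
    loaded via specforge's AutoDistributedTargetModel.
    """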
def __init__(self, model: nn.Module):
super().__init__()
self.model = model
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
        torch_dtype: Optional[torch.dtype] = None,
        device: Optional[str] = None,
cache_dir: Optional[str] = None,
**kwargs,
) -> "CustomEagle3TargetModel":
from specforge.modeling.auto import AutoDistributedTargetModel
target_model = AutoDistributedTargetModel.from_pretrained(
pretrained_model_name_or_path=pretrained_model_name_or_path,
torch_dtype=torch_dtype,
cache_dir=cache_dir,
device=device,
**kwargs,
)
return cls(target_model)
@torch.no_grad()
def generate_eagle3_data(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
loss_mask: torch.Tensor,
) -> Eagle3TargetOutput:
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
layers_to_output_hidden_states=self.aux_hidden_states_layers,
use_cache=False,
)
# For custom backends, the model implementation is responsible for only
# returning the requested layers in `outputs.hidden_states`.
hidden_states = torch.cat(outputs.hidden_states, dim=-1)
target = outputs.logits
target = padding(target, left=False)
input_ids = padding(input_ids, left=False)
loss_mask = loss_mask[..., None].to(target.device)
return Eagle3TargetOutput(
hidden_states=hidden_states,
target=target,
loss_mask=loss_mask,
input_ids=input_ids,
attention_mask=attention_mask,
)


def get_eagle3_target_model(
pretrained_model_name_or_path: str,
backend: str = "sglang",
    torch_dtype: Optional[torch.dtype] = None,
    device: Optional[str] = None,
cache_dir: Optional[str] = None,
**kwargs,
) -> Eagle3TargetModel:
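    """
    Create an EAGLE3 target model using the chosen backend ("sglang", "hf", or
    "custom").

    Example (a minimal sketch; the model path is illustrative):
        >>> model = get_eagle3_target_model(
        ...     "Qwen/Qwen2.5-7B-Instruct", backend="hf", torch_dtype=torch.bfloat16
        ... )
    """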
if backend == "sglang":
return SGLangEagle3TargetModel.from_pretrained(
pretrained_model_name_or_path=pretrained_model_name_or_path,
torch_dtype=torch_dtype,
device=device,
cache_dir=cache_dir,
**kwargs,
)
elif backend == "hf":
return HFEagle3TargetModel.from_pretrained(
pretrained_model_name_or_path=pretrained_model_name_or_path,
torch_dtype=torch_dtype,
device=device,
cache_dir=cache_dir,
**kwargs,
)
elif backend == "custom":
return CustomEagle3TargetModel.from_pretrained(
pretrained_model_name_or_path=pretrained_model_name_or_path,
torch_dtype=torch_dtype,
device=device,
cache_dir=cache_dir,
**kwargs,
)
    else:
        raise ValueError(
            f"Invalid backend: {backend}. Expected one of 'sglang', 'hf', or 'custom'."
        )