"""IQuestCoder model implementation.
This implementation combines ideas from:
- LLaMA: Core architecture and forward pass (for compatibility)
- OLMo: QKV clipping for training stability
- Qwen2: Sliding window attention support
The forward pass is fully compatible with LLaMA weights.
"""
from typing import Callable, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import (
LossKwargs,
auto_docstring,
can_return_tuple,
is_torch_flex_attn_available,
logging,
)
from .configuration_iquestcoder import IQuestCoderConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from transformers.integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
# =============================================================================
# Helper Functions
# =============================================================================
def rotate_half(x: torch.Tensor) -> torch.Tensor:
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
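# Illustrative sketch: rotate_half negates the second half of the last
# dimension and swaps it to the front, e.g.
#
#   >>> rotate_half(torch.tensor([1.0, 2.0, 3.0, 4.0]))
#   tensor([-3., -4., 1., 2.])
#
# which is the pairwise (x1, x2) -> (-x2, x1) rotation RoPE needs, applied in
# the "half-split" layout rather than to interleaved pairs.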
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
unsqueeze_dim: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q: The query tensor.
k: The key tensor.
cos: The cosine part of the rotary embedding.
sin: The sine part of the rotary embedding.
position_ids: Deprecated and unused.
unsqueeze_dim: The dimension along which to unsqueeze cos and sin.
Returns:
Tuple of query and key tensors rotated using the Rotary Position Embedding.
"""
# Borrowed from OLMo: preserve original dtypes for numerical stability
q_dtype, k_dtype = q.dtype, k.dtype
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed.to(q_dtype), k_embed.to(k_dtype)
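# Minimal identity check (illustrative, using the default unsqueeze_dim=1 so
# cos/sin of shape (batch, seq, head_dim) broadcast over the head dimension of
# (batch, num_heads, seq, head_dim) tensors):
#
#   >>> q = k = torch.randn(2, 4, 16, 8)
#   >>> cos, sin = torch.ones(2, 16, 8), torch.zeros(2, 16, 8)
#   >>> q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
#   >>> torch.equal(q_rot, q)  # a zero angle leaves the inputs unchanged
#   True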
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
Expands key/value heads for Grouped Query Attention.
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
(batch, num_attention_heads, seqlen, head_dim).
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
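# Illustrative sketch: with 8 query heads sharing 2 key/value heads
# (n_rep = 4), the extra heads come from expand + reshape, so nothing is
# copied until the final reshape materializes the view:
#
#   >>> kv = torch.randn(1, 2, 5, 64)   # (batch, kv_heads, seq, head_dim)
#   >>> repeat_kv(kv, n_rep=4).shape
#   torch.Size([1, 8, 5, 64])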
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Standard eager attention implementation."""
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = F.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
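# Sanity-check sketch (the `_Toy` module below is hypothetical and only
# provides the attributes eager_attention_forward reads): with no mask and
# zero dropout, the eager path matches PyTorch's fused SDPA up to numerical
# tolerance. Note the eager output is already transposed to
# (batch, seq, heads, head_dim).
#
#   >>> class _Toy(nn.Module):
#   ...     num_key_value_groups = 1
#   >>> q = k = v = torch.randn(1, 2, 4, 8)
#   >>> out, _ = eager_attention_forward(_Toy(), q, k, v, None, scaling=8 ** -0.5)
#   >>> ref = F.scaled_dot_product_attention(q, k, v)
#   >>> torch.allclose(out, ref.transpose(1, 2), atol=1e-5)
#   True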
# =============================================================================
# Model Components
# =============================================================================
class IQuestCoderRMSNorm(nn.Module):
"""Root Mean Square Layer Normalization.
RMSNorm is computationally simpler than LayerNorm while achieving similar
performance. It normalizes the input by its RMS value.
"""
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self) -> str:
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class IQuestCoderRotaryEmbedding(nn.Module):
"""Rotary Position Embedding (RoPE).
Implements rotary positional embeddings as described in the RoFormer paper.
Supports various RoPE scaling methods for extended context lengths.
"""
def __init__(self, config: IQuestCoderConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
@dynamic_rope_update
def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
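# Shape sketch (illustrative; `cfg` stands for a hypothetical
# IQuestCoderConfig whose head_dim is 8 and whose default RoPE gives
# attention_scaling == 1.0):
#
#   >>> rope = IQuestCoderRotaryEmbedding(cfg)
#   >>> x = torch.randn(2, 16, 32)  # (batch, seq, hidden); only dtype/device are read
#   >>> cos, sin = rope(x, torch.arange(16)[None].expand(2, -1))
#   >>> cos.shape  # (batch, seq, head_dim), shared across all decoder layers
#   torch.Size([2, 16, 8])
#   >>> bool(cos[:, 0].eq(1).all() and sin[:, 0].eq(0).all())  # zero angle at position 0
#   True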
class IQuestCoderMLP(nn.Module):
"""Feed-forward network with SwiGLU activation.
Uses the gated linear unit variant with SiLU activation for improved
performance compared to standard FFN.
"""
def __init__(self, config: IQuestCoderConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x: torch.Tensor) -> torch.Tensor:
# SwiGLU: down_proj(act_fn(gate_proj(x)) * up_proj(x))
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
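# Written out with W_g = gate_proj, W_u = up_proj, W_d = down_proj (and
# assuming hidden_act is "silu", as the docstring states), the forward above
# computes
#
#   MLP(x) = W_d(SiLU(W_g(x)) * W_u(x))
#
# where * is elementwise: the SiLU-activated gate branch modulates the up
# projection before the result is projected back to hidden_size.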
class IQuestCoderAttention(nn.Module):
"""Multi-headed attention with support for Grouped Query Attention (GQA).
Features:
- Grouped Query Attention for memory efficiency
- Optional QKV clipping for training stability (from OLMo)
- Optional sliding window attention (from Qwen2)
- Rotary Position Embeddings
"""
def __init__(self, config: IQuestCoderConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim ** -0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
# Projection layers
self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
)
self.k_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.v_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
# Compute Q, K, V projections
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# [OLMo Feature] Optional QKV clipping for training stability
if self.config.clip_qkv is not None:
query_states = query_states.clamp(min=-self.config.clip_qkv, max=self.config.clip_qkv)
key_states = key_states.clamp(min=-self.config.clip_qkv, max=self.config.clip_qkv)
value_states = value_states.clamp(min=-self.config.clip_qkv, max=self.config.clip_qkv)
# Reshape to (batch, heads, seq_len, head_dim)
query_states = query_states.view(hidden_shape).transpose(1, 2)
key_states = key_states.view(hidden_shape).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
# Apply rotary position embeddings
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# Update KV cache if provided
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# [Qwen2 Feature] Sliding window attention
sliding_window = None
if (
self.config.use_sliding_window
and getattr(self.config, "sliding_window", None) is not None
and self.layer_idx >= self.config.max_window_layers
):
sliding_window = self.config.sliding_window
# Select attention implementation
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. "
'Falling back to eager attention. This warning can be removed using the argument '
'`attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
# Compute attention
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
sliding_window=sliding_window,
**kwargs,
)
# Reshape and project output
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
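# Shape walk-through (illustrative numbers, e.g. hidden_size 4096 with
# 32 query heads, 8 KV heads and head_dim 128, i.e. num_key_value_groups = 4):
#
#   hidden_states               (batch, seq, 4096)
#   q after view + transpose    (batch, 32, seq, 128)
#   k, v after view + transpose (batch, 8, seq, 128)   # repeated 4x per group
#                                                      # inside the kernel (GQA)
#   attn_output                 (batch, seq, 32 * 128) -> o_proj -> (batch, seq, 4096)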
class IQuestCoderDecoderLayer(GradientCheckpointingLayer):
"""Transformer decoder layer with pre-normalization.
Architecture: Pre-RMSNorm -> Attention -> Residual -> Pre-RMSNorm -> MLP -> Residual
"""
def __init__(self, config: IQuestCoderConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = IQuestCoderAttention(config=config, layer_idx=layer_idx)
self.mlp = IQuestCoderMLP(config)
self.input_layernorm = IQuestCoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = IQuestCoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Warn if sliding window is enabled but not properly supported
if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
logger.warning_once(
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
"unexpected results may be encountered."
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
# Pre-norm + Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# Pre-norm + MLP
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
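# Data flow of one layer, matching the forward above (pre-norm: the residual
# stream itself is never normalized):
#
#   h = x + Attn(RMSNorm(x))
#   y = h + MLP(RMSNorm(h))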
# =============================================================================
# Base Model
# =============================================================================
@auto_docstring
class IQuestCoderPreTrainedModel(PreTrainedModel):
"""Base class for IQuestCoder models."""
config_class = IQuestCoderConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["IQuestCoderDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True
def _init_weights(self, module: nn.Module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, IQuestCoderRMSNorm):
module.weight.data.fill_(1.0)
@auto_docstring
class IQuestCoderModel(IQuestCoderPreTrainedModel):
"""
IQuestCoder Model outputting raw hidden-states without any specific head on top.
This model is compatible with LLaMA weights while incorporating features from OLMo and Qwen2.
"""
def __init__(self, config: IQuestCoderConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[IQuestCoderDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = IQuestCoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = IQuestCoderRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Embedding:
return self.embed_tokens
def set_input_embeddings(self, value: nn.Embedding):
self.embed_tokens = value
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if not isinstance(past_key_values, (type(None), Cache)):
raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
hidden_states = inputs_embeds
# Create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# Decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# Add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and past_key_values is not None:
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'. "
"This may lead to unexpected behaviour for Flash Attention version of IQuestCoder. "
"Make sure to call `tokenizer.padding_side = 'left'` before tokenizing the input."
)
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
if (
self.config._attn_implementation == "sdpa"
and not (using_static_cache or using_sliding_window_cache)
and not output_attentions
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
sliding_window=self.config.sliding_window if self.config.use_sliding_window else None,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_sliding_window_cache or using_static_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
past_key_values=past_key_values,
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
config: IQuestCoderConfig,
past_key_values: Cache,
):
"""Creates a causal 4D mask from a 2D mask, or returns the 4D mask if already provided."""
if attention_mask is not None and attention_mask.dim() == 4:
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
# [Qwen2 Feature] Handle sliding window mask
if getattr(config, "use_sliding_window", False) and config.sliding_window is not None:
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - config.sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone()
if attention_mask.shape[-1] > target_length:
attention_mask = attention_mask[:, :target_length]
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
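# Mask sketch (illustrative): for sequence_length = target_length = 3 with no
# padding, the helper above yields (writing -inf for torch.finfo(dtype).min)
#
#   [[0, -inf, -inf],
#    [0,    0, -inf],
#    [0,    0,    0]]
#
# and with sliding_window = 2 the bottom-left entry is masked as well, since
# position 2 may then only attend to positions 1..2.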
# =============================================================================
# Model Heads
# =============================================================================
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
...
@auto_docstring
class IQuestCoderForCausalLM(IQuestCoderPreTrainedModel, GenerationMixin):
"""IQuestCoder Model with a language modeling head on top for causal LM."""
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config: IQuestCoderConfig):
super().__init__(config)
self.model = IQuestCoderModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens
def set_input_embeddings(self, value: nn.Embedding):
self.model.embed_tokens = value
def get_output_embeddings(self) -> nn.Linear:
return self.lm_head
def set_output_embeddings(self, new_embeddings: nn.Linear):
self.lm_head = new_embeddings
def set_decoder(self, decoder: IQuestCoderModel):
self.model = decoder
def get_decoder(self) -> IQuestCoderModel:
return self.model
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> CausalLMOutputWithPast:
r"""
Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the causal language modeling loss. Indices should either be in
                `[0, ..., config.vocab_size - 1]` or -100 (see `input_ids` docstring). Tokens with indices set to
                `-100` are ignored (masked); the loss is only computed for tokens with labels in
                `[0, ..., config.vocab_size - 1]`.
Example:
```python
>>> from transformers import AutoTokenizer
>>> from modeling_iquestcoder import IQuestCoderForCausalLM
>>> model = IQuestCoderForCausalLM.from_pretrained("path/to/IQuestCoder")
>>> tokenizer = AutoTokenizer.from_pretrained("path/to/IQuestCoder")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\\nI'm not conscious, but I can talk to you."
```
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# Decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
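# `logits_to_keep` sketch (illustrative): generation only needs the final
# position, so logits_to_keep=1 slices hidden_states to (batch, 1, hidden)
# before lm_head and skips the full (batch, seq, vocab) projection; the
# default of 0 becomes slice(0, None) and keeps every position.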
@auto_docstring(
custom_intro="""
The IQuestCoder Model transformer with a sequence classification head on top (linear layer).
[`IQuestCoderForSequenceClassification`] uses the last token in order to do the classification, as other causal
models (e.g. GPT-2) do.
Since it does classification on the last token, it requires to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row.
If no `pad_token_id` is defined, it simply takes the last value in each row of the batch.
"""
)
class IQuestCoderForSequenceClassification(IQuestCoderPreTrainedModel):
"""IQuestCoder Model with a sequence classification head."""
def __init__(self, config: IQuestCoderConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.model = IQuestCoderModel(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens
def set_input_embeddings(self, value: nn.Embedding):
self.model.embed_tokens = value
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> SequenceClassifierOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss);
            if `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
transformer_outputs: BaseModelOutputWithPast = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
last_non_pad_token = -1
elif input_ids is not None:
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
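# Pooling sketch (illustrative): with pad_token_id = 0 and
# input_ids = [[5, 6, 0, 0]], non_pad_mask is [1, 1, 0, 0], so
# (token_indices * non_pad_mask).argmax(-1) == 1 and the logits at the last
# real token (index 1) are the ones classified.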
@auto_docstring
class IQuestCoderForTokenClassification(IQuestCoderPreTrainedModel):
"""IQuestCoder Model with a token classification head."""
def __init__(self, config: IQuestCoderConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.model = IQuestCoderModel(config)
if getattr(config, "classifier_dropout", None) is not None:
classifier_dropout = config.classifier_dropout
elif getattr(config, "hidden_dropout", None) is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.score = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens
def set_input_embeddings(self, value: nn.Embedding):
self.model.embed_tokens = value
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> TokenClassifierOutput:
r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in
            `[0, ..., config.num_labels - 1]`.
"""
outputs: BaseModelOutputWithPast = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs.last_hidden_state
sequence_output = self.dropout(sequence_output)
logits = self.score(sequence_output)
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.config)
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
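# Unlike the sequence classifier above, every position keeps its own
# prediction: `logits` is (batch, seq, num_labels) and the loss is a
# per-token cross-entropy over all non-ignored positions.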
@auto_docstring
class IQuestCoderForQuestionAnswering(IQuestCoderPreTrainedModel):
"""IQuestCoder Model with a span classification head for extractive question-answering."""
base_model_prefix = "transformer"
def __init__(self, config: IQuestCoderConfig):
super().__init__(config)
self.transformer = IQuestCoderModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Embedding:
return self.transformer.embed_tokens
def set_input_embeddings(self, value: nn.Embedding):
self.transformer.embed_tokens = value
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> QuestionAnsweringModelOutput:
outputs: BaseModelOutputWithPast = self.transformer(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs.last_hidden_state
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
loss = None
if start_positions is not None and end_positions is not None:
loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
return QuestionAnsweringModelOutput(
loss=loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
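# Span sketch (illustrative): qa_outputs maps each hidden state to two
# scalars, so `logits` is (batch, seq, 2); after the split, the predicted
# answer span is simply (start_logits.argmax(-1), end_logits.argmax(-1)).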
__all__ = [
"IQuestCoderPreTrainedModel",
"IQuestCoderModel",
"IQuestCoderForCausalLM",
"IQuestCoderForSequenceClassification",
"IQuestCoderForTokenClassification",
"IQuestCoderForQuestionAnswering",
]