# PHI-LBNET-2.7B-BASE / modeling.py
# thefutureofai's picture
# Upload 79 files
# fd8d063 verified
"""
PhiForLogicalReasoning (LBNets) - Fixed architecture.
Reasoning operates on full sequences, not flattened tokens.
"""
from dataclasses import dataclass
from typing import Optional, Tuple, List, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers import PreTrainedModel, GenerationMixin
from transformers.activations import ACT2FN
from transformers.modeling_outputs import ModelOutput
from transformers.cache_utils import Cache, DynamicCache
from transformers.utils import logging
from transformers.models.phi.modeling_phi import (
PhiDecoderLayer,
PhiPreTrainedModel,
)
from .configuration import PhiReasoningConfig
# Module-level logger from transformers' logging utilities (not referenced
# elsewhere in this file as shown).
logger = logging.get_logger(__name__)
# =============================================================================
# Output Dataclasses
# =============================================================================
@dataclass
class ReasoningModelOutput(ModelOutput):
    """
    Output container for ``PhiReasoningModel.forward``.

    Attributes:
        last_hidden_state: final hidden states after the closing LayerNorm.
        reasoning_states: per-step latent reasoning states, when returned.
        reasoning_used: per-example flag telling whether reasoning was applied.
        halting_step: index of the last reasoning step that ran.
        past_key_values: KV cache for incremental decoding, when caching.
        hidden_states: per-layer hidden states, when requested.
        attentions: per-layer attention weights, when requested.
    """

    # NOTE: field order is part of the public interface (positional access).
    last_hidden_state: Optional[torch.FloatTensor] = None
    reasoning_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    reasoning_used: Optional[torch.BoolTensor] = None
    halting_step: Optional[torch.LongTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class ReasoningCausalLMOutput(ModelOutput):
    """
    Output container for ``PhiForLogicalReasoning.forward``.

    Attributes:
        loss: next-token cross-entropy loss, present when labels are given.
        logits: vocabulary logits for every input position.
        reasoning_states: per-step latent reasoning states, when returned.
        reasoning_used: per-example flag telling whether reasoning was applied.
        halting_step: index of the last reasoning step that ran.
        auxiliary_loss: optional extra training loss from the reasoning module.
        past_key_values: KV cache for incremental decoding, when caching.
        hidden_states: per-layer hidden states, when requested.
        attentions: per-layer attention weights, when requested.
    """

    # NOTE: field order is part of the public interface (positional access).
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    reasoning_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    reasoning_used: Optional[torch.BoolTensor] = None
    halting_step: Optional[torch.LongTensor] = None
    auxiliary_loss: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# =============================================================================
# Reasoning Components
# =============================================================================
class LatentReasoningTokens(nn.Module):
    """
    Learnable latent "scratchpad" tokens shared across the batch.

    Each call returns the base tokens broadcast to the batch plus a learned
    embedding identifying the current reasoning step.
    """

    def __init__(self, config: PhiReasoningConfig):
        super().__init__()
        self.num_tokens = config.num_reasoning_tokens
        self.hidden_size = config.hidden_size
        initial = torch.randn(1, self.num_tokens, self.hidden_size) * 0.02
        self.embeddings = nn.Parameter(initial)
        self.step_embeddings = nn.Embedding(config.max_reasoning_steps, self.hidden_size)

    def forward(
        self, batch_size: int, step: int, device: torch.device, dtype: torch.dtype
    ) -> torch.Tensor:
        """
        Return ``(batch_size, num_tokens, hidden_size)`` latent tokens for ``step``.

        ``step`` must be smaller than ``config.max_reasoning_steps`` (it indexes
        ``step_embeddings``).
        """
        base = self.embeddings.expand(batch_size, -1, -1).to(device=device, dtype=dtype)
        # A (1, 1) index yields a (1, 1, hidden) embedding that broadcasts over
        # both the batch and the token axis.
        step_index = torch.full((1, 1), step, device=device, dtype=torch.long)
        return base + self.step_embeddings(step_index)
class InputComplexityGate(nn.Module):
    """
    Scores each sequence and flags whether it should go through reasoning.

    The score is a sigmoid over an MLP applied to the mean of the hidden
    states along the sequence axis; it is compared to
    ``config.gating_threshold``.
    """

    def __init__(self, config: PhiReasoningConfig):
        super().__init__()
        self.threshold = config.gating_threshold
        reduced = config.hidden_size // 4
        self.gate = nn.Sequential(
            nn.Linear(config.hidden_size, reduced),
            nn.GELU(),
            nn.Dropout(config.reasoning_dropout),
            nn.Linear(reduced, 1),
        )

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.BoolTensor]:
        """Return ``(score, needs_reasoning)``, both of shape ``(batch,)``."""
        # NOTE(review): the mean includes padding positions (no mask is taken).
        summary = hidden_states.mean(dim=1)
        score = self.gate(summary).squeeze(-1).sigmoid()
        return score, score > self.threshold
class ReasoningAttention(nn.Module):
    """
    Multi-head scaled dot-product attention for the reasoning module.

    Without ``key_value_states`` it is self-attention over ``hidden_states``;
    otherwise queries come from ``hidden_states`` and keys/values from
    ``key_value_states``. 2D inputs are treated as length-1 sequences.
    """

    def __init__(self, config: PhiReasoningConfig, is_cross_attention: bool = False):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.scaling = self.head_dim ** -0.5
        # Stored for introspection; forward() does not branch on it.
        self.is_cross_attention = is_cross_attention
        # Registration order (q, k, v, out) matches the checkpoint layout.
        for proj_name in ("q_proj", "k_proj", "v_proj", "out_proj"):
            setattr(self, proj_name, nn.Linear(self.hidden_size, self.hidden_size, bias=True))
        self.dropout = nn.Dropout(config.reasoning_dropout)

    def _split_heads(self, projected: torch.Tensor, batch: int, length: int) -> torch.Tensor:
        """(batch, length, hidden) -> (batch, heads, length, head_dim)."""
        return projected.view(batch, length, self.num_heads, self.head_dim).transpose(1, 2)

    @staticmethod
    def _additive_mask(
        mask: torch.Tensor, kv_len: int, reference: torch.Tensor
    ) -> torch.Tensor:
        """
        Convert a 2D/3D/4D keep-mask into an additive bias for ``reference``.

        Bool masks: True = attend. Other dtypes: 1 = attend, 0 = blocked.
        Blocked entries receive ``torch.finfo(reference.dtype).min``.
        """
        if mask.dim() == 2:
            mask = mask[:, None, None, :]
        elif mask.dim() == 3:
            mask = mask[:, None, :, :]
        if mask.shape[-1] != kv_len:
            mask = mask[..., :kv_len]
        lowest = torch.finfo(reference.dtype).min
        if mask.dtype != torch.bool:
            return (1.0 - mask.to(reference.dtype)) * lowest
        zero = torch.tensor(0.0, dtype=reference.dtype, device=reference.device)
        floor = torch.tensor(lowest, dtype=reference.dtype, device=reference.device)
        return torch.where(mask, zero, floor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return attended features of shape ``(batch, query_len, hidden)``."""
        queries_in = hidden_states.unsqueeze(1) if hidden_states.dim() == 2 else hidden_states
        bsz, q_len, _ = queries_in.shape

        if key_value_states is None:
            keys_in = queries_in
        elif key_value_states.dim() == 2:
            keys_in = key_value_states.unsqueeze(1)
        else:
            keys_in = key_value_states
        kv_len = keys_in.shape[1]

        query = self._split_heads(self.q_proj(queries_in), bsz, q_len)
        key = self._split_heads(self.k_proj(keys_in), bsz, kv_len)
        value = self._split_heads(self.v_proj(keys_in), bsz, kv_len)

        scores = query @ key.transpose(-2, -1) * self.scaling
        if attention_mask is not None:
            scores = scores + self._additive_mask(attention_mask, kv_len, scores)

        # Softmax in float32 for stability, then back to the query dtype.
        probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
        probs = self.dropout(probs)

        merged = (probs @ value).transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        return self.out_proj(merged)
class ReasoningBlock(nn.Module):
    """
    One pre-norm reasoning block.

    The block applies three residual sublayers in order:
    cross-attention from the latent tokens to the context, self-attention
    among the latent tokens, and a feed-forward MLP.
    """

    def __init__(self, config: PhiReasoningConfig):
        super().__init__()
        width = config.hidden_size
        eps = config.layer_norm_eps
        self.cross_attn_norm = nn.LayerNorm(width, eps=eps)
        self.self_attn_norm = nn.LayerNorm(width, eps=eps)
        self.mlp_norm = nn.LayerNorm(width, eps=eps)
        self.cross_attn = ReasoningAttention(config, is_cross_attention=True)
        self.self_attn = ReasoningAttention(config, is_cross_attention=False)
        inner = config.reasoning_intermediate_size
        drop = config.reasoning_dropout
        self.mlp = nn.Sequential(
            nn.Linear(width, inner),
            ACT2FN[config.hidden_act],
            nn.Dropout(drop),
            nn.Linear(inner, width),
            nn.Dropout(drop),
        )

    def forward(
        self,
        reasoning_states: torch.Tensor,
        context_states: torch.Tensor,
        context_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Update the latent tokens given the context; shape is preserved."""
        x = reasoning_states
        x = x + self.cross_attn(
            self.cross_attn_norm(x),
            key_value_states=context_states,
            attention_mask=context_mask,
        )
        x = x + self.self_attn(self.self_attn_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x
class AdaptiveHalting(nn.Module):
    """
    Predicts, per example, whether the reasoning loop may stop.

    Halting is forbidden before ``min_reasoning_steps``, forced at the last
    allowed step (``max_reasoning_steps - 1``), and otherwise decided by
    comparing a sigmoid halting probability with ``halting_threshold``.
    """

    def __init__(self, config: PhiReasoningConfig):
        super().__init__()
        self.threshold = config.halting_threshold
        self.min_steps = config.min_reasoning_steps
        self.max_steps = config.max_reasoning_steps
        reduced = config.hidden_size // 4
        self.halt_predictor = nn.Sequential(
            nn.Linear(config.hidden_size, reduced),
            nn.GELU(),
            nn.Linear(reduced, 1),
        )

    def forward(
        self, reasoning_states: torch.Tensor, step: int
    ) -> Tuple[torch.Tensor, torch.BoolTensor]:
        """Return ``(halt_prob, should_halt)``, both of shape ``(batch,)``."""
        halt_logit = self.halt_predictor(reasoning_states.mean(dim=1)).squeeze(-1)
        halt_prob = torch.sigmoid(halt_logit)
        # The minimum-step rule takes precedence over the forced final step.
        if step < self.min_steps:
            return halt_prob, torch.zeros_like(halt_prob, dtype=torch.bool)
        if step >= self.max_steps - 1:
            return halt_prob, torch.ones_like(halt_prob, dtype=torch.bool)
        return halt_prob, halt_prob > self.threshold
class ReasoningInjector(nn.Module):
    """
    Writes reasoning results back into the token hidden states.

    Every token cross-attends to the latent reasoning states. The result is
    scaled by a learnable scalar gate (initialised to 0.1), passed through
    dropout, and added residually.
    """

    def __init__(self, config: PhiReasoningConfig):
        super().__init__()
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.cross_attn = ReasoningAttention(config, is_cross_attention=True)
        self.gate_scale = nn.Parameter(torch.tensor([0.1]))  # shape (1,)
        self.dropout = nn.Dropout(config.reasoning_dropout)

    def forward(
        self, hidden_states: torch.Tensor, reasoning_states: torch.Tensor
    ) -> torch.Tensor:
        """Return ``hidden_states`` plus the gated reasoning contribution."""
        attended = self.cross_attn(
            self.norm(hidden_states), key_value_states=reasoning_states
        )
        gated = attended * self.gate_scale
        return hidden_states + self.dropout(gated)
# =============================================================================
# Causal Mask Helper
# =============================================================================
def _make_causal_mask(
attention_mask: Optional[torch.Tensor],
batch_size: int,
seq_length: int,
past_length: int,
dtype: torch.dtype,
device: torch.device,
) -> torch.Tensor:
"""
Build a proper 4D causal attention mask that PhiDecoderLayer expects.
Returns: (batch, 1, seq_length, past_length + seq_length) float tensor
with 0.0 for attend and -inf for mask.
"""
total_length = past_length + seq_length
# Start with causal mask: lower-triangular
# Shape: (1, 1, seq_length, total_length)
causal_mask = torch.full(
(1, 1, seq_length, total_length),
torch.finfo(dtype).min,
dtype=dtype,
device=device,
)
# Fill the causal (lower-triangular) portion with 0s
# Each position i can attend to positions 0..past_length+i
for i in range(seq_length):
causal_mask[0, 0, i, : past_length + i + 1] = 0.0
# Expand to batch size
causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
# Apply padding mask if provided
if attention_mask is not None:
# attention_mask is (batch, total_seq_len) with 1=attend, 0=pad
# We need to mask out padded positions in the key dimension
if attention_mask.dim() == 2:
# Ensure it covers total_length
if attention_mask.shape[1] < total_length:
# Pad with 1s on the left (past positions are valid)
pad_len = total_length - attention_mask.shape[1]
attention_mask = F.pad(attention_mask, (pad_len, 0), value=1)
elif attention_mask.shape[1] > total_length:
attention_mask = attention_mask[:, :total_length]
# (batch, 1, 1, total_length)
padding_mask = attention_mask[:, None, None, :].to(dtype)
# Where padding_mask is 0, set to -inf
padding_mask = (1.0 - padding_mask) * torch.finfo(dtype).min
causal_mask = causal_mask.clone() + padding_mask
return causal_mask
# =============================================================================
# Main Model
# =============================================================================
class PhiReasoningModel(PhiPreTrainedModel):
    """
    Phi decoder stack with a latent reasoning module spliced in.

    The first ``config.reasoning_injection_point`` decoder layers run first.
    On the prompt pass (no cached tokens and more than one input token), a
    set of learnable latent tokens iteratively cross-attends to those hidden
    states (``_run_reasoning_loop``). ``ReasoningInjector`` then writes the
    result back into the sequence. The remaining decoder layers and the
    final LayerNorm follow.
    """

    config_class = PhiReasoningConfig

    def __init__(self, config: PhiReasoningConfig):
        super().__init__(config)
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.embed_dropout = nn.Dropout(config.embd_pdrop)
        injection_point = config.reasoning_injection_point
        self.pre_reasoning_layers = nn.ModuleList(
            [PhiDecoderLayer(config, layer_idx) for layer_idx in range(injection_point)]
        )
        self.post_reasoning_layers = nn.ModuleList(
            [
                PhiDecoderLayer(config, layer_idx + injection_point)
                for layer_idx in range(config.num_hidden_layers - injection_point)
            ]
        )
        self.reasoning_tokens = LatentReasoningTokens(config)
        if config.share_reasoning_layers:
            # The same module object is repeated, so the weights are shared.
            shared_block = ReasoningBlock(config)
            self.reasoning_blocks = nn.ModuleList(
                [shared_block for _ in range(config.num_reasoning_layers)]
            )
        else:
            self.reasoning_blocks = nn.ModuleList(
                [ReasoningBlock(config) for _ in range(config.num_reasoning_layers)]
            )
        self.input_gate = (
            InputComplexityGate(config) if config.use_input_gating else None
        )
        self.halting = AdaptiveHalting(config) if config.use_adaptive_halting else None
        self.reasoning_injector = ReasoningInjector(config)
        self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Keyword under which the KV cache is handed to PhiDecoderLayer.
        self._cache_kwarg = self._decoder_cache_kwarg()
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @staticmethod
    def _decoder_cache_kwarg() -> str:
        """
        Return the KV-cache keyword that ``PhiDecoderLayer.forward`` declares.

        transformers releases have named this argument either
        ``past_key_values`` or ``past_key_value``. A name the layer does not
        declare may be swallowed by ``**kwargs``, leaving the cache unfilled.
        """
        import inspect

        try:
            params = inspect.signature(PhiDecoderLayer.forward).parameters
        except (TypeError, ValueError):
            return "past_key_values"
        if "past_key_values" in params:
            return "past_key_values"
        if "past_key_value" in params:
            return "past_key_value"
        return "past_key_values"

    def _run_decoder_stack(
        self,
        layers: nn.ModuleList,
        hidden_states: torch.Tensor,
        past_key_values: Optional[Cache],
        layer_kwargs: dict,
        all_hidden_states: Optional[Tuple[torch.Tensor, ...]],
        all_attentions: Optional[Tuple[Optional[torch.Tensor], ...]],
    ):
        """
        Run ``layers`` in order and normalize each layer's return value.

        A decoder layer may return a bare tensor or a tuple whose first item
        is the hidden states. The cache is replaced only when a layer returns
        a ``Cache`` object; otherwise the object passed in is kept. It is
        presumably updated in place by the layer.

        Returns:
            ``(hidden_states, past_key_values, all_hidden_states, all_attentions)``.
        """
        use_cache = layer_kwargs.get("use_cache", False)
        output_attentions = layer_kwargs.get("output_attentions", False)
        for layer in layers:
            if all_hidden_states is not None:
                all_hidden_states += (hidden_states,)
            layer_outputs = layer(
                hidden_states, **layer_kwargs, **{self._cache_kwarg: past_key_values}
            )
            attn_weights = None
            if isinstance(layer_outputs, torch.Tensor):
                hidden_states = layer_outputs
            else:
                hidden_states = layer_outputs[0]
                if output_attentions and len(layer_outputs) > 1:
                    attn_weights = layer_outputs[1]
                    if isinstance(attn_weights, Cache):
                        attn_weights = None
                if use_cache and isinstance(layer_outputs[-1], Cache):
                    past_key_values = layer_outputs[-1]
            if all_attentions is not None:
                all_attentions += (attn_weights,)
        return hidden_states, past_key_values, all_hidden_states, all_attentions

    def _run_reasoning_loop(
        self,
        context_states: torch.Tensor,
        context_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor], int]:
        """
        Run the iterative latent reasoning loop over a full sequence.

        Args:
            context_states: (batch, seq_len, hidden) output of the
                pre-reasoning layers.
            context_mask: optional (batch, seq_len) padding mask
                (1 = real token, 0 = padding). It restricts which context
                positions the reasoning blocks cross-attend to.

        Returns:
            A tuple ``(reasoning_states, history, final_step)``.
            ``history`` holds the state after every step, detached while
            training. If the eval-time input gate decides that no example in
            the batch needs reasoning, ``reasoning_states`` is a zero
            placeholder, ``history`` is empty and ``final_step`` is 0.
        """
        batch_size = context_states.shape[0]
        device = context_states.device
        dtype = context_states.dtype
        reasoning_history: List[torch.Tensor] = []

        # Input gating (inference only).
        if self.input_gate is not None and not self.training:
            _, needs_reasoning = self.input_gate(context_states)
            if not needs_reasoning.any():
                placeholder = torch.zeros(
                    batch_size,
                    self.config.num_reasoning_tokens,
                    self.config.hidden_size,
                    device=device,
                    dtype=dtype,
                )
                return placeholder, [], 0

        reasoning_states = self.reasoning_tokens(batch_size, 0, device, dtype)
        final_step = 0
        for step in range(self.config.max_reasoning_steps):
            if step > 0:
                step_emb = self.reasoning_tokens(batch_size, step, device, dtype)
                reasoning_states = reasoning_states + 0.1 * step_emb
            for block in self.reasoning_blocks:
                if self.training and reasoning_states.requires_grad:
                    reasoning_states = torch.utils.checkpoint.checkpoint(
                        block,
                        reasoning_states,
                        context_states,
                        context_mask,
                        use_reentrant=False,
                    )
                else:
                    reasoning_states = block(reasoning_states, context_states, context_mask)
            reasoning_history.append(
                reasoning_states.detach() if self.training else reasoning_states
            )
            final_step = step
            # Adaptive halting (inference only): stop once every example halts.
            if self.halting is not None and not self.training:
                _, should_halt = self.halting(reasoning_states, step)
                if should_halt.all():
                    break
        return reasoning_states, reasoning_history, final_step

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_reasoning_states: bool = True,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> ReasoningModelOutput:
        """
        Encode ``input_ids`` (or ``inputs_embeds``), reasoning on the prompt pass.

        ``inputs_embeds`` takes precedence when both inputs are given.
        ``attention_mask`` is a 2D padding mask (1 = attend). It is turned
        into a 4D causal mask for the decoder layers. On the prompt pass it
        also limits what the reasoning blocks attend to.
        ``output_reasoning_states=None`` is treated as True.
        ``output_hidden_states=None`` falls back to ``config.output_hidden_states``.
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        if use_cache is None:
            use_cache = self.config.use_cache
        output_attentions = bool(output_attentions)
        if output_hidden_states is None:
            output_hidden_states = bool(getattr(self.config, "output_hidden_states", False))
        if output_reasoning_states is None:
            # The causal-LM wrapper forwards None by default; keep returning states.
            output_reasoning_states = True

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        batch_size, seq_length = inputs_embeds.shape[:2]
        device = inputs_embeds.device
        dtype = inputs_embeds.dtype

        past_length = 0
        if past_key_values is not None:
            past_length = past_key_values.get_seq_length()
        if position_ids is None:
            position_ids = torch.arange(
                past_length, past_length + seq_length, dtype=torch.long, device=device
            ).unsqueeze(0).expand(batch_size, -1)
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()
        if cache_position is None:
            cache_position = torch.arange(past_length, past_length + seq_length, device=device)

        causal_mask = _make_causal_mask(
            attention_mask=attention_mask,
            batch_size=batch_size,
            seq_length=seq_length,
            past_length=past_length,
            dtype=dtype,
            device=device,
        )
        layer_kwargs = {
            "attention_mask": causal_mask,
            "position_ids": position_ids,
            "cache_position": cache_position,
            "output_attentions": output_attentions,
            "use_cache": use_cache,
        }
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = self.embed_dropout(inputs_embeds)

        # === Pre-reasoning layers ===
        hidden_states, past_key_values, all_hidden_states, all_attentions = (
            self._run_decoder_stack(
                self.pre_reasoning_layers,
                hidden_states,
                past_key_values,
                layer_kwargs,
                all_hidden_states,
                all_attentions,
            )
        )

        # === Reasoning: only on the initial prompt pass ===
        # Cached decoding steps (past_length > 0) and single-token inputs skip it.
        reasoning_history: List[torch.Tensor] = []
        halt_step = 0
        reasoning_ran = False
        if past_length == 0 and seq_length > 1:
            context_mask = None
            if (
                attention_mask is not None
                and attention_mask.dim() == 2
                and attention_mask.shape[-1] == seq_length
            ):
                # Keep the latent tokens from attending to padding positions.
                context_mask = attention_mask.to(device)
            reasoning_states, reasoning_history, halt_step = self._run_reasoning_loop(
                hidden_states, context_mask
            )
            # An empty history means the input gate skipped reasoning. The
            # returned states are then zero placeholders and must not be
            # injected: cross-attention over zeros would still add a
            # bias-only vector to every token.
            reasoning_ran = len(reasoning_history) > 0
            if reasoning_ran:
                hidden_states = self.reasoning_injector(hidden_states, reasoning_states)

        # === Post-reasoning layers ===
        hidden_states, past_key_values, all_hidden_states, all_attentions = (
            self._run_decoder_stack(
                self.post_reasoning_layers,
                hidden_states,
                past_key_values,
                layer_kwargs,
                all_hidden_states,
                all_attentions,
            )
        )

        hidden_states = self.final_layernorm(hidden_states)
        if all_hidden_states is not None:
            all_hidden_states += (hidden_states,)

        return ReasoningModelOutput(
            last_hidden_state=hidden_states,
            reasoning_states=(
                tuple(reasoning_history)
                if output_reasoning_states and reasoning_history
                else None
            ),
            reasoning_used=torch.full(
                (batch_size,), reasoning_ran, dtype=torch.bool, device=device
            ),
            halting_step=torch.tensor([halt_step], device=device),
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
# =============================================================================
# Causal LM Wrapper
# =============================================================================
class PhiForLogicalReasoning(PhiPreTrainedModel, GenerationMixin):
    """Causal language-model head (``lm_head``) on top of ``PhiReasoningModel``."""

    config_class = PhiReasoningConfig
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: PhiReasoningConfig, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored.
        super().__init__(config)
        self.model = PhiReasoningModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_reasoning_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, ReasoningCausalLMOutput]:
        """
        Run the reasoning model and project hidden states to vocabulary logits.

        When ``labels`` are given, a next-token loss is computed: logits at
        position t are scored against ``labels[..., t + 1]``. The loss uses
        ``nn.CrossEntropyLoss`` with its default ``ignore_index``. With
        ``return_dict=False`` the result is ``([loss,] logits,
        reasoning_states, halting_step)``.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_reasoning_states=output_reasoning_states,
            cache_position=cache_position,
        )
        logits = self.lm_head(outputs.last_hidden_state)

        loss = None
        if labels is not None:
            # Compute the loss in float32 (half-precision log-softmax over a
            # large vocabulary loses accuracy). Put labels on the logits'
            # device, since they may live elsewhere when the model is split.
            shift_logits = logits[..., :-1, :].float().contiguous()
            shift_labels = labels[..., 1:].to(shift_logits.device).contiguous()
            loss = nn.CrossEntropyLoss()(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if not return_dict:
            output = (logits, outputs.reasoning_states, outputs.halting_step)
            return ((loss,) + output) if loss is not None else output

        return ReasoningCausalLMOutput(
            loss=loss,
            logits=logits,
            reasoning_states=outputs.reasoning_states,
            reasoning_used=outputs.reasoning_used,
            halting_step=outputs.halting_step,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        **kwargs,
    ):
        """
        Select the inputs for the next generation step.

        Only tokens the cache has not yet seen are fed to the model. That
        count is taken from the cache's actual length, not from whether a
        cache object exists. An empty cache created up front therefore still
        receives the full prompt (and the reasoning pass). ``inputs_embeds``
        is used only while nothing is cached. When no ``position_ids`` are
        supplied, they are derived from a 2D ``attention_mask`` so that
        left-padded rows get correct positions.
        """
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                past_length = past_key_values.get_seq_length()
            elif isinstance(past_key_values, (tuple, list)) and len(past_key_values) > 0:
                # Legacy format: tuple of (key, value) per layer, key (b, h, seq, d).
                past_length = past_key_values[0][0].shape[2]

        if past_length > 0:
            if input_ids.shape[1] > past_length:
                input_ids = input_ids[:, past_length:]
            else:
                input_ids = input_ids[:, -1:]

        model_inputs = {}
        if inputs_embeds is not None and past_length == 0:
            model_inputs["inputs_embeds"] = inputs_embeds
            num_new_tokens = inputs_embeds.shape[1]
        else:
            model_inputs["input_ids"] = input_ids
            num_new_tokens = input_ids.shape[1]

        position_ids = kwargs.get("position_ids")
        if (
            position_ids is None
            and attention_mask is not None
            and attention_mask.dim() == 2
            and 0 < num_new_tokens <= attention_mask.shape[1]
        ):
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            position_ids = position_ids[:, -num_new_tokens:]

        if cache_position is not None and cache_position.shape[0] != num_new_tokens:
            # The caller's bookkeeping disagrees with what is fed (e.g. the
            # cache was never filled). Let the model recompute it.
            cache_position = None

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "attention_mask": attention_mask,
                "position_ids": position_ids,
                "cache_position": cache_position,
                "use_cache": kwargs.get("use_cache", True),
            }
        )
        return model_inputs