"""
PhiForLogicalReasoning (LBNets) - Fixed architecture.

Reasoning operates on full sequences, not flattened tokens.
"""
from dataclasses import dataclass
from typing import Optional, Tuple, List, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers import PreTrainedModel, GenerationMixin
from transformers.activations import ACT2FN
from transformers.modeling_outputs import ModelOutput
from transformers.cache_utils import Cache, DynamicCache
from transformers.utils import logging
from transformers.models.phi.modeling_phi import (
    PhiDecoderLayer,
    PhiPreTrainedModel,
)

from .configuration import PhiReasoningConfig

logger = logging.get_logger(__name__)


# =============================================================================
# Output Dataclasses
# =============================================================================


@dataclass
class ReasoningModelOutput(ModelOutput):
    """Output of the base reasoning model.

    Extends the usual base-model output with diagnostics from the latent
    reasoning loop (the per-step reasoning states and the step at which the
    loop stopped).
    """

    # Final hidden states after the post-reasoning layers and final layer norm.
    last_hidden_state: torch.FloatTensor = None
    # Per-step reasoning token states collected during the reasoning loop,
    # or None when reasoning was skipped (e.g. single-token decode steps).
    reasoning_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # NOTE(review): declared but never populated by the model below — verify
    # whether callers rely on it before removing.
    reasoning_used: Optional[torch.BoolTensor] = None
    # Index of the last reasoning step executed.
    halting_step: Optional[torch.LongTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class ReasoningCausalLMOutput(ModelOutput):
    """Output of the causal-LM wrapper, mirroring ReasoningModelOutput plus
    language-modeling loss and logits."""

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    reasoning_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    reasoning_used: Optional[torch.BoolTensor] = None
    halting_step: Optional[torch.LongTensor] = None
    # NOTE(review): declared but never populated by the wrapper below.
    auxiliary_loss: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


# =============================================================================
# Reasoning Components
# =============================================================================


class LatentReasoningTokens(nn.Module):
    """Learnable latent tokens that serve as the reasoning scratchpad.

    A fixed bank of ``num_reasoning_tokens`` embeddings is shared across the
    batch; a learned per-step embedding is added so that each reasoning
    iteration starts from a step-specific signal.
    """

    def __init__(self, config: PhiReasoningConfig):
        super().__init__()
        self.num_tokens = config.num_reasoning_tokens
        self.hidden_size = config.hidden_size
        # Small init (std ~0.02) keeps the scratchpad close to zero at start.
        self.embeddings = nn.Parameter(
            torch.randn(1, self.num_tokens, self.hidden_size) * 0.02
        )
        # One learned embedding per reasoning-step index.
        self.step_embeddings = nn.Embedding(config.max_reasoning_steps, self.hidden_size)

    def forward(
        self, batch_size: int, step: int, device: torch.device, dtype: torch.dtype
    ) -> torch.Tensor:
        """Return reasoning tokens of shape (batch, num_tokens, hidden) for `step`."""
        tokens = self.embeddings.expand(batch_size, -1, -1).to(device=device, dtype=dtype)
        step_tensor = torch.tensor([step], device=device, dtype=torch.long)
        step_emb = self.step_embeddings(step_tensor).unsqueeze(1)  # (1, 1, hidden)
        return tokens + step_emb


class InputComplexityGate(nn.Module):
    """Determines whether input requires reasoning based on complexity."""

    def __init__(self, config: PhiReasoningConfig):
        super().__init__()
        self.threshold = config.gating_threshold
        # Small MLP scoring head applied to the mean-pooled sequence.
        self.gate = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size // 4),
            nn.GELU(),
            nn.Dropout(config.reasoning_dropout),
            nn.Linear(config.hidden_size // 4, 1),
        )

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.BoolTensor]:
        """Score each batch element.

        Returns:
            score: sigmoid complexity score, shape (batch,).
            needs_reasoning: boolean mask, True where score exceeds threshold.
        """
        pooled = hidden_states.mean(dim=1)  # (batch, hidden)
        score = torch.sigmoid(self.gate(pooled).squeeze(-1))  # (batch,)
        needs_reasoning = score > self.threshold
        return score, needs_reasoning


class ReasoningAttention(nn.Module):
    """Multi-head attention for reasoning blocks (self or cross attention).

    When ``key_value_states`` is None the module performs self-attention;
    otherwise queries come from ``hidden_states`` and keys/values from
    ``key_value_states``.
    """

    def __init__(self, config: PhiReasoningConfig, is_cross_attention: bool = False):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        # Standard 1/sqrt(head_dim) attention scaling.
        self.scaling = self.head_dim ** -0.5
        self.is_cross_attention = is_cross_attention
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.dropout = nn.Dropout(config.reasoning_dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend and return a tensor of shape (batch, seq_len, hidden).

        ``attention_mask`` may be 2D (batch, kv_len) or 3D (batch, q_len,
        kv_len), boolean (True = attend) or additive/multiplicative float.
        """
        # Promote 2D inputs to (batch, 1, hidden) so the math below is uniform.
        if hidden_states.dim() == 2:
            hidden_states = hidden_states.unsqueeze(1)
        batch_size, seq_len, _ = hidden_states.shape
        if key_value_states is None:
            key_value_states = hidden_states
        elif key_value_states.dim() == 2:
            key_value_states = key_value_states.unsqueeze(1)
        kv_seq_len = key_value_states.shape[1]
        # Project and reshape to (batch, heads, len, head_dim).
        q = self.q_proj(hidden_states).view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        k = self.k_proj(key_value_states).view(
            batch_size, kv_seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        v = self.v_proj(key_value_states).view(
            batch_size, kv_seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scaling
        if attention_mask is not None:
            # Broadcast the mask to 4D (batch, 1[, q_len], kv_len).
            if attention_mask.dim() == 2:
                attention_mask = attention_mask[:, None, None, :]
            elif attention_mask.dim() == 3:
                attention_mask = attention_mask[:, None, :, :]
            # Truncate if the mask is longer than the key/value sequence.
            if attention_mask.shape[-1] != kv_seq_len:
                attention_mask = attention_mask[..., :kv_seq_len]
            if attention_mask.dtype == torch.bool:
                # Boolean mask: True = attend (0.0), False = block (-inf).
                mask = torch.where(
                    attention_mask,
                    torch.tensor(0.0, dtype=attn_weights.dtype, device=attn_weights.device),
                    torch.tensor(
                        torch.finfo(attn_weights.dtype).min,
                        dtype=attn_weights.dtype,
                        device=attn_weights.device,
                    ),
                )
            else:
                # 1/0 float mask converted to additive 0/-inf form.
                mask = (1.0 - attention_mask.to(attn_weights.dtype)) * torch.finfo(
                    attn_weights.dtype
                ).min
            attn_weights = attn_weights + mask
        # Softmax in float32 for numerical stability, then cast back.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
        attn_weights = self.dropout(attn_weights)
        output = torch.matmul(attn_weights, v)
        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_size)
        return self.out_proj(output)


class ReasoningBlock(nn.Module):
    """Single reasoning block: cross-attn to context + self-attn + MLP."""

    def __init__(self, config: PhiReasoningConfig):
        super().__init__()
        self.cross_attn_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.self_attn_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.cross_attn = ReasoningAttention(config, is_cross_attention=True)
        self.self_attn = ReasoningAttention(config, is_cross_attention=False)
        mlp_size = config.reasoning_intermediate_size
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, mlp_size),
            ACT2FN[config.hidden_act],
            nn.Dropout(config.reasoning_dropout),
            nn.Linear(mlp_size, config.hidden_size),
            nn.Dropout(config.reasoning_dropout),
        )

    def forward(
        self,
        reasoning_states: torch.Tensor,
        context_states: torch.Tensor,
        context_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run one pre-norm residual pass: cross-attn, self-attn, then MLP."""
        # Cross-attention: reasoning tokens read from the encoded context.
        residual = reasoning_states
        normed = self.cross_attn_norm(reasoning_states)
        reasoning_states = residual + self.cross_attn(
            normed, key_value_states=context_states, attention_mask=context_mask
        )
        # Self-attention among the reasoning tokens.
        residual = reasoning_states
        normed = self.self_attn_norm(reasoning_states)
        reasoning_states = residual + self.self_attn(normed)
        # Position-wise MLP.
        residual = reasoning_states
        normed = self.mlp_norm(reasoning_states)
        reasoning_states = residual + self.mlp(normed)
        return reasoning_states


class AdaptiveHalting(nn.Module):
    """Decides when to stop reasoning based on confidence."""

    def __init__(self, config: PhiReasoningConfig):
        super().__init__()
        self.threshold = config.halting_threshold
        self.min_steps = config.min_reasoning_steps
        self.max_steps = config.max_reasoning_steps
        self.halt_predictor = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size // 4),
            nn.GELU(),
            nn.Linear(config.hidden_size // 4, 1),
        )

    def forward(
        self, reasoning_states: torch.Tensor, step: int
    ) -> Tuple[torch.Tensor, torch.BoolTensor]:
        """Return (halt_prob, should_halt), each of shape (batch,).

        Halting is forced off before ``min_steps`` and forced on at the last
        allowed step; otherwise it is thresholded on the predicted probability.
        """
        pooled = reasoning_states.mean(dim=1)
        halt_prob = torch.sigmoid(self.halt_predictor(pooled).squeeze(-1))
        if step < self.min_steps:
            should_halt = torch.zeros_like(halt_prob, dtype=torch.bool)
        elif step >= self.max_steps - 1:
            should_halt = torch.ones_like(halt_prob, dtype=torch.bool)
        else:
            should_halt = halt_prob > self.threshold
        return halt_prob, should_halt


class ReasoningInjector(nn.Module):
    """Injects reasoning results back into the main hidden states via cross-attention."""

    def __init__(self, config: PhiReasoningConfig):
        super().__init__()
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.cross_attn = ReasoningAttention(config, is_cross_attention=True)
        # Learned scalar gate, initialized small so injection starts gentle.
        self.gate_scale = nn.Parameter(torch.tensor([0.1]))  # shape (1,)
        self.dropout = nn.Dropout(config.reasoning_dropout)

    def forward(
        self, hidden_states: torch.Tensor, reasoning_states: torch.Tensor
    ) -> torch.Tensor:
        """Add gated cross-attention over the reasoning tokens to the residual stream."""
        residual = hidden_states
        hidden_normed = self.norm(hidden_states)
        reasoning_info = self.cross_attn(hidden_normed, key_value_states=reasoning_states)
        return residual + self.dropout(reasoning_info * self.gate_scale)


# =============================================================================
# Causal Mask Helper
# =============================================================================


def _make_causal_mask(
    attention_mask: Optional[torch.Tensor],
    batch_size: int,
    seq_length: int,
    past_length: int,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """
    Build a proper 4D causal attention mask that PhiDecoderLayer expects.

    Args:
        attention_mask: optional (batch, total_len) padding mask, 1=attend, 0=pad.
        batch_size: batch size of the current forward pass.
        seq_length: number of new (query) positions.
        past_length: number of cached key/value positions preceding them.
        dtype: floating dtype used for the additive mask values.
        device: device to allocate the mask on.

    Returns:
        (batch, 1, seq_length, past_length + seq_length) float tensor with
        0.0 for attend and -inf for mask.
    """
    total_length = past_length + seq_length
    # Start with causal mask: lower-triangular
    # Shape: (1, 1, seq_length, total_length)
    causal_mask = torch.full(
        (1, 1, seq_length, total_length),
        torch.finfo(dtype).min,
        dtype=dtype,
        device=device,
    )
    # Fill the causal (lower-triangular) portion with 0s
    # Each position i can attend to positions 0..past_length+i
    for i in range(seq_length):
        causal_mask[0, 0, i, : past_length + i + 1] = 0.0
    # Expand to batch size
    causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
    # Apply padding mask if provided
    if attention_mask is not None:
        # attention_mask is (batch, total_seq_len) with 1=attend, 0=pad
        # We need to mask out padded positions in the key dimension
        if attention_mask.dim() == 2:
            # Ensure it covers total_length
            if attention_mask.shape[1] < total_length:
                # Pad with 1s on the left (past positions are valid)
                pad_len = total_length - attention_mask.shape[1]
                attention_mask = F.pad(attention_mask, (pad_len, 0), value=1)
            elif attention_mask.shape[1] > total_length:
                attention_mask = attention_mask[:, :total_length]
            # (batch, 1, 1, total_length)
            padding_mask = attention_mask[:, None, None, :].to(dtype)
            # Where padding_mask is 0, set to -inf
            padding_mask = (1.0 - padding_mask) * torch.finfo(dtype).min
            # clone() because expand() produced a non-contiguous broadcast view.
            causal_mask = causal_mask.clone() + padding_mask
    return causal_mask


# =============================================================================
# Main Model
# =============================================================================


class PhiReasoningModel(PhiPreTrainedModel):
    """Phi backbone with a latent reasoning loop inserted between two groups
    of decoder layers.

    The first ``reasoning_injection_point`` Phi decoder layers run normally;
    the resulting hidden states act as context for an iterative loop over
    learnable reasoning tokens, whose result is injected back into the
    residual stream before the remaining decoder layers run.
    """

    config_class = PhiReasoningConfig

    def __init__(self, config: PhiReasoningConfig):
        super().__init__(config)
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.embed_dropout = nn.Dropout(config.embd_pdrop)
        # Split the decoder stack at the reasoning injection point.
        injection_point = config.reasoning_injection_point
        self.pre_reasoning_layers = nn.ModuleList(
            [PhiDecoderLayer(config, layer_idx) for layer_idx in range(injection_point)]
        )
        self.post_reasoning_layers = nn.ModuleList(
            [
                PhiDecoderLayer(config, layer_idx + injection_point)
                for layer_idx in range(config.num_hidden_layers - injection_point)
            ]
        )
        self.reasoning_tokens = LatentReasoningTokens(config)
        if config.share_reasoning_layers:
            # Weight tying: the same block instance repeated num_reasoning_layers times.
            shared_block = ReasoningBlock(config)
            self.reasoning_blocks = nn.ModuleList(
                [shared_block for _ in range(config.num_reasoning_layers)]
            )
        else:
            self.reasoning_blocks = nn.ModuleList(
                [ReasoningBlock(config) for _ in range(config.num_reasoning_layers)]
            )
        # Optional components controlled by config flags.
        self.input_gate = (
            InputComplexityGate(config) if config.use_input_gating else None
        )
        self.halting = AdaptiveHalting(config) if config.use_adaptive_halting else None
        self.reasoning_injector = ReasoningInjector(config)
        self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _run_reasoning_loop(
        self, context_states: torch.Tensor
    ) -> Tuple[torch.Tensor, List[torch.Tensor], int]:
        """
        Run the iterative reasoning loop.

        Args:
            context_states: (batch, seq_len, hidden) - full sequence.

        Returns:
            Final reasoning states (batch, num_reasoning_tokens, hidden),
            the list of per-step states (detached during training), and the
            index of the last step executed.
        """
        batch_size = context_states.shape[0]
        device = context_states.device
        dtype = context_states.dtype
        reasoning_history = []
        # Input gating: at inference, skip reasoning entirely when no batch
        # element is judged complex enough. (Gating is bypassed in training
        # so the gate and reasoning stack always receive gradients.)
        if self.input_gate is not None and not self.training:
            complexity_score, needs_reasoning = self.input_gate(context_states)
            if not needs_reasoning.any():
                dummy = torch.zeros(
                    batch_size,
                    self.config.num_reasoning_tokens,
                    self.config.hidden_size,
                    device=device,
                    dtype=dtype,
                )
                return dummy, [], 0
        # Initialize reasoning tokens
        reasoning_states = self.reasoning_tokens(batch_size, 0, device, dtype)
        final_step = 0
        for step in range(self.config.max_reasoning_steps):
            if step > 0:
                # Nudge the state with a scaled step embedding between iterations.
                step_emb = self.reasoning_tokens(batch_size, step, device, dtype)
                reasoning_states = reasoning_states + 0.1 * step_emb
            for block in self.reasoning_blocks:
                if self.training and reasoning_states.requires_grad:
                    # Checkpoint to trade compute for activation memory.
                    reasoning_states = torch.utils.checkpoint.checkpoint(
                        block,
                        reasoning_states,
                        context_states,
                        None,
                        use_reentrant=False,
                    )
                else:
                    reasoning_states = block(reasoning_states, context_states)
            if self.training:
                # Detach history so auxiliary consumers don't extend the graph.
                reasoning_history.append(reasoning_states.detach())
            else:
                reasoning_history.append(reasoning_states)
            final_step = step
            # Adaptive halting (inference only)
            if self.halting is not None and not self.training:
                halt_prob, should_halt = self.halting(reasoning_states, step)
                if should_halt.all():
                    break
        return reasoning_states, reasoning_history, final_step

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_reasoning_states: bool = True,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> ReasoningModelOutput:
        """Run pre-reasoning layers, the reasoning loop (prompt pass only),
        then post-reasoning layers; returns a ReasoningModelOutput."""
        if use_cache is None:
            use_cache = self.config.use_cache
        if output_attentions is None:
            output_attentions = False
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        batch_size, seq_length = inputs_embeds.shape[:2]
        device = inputs_embeds.device
        dtype = inputs_embeds.dtype
        # past length
        past_length = 0
        if past_key_values is not None:
            past_length = past_key_values.get_seq_length()
        # position ids
        if position_ids is None:
            position_ids = torch.arange(
                past_length, past_length + seq_length, dtype=torch.long, device=device
            ).unsqueeze(0).expand(batch_size, -1)
        # cache init
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()
        # cache_position
        if cache_position is None:
            cache_position = torch.arange(past_length, past_length + seq_length, device=device)
        # 4D causal mask for phi
        causal_mask = _make_causal_mask(
            attention_mask=attention_mask,
            batch_size=batch_size,
            seq_length=seq_length,
            past_length=past_length,
            dtype=dtype,
            device=device,
        )
        hidden_states = self.embed_dropout(inputs_embeds)
        # === Pre-reasoning layers ===
        for layer in self.pre_reasoning_layers:
            layer_outputs = layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
            hidden_states = layer_outputs[0]
            if use_cache:
                # NOTE(review): assumes present_key_value is the last element of
                # the HF decoder layer output tuple — this ordering is
                # transformers-version-dependent; confirm against the pinned
                # version's PhiDecoderLayer.forward.
                past_key_values = layer_outputs[-1]
        # === Reasoning: only on the initial prompt pass ===
        # When generating with KV cache, seq_length==1 and past_length>0: skip reasoning.
        reasoning_history = []
        halt_step = 0
        if (past_length == 0) and (seq_length > 1):
            reasoning_states, reasoning_history, halt_step = self._run_reasoning_loop(hidden_states)
            hidden_states = self.reasoning_injector(hidden_states, reasoning_states)
        # === Post-reasoning layers ===
        for layer in self.post_reasoning_layers:
            layer_outputs = layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
            hidden_states = layer_outputs[0]
            if use_cache:
                past_key_values = layer_outputs[-1]
        # Final layer norm over the full sequence.
        hidden_states = self.final_layernorm(hidden_states)
        return ReasoningModelOutput(
            last_hidden_state=hidden_states,
            reasoning_states=tuple(reasoning_history) if reasoning_history else None,
            halting_step=torch.tensor([halt_step], device=device),
            past_key_values=past_key_values if use_cache else None,
        )


# =============================================================================
# Causal LM Wrapper
# =============================================================================


class PhiForLogicalReasoning(PhiPreTrainedModel, GenerationMixin):
    """Causal-LM head on top of PhiReasoningModel, compatible with HF generate()."""

    config_class = PhiReasoningConfig
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: PhiReasoningConfig, *args, **kwargs):
        super().__init__(config)
        self.model = PhiReasoningModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_reasoning_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, ReasoningCausalLMOutput]:
        """Standard causal-LM forward; computes shifted cross-entropy when
        ``labels`` is given."""
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_reasoning_states=output_reasoning_states,
            cache_position=cache_position,
        )
        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            # Shift so position t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
            )
        if not return_dict:
            # NOTE(review): this tuple omits past_key_values, so non-dict
            # callers cannot continue generation from it — confirm intended.
            output = (logits,) + (outputs.reasoning_states, outputs.halting_step)
            return ((loss,) + output) if loss is not None else output
        return ReasoningCausalLMOutput(
            loss=loss,
            logits=logits,
            reasoning_states=outputs.reasoning_states,
            halting_step=outputs.halting_step,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        **kwargs,
    ):
        """Trim input_ids to the last token once a KV cache exists and pack
        the kwargs dict expected by forward()."""
        if past_key_values is not None:
            if input_ids.shape[1] != 1:
                # With a cache, only the newest token needs to be re-encoded.
                input_ids = input_ids[:, -1:]
        model_inputs = {}
        # inputs_embeds can only be used on the first (cache-free) pass.
        if inputs_embeds is not None and past_key_values is None:
            model_inputs["inputs_embeds"] = inputs_embeds
        else:
            model_inputs["input_ids"] = input_ids
        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "use_cache": kwargs.get("use_cache", True),
            }
        )
        return model_inputs