| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from typing import Optional, Tuple, List, Dict, Any |
| | import math |
| | from ..configs.config import ModelConfig, InterleavedThinkingConfig |
| |
|
| |
|
class EfficientAttention(nn.Module):
    """Memory-efficient multi-head self-attention with RoPE and optional flash attention.

    Expects ``config`` to provide ``dim``, ``heads``, ``dropout``,
    ``max_seq_len`` and ``flash_attention`` (bool).
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.n_heads = config.heads
        self.head_dim = config.dim // config.heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(config.dim, config.dim, bias=False)
        self.k_proj = nn.Linear(config.dim, config.dim, bias=False)
        self.v_proj = nn.Linear(config.dim, config.dim, bias=False)
        self.o_proj = nn.Linear(config.dim, config.dim, bias=False)

        self.dropout = nn.Dropout(config.dropout)

        # Non-persistent buffer so the cache follows .to(device)/.cuda()
        # with the module without polluting the state_dict.
        self.register_buffer(
            "rope_cache", self._build_rope_cache(config.max_seq_len), persistent=False
        )

    def _build_rope_cache(self, max_seq_len: int) -> torch.Tensor:
        """Build a RoPE cache of shape (max_seq_len, head_dim // 2, 2).

        Entry ``[t, i]`` holds ``(cos, sin)`` of ``t * inv_freq_i`` — the
        pairing that ``_apply_rope`` consumes. (BUG FIX: the previous
        ``cat((sin, cos))`` layout, once reshaped into pairs of 2, paired
        unrelated frequencies with each other.)
        """
        inv_freq = 1.0 / (10000 ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        t = torch.arange(max_seq_len).float()
        freqs = torch.einsum('i , j -> i j', t, inv_freq)
        return torch.stack((freqs.cos(), freqs.sin()), dim=-1)

    def _apply_rope(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
        """Apply rotary position embeddings to ``x`` of shape (B, T, n_heads, head_dim).

        Dim 1 MUST be the sequence axis: the cache is sliced by ``x.size(1)``.
        """
        rope = self.rope_cache[start_pos:start_pos + x.size(1)].to(x.device)
        xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
        # Broadcast positions over batch and heads: (1, T, 1, head_dim//2, 2).
        rope = rope.reshape(1, xshaped.size(1), 1, xshaped.size(3), 2)
        # Standard 2-D rotation of each (even, odd) feature pair.
        x_out = torch.stack([
            xshaped[..., 0] * rope[..., 0] - xshaped[..., 1] * rope[..., 1],
            xshaped[..., 1] * rope[..., 0] + xshaped[..., 0] * rope[..., 1],
        ], -1)
        return x_out.flatten(3).type_as(x)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Self-attention over ``x`` (batch, seq, dim).

        Args:
            x: Input of shape (B, T, dim).
            mask: Optional mask broadcastable to (B, n_heads, T, T); positions
                where the mask is 0/False are not attended to.

        Returns:
            Tensor of shape (B, T, dim).
        """
        B, T, C = x.size()

        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim)

        # BUG FIX: apply RoPE while the layout is (B, T, H, D) so dim 1 is the
        # sequence axis. Previously the tensors were transposed to (B, H, T, D)
        # first, so the cache was sliced by the number of heads and positions
        # were effectively indexed by head, not by token.
        q = self._apply_rope(q)
        k = self._apply_rope(k)

        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if self.config.flash_attention and hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
            # Fused kernel; dropout only during training.
            attn_output = F.scaled_dot_product_attention(
                q, k, v, attn_mask=mask, dropout_p=self.dropout.p if self.training else 0.0
            )
        else:
            # Reference implementation.
            attn_weights = (q @ k.transpose(-2, -1)) * self.scale
            if mask is not None:
                attn_weights = attn_weights.masked_fill(mask == 0, float('-inf'))
            attn_weights = F.softmax(attn_weights, dim=-1)
            attn_weights = self.dropout(attn_weights)
            attn_output = attn_weights @ v

        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)
        return self.o_proj(attn_output)
| |
|
| |
|
class FeedForward(nn.Module):
    """SwiGLU feed-forward block: ``down(dropout(silu(gate(x)) * up(x)))``."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        inner = 4 * config.dim  # expansion factor of 4

        self.gate_proj = nn.Linear(config.dim, inner, bias=False)
        self.up_proj = nn.Linear(config.dim, inner, bias=False)
        self.down_proj = nn.Linear(inner, config.dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project up with a SiLU-gated branch, then back down to ``dim``."""
        activated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(self.dropout(activated))
| |
|
| |
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention then feed-forward, each residual."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.attention = EfficientAttention(config)
        self.feed_forward = FeedForward(config)
        self.attention_norm = nn.RMSNorm(config.dim)
        self.ffn_norm = nn.RMSNorm(config.dim)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply pre-norm attention and pre-norm feed-forward, both with
        residual connections; shape is preserved."""
        x = x + self.attention(self.attention_norm(x), mask)
        return x + self.feed_forward(self.ffn_norm(x))
| |
|
| |
|
class CompactTransformer(nn.Module):
    """Compact decoder-only transformer with tied input/output embeddings."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # Token and learned absolute position embeddings.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.dim)
        self.embed_positions = nn.Embedding(config.max_seq_len, config.dim)

        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.layers)
        ])

        self.norm = nn.RMSNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)

        # Weight tying: input embedding and output projection share one tensor.
        self.embed_tokens.weight = self.lm_head.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights with N(0, 0.02), biases with 0."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _build_attention_mask(
        self,
        attention_mask: Optional[torch.Tensor],
        T: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Combine the causal mask with an optional (B, T) padding mask.

        Returns a boolean mask broadcastable to (B, heads, T, T) where True
        means "may attend".
        """
        causal = torch.triu(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1)
        allowed = (~causal).unsqueeze(0).unsqueeze(0)  # (1, 1, T, T)
        if attention_mask is not None:
            # BUG FIX: cast to bool so integer/float masks (the common
            # HF-style 1/0 convention) do not break the `&` operator.
            allowed = allowed & attention_mask.bool().unsqueeze(1).unsqueeze(2)
        return allowed

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Run the transformer.

        Args:
            input_ids: Token ids of shape (B, T).
            attention_mask: Optional (B, T) padding mask; 1/True = keep.
            position_ids: Optional (B, T) positions; defaults to 0..T-1.

        Returns:
            Dict with "logits" (B, T, vocab_size) and "hidden_states" (B, T, dim).
        """
        B, T = input_ids.size()

        if position_ids is None:
            position_ids = torch.arange(T, dtype=torch.long, device=input_ids.device).unsqueeze(0)

        # Sum token and position embeddings.
        token_emb = self.embed_tokens(input_ids)
        pos_emb = self.embed_positions(position_ids)
        x = token_emb + pos_emb

        mask = self._build_attention_mask(attention_mask, T, input_ids.device)

        for layer in self.layers:
            x = layer(x, mask)

        x = self.norm(x)
        logits = self.lm_head(x)

        return {"logits": logits, "hidden_states": x}

    def get_num_params(self) -> int:
        """Get total number of parameters."""
        return sum(p.numel() for p in self.parameters())
| |
|
| |
|
class ReasoningPath(nn.Module):
    """Reasoning path for interleaved thinking with optional uncertainty estimation.

    Returns reasoning logits, a (B, 1) confidence score, an optional
    confidence variance, and the final hidden states.
    """

    def __init__(self, config: ModelConfig, thinking_config: InterleavedThinkingConfig, path_id: int = 0):
        super().__init__()
        self.config = config
        self.thinking_config = thinking_config
        self.path_id = path_id

        # Shallow stack of transformer blocks (at most 2).
        self.reasoning_layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(min(2, config.layers // 2))
        ])

        # Two outputs: [confidence logit, variance logit]; the second is only
        # used when uncertainty estimation is enabled.
        self.confidence_head = nn.Sequential(
            nn.Linear(config.dim, config.dim // 2),
            nn.ReLU(),
            nn.Linear(config.dim // 2, 2)
        )

        self.output_projection = nn.Linear(config.dim, config.vocab_size)

        # Optional learned per-path specialization.
        if thinking_config.path_specialization:
            self.specialization_adapter = nn.Linear(config.dim, config.dim)
        else:
            self.specialization_adapter = None

    def forward(self, hidden_states: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        x = hidden_states

        if self.specialization_adapter is not None:
            # Deterministic per-path bias nudges paths apart before the
            # learned adapter specializes them. (Merged the two previously
            # duplicated `is not None` checks.)
            path_bias = torch.sin(torch.tensor(float(self.path_id) * 0.1)) * 0.1
            x = self.specialization_adapter(x + path_bias)

        for layer in self.reasoning_layers:
            x = layer(x, mask)

        # Pool once for the confidence head.
        pooled = x.mean(dim=1)
        confidence_params = self.confidence_head(pooled)

        confidence_var = None
        if self.thinking_config.uncertainty_estimation:
            confidence_mean = torch.sigmoid(confidence_params[:, 0:1])
            confidence_var = F.softplus(confidence_params[:, 1:2]) + 1e-6

            if self.training:
                # Reparameterization trick: sample confidence during training.
                eps = torch.randn_like(confidence_var)
                confidence = confidence_mean + torch.sqrt(confidence_var) * eps
                confidence = torch.clamp(confidence, 0.0, 1.0)
            else:
                confidence = confidence_mean
        else:
            # BUG FIX: the head emits 2 values; previously sigmoid was applied
            # to both, producing a (B, 2) confidence instead of (B, 1) and
            # breaking downstream stacking. Use only the confidence logit.
            confidence = torch.sigmoid(confidence_params[:, 0:1])

        reasoning_logits = self.output_projection(x)

        return {
            "reasoning_logits": reasoning_logits,
            "confidence": confidence,
            "confidence_var": confidence_var,
            "reasoning_states": x
        }
| |
|
| |
|
class EarlyStopController(nn.Module):
    """Controller for early stopping with optional task-specific thresholds.

    Predicts a 3-way complexity distribution, a stop probability and, when
    enabled, a per-sample stopping threshold — all from mean-pooled hidden
    states.
    """

    def __init__(self, config: ModelConfig, thinking_config: InterleavedThinkingConfig):
        super().__init__()
        self.thinking_config = thinking_config

        # 3 complexity classes.
        self.complexity_classifier = nn.Sequential(
            nn.Linear(config.dim, config.dim // 2),
            nn.ReLU(),
            nn.Linear(config.dim // 2, 3),
        )

        # Scalar stop logit per sample.
        self.stop_predictor = nn.Linear(config.dim, 1)

        # Optional learned threshold in (0, 1).
        if thinking_config.task_specific_thresholds:
            self.task_threshold_predictor = nn.Sequential(
                nn.Linear(config.dim, config.dim // 4),
                nn.ReLU(),
                nn.Linear(config.dim // 4, 1),
                nn.Sigmoid()
            )
        else:
            self.task_threshold_predictor = None

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return (complexity_probs (B, 3), stop_prob (B, 1), task_threshold (B, 1) or None)."""
        # Pool once; previously the mean was recomputed for every head.
        pooled = hidden_states.mean(dim=1)

        complexity_logits = self.complexity_classifier(pooled)
        complexity_probs = F.softmax(complexity_logits, dim=-1)

        stop_logits = self.stop_predictor(pooled)
        stop_prob = torch.sigmoid(stop_logits)

        task_threshold = None
        if self.task_threshold_predictor is not None:
            task_threshold = self.task_threshold_predictor(pooled)

        return complexity_probs, stop_prob, task_threshold
| |
|
| |
|
class HierarchicalReasoningPath(nn.Module):
    """Reasoning path operating at one of several abstraction levels.

    Level 0 is the most detailed (2 blocks, dim//2 abstraction); higher
    levels use a single block with progressively narrower abstractions.
    """

    def __init__(self, config: ModelConfig, thinking_config: InterleavedThinkingConfig, level: int):
        super().__init__()
        self.level = level
        self.config = config
        self.thinking_config = thinking_config

        # (number of transformer blocks, divisor for the abstraction width)
        num_blocks, divisor = {0: (2, 2), 1: (1, 4)}.get(level, (1, 8))
        self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(num_blocks))
        self.abstraction_projection = nn.Linear(config.dim, config.dim // divisor)

        self.confidence_head = nn.Linear(config.dim, 1)
        self.output_projection = nn.Linear(config.dim, config.vocab_size)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Run this level's blocks and report logits, a pooled confidence score
        and a down-projected ("abstracted") view of the states."""
        for block in self.layers:
            x = block(x, mask)

        return {
            "reasoning_logits": self.output_projection(x),
            "confidence": torch.sigmoid(self.confidence_head(x.mean(dim=1))),
            "abstracted_states": self.abstraction_projection(x),
            "level": self.level,
        }
| |
|
| |
|
class InterleavedThinking(nn.Module):
    """Interleaved thinking: run several reasoning paths over the base hidden
    states, fuse their logits, and optionally compress the most confident
    path's states for the next reasoning step."""

    def __init__(self, model_config: ModelConfig, thinking_config: InterleavedThinkingConfig):
        super().__init__()
        self.model_config = model_config
        self.thinking_config = thinking_config

        # Reasoning paths: hierarchical (varying abstraction) or flat.
        if thinking_config.hierarchical_paths:
            self.reasoning_paths = nn.ModuleList([
                HierarchicalReasoningPath(model_config, thinking_config, level % thinking_config.num_hierarchy_levels)
                for level in range(thinking_config.max_reasoning_paths)
            ])
        else:
            self.reasoning_paths = nn.ModuleList([
                ReasoningPath(model_config, thinking_config, path_id=i)
                for i in range(thinking_config.max_reasoning_paths)
            ])

        self.early_stop_controller = EarlyStopController(model_config, thinking_config)

        # Fusion of per-path logits.
        if thinking_config.attention_fusion:
            # BUG FIX: batch_first=True — callers pass (batch, seq, embed)
            # tensors; with the default seq-first layout the query and key
            # batch dimensions disagreed at runtime.
            self.fusion_attention = nn.MultiheadAttention(
                embed_dim=model_config.vocab_size,
                num_heads=8,
                dropout=0.1,
                batch_first=True
            )
            self.fusion_norm = nn.LayerNorm(model_config.vocab_size)
        else:
            # Linear combiner over path-concatenated logits.
            self.path_combiner = nn.Linear(
                model_config.vocab_size * thinking_config.max_reasoning_paths,
                model_config.vocab_size
            )

        # Memory compression of the best path's states.
        if thinking_config.adaptive_compression:
            compression_dim = model_config.dim // 4
            self.memory_compressor = nn.Linear(model_config.dim, compression_dim)
            self.memory_reconstructor = nn.Linear(compression_dim, model_config.dim)
            self.compression_gate = nn.Sequential(
                nn.Linear(model_config.dim, 1),
                nn.Sigmoid()
            )
        elif thinking_config.memory_compression:
            self.memory_compressor = nn.Linear(model_config.dim, model_config.dim // 4)
            self.memory_reconstructor = None
            self.compression_gate = None
        else:
            self.memory_compressor = None
            self.memory_reconstructor = None
            self.compression_gate = None

    def forward(
        self,
        base_hidden_states: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        current_depth: int = 0
    ) -> Dict[str, Any]:
        """Run one thinking step.

        Args:
            base_hidden_states: (B, T, dim) hidden states to reason over.
            mask: Optional attention mask forwarded to the reasoning paths.
            current_depth: Index of this step within the outer reasoning loop.

        Returns:
            Dict with "should_stop", "reasoning_results", "final_logits",
            "confidence_scores" and "visualization_data".
        """
        batch_size = base_hidden_states.size(0)

        complexity_probs, stop_prob, task_threshold = self.early_stop_controller(base_hidden_states)

        # Per-sample threshold if predicted, else the global config value.
        effective_threshold = task_threshold if task_threshold is not None else self.thinking_config.early_stop_threshold
        should_stop = stop_prob > effective_threshold

        # BUG FIX: .item() only works for batch size 1 — stop only when the
        # whole batch agrees (and we are past the first couple of steps).
        if current_depth > 1 and bool(should_stop.all()):
            return {
                "should_stop": True,
                "reasoning_results": None,
                "final_logits": None,
                "confidence_scores": None,
                "complexity": complexity_probs,
                "visualization_data": {"early_stop_depth": current_depth} if self.thinking_config.visualization_enabled else None
            }

        # Run every reasoning path.
        path_logits_list = []
        confidence_list = []
        confidence_vars = []
        path_states = []

        for path in self.reasoning_paths:
            result = path(base_hidden_states, mask)
            path_logits_list.append(result["reasoning_logits"])
            confidence_list.append(result["confidence"])
            var = result.get("confidence_var")
            if var is not None:
                confidence_vars.append(var)
            # BUG FIX: the states were previously looked up on the *logits*
            # tensor (string-indexing a tensor raises). Keep them in their own
            # list; hierarchical paths expose no full-width states -> None.
            path_states.append(result.get("reasoning_states"))

        path_logits = torch.stack(path_logits_list, dim=1)       # (B, P, T, V)
        confidence_scores = torch.stack(confidence_list, dim=1)  # (B, P, 1)
        B, P, T, V = path_logits.size()

        if self.thinking_config.attention_fusion:
            # BUG FIX: permute before flattening — a raw .view(B*T, P, V) on a
            # (B, P, T, V) tensor scrambles path and time positions.
            flat_logits = path_logits.permute(0, 2, 1, 3).reshape(B * T, P, V)

            # Query = per-position mean over paths; attend over the paths.
            attn_output, _ = self.fusion_attention(
                flat_logits.mean(dim=1, keepdim=True),
                flat_logits,
                flat_logits
            )
            combined_logits = self.fusion_norm(attn_output.squeeze(1)).view(B, T, V)
        else:
            # Confidence-weight each path's logits.
            confidence_weights = F.softmax(confidence_scores.squeeze(-1), dim=-1)  # (B, P)
            confidence_weights = confidence_weights.unsqueeze(-1).unsqueeze(-1)    # (B, P, 1, 1)
            weighted = path_logits * confidence_weights                            # (B, P, T, V)

            # BUG FIX: the combiner expects all paths concatenated along the
            # vocab axis. Previously the paths were summed to (B, T, V) and
            # then viewed as (B, -1, V*P), which is shape-invalid in general.
            concatenated = weighted.permute(0, 2, 1, 3).reshape(batch_size, T, V * P)
            combined_logits = self.path_combiner(concatenated)

        # Optionally compress the best path's states (only when every path
        # exposed full-width reasoning states).
        compressed_states = None
        reconstruction_loss = None
        if self.memory_compressor is not None and all(s is not None for s in path_states):
            # Pick each sample's most confident path.
            best_path_idx = confidence_scores.mean(dim=-1).argmax(dim=-1)  # (B,)
            stacked_states = torch.stack(path_states, dim=1)               # (B, P, T, D)
            best_reasoning_states = stacked_states[
                torch.arange(B, device=stacked_states.device), best_path_idx
            ]

            if self.thinking_config.adaptive_compression and self.memory_reconstructor is not None:
                compression_gate = self.compression_gate(best_reasoning_states.mean(dim=1))  # (B, 1)
                compressed = self.memory_compressor(best_reasoning_states)
                reconstructed = self.memory_reconstructor(compressed)

                # Train the compressor/reconstructor to preserve the states.
                reconstruction_loss = F.mse_loss(reconstructed, best_reasoning_states)

                # BUG FIX: gate between the *reconstructed* (bottlenecked)
                # states and the raw states — both are (B, T, D). Previously
                # the width-D/4 compressed tensor was mixed in, which cannot
                # broadcast against width-D states.
                compressed_states = torch.where(
                    compression_gate.unsqueeze(-1) > 0.5,
                    reconstructed,
                    best_reasoning_states
                )
            else:
                compressed_states = self.memory_compressor(best_reasoning_states)

        visualization_data = None
        if self.thinking_config.visualization_enabled:
            # .detach() so visualization also works on tensors requiring grad.
            visualization_data = {
                "confidence_scores": confidence_scores.detach().cpu().numpy(),
                "confidence_vars": [v.detach().cpu().numpy() for v in confidence_vars] if confidence_vars else None,
                "complexity_probs": complexity_probs.detach().cpu().numpy(),
                "task_threshold": task_threshold.detach().cpu().numpy() if task_threshold is not None else None,
                "path_logits_shape": path_logits.shape,
                "hierarchical_levels": [getattr(path, 'level', 0) for path in self.reasoning_paths] if self.thinking_config.hierarchical_paths else None,
                "reconstruction_loss": reconstruction_loss.item() if reconstruction_loss is not None else None
            }

        return {
            "should_stop": False,
            "reasoning_results": {
                "path_logits": path_logits,
                "confidence_scores": confidence_scores,
                "complexity": complexity_probs,
                "compressed_states": compressed_states,
                "confidence_vars": confidence_vars if confidence_vars else None,
                "reconstruction_loss": reconstruction_loss
            },
            "final_logits": combined_logits,
            "confidence_scores": confidence_scores.mean(dim=1),
            "visualization_data": visualization_data
        }
| |
|
| |
|
class CompactAIModel(nn.Module):
    """Complete compact AI model with interleaved thinking."""

    def __init__(self, model_config: ModelConfig, thinking_config: InterleavedThinkingConfig):
        super().__init__()
        self.model_config = model_config
        self.thinking_config = thinking_config

        # Backbone language model.
        self.base_model = CompactTransformer(model_config)

        # Interleaved reasoning module.
        self.thinking = InterleavedThinking(model_config, thinking_config)

        # Predicts how many reasoning steps to run (one class per depth).
        self.depth_controller = nn.Linear(model_config.dim, thinking_config.reasoning_depth)

    def _reasoning_mask(
        self,
        attention_mask: Optional[torch.Tensor],
        T: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Build a boolean causal+padding mask broadcastable to (B, H, T, T).

        BUG FIX: the raw (B, T) padding mask was previously forwarded
        directly into the reasoning attention, where it cannot broadcast
        against the (B, H, T, T) score tensor.
        """
        causal = torch.triu(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1)
        allowed = (~causal).unsqueeze(0).unsqueeze(0)  # (1, 1, T, T)
        if attention_mask is not None:
            allowed = allowed & attention_mask.bool().unsqueeze(1).unsqueeze(2)
        return allowed

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        use_thinking: bool = True,
        max_reasoning_depth: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Run the base model and, optionally, iterative reasoning steps.

        Returns a dict with "logits", "thinking_results", "final_tokens" and
        "visualization_data".
        """
        base_outputs = self.base_model(input_ids, attention_mask)
        base_logits = base_outputs["logits"]
        base_hidden = base_outputs["hidden_states"]

        if not use_thinking:
            return {
                "logits": base_logits,
                "thinking_results": None,
                "final_tokens": 0,
                "visualization_data": None
            }

        # Decide how many reasoning steps to run.
        if max_reasoning_depth is None:
            if self.thinking_config.dynamic_depth:
                depth_logits = self.depth_controller(base_hidden.mean(dim=1))
                depth_probs = F.softmax(depth_logits, dim=-1)
                # BUG FIX: argmax over depths is per-sample; .item() fails for
                # batch > 1. Run enough steps for the deepest sample.
                max_reasoning_depth = int(depth_probs.argmax(dim=-1).max().item()) + 1
            else:
                max_reasoning_depth = self.thinking_config.reasoning_depth

        reasoning_mask = self._reasoning_mask(attention_mask, input_ids.size(1), input_ids.device)

        current_hidden = base_hidden
        thinking_results = []
        total_reasoning_tokens = 0
        final_logits = base_logits  # fallback if no thinking step completes
        visualization_history = [] if self.thinking_config.visualization_enabled else None

        for depth in range(max_reasoning_depth):
            thinking_output = self.thinking(current_hidden, reasoning_mask, depth)

            if thinking_output["should_stop"]:
                break

            thinking_results.append(thinking_output["reasoning_results"])
            if thinking_output["final_logits"] is not None:
                final_logits = thinking_output["final_logits"]

            if self.thinking_config.visualization_enabled and thinking_output["visualization_data"]:
                thinking_output["visualization_data"]["depth"] = depth
                visualization_history.append(thinking_output["visualization_data"])

            # Feed compressed states forward only when they preserve the model
            # width. (BUG FIX: previously (B, T, vocab) logits were added to
            # (B, T, dim) hidden states, which is shape-invalid.)
            states = thinking_output["reasoning_results"]["compressed_states"]
            if states is not None and states.shape == current_hidden.shape:
                current_hidden = states

            total_reasoning_tokens += input_ids.size(1)

            # Respect the reasoning token budget.
            if total_reasoning_tokens >= self.thinking_config.token_budget:
                break

        return {
            "logits": final_logits,
            "thinking_results": thinking_results,
            "final_tokens": total_reasoning_tokens,
            "visualization_data": visualization_history
        }
| |
|
| |
|
def create_compact_model(model_size: str = "small") -> CompactAIModel:
    """Create a compact AI model with the specified size.

    Args:
        model_size: One of "tiny", "small" or "medium".

    Raises:
        ValueError: If ``model_size`` is not a known preset.
    """
    presets = {
        "tiny": {"dim": 256, "layers": 8, "heads": 8},
        "small": {"dim": 512, "layers": 12, "heads": 8},
        "medium": {"dim": 768, "layers": 16, "heads": 12},
    }
    if model_size not in presets:
        raise ValueError(f"Unknown model size: {model_size}")

    model_config = ModelConfig(model_size=model_size, **presets[model_size])
    thinking_config = InterleavedThinkingConfig()
    return CompactAIModel(model_config, thinking_config)