"""
GLADIUS v2.0 — Tool Cortex

Tools as embeddings in the same vector space as vocabulary.
Tool activation via cosine similarity threshold — no JSON, no parsing.
SLA2 hybrid attention over tool registry at every layer (via kernel).

STUB VERSION — tool registry exists but no real tools are connected.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import KernelConfig


class ToolCortex(nn.Module):
    """
    Tool understanding via shared embedding space.

    Tools live in the same manifold as tokens. When the pooled hidden
    state is close enough to a tool embedding (cosine similarity >=
    config.tool_activation_threshold), the tool activates. No special
    syntax is needed.

    Selected tool: argmax_i cos(h, t_i), where h is the mean-pooled
    hidden state and t_i is the i-th registered tool embedding.
    """

    def __init__(self, config: KernelConfig):
        super().__init__()
        self.config = config

        # Tool embeddings: same dimension as token embeddings
        self.tool_embeddings = nn.Parameter(
            torch.randn(config.max_tools, config.hidden_dim) * 0.02
        )

        # Tool activation gate (refines raw cosine similarity; defined here
        # but not yet wired into check_activation in this stub version)
        self.activation_gate = nn.Sequential(
            nn.Linear(config.hidden_dim * 2, config.hidden_dim // 2),
            nn.SiLU(),
            nn.Linear(config.hidden_dim // 2, 1),
            nn.Sigmoid(),
        )

        # Tool result projection (maps tool output back to hidden space)
        self.result_proj = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)

        # Registry metadata (not learned — set at runtime)
        self.register_buffer('tool_active', torch.zeros(config.max_tools, dtype=torch.bool))
        self.num_registered = 0

    def register_grid_tools(self):
        """Register ARC grid manipulation tools (rotate, flip, fill, etc.)."""
        # STUB: the names are placeholders; only the slot count matters here,
        # since each embedding is randomly initialized rather than encoded
        # from the tool's description.
        grid_tool_names = ['rotate', 'flip_h', 'flip_v', 'fill', 'crop', 'tile', 'recolor', 'overlay']
        with torch.no_grad():
            for i, _name in enumerate(grid_tool_names):
                if i >= self.config.max_tools:
                    break
                self.tool_embeddings.data[i] = torch.randn(self.config.hidden_dim) * 0.01
                # Guard so repeated calls don't inflate the registered count
                if not self.tool_active[i]:
                    self.num_registered += 1
                self.tool_active[i] = True

    def register_tool(self, tool_id: int, description_embedding: torch.Tensor):
        """
        Register a tool by initializing its embedding.

        In the full system, description_embedding comes from encoding
        the tool's natural language description through the shared embeddings.
        """
        with torch.no_grad():
            self.tool_embeddings.data[tool_id] = description_embedding
            # Guard so re-registering a tool doesn't inflate the count
            if not self.tool_active[tool_id]:
                self.num_registered += 1
            self.tool_active[tool_id] = True
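
    # Hypothetical registration flow (a sketch; `tokenizer` and
    # `token_embedding` are assumed helpers living outside this module):
    #   ids = tokenizer.encode("rotate the grid 90 degrees clockwise")
    #   desc = token_embedding(torch.tensor(ids)).mean(dim=0)  # (hidden_dim,)
    #   cortex.register_tool(tool_id=0, description_embedding=desc)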

    def check_activation(self, hidden: torch.Tensor) -> torch.Tensor | None:
        """
        Check if any tool should activate based on hidden state similarity.

        Args:
            hidden: (batch, seq_len, hidden_dim)
        Returns:
            tool_contribution: (batch, seq_len, hidden_dim) or None
        """
        if self.num_registered == 0:
            return None

        # Pool hidden state
        pooled = hidden.mean(dim=1)  # (B, D)

        # Cosine similarity with all tool embeddings
        pooled_norm = F.normalize(pooled, dim=-1)
        tools_norm = F.normalize(self.tool_embeddings, dim=-1)
        similarities = torch.matmul(pooled_norm, tools_norm.T)  # (B, max_tools)

        # Mask inactive tools
        similarities = similarities.masked_fill(~self.tool_active.unsqueeze(0), -1.0)

        # Find best matching tool
        best_sim, best_tool = similarities.max(dim=-1)  # (B,)

        # Early exit when no batch element clears the activation threshold
        if best_sim.max().item() < self.config.tool_activation_threshold:
            return None

        # Tool activated — in full system, this would invoke the tool
        # STUB: return the tool embedding itself as the contribution
        tool_embed = self.tool_embeddings[best_tool]  # (B, D)
        contribution = self.result_proj(tool_embed)

        # Gate per batch element so a single activating example doesn't leak
        # a tool contribution into the rest of the batch.
        gate = best_sim >= self.config.tool_activation_threshold  # (B,) bool
        contribution = contribution * gate.unsqueeze(-1).to(contribution.dtype)

        # Broadcast across the sequence dimension
        return contribution.unsqueeze(1).expand_as(hidden)
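

if __name__ == "__main__":
    # Minimal smoke test (a sketch: _DemoConfig merely stands in for the real
    # KernelConfig, which this module only requires to expose max_tools,
    # hidden_dim, and tool_activation_threshold). Because of the relative
    # import above, run this file as a module within its package.
    from dataclasses import dataclass

    @dataclass
    class _DemoConfig:
        max_tools: int = 16
        hidden_dim: int = 64
        tool_activation_threshold: float = 0.5

    cortex = ToolCortex(_DemoConfig())
    cortex.register_grid_tools()

    hidden = torch.randn(2, 10, 64)
    out = cortex.check_activation(hidden)
    # With randomly initialized embeddings the threshold is rarely crossed,
    # so None is the expected (and valid) result here.
    print(None if out is None else tuple(out.shape))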