|
|
"""
|
|
|
COGNITIVE-CORE: Reusable Cognitive Modules
|
|
|
===========================================
|
|
|
|
|
|
Complete library of cognitive modules that can be composed to build
|
|
|
any cognitive model: vision, language, world model, multimodal, etc.
|
|
|
|
|
|
All modules are agnostic and can be configured for different use cases.
|
|
|
|
|
|
Copyright © 2026 Mike Amega (Logo) - Ame Web Studio
|
|
|
License: Proprietary - All Rights Reserved
|
|
|
"""
|
|
|
|
|
|
import math
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
from typing import Dict, List, Optional, Any, Tuple
|
|
|
from collections import deque
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
|
from .cognitive_base import CognitiveConfig, CognitiveModule
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization - More efficient than LayerNorm.

    Unlike LayerNorm there is no mean subtraction and no bias: each feature
    vector is rescaled by the reciprocal of its RMS, then multiplied by a
    learned per-feature gain.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps  # added under the sqrt for numerical stability
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize x along its last dimension and apply the learned gain."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        normalized = x * torch.rsqrt(mean_square + self.eps)
        return normalized * self.weight
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) with scaling support.

    Cos/sin tables are precomputed for all positions up to ``max_seq_len``;
    ``scaling`` > 1 compresses positions (linear position interpolation).
    """

    def __init__(
        self, dim: int, max_seq_len: int = 4096, base: int = 10000, scaling: float = 1.0
    ):
        super().__init__()
        self.dim = dim
        self.scaling = scaling

        # Per-pair inverse frequencies: base^(-2i/dim) for i in [0, dim/2).
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (base ** exponents)
        self.register_buffer("inv_freq", inv_freq)

        # Angle table: one row per (scaled) position, duplicated so the
        # rotation covers both halves of the head dimension.
        positions = torch.arange(max_seq_len).float() / scaling
        angles = torch.outer(positions, inv_freq)
        table = torch.cat([angles, angles], dim=-1)
        self.register_buffer("cos_cache", table.cos()[None, None, :, :])
        self.register_buffer("sin_cache", table.sin()[None, None, :, :])

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, seq_len: int, offset: int = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate q and k for absolute positions [offset, offset + seq_len)."""
        cos = self.cos_cache[:, :, offset : offset + seq_len, :].to(q.dtype)
        sin = self.sin_cache[:, :, offset : offset + seq_len, :].to(q.dtype)
        rotate = self._rotate_half
        return q * cos + rotate(q) * sin, k * cos + rotate(k) * sin

    def _rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """Swap the two halves of the last dim, negating the second half."""
        half = x.shape[-1] // 2
        return torch.cat([-x[..., half:], x[..., :half]], dim=-1)
|
|
|
|
|
|
|
|
|
class SinusoidalPositionalEncoding(nn.Module):
    """Classical sinusoidal positional encoding (Vaswani et al. style)."""

    def __init__(self, d_model: int, max_seq_len: int = 4096, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        positions = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        # Geometric frequency progression over the even feature indices.
        freq = torch.exp(
            -math.log(10000.0) / d_model * torch.arange(0, d_model, 2).float()
        )

        table = torch.zeros(max_seq_len, d_model)
        table[:, 0::2] = torch.sin(positions * freq)  # even dims: sine
        table[:, 1::2] = torch.cos(positions * freq)  # odd dims: cosine

        # (1, max_seq_len, d_model) so it broadcasts over the batch.
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add encodings for the first x.size(1) positions, then dropout."""
        return self.dropout(x + self.pe[:, : x.size(1)])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention (GQA) with RoPE and KV-Cache support.

    Keys/values use fewer heads than queries (``n_kv_heads <= n_heads``) and
    are repeated to match, which shrinks the KV-cache. The cache stores the
    *unrotated* per-head K/V projections; RoPE is re-applied over the full
    concatenated sequence each step so rotations stay at absolute positions.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int = 8,
        n_kv_heads: int = 4,
        max_seq_len: int = 4096,
        dropout: float = 0.1,
        use_rope: bool = True,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = d_model // n_heads
        self.n_rep = n_heads // n_kv_heads  # query heads served per kv head
        self.scale = self.head_dim**-0.5

        self.q_proj = nn.Linear(d_model, n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * self.head_dim, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.rope = RotaryEmbedding(self.head_dim, max_seq_len) if use_rope else None

    def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
        """Repeat each kv head n_rep times so K/V match the query head count."""
        if self.n_rep == 1:
            return x
        B, n_kv, T, D = x.shape
        return (
            x[:, :, None, :, :]
            .expand(B, n_kv, self.n_rep, T, D)
            .reshape(B, self.n_heads, T, D)
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple]]:
        """Attend over x (plus any cached K/V).

        Args:
            x: (B, T, d_model) new tokens.
            mask: optional attention mask; positions where mask == 0 are
                excluded. No implicit causal mask is applied.
            kv_cache: optional (k, v) from a previous call, each of shape
                (B, n_kv_heads, T_past, head_dim), unrotated.
            use_cache: when True, also return the updated (k, v) cache.

        Returns:
            (output, new_cache) where output is (B, T, d_model) and
            new_cache is None unless use_cache is True.
        """
        B, T, C = x.shape

        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        offset = 0
        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            offset = k_cache.size(2)
            k = torch.cat([k_cache, k], dim=2)
            v = torch.cat([v_cache, v], dim=2)

        # Snapshot K/V *before* RoPE/head-repetition: this is exactly what the
        # next step's kv_cache must contain. (Previously the projections were
        # recomputed from x at the end of forward, duplicating the k/v linear
        # layers' work; the values are identical.)
        new_cache = (k, v) if use_cache else None

        if self.rope is not None:
            # Queries rotate at their absolute positions; keys rotate over the
            # whole (cached + new) sequence starting from position 0.
            q, _ = self.rope(q, q, T, offset)
            _, k = self.rope(k, k, k.size(2), 0)

        k = self._repeat_kv(k)
        v = self._repeat_kv(v)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float("-inf"))

        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        out = (attn @ v).transpose(1, 2).reshape(B, T, -1)
        out = self.o_proj(out)

        return out, new_cache
|
|
|
|
|
|
|
|
|
class CrossAttention(nn.Module):
    """Cross-attention for multimodal fusion.

    Queries come from one stream, keys/values from another; standard
    multi-head scaled dot-product attention in between.
    """

    def __init__(self, d_model: int, n_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.scale = self.head_dim**-0.5

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        # (B, L, d_model) -> (B, n_heads, L, head_dim)
        B, L, _ = t.shape
        return t.view(B, L, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        query: torch.Tensor,
        key_value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend from `query` tokens (B, T, C) over `key_value` tokens (B, S, C).

        Positions where mask == 0 are excluded from attention.
        """
        B, T, _ = query.shape

        q = self._split_heads(self.q_proj(query))
        k = self._split_heads(self.k_proj(key_value))
        v = self._split_heads(self.v_proj(key_value))

        scores = q @ k.transpose(-2, -1) * self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))

        weights = self.dropout(F.softmax(scores, dim=-1))

        merged = (weights @ v).transpose(1, 2).reshape(B, T, -1)
        return self.o_proj(merged)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SwiGLU(nn.Module):
    """SwiGLU activation - better than GELU for transformers.

    Gated FFN: out = w2(silu(w1(x)) * w3(x)). The hidden width is reduced to
    2/3 of d_ff (to keep parameter count comparable to a plain MLP) and then
    rounded up to a multiple of 64 for hardware friendliness.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        hidden = int(d_ff * 2 / 3)
        hidden = -(-hidden // 64) * 64  # ceil to multiple of 64

        self.w1 = nn.Linear(d_model, hidden, bias=False)  # gate branch
        self.w2 = nn.Linear(hidden, d_model, bias=False)  # output projection
        self.w3 = nn.Linear(d_model, hidden, bias=False)  # value branch
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.dropout(self.w2(gate * value))
|
|
|
|
|
|
|
|
|
class MLP(nn.Module):
    """Standard MLP with GELU activation: expand -> GELU -> contract."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        layers = [
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the feed-forward stack over the last dimension."""
        return self.net(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Expert(nn.Module):
    """Single expert module: a SwiGLU feed-forward block tagged with a type.

    The ``expert_type`` tag is metadata only — it does not change the
    computation performed by the expert.
    """

    def __init__(self, d_model: int, d_ff: int, expert_type: str = "general"):
        super().__init__()
        self.expert_type = expert_type
        self.ffn = SwiGLU(d_model, d_ff)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the expert's gated feed-forward network."""
        return self.ffn(x)
|
|
|
|
|
|
|
|
|
class SparseMoE(nn.Module):
    """Sparse Mixture of Experts with Top-K routing.

    Each token is routed to its ``top_k`` highest-probability experts and the
    expert outputs are combined with renormalized routing weights. Also
    returns an auxiliary load-balancing loss that penalizes concentrated
    expert usage.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        num_experts: int = 8,
        top_k: int = 2,
        expert_types: Optional[List[str]] = None,
        aux_loss_weight: float = 0.01,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.aux_loss_weight = aux_loss_weight

        if expert_types is None:
            expert_types = ["general"]

        self.router = nn.Linear(d_model, num_experts, bias=False)
        # Cycle through expert_types so every expert gets a (possibly shared) tag.
        self.experts = nn.ModuleList(
            [
                Expert(d_model, d_ff, expert_types[i % len(expert_types)])
                for i in range(num_experts)
            ]
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Route each token to its top-k experts and mix the results.

        Args:
            x: (B, T, C) input.

        Returns:
            (output, aux_loss): mixed output of shape (B, T, C) and a scalar
            load-balancing auxiliary loss.
        """
        B, T, C = x.shape
        x_flat = x.view(-1, C)

        router_logits = self.router(x_flat)
        # Compute the routing distribution once and reuse it for both top-k
        # selection and the auxiliary loss (it was previously softmax-ed twice).
        router_probs = F.softmax(router_logits, dim=-1)

        topk_weights, topk_indices = torch.topk(router_probs, self.top_k, dim=-1)
        # Renormalize so the selected experts' weights sum to 1 per token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        output = torch.zeros_like(x_flat)

        for i, expert in enumerate(self.experts):
            mask = (topk_indices == i).any(dim=-1)
            if not mask.any():
                continue  # expert received no tokens this step
            expert_weight = torch.where(
                topk_indices == i, topk_weights, torch.zeros_like(topk_weights)
            ).sum(dim=-1)
            expert_out = expert(x_flat[mask])
            output[mask] += expert_out * expert_weight[mask].unsqueeze(-1)

        # Load-balancing loss: squared mean usage summed over experts, scaled
        # by num_experts (== 1.0 for a perfectly uniform router) and weighted.
        expert_usage = router_probs.mean(dim=0)
        aux_loss = (
            self.num_experts
            * (expert_usage * expert_usage).sum()
            * self.aux_loss_weight
        )

        return output.view(B, T, C), aux_loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ContrastiveLPOL(CognitiveModule):
    """
    LPOL Memory System with configurable knowledge domains.
    Uses contrastive learning for memory retrieval.
    """

    def __init__(
        self,
        d_model: int,
        config: CognitiveConfig,
        domains: Optional[List[str]] = None,
        slots_per_domain: int = 512,
        retrieval_k: int = 8,
    ):
        # d_model: feature width of inputs and memory slots.
        # domains: named memory banks; defaults to a 9-domain taxonomy below.
        # slots_per_domain: learnable memory vectors per domain.
        # retrieval_k: number of memory slots attended to per token.
        super().__init__(config)

        if domains is None:
            domains = [
                "semantic",
                "episodic",
                "procedural",
                "spatial",
                "temporal",
                "causal",
                "social",
                "emotional",
                "conceptual",
            ]

        self.domains = domains
        self.k = retrieval_k

        # One learnable (slots_per_domain, d_model) bank per domain.
        self.memories = nn.ParameterDict(
            {
                domain: nn.Parameter(torch.randn(slots_per_domain, d_model) * 0.01)
                for domain in domains
            }
        )

        # Classifies the pooled input into a distribution over domains
        # (returned for inspection; retrieval itself spans all domains).
        self.domain_clf = nn.Sequential(
            nn.Linear(d_model, len(domains) * 2),
            nn.GELU(),
            nn.Linear(len(domains) * 2, len(domains)),
        )

        # Query from input tokens; key/value from memory slots.
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        # Fuses [input, retrieved] back down to d_model.
        self.out_proj = nn.Linear(d_model * 2, d_model)

    def forward(self, x: torch.Tensor, **kwargs) -> Dict[str, Any]:
        """Retrieve top-k memory slots per token and fuse them with the input.

        Args:
            x: (B, T, d_model) input sequence.

        Returns:
            dict with "output" (B, T, d_model), "domain_probs" (B, n_domains)
            and "retrieval_weights" (B, T, k).
        """
        B, T, C = x.shape

        domain_probs = F.softmax(self.domain_clf(x.mean(dim=1)), dim=-1)
        # Concatenate every domain's slots into one flat retrieval bank.
        all_mem = torch.cat([self.memories[d] for d in self.domains], dim=0)

        q = self.q_proj(x)
        k = self.k_proj(all_mem)
        v = self.v_proj(all_mem)

        # Scaled dot-product similarity of each token against every slot.
        sim = torch.matmul(q, k.T) / math.sqrt(C)
        topk_sim, topk_idx = torch.topk(sim, min(self.k, all_mem.size(0)), dim=-1)
        weights = F.softmax(topk_sim, dim=-1)
        # Soft-read: weighted sum of the top-k value vectors per token.
        retrieved = (weights.unsqueeze(-1) * v[topk_idx]).sum(dim=2)
        output = self.out_proj(torch.cat([x, retrieved], dim=-1))

        return {
            "output": output,
            "domain_probs": domain_probs,
            "retrieval_weights": weights,
        }

    def reset_state(self):
        # Memories are learned parameters, not transient state — nothing to reset.
        pass

    def update_memory(self, x: torch.Tensor, domain: str, lr: float = 0.01):
        """Online memory update."""
        # Blends the pooled input into the LEAST similar slot of the domain
        # (writes to the most "novel" slot). Gradient-free.
        # NOTE(review): for batch size > 1, `idx` has one slot per batch row;
        # duplicate indices would make later rows win — confirm callers use B == 1.
        if domain in self.memories:
            with torch.no_grad():
                mem = self.memories[domain]
                sim = F.cosine_similarity(
                    x.mean(dim=1, keepdim=True), mem.unsqueeze(0), dim=-1
                )
                _, idx = sim.min(dim=-1)
                mem[idx] = (1 - lr) * mem[idx] + lr * x.mean(dim=1)
|
|
|
|
|
|
|
|
|
class MultiScaleMemory(CognitiveModule):
    """Short-term and long-term memory with consolidation.

    Maintains two persistent EMA states (short-term and long-term). Each
    forward pass pools the input, updates the short-term state through a
    GRU cell, and — when a learned consolidation gate fires — folds the
    short-term state into the long-term one.
    """

    def __init__(
        self,
        d_model: int,
        config: CognitiveConfig,
        short_term_dim: int = 512,
        long_term_dim: int = 256,
        st_decay: float = 0.95,
        lt_decay: float = 0.99,
        consolidation_threshold: float = 0.7,
    ):
        # st_decay / lt_decay: EMA retention factors (closer to 1 = slower drift).
        # consolidation_threshold: minimum gate score to write into long-term.
        super().__init__(config)

        self.st_decay = st_decay
        self.lt_decay = lt_decay
        self.consolidation_threshold = consolidation_threshold

        # Compresses the pooled input into short-term space.
        self.st_compress = nn.Sequential(
            nn.Linear(d_model, short_term_dim),
            nn.GELU(),
            nn.Linear(short_term_dim, short_term_dim),
        )
        self.st_gate = nn.GRUCell(short_term_dim, short_term_dim)

        # Scores [short-term, long-term] -> (0, 1) consolidation strength.
        self.consolidation = nn.Sequential(
            nn.Linear(short_term_dim + long_term_dim, 256),
            nn.SiLU(),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )
        self.st_to_lt = nn.Linear(short_term_dim, long_term_dim)
        self.lt_gate = nn.GRUCell(long_term_dim, long_term_dim)

        # Maps the concatenated memory states back to model space.
        self.fusion = nn.Sequential(
            nn.Linear(short_term_dim + long_term_dim, d_model), nn.Tanh()
        )

        # Persistent (detached) memory states, kept as batch-size-1 rows.
        self.register_buffer("st_state", torch.zeros(1, short_term_dim))
        self.register_buffer("lt_state", torch.zeros(1, long_term_dim))

    def forward(self, x: torch.Tensor, **kwargs) -> Dict[str, Any]:
        """Update both memory scales from (B, T, d_model) input.

        Returns:
            dict with "st" (B, short_term_dim), "lt" (B, long_term_dim),
            "fused" (B, d_model) and the mean "consolidation_score" (float).
        """
        B = x.size(0)
        h_compressed = self.st_compress(x.mean(dim=1))

        # Short-term update: EMA between the previous state and the GRU step.
        st_prev = self.st_state.expand(B, -1)
        st_new = self.st_decay * st_prev + (1 - self.st_decay) * self.st_gate(
            h_compressed, st_prev
        )

        lt_prev = self.lt_state.expand(B, -1)
        consolidation_score = self.consolidation(torch.cat([st_new, lt_prev], dim=-1))

        # Consolidate into long-term only when any sample's gate exceeds the
        # threshold (the whole batch is then updated together).
        if (consolidation_score > self.consolidation_threshold).any():
            lt_input = self.st_to_lt(st_new)
            lt_new = self.lt_decay * lt_prev + (1 - self.lt_decay) * self.lt_gate(
                lt_input, lt_prev
            )
        else:
            lt_new = lt_prev

        # Persist only the first batch row, detached from the graph.
        self.st_state = st_new[:1].detach()
        self.lt_state = lt_new[:1].detach()

        fused = self.fusion(torch.cat([st_new, lt_new], dim=-1))

        return {
            "st": st_new,
            "lt": lt_new,
            "fused": fused,
            "consolidation_score": consolidation_score.mean().item(),
        }

    def reset_state(self):
        # Zero both persistent memory states in place.
        self.st_state.zero_()
        self.lt_state.zero_()
|
|
|
|
|
|
|
|
|
class EpisodicMemory(CognitiveModule):
    """Episodic memory for experience storage and retrieval.

    Experiences are mean-pooled into single vectors and written into a
    fixed-size ring buffer; retrieval returns the mean of the k most
    cosine-similar stored episodes.
    """

    def __init__(self, d_model: int, config: CognitiveConfig, max_episodes: int = 1000):
        super().__init__(config)

        self.encoder = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, d_model),
        )

        # Ring buffer of episode embeddings; `count` grows monotonically and
        # `count % max` selects the next write slot.
        self.register_buffer("episodes", torch.zeros(max_episodes, d_model))
        self.register_buffer("count", torch.tensor(0))
        self.max = max_episodes

    def forward(self, x: torch.Tensor, **kwargs) -> Dict[str, Any]:
        """Encode the input; storage/retrieval happen via explicit calls."""
        return {"encoded": self.encoder(x)}

    def store(self, x: torch.Tensor):
        """Store an experience."""
        with torch.no_grad():
            slot = self.count.item() % self.max
            pooled = x.mean(dim=(0, 1)) if x.dim() == 3 else x.mean(dim=0)
            self.episodes[slot] = pooled
            self.count += 1

    def retrieve(self, query: torch.Tensor, k: int = 5) -> torch.Tensor:
        """Retrieve k most similar episodes."""
        stored = min(self.count.item(), self.max)
        if stored == 0:
            # Nothing remembered yet: neutral response of the query's shape.
            return torch.zeros_like(query)

        bank = self.episodes[:stored]
        sim = F.cosine_similarity(query.unsqueeze(1), bank.unsqueeze(0), dim=-1)
        _, best = sim.topk(min(k, stored), dim=-1)
        return bank[best].mean(dim=1)

    def reset_state(self):
        """Forget all stored episodes (slots are simply considered empty)."""
        self.count.zero_()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class WorldBuffer(CognitiveModule):
    """Single domain world buffer with state prediction.

    Keeps a persistent latent world state, predicts its own next state, and
    measures "surprise" as the MSE between the prediction and the newly
    encoded observation.
    """

    def __init__(self, d_model: int, config: CognitiveConfig, domain: str = "physical"):
        super().__init__(config)
        self.domain = domain

        # Width of the internal world state (config override with fallback).
        state_dim = getattr(config, "world_state_dim", 256)

        # Maps observations into world-state space.
        self.encoder = nn.Sequential(
            nn.Linear(d_model, state_dim), nn.GELU(), nn.Linear(state_dim, state_dim)
        )

        # Recurrent state transition.
        self.dynamics = nn.GRUCell(state_dim, state_dim)

        # Predicts the next encoded observation from the current state.
        self.predictor = nn.Sequential(
            nn.Linear(state_dim, state_dim), nn.Tanh(), nn.Linear(state_dim, state_dim)
        )

        # Persistent state, last prediction, and last surprise value.
        self.register_buffer("state", torch.zeros(1, state_dim))
        self.register_buffer("prediction", torch.zeros(1, state_dim))
        self.register_buffer("surprise", torch.tensor(0.0))

    def forward(self, x: torch.Tensor, **kwargs) -> Dict[str, Any]:
        """Encode x, score surprise against the previous prediction, and
        advance the persistent world state.

        Returns:
            dict with "surprise" (float) and "state" (B, state_dim).
        """
        if x.dim() == 3:
            x = x.mean(dim=1)  # pool sequence inputs to one vector per sample

        encoded = self.encoder(x)

        # Surprise is only meaningful once a prediction exists (norm > 0).
        if self.prediction.norm() > 0:
            surprise = F.mse_loss(
                encoded, self.prediction.expand(encoded.size(0), -1)
            ).item()
        else:
            surprise = 0.0

        self.surprise = torch.tensor(surprise)

        # Advance the state, then EMA-blend it into the persistent buffer
        # (first batch row only, detached from the graph).
        new_state = self.dynamics(encoded, self.state.expand(encoded.size(0), -1))
        update_rate = getattr(self.config, "world_update_rate", 0.1)
        self.state = (
            update_rate * new_state[:1] + (1 - update_rate) * self.state
        ).detach()
        # Precompute next step's expected observation.
        self.prediction = self.predictor(self.state).detach()

        return {"surprise": surprise, "state": new_state}

    def reset_state(self):
        # Clear all persistent world-model state in place.
        self.state.zero_()
        self.prediction.zero_()
        self.surprise.zero_()
|
|
|
|
|
|
|
|
|
class MultiWorldBuffer(CognitiveModule):
    """Multi-domain world model buffers.

    Runs one WorldBuffer per domain over the same input and tracks the mean
    surprise across domains.
    """

    def __init__(
        self, d_model: int, config: CognitiveConfig, domains: Optional[List[str]] = None
    ):
        super().__init__(config)

        if domains is None:
            domains = ["physical", "social", "abstract", "temporal"]

        self.world_buffers = nn.ModuleDict(
            {d: WorldBuffer(d_model, config, d) for d in domains}
        )
        self.register_buffer("aggregate_surprise", torch.tensor(0.0))

    def forward(self, x: torch.Tensor, **kwargs) -> Dict[str, Any]:
        """Update every domain buffer and aggregate their surprise signals.

        Returns:
            dict with per-domain forward results and the mean surprise (float).
        """
        results = {}
        total_surprise = 0.0

        for domain, buffer in self.world_buffers.items():
            result = buffer(x)
            results[domain] = result
            total_surprise += result["surprise"]

        # Rebuild the buffer on its current device: a bare torch.tensor(...)
        # would silently reset it to CPU after the module was moved with
        # .to(device)/.cuda().
        self.aggregate_surprise = torch.tensor(
            total_surprise / len(self.world_buffers),
            device=self.aggregate_surprise.device,
        )

        return {
            "domain_results": results,
            "aggregate_surprise": self.aggregate_surprise.item(),
        }

    def reset_state(self):
        """Reset every domain buffer's internal state."""
        for buffer in self.world_buffers.values():
            buffer.reset_state()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NonVerbalTension(nn.Module):
    """Tracks prediction error as internal tension signal.

    Recent MSE prediction errors are kept in a ring buffer; ``integrate``
    folds their mean into an exponential moving average ("tension").
    """

    def __init__(self, integration_rate: float = 0.1, buffer_size: int = 100):
        super().__init__()
        self.integration_rate = integration_rate
        # Ring buffer of recent errors plus a monotonically growing write index.
        self.register_buffer("prediction_errors", torch.zeros(buffer_size))
        self.register_buffer("error_idx", torch.tensor(0))
        # Smoothed tension value.
        self.register_buffer("integrated_tension", torch.tensor(0.0))

    def update(self, pred: torch.Tensor, actual: torch.Tensor):
        """Record the MSE between a prediction and the observed outcome."""
        with torch.no_grad():
            mse = F.mse_loss(pred.float(), actual.float()).item()
            slot = self.error_idx.item() % len(self.prediction_errors)
            self.prediction_errors[slot] = mse
            self.error_idx += 1

    def integrate(self) -> float:
        """Fold the buffered errors into the smoothed tension; return it."""
        filled = min(self.error_idx.item(), len(self.prediction_errors))
        if filled > 0:
            raw = self.prediction_errors[:filled].mean().item()
            rate = self.integration_rate
            self.integrated_tension = (
                1 - rate
            ) * self.integrated_tension + rate * raw
        return self.integrated_tension.item()
|
|
|
|
|
|
|
|
|
class InternalState(CognitiveModule):
    """Complete internal cognitive state tracker.

    Combines a tension signal (EMA of prediction errors) with a slowly
    drifting "discomfort" vector derived from the fused latent state.
    """

    def __init__(self, d_model: int, config: CognitiveConfig):
        # d_model is accepted for interface consistency but unused here; the
        # encoder input width comes from config.latent_state_dim instead.
        super().__init__(config)

        internal_dim = getattr(config, "internal_state_dim", 128)
        latent_dim = getattr(config, "latent_state_dim", 768)

        self.tension = NonVerbalTension()

        # Projects the fused latent state into internal-state space.
        self.encoder = nn.Sequential(nn.Linear(latent_dim, internal_dim), nn.Tanh())

        # Persistent discomfort vector, updated as a slow EMA each forward.
        self.register_buffer("discomfort", torch.zeros(1, internal_dim))

    def forward(
        self,
        fused: torch.Tensor,
        pred: Optional[torch.Tensor] = None,
        actual: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Update tension/discomfort from the fused state.

        Args:
            fused: latent state of width config.latent_state_dim (2-D or 3-D).
            pred/actual: optional prediction pair; when both are given their
                MSE is recorded into the tension buffer.

        Returns:
            dict with "tension" (float), "discomfort" (1, internal_dim) and
            "encoded_state" (B, internal_dim).
        """
        if pred is not None and actual is not None:
            self.tension.update(pred, actual)

        tension = self.tension.integrate()

        encoded = self.encoder(fused)
        if encoded.dim() == 3:
            encoded = encoded.mean(dim=1)  # pool sequence dimension

        # Slow EMA over the first batch row only, detached from the graph.
        self.discomfort = 0.9 * self.discomfort + 0.1 * encoded[:1].detach()

        return {
            "tension": tension,
            "discomfort": self.discomfort,
            "encoded_state": encoded,
        }

    def reset_state(self):
        # Tension buffers are owned by NonVerbalTension; only discomfort
        # is cleared here.
        self.discomfort.zero_()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DreamPhase(CognitiveModule):
    """Dream phase for memory consolidation.

    Records (state, tension) pairs during waking operation; when recent
    tension stays high, callers can enter a "dream" in which random recorded
    states are replayed through a consolidation network.
    """

    def __init__(
        self,
        d_model: int,
        config: CognitiveConfig,
        buffer_size: int = 256,
        dream_threshold: float = 0.7,
    ):
        # d_model is accepted for interface consistency; the consolidator
        # operates in config.internal_state_dim space.
        super().__init__(config)

        internal_dim = getattr(config, "internal_state_dim", 128)

        # Bounded FIFO of (cpu_state, tension) pairs.
        self.buffer = deque(maxlen=buffer_size)
        self.is_dreaming = False
        self.dream_steps = 0
        self.dream_threshold = dream_threshold
        self.total_dreams = 0

        self.consolidator = nn.Sequential(
            nn.Linear(internal_dim, internal_dim),
            nn.GELU(),
            nn.Linear(internal_dim, internal_dim),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor, **kwargs) -> Dict[str, Any]:
        """Report dreaming status; the input itself is not processed here."""
        return {"is_dreaming": self.is_dreaming, "dream_steps": self.dream_steps}

    def record(self, state: torch.Tensor, tension: float):
        """Record state for potential dream consolidation."""
        # Stored on CPU so the replay buffer does not hold GPU memory.
        self.buffer.append((state.detach().cpu(), tension))

    def should_dream(self) -> bool:
        # Dream when the mean tension of the last 10 records exceeds the
        # threshold; requires at least 10 records.
        if len(self.buffer) < 10:
            return False
        recent = [t for _, t in list(self.buffer)[-10:]]
        return sum(recent) / len(recent) > self.dream_threshold

    def enter_dream(self):
        # Begin a dream episode and reset the per-dream step counter.
        self.is_dreaming = True
        self.dream_steps = 0
        self.total_dreams += 1

    def dream_step(self, identity: torch.Tensor) -> Optional[torch.Tensor]:
        """Execute one dream consolidation step.

        Args:
            identity: tensor whose device the replayed state is moved to
                (only its .device is used here).

        Returns:
            The consolidated state, or None when not dreaming / buffer empty.
        """
        if not self.is_dreaming or len(self.buffer) == 0:
            return None

        self.dream_steps += 1

        # Replay a uniformly random recorded state.
        idx = torch.randint(0, len(self.buffer), (1,)).item()
        state, _ = self.buffer[idx]
        state = state.to(identity.device)

        consolidated = self.consolidator(state)

        # Dreams are capped at 50 steps.
        if self.dream_steps > 50:
            self.is_dreaming = False

        return consolidated

    def reset_state(self):
        # Drop all recorded experience and exit any active dream.
        self.buffer.clear()
        self.is_dreaming = False
        self.dream_steps = 0
|
|
|
|
|
|
|
|
|
class SelfTrace(CognitiveModule):
    """Identity tracking across time.

    Maintains a slow-moving "identity" vector updated from recorded states;
    higher tension accelerates the drift (capped at 0.1 per record).
    """

    def __init__(self, d_model: int, config: CognitiveConfig):
        super().__init__(config)

        internal_dim = getattr(config, "internal_state_dim", 128)

        # Persistent identity vector plus a count of recorded traces.
        self.register_buffer("identity", torch.zeros(1, internal_dim))
        self.register_buffer("n_traces", torch.tensor(0))

    def forward(self, x: torch.Tensor, **kwargs) -> Dict[str, Any]:
        """Expose the current identity; updates happen via record()."""
        return {"identity": self.identity, "n_traces": self.n_traces.item()}

    def record(self, state: torch.Tensor, tension: float):
        """Update identity based on state and tension."""
        with torch.no_grad():
            if state.dim() > 2:
                state = state.mean(dim=1)  # pool sequence inputs

            # Blend weight grows with tension but never exceeds 0.1.
            blend = min(0.1, 0.01 * max(1.0, tension))
            self.identity = self.identity * (1 - blend) + state[:1] * blend
            self.n_traces += 1

    def get_identity(self) -> torch.Tensor:
        """Return the current identity vector of shape (1, internal_dim)."""
        return self.identity

    def reset_state(self):
        """Clear the identity trace and its counter."""
        self.identity.zero_()
        self.n_traces.zero_()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NeurogenesisLayer(CognitiveModule):
    """Layer with dynamic neuron birth/death based on usage.

    A fixed max-capacity weight matrix is allocated up front; only the first
    ``n_neurons`` rows are active. An EMA of each neuron's mean absolute
    activation ("usage") drives neuron birth under high coherence and death
    of long-lived, rarely-used neurons.
    """

    def __init__(
        self,
        input_dim: int,
        n_neurons: int,
        config: CognitiveConfig,
        max_neurons: int = 256,
        usage_decay: float = 0.99,
        birth_threshold: float = 0.8,
        death_threshold: float = 0.01,
    ):
        super().__init__(config)

        self.input_dim = input_dim
        self.max_neurons = max_neurons
        self.usage_decay = usage_decay
        self.birth_threshold = birth_threshold
        self.death_threshold = death_threshold

        # Full-capacity parameters; rows >= n_neurons are dormant.
        self.weights = nn.Parameter(torch.randn(max_neurons, input_dim) * 0.02)
        self.bias = nn.Parameter(torch.zeros(max_neurons))

        self.register_buffer("n_neurons", torch.tensor(n_neurons))
        self.register_buffer("usage", torch.ones(max_neurons))
        self.register_buffer("lifetime", torch.zeros(max_neurons))
        self.register_buffer("births", torch.tensor(0))
        self.register_buffer("deaths", torch.tensor(0))

    def forward(self, x: torch.Tensor, **kwargs) -> Dict[str, Any]:
        """Apply the active neurons and update their usage statistics.

        Args:
            x: (..., input_dim) input of any leading shape.

        Returns:
            dict with "output" (tanh activations), "n_neurons" and "avg_usage".
        """
        n = self.n_neurons.item()
        out = torch.tanh(F.linear(x, self.weights[:n], self.bias[:n]))

        with torch.no_grad():
            # Mean absolute activation per neuron over all leading dims.
            # BUG FIX: the previous ternary only applied usage_decay for 3-D
            # inputs; 2-D (batch, n) inputs overwrote usage wholesale. The
            # EMA is now applied uniformly for every input rank.
            per_neuron = out.abs().reshape(-1, n).mean(dim=0)
            self.usage[:n] = (
                self.usage_decay * self.usage[:n]
                + (1 - self.usage_decay) * per_neuron
            )
            self.lifetime[:n] += 1

        return {
            "output": out,
            "n_neurons": n,
            "avg_usage": self.usage[:n].mean().item(),
        }

    def maybe_birth(self, coherence: float) -> bool:
        """Try to add a neuron if coherence is high."""
        n = self.n_neurons.item()
        if coherence > self.birth_threshold and n < self.max_neurons:
            with torch.no_grad():
                # Fresh random weights; usage starts optimistic at 1.0 so the
                # newborn is not immediately culled.
                nn.init.normal_(self.weights[n], std=0.02)
                self.bias[n] = 0
                self.usage[n] = 1.0
                self.lifetime[n] = 0
            self.n_neurons += 1
            self.births += 1
            return True
        return False

    def maybe_death(self) -> int:
        """Remove underused neurons.

        Neurons older than 100 steps with usage below the death threshold are
        removed via swap-with-last; at least 8 neurons are always kept.

        Returns:
            Number of neurons removed.
        """
        n = self.n_neurons.item()
        if n <= 8:
            return 0

        dead = 0
        with torch.no_grad():
            # Iterate from the end so swap-with-last never disturbs
            # not-yet-visited indices.
            for i in range(n - 1, 7, -1):
                if self.usage[i] < self.death_threshold and self.lifetime[i] > 100:
                    last = self.n_neurons.item() - 1
                    if i < last:
                        self.weights.data[i] = self.weights.data[last]
                        self.bias.data[i] = self.bias.data[last]
                        self.usage[i] = self.usage[last]
                        self.lifetime[i] = self.lifetime[last]
                    self.n_neurons -= 1
                    self.deaths += 1
                    dead += 1
        return dead

    def get_stats(self) -> Dict[str, Any]:
        """Summarize current population: counts, births, deaths, mean usage."""
        n = self.n_neurons.item()
        return {
            "total_neurons": n,
            "births": self.births.item(),
            "deaths": self.deaths.item(),
            "avg_usage": self.usage[:n].mean().item() if n > 0 else 0,
        }

    def reset_state(self):
        # Population statistics are long-lived by design; nothing transient.
        pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EARCPModule(CognitiveModule):
    """
    Ensemble Auto-Regulated Coherence Protocol.
    Compresses hidden states and regulates information flow.
    """

    def __init__(self, d_model: int, config: CognitiveConfig):
        super().__init__(config)

        latent_dim = getattr(config, "latent_state_dim", 768)
        d_ff = getattr(config, "d_ff", 2048)

        mid = (d_model + latent_dim) // 2
        # Pools the hidden sequence into latent space via a bottleneck.
        self.compress = nn.Sequential(
            nn.Linear(d_model, mid),
            nn.SiLU(),
            nn.Linear(mid, latent_dim),
        )

        # Gates between the compressed summary and the externally fused state.
        self.state_gate = nn.Linear(latent_dim * 2, latent_dim)

        # Single-slot attention: every token queries the blended state.
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(latent_dim, d_model)
        self.v_proj = nn.Linear(latent_dim, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        self.coherence_proc = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.SiLU(), nn.Linear(d_ff, d_model)
        )

        # Zero-init the residual branches so the module starts as an identity.
        nn.init.zeros_(self.out_proj.weight)
        nn.init.zeros_(self.coherence_proc[-1].weight)

    def forward(self, h: torch.Tensor, fused: torch.Tensor, **kwargs) -> Dict[str, Any]:
        """Blend the compressed hidden summary with `fused`, then inject the
        blend back into `h` through two small residual updates.

        Args:
            h: (B, T, d_model) hidden sequence.
            fused: (B, latent_dim) external fused state.

        Returns:
            dict with "hidden", "state" and a scalar "coherence" in (0, 1).
        """
        summary = self.compress(h.mean(dim=1))

        gate = torch.sigmoid(self.state_gate(torch.cat([summary, fused], dim=-1)))
        state = gate * summary + (1 - gate) * fused

        queries = self.q_proj(h)
        keys = self.k_proj(state).unsqueeze(1)
        values = self.v_proj(state).unsqueeze(1)

        scores = queries @ keys.transpose(-2, -1) / math.sqrt(h.size(-1))
        attn = F.softmax(scores, dim=-1)

        # Small residual injections keep the update conservative.
        h = h + 0.02 * self.out_proj(attn @ values)
        h = h + 0.1 * self.coherence_proc(h)

        coherence = torch.sigmoid(h.mean()).item()

        return {"hidden": h, "state": state, "coherence": coherence}

    def reset_state(self):
        """No persistent state to reset."""
        pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VAEEncoder(nn.Module):
    """Convolutional VAE Encoder for visual inputs.

    A stack of stride-2 convolutions halves the spatial resolution per stage,
    then two linear heads produce the Gaussian posterior parameters and a
    reparameterized sample.
    """

    def __init__(
        self, in_channels: int = 3, latent_dim: int = 256, channels: List[int] = None
    ):
        super().__init__()

        if channels is None:
            channels = [32, 64, 128, 256]

        stages: List[nn.Module] = []
        prev = in_channels
        for width in channels:
            stages.append(nn.Conv2d(prev, width, 4, 2, 1))
            stages.append(nn.BatchNorm2d(width))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
            prev = width
        self.encoder = nn.Sequential(*stages)

        # NOTE(review): assumes the conv stack ends at a 4x4 feature map
        # (e.g. 64x64 input with four stages) — confirm against callers.
        self.flat_size = channels[-1] * 4 * 4

        self.fc_mu = nn.Linear(self.flat_size, latent_dim)
        self.fc_logvar = nn.Linear(self.flat_size, latent_dim)

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode images into (z, mu, logvar); z is sampled via the
        reparameterization trick (z = mu + sigma * eps, eps ~ N(0, I))."""
        features = self.encoder(x)
        features = features.view(features.size(0), -1)

        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)

        sigma = torch.exp(0.5 * logvar)
        z = mu + torch.randn_like(sigma) * sigma

        return z, mu, logvar
|
|
|
|
|
|
|
|
|
class VAEDecoder(nn.Module):
    """Convolutional VAE Decoder for visual outputs.

    Projects the latent to a 4x4 feature map, doubles the resolution with
    each transposed convolution, and squashes the final image to [0, 1].
    """

    def __init__(
        self, latent_dim: int = 256, out_channels: int = 3, channels: List[int] = None
    ):
        super().__init__()

        if channels is None:
            channels = [256, 128, 64, 32]

        self.fc = nn.Linear(latent_dim, channels[0] * 4 * 4)
        self.init_channels = channels[0]

        stages: List[nn.Module] = []
        for cin, cout in zip(channels, channels[1:]):
            stages.append(nn.ConvTranspose2d(cin, cout, 4, 2, 1))
            stages.append(nn.BatchNorm2d(cout))
            stages.append(nn.ReLU(inplace=True))

        # Final upsample to image space with sigmoid range squashing.
        stages.append(nn.ConvTranspose2d(channels[-1], out_channels, 4, 2, 1))
        stages.append(nn.Sigmoid())

        self.decoder = nn.Sequential(*stages)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent vectors (B, latent_dim) into images in [0, 1]."""
        feat = self.fc(z).view(z.size(0), self.init_channels, 4, 4)
        return self.decoder(feat)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UniversalLatentSpace(CognitiveModule):
    """Universal Latent Space for cross-modal alignment.

    Each modality is projected into a shared latent space, the averaged
    representation is lightly refined by attending over a set of learnable
    anchor points, and the result is projected back to model space.
    """

    def __init__(
        self,
        d_model: int,
        config: CognitiveConfig,
        uls_dim: int = 1024,
        n_anchors: int = 64,
    ):
        super().__init__(config)

        self.uls_dim = uls_dim

        # Learnable anchor points shared by every modality.
        self.anchors = nn.Parameter(torch.randn(n_anchors, uls_dim) * 0.02)

        def make_bridge(src: int, mid: int, dst: int) -> nn.Sequential:
            # Two-layer projector with a final RMSNorm.
            return nn.Sequential(
                nn.Linear(src, mid),
                nn.GELU(),
                nn.Linear(mid, dst),
                RMSNorm(dst),
            )

        # Modality-specific encoders into the shared latent space.
        self.text_to_uls = make_bridge(d_model, d_model, uls_dim)
        self.vision_to_uls = make_bridge(d_model, d_model, uls_dim)
        self.audio_to_uls = make_bridge(d_model, d_model, uls_dim)

        # Projection back into model space.
        self.uls_to_model = make_bridge(uls_dim, d_model, d_model)

        self.anchor_attn = nn.MultiheadAttention(uls_dim, num_heads=4, batch_first=True)

    def forward(self, features: Dict[str, torch.Tensor], **kwargs) -> Dict[str, Any]:
        """Project the available modalities into the shared space and back.

        Args:
            features: dict possibly holding "text"/"vision"/"audio" tensors.

        Returns:
            dict with "unified", "enhanced" (shared-space) and "output"
            (model-space) tensors.
        """
        projectors = {
            "text": self.text_to_uls,
            "vision": self.vision_to_uls,
            "audio": self.audio_to_uls,
        }
        unified_features = [
            bridge(features[name])
            for name, bridge in projectors.items()
            if features.get(name) is not None
        ]

        if unified_features:
            # Average the modality projections elementwise.
            unified = torch.stack(unified_features, dim=0).mean(dim=0)
        else:
            # No modality present: fall back to a single zero token.
            unified = torch.zeros(1, 1, self.uls_dim, device=self.anchors.device)

        # Light residual refinement via attention over the anchors.
        anchors = self.anchors.unsqueeze(0).expand(unified.size(0), -1, -1)
        attended, _ = self.anchor_attn(unified, anchors, anchors)
        enhanced = unified + 0.1 * attended

        return {
            "unified": unified,
            "enhanced": enhanced,
            "output": self.uls_to_model(enhanced),
        }

    def reset_state(self):
        """Stateless across calls."""
        pass
|
|
|
|