
CULO — Colaboración de Unidades Locales de Expertos

The idea is practical, easy to implement on top of a Transformer or MLP backbone, and comes with explicit mechanisms for collaboration between local experts, load balancing, and training stability.

An MoE architecture with experts organized into local units (clusters) and dual routing (global + local) that encourages collaboration between nearby experts, reduces learning fragmentation, and improves memory usage and latency at deployment time.

Created by Novaciano, 2025.

High-level intuition

Instead of hundreds of fully independent experts, the experts are grouped into local units (for example, by position, by token type, or by embedding segment).

The router has two stages: a global gate that assigns the input to the 1–2 most relevant local units, and a local gate within each unit that selects 1–k concrete experts.

Within each unit, experts share lightweight parameters (e.g., a normalization layer or a small set of coefficients) that enable knowledge transfer between neighboring experts; this is the collaboration.

Auxiliary losses are added for load balancing (as in standard MoE) and for cooperation (consistency between the outputs of experts in the same unit).
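As a hedged sketch of how these two auxiliary losses could be implemented (the helpers `balance_loss` and `cooperation_loss` are illustrative names, not a fixed API):

```python
import torch

def balance_loss(unit_weights: torch.Tensor) -> torch.Tensor:
    """Penalize uneven unit usage: variance of the mean routing weight per unit.

    unit_weights: (tokens, N_units) soft routing weights.
    Zero when every unit receives the same average load.
    """
    usage = unit_weights.mean(dim=0)
    return usage.var()

def cooperation_loss(expert_outs: torch.Tensor) -> torch.Tensor:
    """Encourage the experts of one unit to agree: variance around their mean output.

    expert_outs: (tokens, E, D) outputs of the selected experts of a single unit.
    """
    mean = expert_outs.mean(dim=1, keepdim=True)
    return ((expert_outs - mean) ** 2).mean()
```

Both terms are zero in the ideal case (perfectly uniform load, perfectly agreeing experts), so they can be added to the task loss with small coefficients.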

Expected advantages

Better generalization through shared adapters within units (local transfer).

Less learning fragmentation than pure MoE (fewer experts that never learn).

Cost control: dual routing selects a few units plus a few experts, yielding lower FLOPs than an MoE with hundreds of experts.

Modular scalability: easy to add units/experts per domain or per deployment.

Architecture (schematic)

Input x → global router → selects up to U_g local units → for each unit:

Local router → selects up to k experts within the unit.

The experts (FFN / small Transformer) process x in parallel.

A collaborative fusion mechanism is applied: weighted mixture plus a shared adaptation layer. The aggregated output flows into the rest of the model.

Key parameters:

N_units = number of local units.

E_per_unit = experts per unit.

k = active experts per input within a unit.

U_g = local units selected per input (e.g., 1 or 2).
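The cost implication of these parameters can be made concrete with a small illustrative calculation (hypothetical helper names):

```python
def total_experts(N_units: int, E_per_unit: int) -> int:
    """Total parameter pool: every expert in every unit."""
    return N_units * E_per_unit

def active_experts(U_g: int, k: int) -> int:
    """Experts actually evaluated per input under dual routing."""
    return U_g * k

# With the suggested defaults below: 32 units x 8 experts = 256 experts in total,
# yet only 1-2 units x 1 expert = 1-2 experts run per input.
```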

Suggested starting hyperparameters

N_units = 32

E_per_unit = 8 → total experts = 256

k = 1 (per unit), U_g = 1 or 2

λ1 (unit balance) = 0.1, λ2 (expert balance) = 0.1, λ3 (cooperation) = 0.01, λ4 (cost) = 0.001

Warmup epochs: 3–5 with soft routing; 50–200 total epochs depending on the dataset.
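The soft-routing warmup can be sketched as a simple temperature anneal on the gating softmax (an assumption about the schedule shape; linear here, but cosine or exponential decay would also work):

```python
def routing_temperature(step: int, warmup_steps: int,
                        t_start: float = 5.0, t_end: float = 1.0) -> float:
    """Linearly anneal the gating softmax temperature from soft to (near-)hard routing."""
    if step >= warmup_steps:
        return t_end
    frac = step / warmup_steps
    return t_start + frac * (t_end - t_start)

# Inside a router one would then use: softmax(logits / routing_temperature(step, warmup_steps))
```

High temperature early on keeps the top-k selection soft, so all experts receive gradient signal before routing is sharpened.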

Forward Pass

# x: batch x seq_len x dim
# router_global: map x -> logits over N_units
# router_local[u]: map x -> logits over E_per_unit

unit_idx, unit_scores = topk(router_global(x), U_g)  # indices and scores
outputs = []
for u, unit_score in zip(unit_idx, unit_scores):
    expert_idx, expert_scores = topk(router_local[u](x), k)  # per-token or per-position
    # compute the selected experts in parallel (sparse)
    expert_outs = [experts[u][e](x) for e in expert_idx]
    # collaborative fusion (weighted sum) + unit adapter
    weights = softmax(expert_scores)
    fused = sum(w * out for w, out in zip(weights, expert_outs))
    fused = unit_adapter[u](fused)  # small shared layer per unit
    outputs.append(fused * unit_score)  # combine units
y = sum(outputs)  # final aggregated output

Notes: the router_* modules can operate per token or per block (e.g., per sequence). unit_adapter is the component that implements the collaboration.

Recommended training routine

Warmup: start with soft gating (top-k with a high temperature) and progressively lower the temperature toward hard, sparse routing.

Gradual sparsity: start with more active experts per input and reduce toward the target k for stability.

Expert dropout: randomly disable experts in early epochs to force redundancy and avoid collapse.

Optimizer: AdamW; apply standard weight decay and an LR schedule with warmup.

Batching: group tokens by destination unit (when possible) for GPU/TPU efficiency.
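The expert-dropout idea above can be sketched as masking local-router logits before the top-k selection (a minimal illustration; `drop_expert_logits` is a hypothetical helper, not part of the notebook below):

```python
import torch

def drop_expert_logits(logits: torch.Tensor, p_drop: float,
                       training: bool = True) -> torch.Tensor:
    """Randomly mask expert logits so top-k cannot always pick the same experts.

    logits: (tokens, E) local-router logits. Dropped experts receive -inf,
    which excludes them from the subsequent top-k / softmax.
    """
    if not training or p_drop <= 0.0:
        return logits
    keep = torch.rand(logits.shape[-1], device=logits.device) >= p_drop
    if not keep.any():  # never drop every expert in a unit
        keep[torch.randint(logits.shape[-1], (1,))] = True
    return logits.masked_fill(~keep, float('-inf'))
```

Disabling it at evaluation time (`training=False`) returns the logits unchanged, mirroring standard dropout semantics.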

Variants and extensions

  • Hierarchical-CULO: multiple levels of units (e.g., global → regional → micro-units).

  • Temporal CULO: local units per temporal position for time-series models.

  • Adapter sharing: instead of one adapter per unit, share factors between similar units (dynamic clustering).

  • Multimodal context routing: use multimodal embeddings (image + text) for gating.

Use cases

  • Language models over heterogeneous domains (legal, medicine, code): create one unit per domain.

  • Multimodal models: units per modality (text, audio, image) and sub-units per subdomain.

  • Edge deployment: reduce k and U_g to lower latency and power consumption.

Limitations and risks

  • Higher routing complexity: two stages add decision latency.

  • Sensitive hyperparameters: N_units, E_per_unit, k, U_g, and the loss coefficients require tuning.

  • Risk of collapse: if balancing fails, a few experts absorb everything. Mitigate with the balance loss and expert dropout.

  • Implementation: efficient sparse-routing support is required (e.g., XLA/TPU or specialized kernels).

Summary

CULO (Colaboración de Unidades Locales de Expertos) is a hierarchical MoE with global → local routing, per-unit shared adapters, and balance and cooperation losses. It reduces expert fragmentation and controls FLOPs while allowing localized specialization. Well suited to heterogeneous domains and latency-constrained deployments.


NOTEBOOK 💻

# %% [markdown]
# CULO — Demo Notebook (Research-style)
#
# Purpose: concise, research-oriented PyTorch notebook implementing the CULO layer
# (Colaboración de Unidades Locales de Expertos). Includes minimal experiments with
# synthetic data, diagnostics (unit usage, cooperation loss), and plots. CPU-first.
#
# Usage: open with Jupyter / VSCode (Jupytext-friendly). Run sequentially.

# %%
# Imports and device
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random
from typing import Tuple

# Device: CPU (per user choice). If you want GPU, change device to 'cuda' manually.
DEVICE = torch.device('cpu')
print('Device:', DEVICE)

# %% [markdown]
# Model definition — terse, research-oriented comments.
# CULO: hierarchical MoE with $N_{units}$ units, each containing $E$ experts.
# Routing is token-wise. Global router selects up to $U_g$ units; local router selects up to $k$ experts per unit.

# %%
class Expert(nn.Module):
    """Simple FFN expert."""
    def __init__(self, d_model: int, d_hidden: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_hidden),
            nn.GELU(),
            nn.Linear(d_hidden, d_model)
        )
    def forward(self, x):
        return self.net(x)

class CULOLayer(nn.Module):
    """Compact CULO implementation focused on clarity for research prototyping.

    Key elements:
      - global_router: token -> N_units logits
      - local_routers: per-unit token -> E logits
      - experts: per-unit ModuleList of Expert
      - unit_adapters: small shared linear per unit
      - unit_scalars: learned scalar per unit
    """
    def __init__(self, d_model: int, d_hidden: int, N_units: int = 8, E_per_unit: int = 4,
                 k:int = 1, U_g:int = 1, coop_coef: float = 0.01):
        super().__init__()
        assert 1 <= U_g <= N_units
        self.d_model = d_model
        self.N_units = N_units
        self.E_per_unit = E_per_unit
        self.k = k
        self.U_g = U_g
        self.coop_coef = coop_coef

        self.global_router = nn.Linear(d_model, N_units)
        self.local_routers = nn.ModuleList([nn.Linear(d_model, E_per_unit) for _ in range(N_units)])
        self.experts = nn.ModuleList([nn.ModuleList([Expert(d_model, d_hidden) for _ in range(E_per_unit)]) for _ in range(N_units)])
        self.unit_adapters = nn.ModuleList([nn.Sequential(nn.Linear(d_model, d_model), nn.GELU()) for _ in range(N_units)])
        self.unit_scalars = nn.Parameter(torch.ones(N_units))

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        # x: (B, S, D)
        B, S, D = x.shape
        tokens = x.reshape(B*S, D)  # (T, D)
        T = tokens.shape[0]

        # Global routing
        g_logits = self.global_router(tokens)  # (T, N)
        g_soft = F.softmax(g_logits, dim=-1)
        top_vals, top_idx = torch.topk(g_soft, self.U_g, dim=-1)  # (T, U_g)
        mask_units = torch.zeros_like(g_soft)
        mask_units.scatter_(1, top_idx, top_vals)
        unit_weights = mask_units / (mask_units.sum(dim=-1, keepdim=True) + 1e-9)
        usage_per_unit = unit_weights.mean(dim=0)  # (N,)

        # Pre-allocate
        fused_per_unit = tokens.new_zeros((T, self.N_units, D))
        coop_loss = 0.0

        # Per-unit processing
        for u in range(self.N_units):
            l_logits = self.local_routers[u](tokens)  # (T, E)
            l_soft = F.softmax(l_logits, dim=-1)
            top_e_vals, top_e_idx = torch.topk(l_soft, self.k, dim=-1)
            mask_experts = torch.zeros_like(l_soft)
            mask_experts.scatter_(1, top_e_idx, top_e_vals)
            expert_weights = mask_experts / (mask_experts.sum(dim=-1, keepdim=True) + 1e-9)

            # compute all expert outputs (small-scale demo)
            expert_outs = []
            for e in range(self.E_per_unit):
                out = self.experts[u][e](tokens)  # (T, D)
                expert_outs.append(out.unsqueeze(1))
            expert_outs = torch.cat(expert_outs, dim=1)  # (T, E, D)

            fused = (expert_weights.unsqueeze(-1) * expert_outs).sum(dim=1)  # (T, D)
            adapted = self.unit_adapters[u](fused)  # (T, D)
            scaled = adapted * self.unit_scalars[u] * unit_weights[:, u].unsqueeze(-1)
            fused_per_unit[:, u, :] = scaled

            # lightweight cooperation proxy: sample tokens and compute variance among selected experts
            # This proxy is cheap and works for prototyping; replace with more rigorous metric if needed.
            sample_n = min(32, T)
            if sample_n > 0:
                idxs = torch.randperm(T)[:sample_n]
                topk_for_sample = top_e_idx[idxs]  # (sample_n, k)
                # build mask for selected experts
                mask_s = torch.zeros((sample_n, self.E_per_unit), device=tokens.device)
                mask_s.scatter_(1, topk_for_sample, 1.0)
                mask_s = mask_s.unsqueeze(-1)  # (sample_n, E, 1)
                sampled_expert_outs = expert_outs[idxs] * mask_s  # (sample_n, E, D)
                denom = mask_s.sum(dim=1) + 1e-9
                mean_selected = sampled_expert_outs.sum(dim=1) / denom
                diffs = sampled_expert_outs - mean_selected.unsqueeze(1)
                sq = (diffs ** 2).sum(dim=-1)
                coop_loss = coop_loss + sq.mean()

        out_tokens = fused_per_unit.sum(dim=1)  # (T, D)
        out = out_tokens.view(B, S, D)

        diagnostics = {
            # keep usage_per_unit in the graph so a balance loss on it can backpropagate
            'usage_per_unit': usage_per_unit,
            'coop_loss': coop_loss * (self.coop_coef / max(1, self.N_units))
        }

        return out, diagnostics

# %%
class DemoModel(nn.Module):
    """Embedding -> CULO -> residual -> classifier (toy language modeling head)."""
    def __init__(self, vocab: int = 200, d_model: int = 64, d_hidden: int = 128,
                 N_units: int = 8, E_per_unit: int = 4, k: int = 1, U_g: int = 1):
        super().__init__()
        self.embed = nn.Embedding(vocab, d_model)
        self.culo = CULOLayer(d_model=d_model, d_hidden=d_hidden, N_units=N_units, E_per_unit=E_per_unit, k=k, U_g=U_g)
        self.norm = nn.LayerNorm(d_model)
        self.out_proj = nn.Linear(d_model, vocab)

    def forward(self, tokens):
        x = self.embed(tokens)
        res, diag = self.culo(x)
        x = self.norm(x + res)
        logits = self.out_proj(x)
        return logits, diag

# %% [markdown]
# Small reproducible experiment (synthetic data). Hyperparameters set for CPU-friendly run.

# %%
def run_experiment(steps=50, B=8, S=16, vocab=200,
                   d_model=64, d_hidden=128, N_units=8, E_per_unit=4, k=1, U_g=1,
                   lr=1e-3, print_every=10, seed=42):
    torch.manual_seed(seed); random.seed(seed)
    model = DemoModel(vocab=vocab, d_model=d_model, d_hidden=d_hidden, N_units=N_units, E_per_unit=E_per_unit, k=k, U_g=U_g).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    loss_fn = nn.CrossEntropyLoss()

    history = {'loss': [], 'total_loss': [], 'usage_var': [], 'coop': []}

    for step in range(steps):
        tokens = torch.randint(0, vocab, (B, S), device=DEVICE)
        logits, diag = model(tokens)
        target = torch.randint(0, vocab, (B, S), device=DEVICE)
        loss = loss_fn(logits.view(-1, vocab), target.view(-1))

        usage = diag['usage_per_unit'].to(DEVICE)
        l_balance = usage.var()
        l_coop = diag['coop_loss'].to(DEVICE)
        # coop_loss already carries coop_coef from inside the layer; don't scale it twice
        total_loss = loss + 0.1 * l_balance + l_coop

        opt.zero_grad(); total_loss.backward(); opt.step()

        history['loss'].append(loss.item()); history['total_loss'].append(total_loss.item())
        history['usage_var'].append(usage.var().item()); history['coop'].append(l_coop.item())

        if (step + 1) % print_every == 0 or step == 0 or step == steps-1:
            print(f"step {step+1}/{steps}: loss={loss.item():.4f}, total={total_loss.item():.4f}, usage_var={usage.var().item():.6f}")

    return model, history

# %% [markdown]
# Run the experiment (small, CPU-friendly). Adjust `steps` for longer runs.

# %%
if __name__ == '__main__':
    model, history = run_experiment(steps=60, B=8, S=16, vocab=200, d_model=64, d_hidden=128, N_units=8, E_per_unit=4, k=1, U_g=1)

    # Plot diagnostics
    fig, axs = plt.subplots(2, 2, figsize=(10, 6))
    axs = axs.flatten()
    axs[0].plot(history['loss'])
    axs[0].set_title('Task loss')
    axs[1].plot(history['total_loss'])
    axs[1].set_title('Total loss (with regularizers)')
    axs[2].plot(history['usage_var'])
    axs[2].set_title('Unit usage variance')
    axs[3].plot(history['coop'])
    axs[3].set_title('Cooperation proxy')
    plt.tight_layout()
    plt.show()

# %% [markdown]
# Notes (research style):
# - This prototype uses dense computation for all experts (computes every expert's output) and masks
#   via top-k selection. Replace with bucketed / sparse kernels for production efficiency.
# - The cooperation proxy is a heuristic; for papers consider stronger metrics (e.g., representation-distance on held-out tokens, mutual information, or task-specific agreement metrics).
# - Hyperparameters to sweep: N_units, E_per_unit, k, U_g, coop_coef, and regularization weights.
# - For reproducible experiments on GPU, set DEVICE = torch.device('cuda') and ensure deterministic flags as needed.

# End of notebook
