SociAgentTransformer
Transformer + Mixture-of-Experts for Agent Decision Making
1.45M params | ~5.5 MB (fp32) | ~1ms inference (50 agents, ONNX)
Input
Agent State Feature Vector
(B, 47)
Tokenizer
Feature Tokenizer
Split the 47 input features into 6 semantic groups, project each to d_model=128
Personality
[0:6] Big5 + Age
6 -> 128
Time
[6:12] sin/cos + day
6 -> 128
Needs + Mood
[12:21] 6 needs + urgency
9 -> 128
Location
[21:31] zone + flags + people
10 -> 128
Time Period
[31:38] 7-class one-hot
7 -> 128
Last Action
[38:47] 9-class one-hot
9 -> 128
+ learnable positional embeddings per token
(B, 6, 128)
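The grouping above can be sketched in NumPy. Shapes and slice boundaries come from the card; the random weights and names (`groups`, `projections`) are illustrative stand-ins for what would presumably be trained linear layers:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 128
# (start, end) slices of the 47-dim state vector, one per semantic group above
groups = [(0, 6), (6, 12), (12, 21), (21, 31), (31, 38), (38, 47)]

# one projection (W, b) per group: (group_dim -> d_model); random stand-ins here
projections = [(rng.standard_normal((e - s, d_model)) * 0.02,
                np.zeros(d_model)) for s, e in groups]
pos_emb = rng.standard_normal((6, d_model)) * 0.02  # learnable per-token position

def tokenize(x):  # x: (B, 47) agent state features
    tokens = [x[:, s:e] @ W + b for (s, e), (W, b) in zip(groups, projections)]
    return np.stack(tokens, axis=1) + pos_emb  # (B, 6, d_model)

out = tokenize(rng.standard_normal((4, 47)))
print(out.shape)  # (4, 6, 128)
```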
Encoder
x 4
Transformer Encoder Block
Multi-Head Self-Attention
8 heads, d_k=16, batch_first=True
Q, K, V: (B, 6, 128) -> (B, 6, 128)
Add & LayerNorm
Mixture-of-Experts Feed-Forward
4 experts, top-2 routing, gated softmax
Expert 0
128->256->128
Expert 1
128->256->128
Expert 2
128->256->128
Expert 3
128->256->128
Gate: Linear(128, 4) -> top-2
Add & LayerNorm
(B, 6, 128)
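A minimal sketch of the MoE feed-forward with top-2 routing, assuming per-token routing and gate weights renormalized over the selected two experts (the card says "gated softmax" without specifying renormalization, so that detail is an assumption; GELU inside the experts is also assumed):

```python
import numpy as np

rng = np.random.default_rng(1)
d, d_ff, n_exp, top_k = 128, 256, 4, 2

def softmax(z):
    e = np.exp(z - z.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

gelu = lambda x: 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

# four expert MLPs (128 -> 256 -> 128) and a linear gate (128 -> 4)
experts = [(rng.standard_normal((d, d_ff)) * 0.02,
            rng.standard_normal((d_ff, d)) * 0.02) for _ in range(n_exp)]
W_gate = rng.standard_normal((d, n_exp)) * 0.02

def moe_ffn(x):  # x: (tokens, d) -- batch and sequence dims flattened
    probs = softmax(x @ W_gate)                   # (tokens, 4) gate softmax
    top = np.argsort(probs, axis=-1)[:, -top_k:]  # indices of the top-2 experts
    out = np.zeros_like(x)
    for t in range(x.shape[0]):
        w = probs[t, top[t]]
        w = w / w.sum()                           # renormalize over top-2 (assumed)
        for weight, e in zip(w, top[t]):
            W1, W2 = experts[e]
            out[t] += weight * (gelu(x[t] @ W1) @ W2)
    return out

y = moe_ffn(rng.standard_normal((6, d)))
print(y.shape)  # (6, 128)
```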
Pooling
[CLS] Query Aggregation
Learned query (1, 1, 128) attends to all 6 tokens via cross-attention
cls_query -> cross_attn(Q=cls, K=tokens, V=tokens) -> LayerNorm
h: (B, 128)
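The [CLS]-query pooling can be sketched as scaled dot-product cross-attention from one learned query to the 6 tokens. This uses a single head and no Q/K/V projections for brevity; the real model may well use multi-head attention with separate projections, and the LayerNorm here omits affine parameters:

```python
import numpy as np

rng = np.random.default_rng(2)
d = 128

def softmax(z):
    e = np.exp(z - z.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

cls_query = rng.standard_normal((1, d)) * 0.02  # learned (1, 1, 128) query

def cls_pool(tokens):  # tokens: (B, 6, 128) encoder output
    scores = (cls_query @ tokens.transpose(0, 2, 1)) / np.sqrt(d)  # (B, 1, 6)
    attn = softmax(scores)              # attention over the 6 feature tokens
    h = (attn @ tokens)[:, 0, :]        # (B, 128) pooled representation
    mu, sd = h.mean(-1, keepdims=True), h.std(-1, keepdims=True)
    return (h - mu) / (sd + 1e-5)       # LayerNorm, no affine params

h = cls_pool(rng.standard_normal((4, 6, d)))
print(h.shape)  # (4, 128)
```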
Task Heads
Action Head
2-layer MLP
Linear(128, 128)
GELU + Dropout(0.1)
Linear(128, 9)
(B, 9) logits
Location Head
Action-conditioned MLP
Linear(128+9, 128)
GELU + Dropout(0.1)
Linear(128, 38)
(B, 38) logits
Duration Head
Regression MLP (action-conditioned)
Input: concat(h, softmax(action_logits).detach()) -> 137
Linear(137, 64)
GELU
Linear(64, 1)
sigmoid*7+1 -> duration in [1, 8]
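The duration head's action conditioning can be sketched as follows. The pooled vector h (128) is concatenated with the softmax over the action logits (9) to give the 137-dim input, and sigmoid*7+1 maps the scalar output into the 1-8 tick range. Weights here are random stand-ins, and `.detach()` from the real model (which stops duration-loss gradients reaching the action head) is a no-op without autograd:

```python
import numpy as np

rng = np.random.default_rng(3)

def softmax(z):
    e = np.exp(z - z.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

gelu = lambda x: 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))
sigmoid = lambda x: 1 / (1 + np.exp(-x))

W1 = rng.standard_normal((137, 64)) * 0.02   # 128 (h) + 9 (action probs)
W2 = rng.standard_normal((64, 1)) * 0.02

def duration_head(h, action_logits):
    probs = softmax(action_logits)             # (B, 9), detached in the real model
    x = np.concatenate([h, probs], axis=-1)    # (B, 137)
    raw = gelu(x @ W1) @ W2                    # (B, 1)
    return sigmoid(raw) * 7 + 1                # ticks in (1, 8)

dur = duration_head(rng.standard_normal((4, 128)), rng.standard_normal((4, 9)))
print(dur.shape)  # (4, 1)
```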
Output
Action Type
9 classes: move, work, eat, sleep, talk, ...
Target Location
38 locations: cafe, park, office, home, ...
Duration
1-8 ticks
(15 min each)
Training
Multi-Task Loss
L = 1.0*CE_action(weighted) + 0.5*CE_location + 0.2*MSE_duration
AdamW (lr=3e-4, wd=1e-4) | CosineAnnealing | Grad clip=1.0 | 30 epochs | Batch=512
ONNX export with opset 17 | CPU inference ~1ms for 50 agents
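The multi-task loss above can be sketched with random stand-in predictions and targets. The 1.0/0.5/0.2 weights come from the card; the action class weights are unspecified there, so uniform weights are used as a placeholder:

```python
import numpy as np

rng = np.random.default_rng(4)

def cross_entropy(logits, targets, class_weights=None):
    z = logits - logits.max(-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(-1, keepdims=True))
    nll = -log_probs[np.arange(len(targets)), targets]
    if class_weights is not None:              # class-weighted CE_action
        w = class_weights[targets]
        return (nll * w).sum() / w.sum()
    return nll.mean()

B = 8
action_logits = rng.standard_normal((B, 9))
loc_logits = rng.standard_normal((B, 38))
dur_pred, dur_true = rng.uniform(1, 8, B), rng.uniform(1, 8, B)
action_t, loc_t = rng.integers(0, 9, B), rng.integers(0, 38, B)
action_w = np.ones(9)  # placeholder: real class weights are unspecified

loss = (1.0 * cross_entropy(action_logits, action_t, action_w)
        + 0.5 * cross_entropy(loc_logits, loc_t)
        + 0.2 * np.mean((dur_pred - dur_true) ** 2))
print(loss > 0)  # True
```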