KSimplex Geometric Prior for Stable Diffusion: Complete Mathematical Reference
A Cayley-Menger validated pentachoron attention skeleton for rectified flow diffusion.
Author: AbstractPhil
License: MIT
Weights: AbstractPhil/sd15-rectified-geometric-matching
Abstract
We present KSimplex, a geometric cross-attention prior that modulates CLIP conditioning before it enters the frozen UNet of a Stable Diffusion 1.5 rectified flow pipeline. The prior operates in a 4-dimensional simplex (pentachoron) coordinate space, using Cayley-Menger determinants to validate geometric configurations and enforce structural coherence. With only 4.8M trainable parameters (0.56% of the 859M frozen backbone), the prior produces measurable improvements in spatial coherence and anatomical correctness after a single epoch on 10,000 synthetic images.
This document catalogs every mathematical formula implemented in the system, from simplex construction through training loss to ODE sampling.
1. Architecture
CLIP(prompt) → [B, 77, 768]
→ KSimplexCrossAttentionPrior (4.8M params, fp32)
→ StackedKSimplexAttention (4 layers)
→ KSimplexAttentionLayer × 4
→ coordinate projection
→ soft vertex assignment
→ deformed template anchoring
→ geometric attention (distance-based)
→ CM-validated value projection
→ residual blend
→ modulated [B, 77, 768]
→ Frozen SD1.5 UNet (859M params, fp16)
→ v_predicted
All original SD1.5 weights load unmodified. Geometric parameters are purely additive.
2. Simplex Foundations
2.1 Regular Simplex Template
A k-simplex is the k-dimensional generalization of a triangle. For k=4 (pentachoron), we embed n = k+1 = 5 vertices in ℝ^{d_e} (d_e = 32).
The regular template T ∈ ℝ^{(k+1) × d_e} has unit edge lengths:
‖T_i − T_j‖² = 1 for all i ≠ j
Constructed via geovocab2.shapes.factory.SimplexFactory(k=4, embed_dim=32, method="regular").
2.2 Edge Count
For n = k+1 vertices, the edge count is |E| = n(n−1)/2.
For k=4: |E| = 5·4/2 = 10 edges.
2.3 Stability Constraint
Empirically validated: the embedding dimension must satisfy d_e / k ≥ 8.
For k=4, d_e=32: ratio = 8 ✓. Below this ratio, CM determinants become numerically unstable.
2.4 Pairwise Squared Distances (Vectorized)
Given vertices V ∈ ℝ^{n × d_e}:
D²_{ij} = Σ_m (V_{i,m} − V_{j,m})²
Computed as: diff = V.unsqueeze(1) - V.unsqueeze(0), D² = (diff²).sum(dim=-1).
Upper triangle extracted via torch.triu_indices(n, n, offset=1).
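A minimal PyTorch sketch of this vectorized computation (the function name is illustrative, not the repository's identifier):

```python
import torch

def pairwise_sq_dists(V: torch.Tensor) -> torch.Tensor:
    """All-pairs squared Euclidean distances for vertices V of shape (n, d_e)."""
    diff = V.unsqueeze(1) - V.unsqueeze(0)   # (n, n, d_e) broadcasted differences
    return (diff ** 2).sum(dim=-1)           # (n, n) squared distances

# Edge distances = upper triangle: 10 edges for a pentachoron (n = 5)
V = torch.randn(5, 32)
D2 = pairwise_sq_dists(V)
iu = torch.triu_indices(5, 5, offset=1)
edges = D2[iu[0], iu[1]]                     # shape (10,)
```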
3. Cayley-Menger Determinant
3.1 Bordered Distance Matrix
For n = k+1 points with pairwise squared distances d²_{ij}, the CM determinant is the determinant of the (n+1)×(n+1) bordered matrix:
Structure:
- cm[0, 0] = 0 (corner)
- cm[0, 1:] = cm[1:, 0] = 1 (border row/column)
- cm[i+1, j+1] = d²_{ij} for i ≠ j (squared distances)
- cm[i+1, i+1] = 0 (diagonal)
Computed via torch.linalg.det(cm) in float32 (half precision not supported).
3.2 Simplex Validity Condition
A set of distances defines a valid k-simplex in Euclidean space iff:
(−1)^{k+1} · CM > 0
For k=4: (−1)^5 · CM > 0, so CM must be negative.
3.3 Simplex Volume from CM
The squared volume of a k-simplex:
Vol² = (−1)^{k+1} / (2^k · (k!)²) · CM
For k=4: denom = 2^4 · (4!)^2 = 16 · 576 = 9216.
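The bordered-matrix construction and volume formula can be sanity-checked against the regular unit-edge pentachoron, whose CM determinant is exactly −5 (hence Vol = √5/96). A self-contained sketch, not the repository's implementation:

```python
import math
import torch

def cayley_menger(D2: torch.Tensor) -> torch.Tensor:
    """Bordered (n+1)x(n+1) Cayley-Menger matrix from squared distances D2 (n, n)."""
    n = D2.shape[0]
    cm = torch.zeros(n + 1, n + 1, dtype=torch.float64)
    cm[0, 1:] = 1.0   # border row of ones
    cm[1:, 0] = 1.0   # border column of ones
    cm[1:, 1:] = D2   # squared distances, zero diagonal
    return cm

def simplex_volume(D2: torch.Tensor, k: int) -> torch.Tensor:
    """k-simplex volume from the CM determinant (Vol^2 formula of section 3.3)."""
    det = torch.linalg.det(cayley_menger(D2))
    vol2 = ((-1) ** (k + 1)) / (2 ** k * math.factorial(k) ** 2) * det
    return torch.sqrt(torch.clamp(vol2, min=0.0))

# Regular unit-edge pentachoron: d^2_ij = 1 for all i != j
k = 4
D2 = torch.ones(k + 1, k + 1, dtype=torch.float64) - torch.eye(k + 1, dtype=torch.float64)
det = torch.linalg.det(cayley_menger(D2))   # -5 for the regular pentachoron
vol = simplex_volume(D2, k)                 # sqrt(5)/96
```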
4. Deformed Template
4.1 Per-Layer Deformation
Each attention layer maintains learnable offsets ΔV ∈ ℝ^{(k+1) × d_e}, initialized as N(0, 0.01):
V_deformed = T + δ · ΔV
where the deformation scale δ is clamped to [0.05, 0.5].
Stability zone: δ ∈ [0.15, 0.35] empirically optimal. The clamp bounds allow the optimizer to explore slightly beyond.
4.2 Timestep-Conditioned Deformation (Phase 3)
A small MLP maps the normalized timestep t ∈ [0,1] to per-layer deformation factors:
f(t) = σ(W₂ · SiLU(W₁ · t)) ∈ (0,1)^L
where W₁ ∈ ℝ^{64×1}, W₂ ∈ ℝ^{L×64}, σ = sigmoid.
The effective deformation scale for layer ℓ at timestep t:
δ_ℓ(t) = δ_ℓ · (0.5 + f_ℓ(t))
The multiplier (0.5 + f) ∈ [0.5, 1.5] can both reduce and amplify the base deformation.
Critical implementation detail: Effective scales are computed transiently — the underlying parameter δ_ℓ is never mutated during forward passes. An earlier version used .data.copy_() which caused training drift.
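A minimal sketch of this schedule, assuming the dimensions stated above (the module and argument names are hypothetical); note the effective scales are returned, never written back into the parameter:

```python
import torch
import torch.nn as nn

class DeformSchedule(nn.Module):
    """Timestep-conditioned deformation factors, per section 4.2 (illustrative)."""
    def __init__(self, num_layers: int = 4, hidden: int = 64):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(1, hidden), nn.SiLU(), nn.Linear(hidden, num_layers)
        )

    def forward(self, t: torch.Tensor, base_delta: torch.Tensor) -> torch.Tensor:
        # f(t) in (0,1)^L; effective scale = delta * (0.5 + f), computed transiently
        f = torch.sigmoid(self.mlp(t.view(-1, 1)))       # (B, L)
        return base_delta.clamp(0.05, 0.5) * (0.5 + f)   # no in-place parameter mutation

sched = DeformSchedule()
base = torch.full((4,), 0.25)                # delta_base from the hyperparameter table
eff = sched(torch.tensor([0.3]), base)       # (1, 4), each value in [0.125, 0.375]
```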
5. Token-to-Simplex Mapping
5.1 Coordinate Projection
Each CLIP token h_i ∈ ℝ^{768} is projected into simplex coordinate space:
c_i = W_coord · LN(h_i)
where W_coord ∈ ℝ^{d_e × 768}.
5.2 Soft Vertex Assignment
Tokens are softly routed to simplex vertices:
w_i = softmax(W_vertex · LN(h_i))
where W_vertex ∈ ℝ^{(k+1) × 768}.
5.3 Template Anchoring
The geometric contribution for each token blends the deformed template vertices:
a_i = Σ_v w_{i,v} · V_deformed,v = w_iᵀ · V_deformed
Batched: template_contribution = vertex_weights @ deformed # (B, 77, d_e)
5.4 Final Coordinates
Combined projection + anchor:
ĉ_i = c_i + a_i
5.5 Unit Sphere Normalization
L2-normalized to bound pairwise squared distances to [0, 4]:
c̃_i = ĉ_i / ‖ĉ_i‖₂
This keeps the CM determinant well-conditioned regardless of initialization scale.
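The full token-to-simplex pipeline of section 5 can be sketched in a few lines of PyTorch (weight names are illustrative, not the repository's identifiers):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def token_coords(h, W_coord, W_vertex, ln, deformed):
    """Map CLIP tokens h (B, 77, 768) to unit-norm simplex coordinates (B, 77, d_e).

    deformed: (k+1, d_e) deformed template vertices.
    """
    x = ln(h)
    c = x @ W_coord.T                       # 5.1 coordinate projection  (B, 77, d_e)
    w = F.softmax(x @ W_vertex.T, dim=-1)   # 5.2 soft vertex assignment (B, 77, k+1)
    a = w @ deformed                        # 5.3 template anchoring     (B, 77, d_e)
    return F.normalize(c + a, dim=-1)       # 5.4 + 5.5 combine and unit-normalize

B, d_e, k = 2, 32, 4
h = torch.randn(B, 77, 768)
coords = token_coords(
    h,
    torch.randn(d_e, 768) * 0.02,
    torch.randn(k + 1, 768) * 0.02,
    nn.LayerNorm(768),
    torch.randn(k + 1, d_e),
)
```

Unit-normalized coordinates guarantee every pairwise squared distance lies in [0, 4], which is what keeps the downstream CM determinant well-conditioned.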
6. Geometric Attention
6.1 Distance-Based Logits
Unlike standard dot-product attention, KSimplex attention derives from pairwise distances in simplex space:
logits_{ij} = −‖c̃_i − c̃_j‖² / √d_e
The 1/√d_e scaling mirrors standard attention's 1/√d_k. Tokens that are geometrically close in simplex space attend more strongly to each other.
6.2 Attention Weights
α = softmax(logits)
Geometric attention is shared across all H heads. Each head has independent value projections.
6.3 Multi-Head Value Projection
output = Concat_h(α · V^h) · W_O, with V^h = X · W_V^h
where X is the layer input and d_h = 768/8 = 96.
6.4 Residual Connection
Each layer adds to the residual stream:
X ← X + output
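A minimal sketch of sections 6.1–6.4 together, with one shared distance-based attention map across heads and per-head values (names are illustrative):

```python
import torch
import torch.nn.functional as F

def geometric_attention(X, coords, W_V, W_O, H=8):
    """Distance-based attention: one shared map over H heads, plus residual.

    X: (B, T, 768) layer input; coords: (B, T, d_e) unit-norm simplex coordinates.
    """
    B, T, D = X.shape
    d_e = coords.shape[-1]
    d2 = torch.cdist(coords, coords) ** 2           # (B, T, T) pairwise squared distances
    alpha = F.softmax(-d2 / d_e ** 0.5, dim=-1)     # close tokens attend strongly
    V = (X @ W_V).view(B, T, H, D // H).transpose(1, 2)   # (B, H, T, d_h) head values
    out = alpha.unsqueeze(1) @ V                    # shared attention map, per-head values
    out = out.transpose(1, 2).reshape(B, T, D) @ W_O
    return X + out                                  # residual connection

B, T, D, d_e = 2, 77, 768, 32
Y = geometric_attention(
    torch.randn(B, T, D),
    F.normalize(torch.randn(B, T, d_e), dim=-1),
    torch.randn(D, D) * 0.02,
    torch.randn(D, D) * 0.02,
)
```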
7. Entropy Sharpening (Emergent Property)
7.1 Per-Layer Attention Entropy
H_ℓ = −Σ_j α_{ij} log α_{ij}
averaged over batch and token dimensions.
7.2 Monotonic Decrease
Validated across classification, language modeling, and now diffusion:
H₁ > H₂ > H₃ > H₄
This coarse-to-fine sharpening emerges naturally from stacked geometric constraints — no explicit entropy loss is applied. Early layers attend broadly (soft geometry), later layers attend sharply (crystallized structure).
8. Residual Blending
8.1 Learnable Blend
The geo_prior output blends with original CLIP conditioning:
E_out = (1 − β) · E_clip + β · E_geo, β = σ(β_logit)
Initialized: β_logit = 0.0 → β = 0.5 (equal mix). Sigmoid bounds β ∈ (0, 1).
8.2 Timestep-Conditioned Blend (Optional)
β(t) = σ(MLP(t))
Allows the network to learn different blend strengths at different noise levels.
9. Geometric Loss Functions
9.1 CM Validity Loss
Penalizes invalid simplex configurations with a hinge loss:
L_CM = mean_ℓ ReLU(ε − (−1)^{k+1} · CM_ℓ)
where ε = 10⁻⁶ is a small margin preventing zero-volume degeneracies. Distances are sampled from the first k+1 tokens of each layer (deterministic sampling for stability).
9.2 Volume Consistency Loss
Prevents all layers from collapsing to identical geometry:
L_vol = −std_ℓ(log |Vol²_ℓ|)
The negative standard deviation of log-volumes rewards spread: each layer should carve a distinct geometric structure, not duplicate.
9.3 Combined Geometric Loss
L_geo = λ_CM · L_CM + λ_vol · L_vol
Defaults: λ_CM = 0.01, λ_vol = 0.005.
9.4 Geometric Loss Warmup
Linear ramp over W steps prevents geometric regularization from dominating early training:
λ_geo(step) = λ_geo · min(step / W, 1)
Default: W = 200 steps.
10. Rectified Flow Matching
10.1 Linear Interpolation
Rectified flow defines a straight path between clean data x₀ and Gaussian noise ε:
x_t = (1 − t) · x₀ + t · ε
10.2 Velocity Target
The derivative along this path is constant:
v = dx_t/dt = ε − x₀
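Both quantities can be sketched in a few lines (the function name is illustrative); note the velocity target does not depend on t:

```python
import torch

def flow_pair(x0: torch.Tensor, eps: torch.Tensor, t: torch.Tensor):
    """Rectified-flow interpolant x_t and its constant velocity target."""
    t = t.view(-1, *([1] * (x0.dim() - 1)))   # broadcast t over latent dims
    x_t = (1 - t) * x0 + t * eps              # straight path from data to noise
    v = eps - x0                              # d(x_t)/dt, independent of t
    return x_t, v

x0 = torch.randn(4, 4, 64, 64)
eps = torch.randn_like(x0)
x_t, v = flow_pair(x0, eps, torch.rand(4))
```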
10.3 Shifted Schedule
Shift parameter s > 1 biases the timestep distribution toward higher noise:
t_shifted = s · t / (1 + (s − 1) · t)
| t_base | t_shifted (s=2.5) |
|---|---|
| 0.1 | 0.217 |
| 0.3 | 0.517 |
| 0.5 | 0.714 |
| 0.7 | 0.854 |
| 0.9 | 0.957 |
Applied identically during training (timestep sampling) and inference (sigma schedule).
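The shift is a one-line warp; the sketch below reproduces the table values:

```python
def shift_timestep(t: float, s: float = 2.5) -> float:
    """Monotone warp biasing timesteps toward higher noise (section 10.3)."""
    return s * t / (1 + (s - 1) * t)

# Reproduces the table above: t=0.1 -> 0.217, t=0.5 -> 0.714, t=0.9 -> 0.957
```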
10.4 Logit-Normal Timestep Sampling
Biases samples toward mid-range timesteps where the learning signal is strongest:
t = σ(z), z ~ N(μ, σ²)
Default: μ = 0, σ = 1. Then scaled to [t_min, t_max] and shifted.
10.5 Continuous Float Timesteps
Timesteps passed to the UNet are continuous floats, not quantized to integer indices.
The raw continuous t is additionally passed to the geo_prior as t_continuous to avoid quantization artifacts in the deformation schedule.
11. Min-SNR Loss Weighting
11.1 Signal-to-Noise Ratio
For the flow matching parameterization where σ = t:
SNR(t) = (1 − t)² / t²
11.2 Min-SNR-γ Clamping
Prevents low-noise timesteps (small t, high SNR) from dominating gradients:
w(t) = min(SNR(t), γ) / SNR(t)
11.3 Velocity Prediction Adjustment
An additional factor accounts for the velocity parameterization:
w'(t) = w(t) / (SNR(t) + 1)
11.4 Weighted Task Loss
L_task = E_t [ w'(t) · ‖v_θ(x_t, t, E) − v‖² ]
Default: γ = 5.0.
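Sections 11.1–11.3 compose into a single scalar weight; a minimal sketch (function name is illustrative):

```python
def min_snr_weight(t: float, gamma: float = 5.0) -> float:
    """Min-SNR-gamma weight with velocity-prediction adjustment (section 11)."""
    snr = (1 - t) ** 2 / t ** 2      # sigma = t parameterization
    w = min(snr, gamma) / snr        # clamp the weight of high-SNR (low-noise) steps
    return w / (snr + 1)             # extra factor for velocity prediction

# At t = 0.5: SNR = 1, so w = 1 and the final weight is 1 / (1 + 1) = 0.5
```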
12. Classifier-Free Guidance
12.1 CFG Dropout (Training)
Encoder hidden states are zeroed with probability p_drop to train the unconditional path:
E′ = 0 with probability p_drop, otherwise E′ = E
Default: p_drop = 0.1.
12.2 CFG at Inference
v_guided = v_uncond + s_cfg · (v_cond − v_uncond)
Implemented via a batched forward pass: latent_input = cat([x, x]), enc_input = cat([E_cond, E_uncond]), then chunk the output.
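The guidance combination itself is a one-liner; a sketch of just that step (the batched UNet call is omitted, names are illustrative):

```python
import torch

def cfg_velocity(v_cond: torch.Tensor, v_uncond: torch.Tensor, scale: float) -> torch.Tensor:
    """Classifier-free guidance: extrapolate from unconditional toward conditional."""
    return v_uncond + scale * (v_cond - v_uncond)

v_c = torch.randn(2, 4, 64, 64)
v_u = torch.randn(2, 4, 64, 64)
v = cfg_velocity(v_c, v_u, scale=7.5)   # scale=1 recovers the conditional prediction
```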
13. Total Training Loss
L_total = L_task + λ_geo(step) · L_geo
Expanding fully:
L_total = E_t [ w'(t) · ‖v_θ(x_t, t, E) − v‖² ] + λ_geo · min(step/W, 1) · (λ_CM · L_CM + λ_vol · L_vol)
Only the geo_prior parameters (4.8M) receive gradients. The UNet backbone (859M) and CLIP encoder are frozen.
14. Euler ODE Sampling
14.1 Sigma Schedule
Linear spacing from 1→0 with shift applied:
σ_i = shift(1 − i/N) = s · (1 − i/N) / (1 + (s − 1) · (1 − i/N)), i = 0, …, N
producing N+1 values from σ₀ ≈ 1 to σ_N = 0.
14.2 Euler Integration Step
x_{i+1} = x_i + dt · v_θ(x_i, σ_i, E), where dt = σ_{i+1} − σ_i
Note: dt is negative (moving from noise toward clean).
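A self-contained sketch of the sampler (not the repository's implementation): because the true rectified-flow velocity ε − x₀ is constant, Euler integration recovers x₀ from pure noise exactly, which makes a useful sanity check.

```python
import torch

def shift_sigma(t: torch.Tensor, s: float = 2.5) -> torch.Tensor:
    """Shifted schedule warp applied to the linear 1 -> 0 grid."""
    return s * t / (1 + (s - 1) * t)

def euler_sample(v_theta, x: torch.Tensor, num_steps: int = 20, s: float = 2.5) -> torch.Tensor:
    """Euler ODE integration from noise (sigma = 1) to clean (sigma = 0)."""
    sigmas = shift_sigma(torch.linspace(1.0, 0.0, num_steps + 1), s)
    for i in range(num_steps):
        dt = sigmas[i + 1] - sigmas[i]        # negative step toward clean data
        x = x + dt * v_theta(x, sigmas[i])
    return x

# Sanity check with the exact constant velocity v = eps - x0:
x0 = torch.randn(1, 4, 8, 8)
eps = torch.randn(1, 4, 8, 8)
x_hat = euler_sample(lambda x, sig: eps - x0, eps.clone(), num_steps=10)
```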
15. Learning Rate Schedule
15.1 Linear Warmup
η(step) = η₀ · step / W_warmup for step < W_warmup
15.2 Cosine Decay
η(step) = η_min + (η₀ − η_min) · ½ (1 + cos(π · p))
where p ∈ [0, 1] is the post-warmup training progress.
Defaults: η₀ = 10⁻⁴, η_min = 10⁻⁶, W_warmup = 100.
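Both phases can be expressed as one function of the step (the `total` argument is an assumption, standing in for the run's total step count):

```python
import math

def lr_at(step: int, total: int, lr0: float = 1e-4,
          lr_min: float = 1e-6, warmup: int = 100) -> float:
    """Linear warmup then cosine decay to lr_min (section 15, illustrative)."""
    if step < warmup:
        return lr0 * step / warmup                     # linear ramp from 0 to lr0
    p = (step - warmup) / max(1, total - warmup)       # post-warmup progress in [0, 1]
    return lr_min + (lr0 - lr_min) * 0.5 * (1 + math.cos(math.pi * p))
```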
16. VAE Latent Space
16.1 Encoding
z = s_f · Enc(x)
where s_f = 0.18215 (SD1.5 scaling factor).
16.2 Decoding
x̂ = Dec(z / s_f)
Latent shape: (B, 4, H/8, W/8). For 512×512 images: (B, 4, 64, 64).
17. Mixed Precision Strategy
| Component | Dtype | Rationale |
|---|---|---|
| Frozen UNet backbone | fp16 | 859M params, memory efficiency |
| Geo prior (all ops) | fp32 | CM det numerical stability |
| CM determinant | fp32 | torch.linalg.det requires it |
| Geo prior output | → fp16 | Cast to match UNet cross-attn |
| GradScaler | Enabled | Mixed fp32 geo + fp16 UNet pass |
The geo_prior runs inside torch.amp.autocast("cuda", enabled=False) with explicit .float() casts. On first inference call, geo_prior auto-casts itself to fp32 (one-time, 4.8M params ≈ 19MB).
18. Complete Hyperparameter Table
| Symbol | Parameter | Value | Source |
|---|---|---|---|
| k | Simplex dimension | 4 | SimplexConfig |
| d_e | Coordinate embedding dim | 32 | SimplexConfig |
| L | Stacked attention layers | 4 | SimplexConfig |
| H | Attention heads | 8 | SimplexConfig |
| d_h | Head dimension (768/H) | 96 | Derived |
| n | Vertices (k+1) | 5 | Derived |
| |E| | Edges (n choose 2) | 10 | Derived |
| δ_base | Base deformation scale | 0.25 | SimplexConfig |
| β_init | Initial blend logit | 0.0 (→β=0.5) | SimplexConfig |
| λ_CM | CM validity weight | 0.01 | SimplexConfig |
| λ_vol | Volume consistency weight | 0.005 | SimplexConfig |
| λ_geo | Total geo loss weight | 0.01 | TrainConfig |
| W_geo | Geo loss warmup steps | 200 | TrainConfig |
| s | Schedule shift | 2.5 | TrainConfig |
| γ | Min-SNR gamma | 5.0 | TrainConfig |
| p_drop | CFG dropout | 0.1 | TrainConfig |
| η₀ | Base learning rate | 10⁻⁴ | TrainConfig |
| η_min | Minimum learning rate | 10⁻⁶ | TrainConfig |
| W_lr | LR warmup steps | 100 | TrainConfig |
| s_f | VAE scaling factor | 0.18215 | SD1.5 |
19. Empirical Validation
Prior KSimplex Research
| Benchmark | Result |
|---|---|
| FMNIST classification | 89.13% |
| CIFAR-10 classification | 84.59% |
| CIFAR-100 classification | 69.08% |
| KSimplex LLM (54M params) | PPL 113 |
| Geometric validity (CM sign) | 100% |
| Attention entropy | Monotonically decreasing across layers |
| Token sweet spot | 25–77 tokens |
| Deformation stability zone | δ ∈ [0.15, 0.35] |
Diffusion Results (1 epoch, 10k ImageNet-synthetic)
- Trainable params: 4,845,725 (0.56% of model)
- Training time: ~7 min on L4 (24GB)
- Batch size: 6
- Fragmented anatomy (split bodies, duplicated heads) → coherent spatial composition
- Circular geometry (bowls, wheels) maintains curvature
- Novel prompts (unseen during training) produce equally coherent outputs
- Lune base capabilities fully preserved — purely additive improvement
20. Formula Index
Quick reference of all 42 formulas by section:
| # | Formula | Section |
|---|---|---|
| 1 | Regular simplex: d²_{ij} = 1 ∀ i≠j | §2.1 |
| 2 | Edge count: n(n-1)/2 | §2.2 |
| 3 | Stability: d_e/k ≥ 8 | §2.3 |
| 4 | Pairwise distances: D²_{ij} = Σ(V_i - V_j)² | §2.4 |
| 5 | CM bordered matrix construction | §3.1 |
| 6 | CM determinant: det(CM) | §3.1 |
| 7 | Validity: (-1)^{k+1} · CM > 0 | §3.2 |
| 8 | Volume²: (-1)^{k+1}/(2^k·(k!)²) · CM | §3.3 |
| 9 | k=4 volume: -1/9216 · CM | §3.3 |
| 10 | Deformed template: T + δ·ΔV | §4.1 |
| 11 | Deformation clamp: [0.05, 0.5] | §4.1 |
| 12 | Deform schedule MLP: σ(W₂·SiLU(W₁·t)) | §4.2 |
| 13 | Effective scale: δ·(0.5 + f(t)) | §4.2 |
| 14 | Coord projection: W_coord · LN(h) | §5.1 |
| 15 | Soft vertex assign: softmax(W_vertex · LN(h)) | §5.2 |
| 16 | Template anchor: w^T · V_deformed | §5.3 |
| 17 | Final coords: c + a | §5.4 |
| 18 | L2 normalize: ĉ/‖ĉ‖₂ | §5.5 |
| 19 | Distance logits: -d²/√d_e | §6.1 |
| 20 | Attention weights: softmax(logits) | §6.2 |
| 21 | Multi-head output: Concat(α·V^h)·W_O | §6.3 |
| 22 | Residual: H + output | §6.4 |
| 23 | Entropy: -Σ α log α | §7.1 |
| 24 | Entropy ordering: H₁ > H₂ > H₃ > H₄ | §7.2 |
| 25 | Blend: (1-β)·E_clip + β·E_geo | §8.1 |
| 26 | Blend factor: σ(β_logit) | §8.1 |
| 27 | Timestep blend: σ(MLP(t)) | §8.2 |
| 28 | CM validity loss: ReLU hinge | §9.1 |
| 29 | Volume consistency: -std(log|vol²|) | §9.2 |
| 30 | Combined geo: λ_CM·L_CM + λ_vol·L_vol | §9.3 |
| 31 | Geo warmup: λ·min(s/W, 1) | §9.4 |
| 32 | Flow interpolation: (1-t)·x₀ + t·ε | §10.1 |
| 33 | Velocity target: ε - x₀ | §10.2 |
| 34 | Schedule shift: s·t/(1+(s-1)·t) | §10.3 |
| 35 | Logit-normal: σ(N(μ,σ²)) | §10.4 |
| 36 | SNR: (1-t)²/t² | §11.1 |
| 37 | Min-SNR weight: min(SNR,γ)/SNR | §11.2 |
| 38 | Velocity adjustment: w/(SNR+1) | §11.3 |
| 39 | CFG dropout: E'=0 w.p. p_drop | §12.1 |
| 40 | CFG guidance: v_u + s·(v_c - v_u) | §12.2 |
| 41 | Euler step: x + dt·v_θ | §14.2 |
| 42 | Cosine LR: η_min + (η₀-η_min)·½(1+cos(πp)) | §15.2 |
21. Usage
import torch
from sd15_trainer_geo.pipeline import load_pipeline, load_geo_from_hub
from sd15_trainer_geo.generate import generate, show_images
pipe = load_pipeline(device="cuda", dtype=torch.float16)
pipe.unet.load_pretrained(
"AbstractPhil/tinyflux-experts",
subfolder="", filename="sd15-flow-lune-unet.safetensors"
)
load_geo_from_hub(pipe, "AbstractPhil/sd15-rectified-geometric-matching")
out = generate(pipe, ["a tabby cat on a windowsill"], shift=2.5, seed=42)
show_images(out)
Citation
@misc{abstractphil2026ksimplex,
title = {KSimplex Geometric Prior for Stable Diffusion},
author = {AbstractPhil},
year = {2026},
url = {https://huggingface.co/AbstractPhil/sd15-rectified-geometric-matching}
}
License: MIT — Build on it freely.