Spaces:
Running
Running
Fix: Bind to 0.0.0.0 and show_api=False to resolve runtime and schema errors
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- README.md +80 -6
- app.py +169 -0
- config.json +58 -0
- gfn/__init__.py +30 -0
- gfn/realizations/__init__.py +19 -0
- gfn/realizations/api.py +66 -0
- gfn/realizations/gssm/__init__.py +10 -0
- gfn/realizations/gssm/api.py +102 -0
- gfn/realizations/gssm/config/__init__.py +63 -0
- gfn/realizations/gssm/config/defaults.py +111 -0
- gfn/realizations/gssm/config/loader.py +163 -0
- gfn/realizations/gssm/config/schema.py +199 -0
- gfn/realizations/gssm/config/serialization.py +39 -0
- gfn/realizations/gssm/config/validator.py +109 -0
- gfn/realizations/gssm/constants.py +67 -0
- gfn/realizations/gssm/core/__init__.py +7 -0
- gfn/realizations/gssm/core/state.py +60 -0
- gfn/realizations/gssm/core/types.py +27 -0
- gfn/realizations/gssm/csrc/README.md +2 -0
- gfn/realizations/gssm/csrc/compile_cuda_12.9.bat +68 -0
- gfn/realizations/gssm/csrc/extension.cpp +87 -0
- gfn/realizations/gssm/csrc/geometry/low_rank.cu +160 -0
- gfn/realizations/gssm/csrc/integrators/integrators.cpp +252 -0
- gfn/realizations/gssm/csrc/integrators/integrators.h +41 -0
- gfn/realizations/gssm/csrc/losses/toroidal.cu +99 -0
- gfn/realizations/gssm/csrc/setup.py +38 -0
- gfn/realizations/gssm/cuda/__init__.py +11 -0
- gfn/realizations/gssm/cuda/autograd/__init__.py +0 -0
- gfn/realizations/gssm/cuda/kernels/__init__.py +0 -0
- gfn/realizations/gssm/cuda/kernels/geometry_kernels.py +99 -0
- gfn/realizations/gssm/cuda/kernels/integrator_kernels.py +73 -0
- gfn/realizations/gssm/cuda/ops/__init__.py +52 -0
- gfn/realizations/gssm/data/__init__.py +16 -0
- gfn/realizations/gssm/data/dataset.py +14 -0
- gfn/realizations/gssm/data/loader.py +53 -0
- gfn/realizations/gssm/data/replay.py +130 -0
- gfn/realizations/gssm/data/transforms.py +40 -0
- gfn/realizations/gssm/errors.py +23 -0
- gfn/realizations/gssm/geometry/__init__.py +42 -0
- gfn/realizations/gssm/geometry/adaptive.py +83 -0
- gfn/realizations/gssm/geometry/base.py +70 -0
- gfn/realizations/gssm/geometry/euclidean.py +20 -0
- gfn/realizations/gssm/geometry/factory.py +117 -0
- gfn/realizations/gssm/geometry/hierarchical.py +84 -0
- gfn/realizations/gssm/geometry/holographic.py +91 -0
- gfn/realizations/gssm/geometry/hyperbolic.py +97 -0
- gfn/realizations/gssm/geometry/low_rank.py +324 -0
- gfn/realizations/gssm/geometry/reactive.py +109 -0
- gfn/realizations/gssm/geometry/spherical.py +47 -0
- gfn/realizations/gssm/geometry/torus.py +274 -0
README.md
CHANGED
|
@@ -1,12 +1,86 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
|
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: GFN MNIAH Solver
|
| 3 |
+
emoji: 📍
|
| 4 |
+
colorFrom: red
|
| 5 |
+
colorTo: yellow
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 4.44.1
|
| 8 |
+
python_version: "3.11"
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
| 11 |
+
license: cc-by-nc-nd-4.0
|
| 12 |
+
library_name: gfn
|
| 13 |
+
language: en
|
| 14 |
+
pipeline_tag: other
|
| 15 |
+
tags:
|
| 16 |
+
- gfn
|
| 17 |
+
- physics-informed
|
| 18 |
+
- geometric-deep-learning
|
| 19 |
+
- g-ssm
|
| 20 |
+
- needle-in-a-haystack
|
| 21 |
+
- long-context
|
| 22 |
+
model-index:
|
| 23 |
+
- name: gfn-gssm-mniah-k2
|
| 24 |
+
results:
|
| 25 |
+
- task:
|
| 26 |
+
type: other
|
| 27 |
+
name: Multi-Needle Retrieval
|
| 28 |
+
dataset:
|
| 29 |
+
name: synthetic-needle-haystack
|
| 30 |
+
type: synthetic
|
| 31 |
+
metrics:
|
| 32 |
+
- name: Accuracy
|
| 33 |
+
type: accuracy
|
| 34 |
+
value: 100.0
|
| 35 |
---
|
| 36 |
|
| 37 |
+
# 📍 GFN MNIAH (Multi-Needle-in-a-Haystack) Solver
|
| 38 |
+
|
| 39 |
+
[](https://doi.org/10.5281/zenodo.19141133)
|
| 40 |
+
[](https://huggingface.co/DepthMuun)
|
| 41 |
+
[](https://github.com/DepthMuun/gfn)
|
| 42 |
+
|
| 43 |
+
This repository demonstrates the **Geometric Flow Network (GFN)** framework's ability to handle extreme-length context retrieval without KV-cache. It implements the **Geodesic State Space Model (G-SSM)** specialized for the K=2 Needle-in-a-Haystack task.
|
| 44 |
+
|
| 45 |
+
## 🚀 Technical Highlights
|
| 46 |
+
- **Context Length**: Validated up to **32,000 tokens** (linear complexity $O(L)$).
|
| 47 |
+
- **Inductive Bias**: Topologically encodes "needles" as geodesic impulses that curve the world state manifold into a target decision region.
|
| 48 |
+
- **Inference VRAM**: ~38MB constant (O(1) memory).
|
| 49 |
+
|
| 50 |
+
## 🛠️ Local Installation & Usage
|
| 51 |
+
To run this model locally, you need the **GFN Framework** and the model assets.
|
| 52 |
+
|
| 53 |
+
### 1. Install GFN Framework
|
| 54 |
+
```bash
|
| 55 |
+
pip install git+https://github.com/DepthMuun/gfn.git
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
### 2. Clone this repository
|
| 59 |
+
```bash
|
| 60 |
+
git lfs install
|
| 61 |
+
git clone https://huggingface.co/spaces/DepthMuun/gfn-gssm-mniah-k2-space
|
| 62 |
+
cd gfn-gssm-mniah-k2-space
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
### 3. Run the Interactive Demo
|
| 66 |
+
```bash
|
| 67 |
+
python app.py
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
## 📜 Citation
|
| 71 |
+
If you use this work, please cite:
|
| 72 |
+
```latex
|
| 73 |
+
@article{sturtz2026geometry,
|
| 74 |
+
title={Geometric Flow Networks: A Physics-Informed Paradigm for Sequential Intelligence},
|
| 75 |
+
author={Stürtz, Joaquín},
|
| 76 |
+
journal={Zenodo Preprints},
|
| 77 |
+
year={2026},
|
| 78 |
+
doi={10.5281/zenodo.19141133},
|
| 79 |
+
url={https://doi.org/10.5281/zenodo.19141133}
|
| 80 |
+
}
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
## 🔗 Resources
|
| 84 |
+
- **Official Checkpoint**: [GFN MNIAH Model](https://huggingface.co/DepthMuun/gfn-gssm-mniah-k2)
|
| 85 |
+
- **Framework Source**: [GitHub: DepthMuun/gfn](https://github.com/DepthMuun/gfn)
|
| 86 |
+
- **Official Paper**: [Zenodo](https://doi.org/10.5281/zenodo.19141133)
|
app.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import math
|
| 4 |
+
import sys
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
import json
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
# Add local gfn folder to path if it exists (for HF Spaces)
|
| 11 |
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 12 |
+
if os.path.exists(os.path.join(script_dir, "gfn")):
|
| 13 |
+
sys.path.insert(0, script_dir)
|
| 14 |
+
|
| 15 |
+
import gfn
|
| 16 |
+
|
| 17 |
+
def load_model():
    """Build the G-SSM from config.json and load the trained checkpoint.

    Returns:
        (model, device): the eval-mode model and the torch device it lives on.

    Raises:
        FileNotFoundError: if ``mniah_model_final.pt`` is missing next to this file.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Resolve all assets relative to this file once (robust on HF Spaces,
    # where the working directory is not guaranteed).
    base_dir = os.path.dirname(os.path.abspath(__file__))

    config_path = os.path.join(base_dir, "config.json")
    with open(config_path, "r") as f:
        config = json.load(f)

    arch = config['architecture']
    model = gfn.gssm.create(
        vocab_size=arch['vocab_size'],
        dim=arch['dim'],
        depth=arch['depth'],
        heads=arch['heads'],
        integrator=arch['integrator'],
        impulse_scale=arch['impulse_scale'],
        dynamics_type=arch['dynamics_type'],
        topology_type=arch['topology_type'],
        physics=config['physics'],
        holographic=arch.get('holographic', True),
    ).to(device)

    checkpoint_path = os.path.join(base_dir, "mniah_model_final.pt")
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Missing model weights: {checkpoint_path}. Please place the trained checkpoint here.")

    # weights_only=True: safe tensor-only deserialization (no pickle code exec).
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(ckpt['model'])
    model.eval()
    return model, device
|
| 46 |
+
|
| 47 |
+
# Load the model once at import time; the Gradio handlers close over these.
model, device = load_model()

# NOTE: `json` is already imported at the top of the file; the duplicate
# import was removed. `tempfile` is used by run_mniah for the JSON trace.
import tempfile
|
| 51 |
+
|
| 52 |
+
def run_mniah(seq_len, needle_pos_str, num_needles):
    """Run one MNIAH evaluation against the globally loaded `model`.

    Args:
        seq_len: Total context length L (haystack size).
        needle_pos_str: Comma-separated 1-indexed needle positions, or "" for random.
        num_needles: Expected number of needles K.

    Returns:
        (markdown_summary, json_trace_path) on success;
        (error_message, None) on validation failure or runtime error.
    """
    try:
        if needle_pos_str.strip() == "":
            # Random needles: randperm guarantees unique draws in [0, seq_len),
            # shifted by `lo` so reported positions are 1-indexed.
            lo = 1
            pool = torch.randperm(seq_len)[:num_needles]
            positions = sorted((pool + lo).tolist())
        else:
            positions = sorted(int(p.strip()) for p in needle_pos_str.split(","))
            # Bug fix: these early returns previously yielded a single string,
            # but the click handler is wired to TWO outputs (markdown, file);
            # return (message, None) so Gradio can unpack them.
            if len(positions) != num_needles:
                return f"Error: Number of positions ({len(positions)}) must match needle count ({num_needles}).", None
            if any(p < 1 or p > seq_len for p in positions):
                return f"Error: Positions must be between 1 and {seq_len}.", None

        # Input generation: haystack of token 0, needles encoded as token 1.
        # Format: [Haystack, Needle, ...] (no context token for the k2 model).
        x = torch.zeros(1, seq_len, dtype=torch.long, device=device)
        for p in positions:
            x[0, p - 1] = 1  # UI positions are 1-indexed

        t0 = time.time()
        with torch.no_grad():
            output = model(x)
            x_pred = output[0]  # [B, L, D] or [B, L, H, D]
            if x_pred.ndim == 4:
                x_pred = x_pred.mean(dim=2)  # collapse the head axis
        elapsed = time.time() - t0

        # Binary prediction: wrap-around (toroidal) distance to the +pi/2
        # and -pi/2 attractors, averaged over the feature dimension.
        PI = math.pi
        TWO_PI = 2.0 * PI
        half_pi = PI * 0.5
        dist_pos = torch.min(
            torch.abs(x_pred - half_pi) % TWO_PI,
            TWO_PI - (torch.abs(x_pred - half_pi) % TWO_PI)
        ).mean(dim=-1)
        dist_neg = torch.min(
            torch.abs(x_pred + half_pi) % TWO_PI,
            TWO_PI - (torch.abs(x_pred + half_pi) % TWO_PI)
        ).mean(dim=-1)

        preds = (dist_pos < dist_neg).long()[0]  # [L]

        # Result summary: accuracy before/after the last needle (0-indexed).
        last_pos = positions[-1] - 1
        acc_after = (preds[last_pos + 1:] == 1).float().mean().item() if last_pos + 1 < seq_len else 1.0
        acc_before = (preds[:last_pos] == 0).float().mean().item() if last_pos > 0 else 1.0
        result_data = {
            "Status": "Success",
            "Context Length": seq_len,
            "Needle Positions": positions,
            "Predictions (subset)": preds.tolist()[:500],
            "Acc After Last Needle": f"{acc_after:.2%}",
            "Acc Before Flip": f"{acc_before:.2%}",
            "Inference Time": f"{elapsed:.3f}s",
            "Full Trace": preds.tolist()
        }

        # Persist the full trace so the UI can offer it as a download.
        # delete=False is required: Gradio serves the file after we return.
        with tempfile.NamedTemporaryFile(mode='w', suffix=".json", delete=False) as tf:
            json.dump(result_data, tf, indent=4)
            trace_path = tf.name

        # Markdown table for the UI.
        md_summary = f"""
### 📊 Invariant Evaluation Summary

| Metric | Value |
|:---|:---|
| **Status** | 🟢 {result_data['Status']} |
| **Context Length ($L$)** | {seq_len:,} tokens |
| **Needles Count ($K$)** | {num_needles} (Positions: {positions}) |
| **Accuracy (Before Critical Threshold)** | **{acc_before:.2%}** |
| **Accuracy (After Critical Threshold)** | **{acc_after:.2%}** |
| **Inference Time** | {elapsed:.3f}s |

> *Note: Model was trained explicitly on $K=2$. Extrapolating to higher $K$ demonstrates how the physical integrator accumulates forces. The 'Accuracy Before' metric measures strict adherence before the final expected needle, but geometrical flips might occur exactly when the integrated energy crosses the learned $K=2$ threshold.*
"""

        return md_summary, trace_path
    except Exception as e:
        return f"### ❌ Error\n{str(e)}", None
|
| 137 |
+
|
| 138 |
+
# ── Gradio UI ────────────────────────────────────────────────────────────────
with gr.Blocks(title="G-SSM MNIAH Solver") as demo:
    gr.Markdown("# 📍 G-SSM Multi-Needle-in-a-Haystack (MNIAH)")
    gr.Markdown("""
    This space demonstrates the **Geodesic State Space Model (G-SSM)** on the Multi-Needle task.
    The model must detect and remember exactly **K needles** scattered across a long sequence
    of "haystack" tokens. Its internal state only "flips" to the target configuration after **all K impulses**
    have been integrated into the geodetic flow.
    """)

    # Inputs: context length, needle count, optional explicit positions.
    with gr.Row():
        seq_len = gr.Number(value=1000, minimum=64, maximum=10_000_000_000, precision=0, label="Sequence Length (Up to 10B, be mindful of VRAM!)")
        num_k = gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Number of Needles (K)")
        positions = gr.Textbox(label="Manual Positions (comma separated, empty for random)", placeholder="e.g. 100, 450")

    submit_btn = gr.Button("Evaluate Geometric Memory", variant="primary")

    # Outputs: markdown summary plus a downloadable JSON trace file.
    with gr.Row():
        with gr.Column(scale=2):
            output_md = gr.Markdown("### 📊 Results will appear here...")
        with gr.Column(scale=1):
            download_btn = gr.File(label="Full Trace (JSON)")

    # run_mniah must return exactly (markdown, file-path-or-None).
    submit_btn.click(fn=run_mniah, inputs=[seq_len, positions, num_k], outputs=[output_md, download_btn])

    gr.Examples(
        examples=[
            [1000, "100, 900", 2],
            [4000, "50, 1500, 3900", 3],
            [32000, "", 1]
        ],
        inputs=[seq_len, positions, num_k]
    )

if __name__ == "__main__":
    # Fix: bind to all interfaces and disable the API schema page so the
    # demo is reachable when run via `python app.py` (locally or in the
    # Space container) — previously the script built `demo` but never
    # launched it.
    demo.launch(server_name="0.0.0.0", show_api=False)
|
config.json
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architecture": {
|
| 3 |
+
"vocab_size": 2,
|
| 4 |
+
"dim": 16,
|
| 5 |
+
"depth": 2,
|
| 6 |
+
"heads": 2,
|
| 7 |
+
"integrator": "leapfrog",
|
| 8 |
+
"impulse_scale": 80.0,
|
| 9 |
+
"dynamics_type": "direct",
|
| 10 |
+
"topology_type": "torus",
|
| 11 |
+
"holographic": true
|
| 12 |
+
},
|
| 13 |
+
"physics": {
|
| 14 |
+
"embedding": {
|
| 15 |
+
"type": "functional",
|
| 16 |
+
"mode": "linear",
|
| 17 |
+
"coord_dim": 16
|
| 18 |
+
},
|
| 19 |
+
"readout": {
|
| 20 |
+
"type": "implicit",
|
| 21 |
+
"coord_dim": 16
|
| 22 |
+
},
|
| 23 |
+
"active_inference": {
|
| 24 |
+
"enabled": true,
|
| 25 |
+
"dynamic_time": {
|
| 26 |
+
"enabled": true
|
| 27 |
+
},
|
| 28 |
+
"reactive_curvature": {
|
| 29 |
+
"enabled": true,
|
| 30 |
+
"plasticity": 0.2
|
| 31 |
+
},
|
| 32 |
+
"singularities": {
|
| 33 |
+
"enabled": true,
|
| 34 |
+
"strength": 20.0,
|
| 35 |
+
"threshold": 0.8
|
| 36 |
+
}
|
| 37 |
+
},
|
| 38 |
+
"fractal": {
|
| 39 |
+
"enabled": true,
|
| 40 |
+
"threshold": 0.5,
|
| 41 |
+
"alpha": 0.2
|
| 42 |
+
},
|
| 43 |
+
"topology": {
|
| 44 |
+
"type": "torus",
|
| 45 |
+
"riemannian_type": "low_rank"
|
| 46 |
+
},
|
| 47 |
+
"stability": {
|
| 48 |
+
"enable_trace_normalization": true,
|
| 49 |
+
"base_dt": 0.4,
|
| 50 |
+
"velocity_saturation": 15.0,
|
| 51 |
+
"friction": 2.0,
|
| 52 |
+
"toroidal_curvature_scale": 0.01
|
| 53 |
+
},
|
| 54 |
+
"hysteresis": {
|
| 55 |
+
"enabled": false
|
| 56 |
+
}
|
| 57 |
+
}
|
| 58 |
+
}
|
gfn/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GFN (Geodesic Flow Network) Package
|
| 3 |
+
==================================
|
| 4 |
+
Unified framework for Geodesic State Space Models (G-SSM)
|
| 5 |
+
and Inertial State Networks (ISN).
|
| 6 |
+
|
| 7 |
+
This package implements the GFN paradigm as a platform for
|
| 8 |
+
physics-informed neural dynamics.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
# ── Realizations ──────────────────────────────────────────────────────────────
|
| 12 |
+
from .realizations import api, gssm, isn
|
| 13 |
+
from .realizations.api import create, load, save
|
| 14 |
+
|
| 15 |
+
# ── Dynamic Registry
|
| 16 |
+
REALIZATIONS = api.list_available()
|
| 17 |
+
|
| 18 |
+
# ── Package Metadata ──────────────────────────────────────────────────────────
|
| 19 |
+
__version__ = "2.7.0"
|
| 20 |
+
__author__ = "DepthMuun"
|
| 21 |
+
|
| 22 |
+
__all__ = [
|
| 23 |
+
"gssm",
|
| 24 |
+
"isn",
|
| 25 |
+
"api",
|
| 26 |
+
"create",
|
| 27 |
+
"load",
|
| 28 |
+
"save",
|
| 29 |
+
"REALIZATIONS",
|
| 30 |
+
]
|
gfn/realizations/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GFN Realizations Subpackage
|
| 3 |
+
===========================
|
| 4 |
+
Contains specific implementations of the GFN paradigm:
|
| 5 |
+
- G-SSM: Geodesic State Space Model (Riemannian/Symplectic)
|
| 6 |
+
- ISN: Inertial State Network (Physics-Informed Interaction Engine)
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from . import api
|
| 10 |
+
from .api import create, list_available
|
| 11 |
+
|
| 12 |
+
# Trigger registration of standard realizations
|
| 13 |
+
from . import gssm
|
| 14 |
+
from . import isn
|
| 15 |
+
|
| 16 |
+
# Future realizations can be added here or via external plugins
|
| 17 |
+
# from . import rt
|
| 18 |
+
|
| 19 |
+
__all__ = ['gssm', 'isn', 'api', 'create', 'list_available']
|
gfn/realizations/api.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GFN Realizations API Router (Purified Version)
|
| 3 |
+
==============================================
|
| 4 |
+
Agnostic factory and dynamic registry for GFN architectural realizations.
|
| 5 |
+
Follows SOLID principles: open for extension, closed for modification.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
from typing import List, Dict, Any, Optional, Protocol, runtime_checkable
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
@runtime_checkable
class RealizationProvider(Protocol):
    """Protocol defining the interface any GFN realization must provide.

    Implementers are typically modules (not classes) exposing module-level
    ``create``/``save``/``load`` functions; ``@runtime_checkable`` allows
    structural ``isinstance`` checks against this protocol.
    """
    # Build a model instance from keyword configuration.
    def create(self, **kwargs) -> nn.Module: ...
    # Persist a model (format is realization-specific).
    def save(self, model: nn.Module, path: str): ...
    # Restore a model previously written by save().
    def load(self, path: str, **kwargs) -> nn.Module: ...

# The Dynamic Registry: lowercase realization name -> provider module/object.
_REGISTRY: Dict[str, RealizationProvider] = {}
|
| 23 |
+
|
| 24 |
+
def register(name: str, provider: RealizationProvider):
    """
    Register a new realization architecture.

    Args:
        name: Unique identifier for the realization (stored lowercase).
        provider: An object or module implementing the RealizationProvider protocol.
    """
    key = name.lower()
    if key in _REGISTRY:
        # Lazy %-style args: the message is only rendered if DEBUG is enabled.
        logger.debug("Overwriting GFN realization provider: %s", key)
    _REGISTRY[key] = provider
|
| 36 |
+
|
| 37 |
+
def list_available() -> List[str]:
    """Return the names of every realization currently in the registry."""
    return [realization_name for realization_name in _REGISTRY]
|
| 40 |
+
|
| 41 |
+
def create(name: str, **kwargs) -> nn.Module:
    """
    Unified factory to create any registered GFN realization by name.
    """
    key = name.lower()
    provider = _REGISTRY.get(key)
    if provider is None:
        raise ValueError(
            f"GFN Error: Realization '{key}' is not registered. "
            f"Ensure the subpackage is imported. Available: {list_available()}"
        )
    return provider.create(**kwargs)
|
| 52 |
+
|
| 53 |
+
def save(model: nn.Module, path: str, realization: Optional[str] = None):
    """Unified save interface delegate.

    If a registered realization is named, delegate to its save(); otherwise
    fall back to a plain ``torch.save`` of the state_dict.
    """
    # Normalize once instead of lowercasing twice as before.
    key = realization.lower() if realization else None
    if key is not None and key in _REGISTRY:
        _REGISTRY[key].save(model, path)
    else:
        import torch  # local import keeps the module import-light
        torch.save(model.state_dict(), path)
|
| 60 |
+
|
| 61 |
+
def load(path: str, realization: str, **kwargs) -> nn.Module:
    """Unified load interface delegate."""
    key = realization.lower()
    provider = _REGISTRY.get(key)
    if provider is None:
        raise ValueError(f"GFN Error: Realization provider for '{key}' not found.")
    return provider.load(path, **kwargs)
|
gfn/realizations/gssm/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Export core API
|
| 2 |
+
from .api import create, save, load, Model, Manifold, loss, Trainer
|
| 3 |
+
|
| 4 |
+
# Register with central realization registry
|
| 5 |
+
try:
|
| 6 |
+
from .. import api as central_api
|
| 7 |
+
from . import api as gssm_api
|
| 8 |
+
central_api.register('gssm', gssm_api)
|
| 9 |
+
except ImportError:
|
| 10 |
+
pass # Fallback for standalone GSSM usage
|
gfn/realizations/gssm/api.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
gfn/api.py — GFN V5
|
| 3 |
+
Interfaz pública simplificada y orquestación de alto nivel.
|
| 4 |
+
Centraliza la creación, carga y evaluación de modelos.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from typing import Optional, Dict, Any, Union
|
| 10 |
+
|
| 11 |
+
from .models.factory import ModelFactory
|
| 12 |
+
from .models.manifold import ManifoldModel
|
| 13 |
+
from .losses.factory import LossFactory
|
| 14 |
+
from .training.trainer import GFNTrainer
|
| 15 |
+
from .training.evaluation import ManifoldMetricEvaluator
|
| 16 |
+
|
| 17 |
+
# -- Alias principales
|
| 18 |
+
Model = ManifoldModel
|
| 19 |
+
Manifold = ManifoldModel
|
| 20 |
+
Trainer = GFNTrainer
|
| 21 |
+
|
| 22 |
+
def create(*args, **kwargs):
    """Factory for Manifold models (V5); delegates to ModelFactory.create."""
    return ModelFactory.create(*args, **kwargs)
|
| 25 |
+
|
| 26 |
+
def loss(config, **kwargs):
    """Factory for loss functions (V5); delegates to LossFactory.create."""
    return LossFactory.create(config, **kwargs)
|
| 29 |
+
|
| 30 |
+
def save(model: nn.Module, path: str):
    """
    Save the model together with its configuration (HuggingFace style).
    """
    if hasattr(model, 'save_pretrained'):
        model.save_pretrained(path)
    else:
        # Fallback for models that do not inherit from BaseModel:
        # persist only the raw state_dict (no config is saved).
        torch.save({'state_dict': model.state_dict()}, path)
|
| 39 |
+
|
| 40 |
+
def load(path: str, device: Optional[str] = None):
    """
    Load a saved model together with its configuration.

    Supports directories (HF style) or legacy standalone .pth/.bin files.
    Legacy checkpoints must embed a 'config' entry; otherwise loading fails.
    """
    import os
    if os.path.isdir(path):
        return ModelFactory.from_pretrained(path)

    # Fallback for legacy standalone checkpoint files.
    # weights_only=True keeps deserialization safe (no pickle code exec).
    checkpoint = torch.load(path, map_location=device or 'cpu', weights_only=True)
    config = checkpoint.get('config')
    if config is None:
        raise ValueError(f"No se encontró configuración en el checkpoint {path}. Use directorios HF para carga completa.")

    model = create(config=config)

    # Robust state_dict extraction (handles different saving conventions:
    # 'state_dict', 'model', or the checkpoint being the state_dict itself).
    state_dict = checkpoint.get('state_dict') or checkpoint.get('model') or checkpoint

    # Filter state_dict against the model's actual parameters so stale keys
    # from older formats do not abort the load.
    model_state = model.state_dict()
    filtered_state = {k: v for k, v in state_dict.items() if k in model_state}

    # Log how many keys were dropped, for debugging legacy checkpoints.
    n_filtered = len(state_dict) - len(filtered_state)
    if n_filtered > 0:
        import logging
        logging.getLogger("gssm.api").info(f"Filtered {n_filtered} unexpected keys from state_dict (legacy or auxiliary data).")

    # strict=False: tolerate missing non-essential parameters.
    model.load_state_dict(filtered_state, strict=False)
    return model
|
| 73 |
+
|
| 74 |
+
def benchmark(model: nn.Module, dataloader: torch.utils.data.DataLoader,
              device: Optional[str] = None) -> Dict[str, float]:
    """
    Run a quick evaluation of geometric and task metrics.

    Iterates the dataloader once in eval/no-grad mode, collects the final
    manifold states, and delegates metric computation to
    ManifoldMetricEvaluator.full_report. Returns {} on an empty dataloader.
    """
    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    evaluator = ManifoldMetricEvaluator(model)
    all_x, all_v, all_y = [], [], []

    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            # Model forward returns (logits, (final_position, final_velocity), info).
            logits, (xf, vf), info = model(x)

            # Move batch results to CPU so GPU memory is not held across batches.
            all_x.append(xf.detach().cpu())
            all_v.append(vf.detach().cpu())
            all_y.append(y.detach().cpu())

    if not all_x:
        return {}

    x_total = torch.cat(all_x, dim=0)
    v_total = torch.cat(all_v, dim=0)
    y_total = torch.cat(all_y, dim=0)

    return evaluator.full_report(x_total, v_total, y_total)
|
gfn/realizations/gssm/config/__init__.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
gfn/config/__init__.py
|
| 3 |
+
Public API for the configuration module — GFN V5
|
| 4 |
+
Centralized configuration system for all GFN components.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from .schema import (
|
| 8 |
+
TopologyConfig,
|
| 9 |
+
StabilityConfig,
|
| 10 |
+
DynamicTimeConfig,
|
| 11 |
+
HysteresisConfig,
|
| 12 |
+
ActiveInferenceConfig,
|
| 13 |
+
EmbeddingConfig,
|
| 14 |
+
ReadoutConfig,
|
| 15 |
+
MixtureConfig,
|
| 16 |
+
DynamicsConfig,
|
| 17 |
+
FractalConfig,
|
| 18 |
+
SingularityConfig,
|
| 19 |
+
PhysicsConfig,
|
| 20 |
+
TrainerConfig,
|
| 21 |
+
ManifoldConfig,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
from .defaults import (
|
| 25 |
+
PHYSICS_DEFAULTS,
|
| 26 |
+
MODEL_DEFAULTS,
|
| 27 |
+
TRAINING_DEFAULTS,
|
| 28 |
+
LOSS_DEFAULTS,
|
| 29 |
+
get_default,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
from .loader import dict_to_physics_config
|
| 33 |
+
from .validator import ConfigValidator, validate_manifold_config, validate_and_print
|
| 34 |
+
|
| 35 |
+
__all__ = [
|
| 36 |
+
# Schema classes
|
| 37 |
+
"TopologyConfig",
|
| 38 |
+
"StabilityConfig",
|
| 39 |
+
"DynamicTimeConfig",
|
| 40 |
+
"HysteresisConfig",
|
| 41 |
+
"ActiveInferenceConfig",
|
| 42 |
+
"EmbeddingConfig",
|
| 43 |
+
"ReadoutConfig",
|
| 44 |
+
"MixtureConfig",
|
| 45 |
+
"DynamicsConfig",
|
| 46 |
+
"FractalConfig",
|
| 47 |
+
"SingularityConfig",
|
| 48 |
+
"PhysicsConfig",
|
| 49 |
+
"TrainerConfig",
|
| 50 |
+
"ManifoldConfig",
|
| 51 |
+
# Defaults
|
| 52 |
+
"PHYSICS_DEFAULTS",
|
| 53 |
+
"MODEL_DEFAULTS",
|
| 54 |
+
"TRAINING_DEFAULTS",
|
| 55 |
+
"LOSS_DEFAULTS",
|
| 56 |
+
"get_default",
|
| 57 |
+
# Loader
|
| 58 |
+
"dict_to_physics_config",
|
| 59 |
+
# Validator
|
| 60 |
+
"ConfigValidator",
|
| 61 |
+
"validate_manifold_config",
|
| 62 |
+
"validate_and_print",
|
| 63 |
+
]
|
gfn/realizations/gssm/config/defaults.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
config/defaults.py — GFN V5
|
| 3 |
+
Valores por defecto centralizados para todas las configuraciones.
|
| 4 |
+
Elimina hardcodes dispersos en implementaciones.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from typing import Dict, Any
|
| 8 |
+
from ..constants import (
|
| 9 |
+
DEFAULT_DT, DEFAULT_FRICTION, DEFAULT_PLASTICITY,
|
| 10 |
+
MAX_VELOCITY, CURVATURE_CLAMP, VELOCITY_FRICTION_SCALE,
|
| 11 |
+
SINGULARITY_THRESHOLD, BLACK_HOLE_STRENGTH, EPSILON_STANDARD,
|
| 12 |
+
TOPOLOGY_TORUS, TOPOLOGY_EUCLIDEAN
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
# ─── Physics ────────────────────────────────────────────────────────────────
PHYSICS_DEFAULTS: Dict[str, Any] = {
    # Topology
    'topology_type': TOPOLOGY_EUCLIDEAN,
    'riemannian_type': 'low_rank',
    'major_radius_R': 2.0,
    'minor_radius_r': 1.0,

    # Stability — references into constants.py (single source of truth)
    'base_dt': DEFAULT_DT,
    'adaptive_dt': True,
    'friction': DEFAULT_FRICTION,
    'velocity_clamp': MAX_VELOCITY,
    'curvature_clamp': CURVATURE_CLAMP,
    'enable_trace_normalization': True,
    'velocity_friction_scale': VELOCITY_FRICTION_SCALE,
    'integrator_type': 'leapfrog',
    'friction_mode': 'static',

    # Active inference
    'active_inference_enabled': True,
    'holographic_geometry': False,
    'plasticity': DEFAULT_PLASTICITY,

    # Singularities
    'singularity_enabled': False,
    'singularity_threshold': SINGULARITY_THRESHOLD,
    'singularity_strength': BLACK_HOLE_STRENGTH,
    'singularity_epsilon': EPSILON_STANDARD,

    # Hysteresis
    # NOTE(review): schema.HysteresisConfig defaults hyst_decay to 0.1 while
    # this table ships 0.95 — confirm which value is authoritative.
    'hysteresis_enabled': False,
    'hysteresis_decay': 0.95,
    'hysteresis_ghost_force': True,

    # Stochasticity / Curiosity
    'stochasticity_enabled': False,
    'stochasticity_type': 'brownian',
    'stochasticity_sigma': 0.01,
    'curiosity_enabled': False,
    'curiosity_strength': 0.1,
}

# ─── Model ───────────────────────────────────────────────────────────────────
MODEL_DEFAULTS: Dict[str, Any] = {
    'dim': 64,
    'heads': 4,
    'depth': 2,
    'rank': 16,
    'vocab_size': 256,
    'holographic': False,
    'pooling_type': None,
    'initial_spread': 1e-3,
    'n_trajectories': 1,
}

# ─── Training ────────────────────────────────────────────────────────────────
TRAINING_DEFAULTS: Dict[str, Any] = {
    'lr': 1e-3,
    'optimizer_type': 'adam',
    'weight_decay': 0.0,
    'grad_clip': 1.0,
    'epochs': 10,
    'batch_size': 32,
    'scheduler_type': 'cosine_warmup',
    'warmup_steps': 100,
    'min_lr': 1e-6,
    'task': 'lm',
}

# ─── Losses ──────────────────────────────────────────────────────────────────
LOSS_DEFAULTS: Dict[str, Any] = {
    'type': 'generative',
    'mode': 'nll',
    'entropy_coef': 0.0,
    'label_smoothing': 0.0,

    # Physics-informed loss weights
    'lambda_physics': 0.01,
    'lambda_geo': 0.001,
    'lambda_ham': 0.0,
    'lambda_kin': 0.0,
}
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def get_default(section: str, key: str, fallback=None):
    """Look up a default value from the named section table.

    Example: get_default('physics', 'base_dt') -> 0.1

    Unknown sections or keys fall back to *fallback* (default None).
    """
    sections = {
        'physics': PHYSICS_DEFAULTS,
        'model': MODEL_DEFAULTS,
        'training': TRAINING_DEFAULTS,
        'loss': LOSS_DEFAULTS,
    }
    table = sections.get(section, {})
    return table.get(key, fallback)
|
gfn/realizations/gssm/config/loader.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
config/loader.py — GFN V5
|
| 3 |
+
Conversión de dicts de configuración a PhysicsConfig tipado.
|
| 4 |
+
Soporte para overrides anidados sobre configs existentes.
|
| 5 |
+
"""
|
| 6 |
+
from typing import Dict, Any, Optional
|
| 7 |
+
from .schema import (
|
| 8 |
+
PhysicsConfig, TopologyConfig, StabilityConfig, DynamicsConfig,
|
| 9 |
+
ActiveInferenceConfig, DynamicTimeConfig, HysteresisConfig,
|
| 10 |
+
EmbeddingConfig, FractalConfig, SingularityConfig,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def dict_to_physics_config(d: Dict[str, Any]) -> PhysicsConfig:
    """Build a typed PhysicsConfig from a nested configuration dict.

    Supports every sub-field of PhysicsConfig; any field absent from the
    dict keeps its schema default. If *d* is already a PhysicsConfig
    instance, it is returned untouched.
    """
    if isinstance(d, PhysicsConfig):
        return d

    config = PhysicsConfig()
    _apply_dict_to_physics_config(config, d)
    return config
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def apply_physics_overrides(cfg: PhysicsConfig, overrides: Dict[str, Any]) -> PhysicsConfig:
    """Apply an override dict onto an EXISTING PhysicsConfig (in place).

    Unlike dict_to_physics_config(), this does not start from schema
    defaults: only the fields present in *overrides* are modified and every
    other field is left untouched. This is what ModelFactory uses when a
    preset is combined with a `physics` kwarg.

    Args:
        cfg: Existing PhysicsConfig (e.g. the result of get_preset()).
        overrides: Nested dict with the fields to overwrite.

    Returns:
        The same *cfg*, mutated in place (also returned for chaining).
    """
    if overrides:
        _apply_dict_to_physics_config(cfg, overrides)
    return cfg
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _apply_dict_to_physics_config(cfg: PhysicsConfig, d: Dict[str, Any]) -> None:
    """Internal helper — copies the fields present in *d* onto *cfg* in place.

    Each sub-section accepts both its plain key ('topology') and the legacy
    '<name>_config' key; sections absent from *d* are left untouched.
    """

    # ── Topology ──────────────────────────────────────────────────────────────
    t_d = d.get('topology', d.get('topology_config', {}))
    if isinstance(t_d, dict) and t_d:
        _apply(cfg.topology, t_d, [
            'type', 'R', 'r', 'curvature',
            'riemannian_type', 'riemannian_rank', 'riemannian_class',
            'geometry_scope'
        ])
        # Legacy radius aliases.
        if 'major_radius' in t_d: cfg.topology.R = t_d['major_radius']
        if 'minor_radius' in t_d: cfg.topology.r = t_d['minor_radius']

    # ── Stability ─────────────────────────────────────────────────────────────
    s_d = d.get('stability', d.get('stability_config', {}))
    if isinstance(s_d, dict) and s_d:
        _apply(cfg.stability, s_d, [
            'base_dt', 'adaptive', 'dt_min', 'dt_max',
            'enable_trace_normalization', 'wrap_x',
            'friction', 'velocity_friction_scale',
            'curvature_clamp', 'friction_mode',
            'integrator_type',
            # legacy alias
            # NOTE(review): original comment claimed this key is "ignored,
            # does not exist in StabilityConfig", but the schema DOES declare
            # velocity_saturation, so it is applied like any other field.
            'velocity_saturation',
        ])
        # Legacy name alias.
        if 'toroidal_curvature_scale' in s_d:
            cfg.stability.curvature_clamp = s_d['toroidal_curvature_scale']

    # ── Dynamics ──────────────────────────────────────────────────────────────
    dyn_d = d.get('dynamics', d.get('dynamics_config', {}))
    if isinstance(dyn_d, dict) and dyn_d:
        if 'type' in dyn_d:
            cfg.dynamics.type = dyn_d['type']

    # ── Active Inference ──────────────────────────────────────────────────────
    ai_d = d.get('active_inference', d.get('active_inference_config', {}))
    if isinstance(ai_d, dict) and ai_d:
        _apply(cfg.active_inference, ai_d, [
            'enabled', 'holographic_geometry',
            'thermodynamic_geometry', 'plasticity',
        ])
        # Dynamic time
        dt_d = ai_d.get('dynamic_time', {})
        if isinstance(dt_d, dict) and dt_d:
            _apply(cfg.active_inference.dynamic_time, dt_d, ['enabled', 'type'])
        # Reactive curvature — stored as a plain inner dict
        rc_d = ai_d.get('reactive_curvature', {})
        if isinstance(rc_d, dict) and rc_d:
            cfg.active_inference.reactive_curvature.update(rc_d)
        # Stochasticity — stored as a plain inner dict
        st_d = ai_d.get('stochasticity', {})
        if isinstance(st_d, dict) and st_d:
            cfg.active_inference.stochasticity.update(st_d)
        # Curiosity — stored as a plain inner dict
        cu_d = ai_d.get('curiosity', {})
        if isinstance(cu_d, dict) and cu_d:
            cfg.active_inference.curiosity.update(cu_d)
    # ── Hysteresis (may live at the root OR inside active_inference) ──────────
    hyst_src = d.get('hysteresis', ai_d.get('hysteresis', {}) if isinstance(ai_d, dict) else {})
    if isinstance(hyst_src, dict) and hyst_src:
        _apply(cfg.hysteresis, hyst_src, [
            'enabled', 'ghost_force', 'hyst_decay',
            'hyst_update_w', 'hyst_update_b',
            'hyst_readout_w', 'hyst_readout_b',
        ])

    # ── Singularities (may live at the root OR inside active_inference) ───────
    sing_src = d.get('singularities', ai_d.get('singularities', {}) if isinstance(ai_d, dict) else {})
    if isinstance(sing_src, dict) and sing_src:
        _apply(cfg.singularities, sing_src, [
            'enabled', 'epsilon', 'strength', 'threshold'
        ])

    # ── Embedding ─────────────────────────────────────────────────────────────
    emb_d = d.get('embedding', d.get('embedding_config', {}))
    if isinstance(emb_d, dict) and emb_d:
        _apply(cfg.embedding, emb_d, [
            'type', 'mode', 'coord_dim', 'impulse_scale', 'omega_0'
        ])

    # ── Readout ───────────────────────────────────────────────────────────────
    read_d = d.get('readout', d.get('readout_config', {}))
    if isinstance(read_d, dict) and read_d:
        _apply(cfg.readout, read_d, ['type'])

    # ── Mixture ───────────────────────────────────────────────────────────────
    mix_d = d.get('mixture', d.get('mixture_config', {}))
    if isinstance(mix_d, dict) and mix_d:
        _apply(cfg.mixture, mix_d, ['coupler_mode'])

    # ── Fractal ───────────────────────────────────────────────────────────────
    frac_d = d.get('fractal', {})
    if isinstance(frac_d, dict) and frac_d:
        _apply(cfg.fractal, frac_d, ['enabled', 'threshold', 'alpha'])

    # ── Top-level trajectory_mode ─────────────────────────────────────────────
    if 'trajectory_mode' in d:
        cfg.trajectory_mode = d['trajectory_mode']

    # ── Attention/mixer alias (legacy ECG configs) ────────────────────────────
    # 'attention': {'mixer_type': 'low_rank'} — ignored here; applied in ManifoldConfig
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def _apply(target, source: dict, keys: list) -> None:
|
| 157 |
+
"""Copia las claves presentes en source hacia target (setattr)."""
|
| 158 |
+
for k in keys:
|
| 159 |
+
if k in source:
|
| 160 |
+
try:
|
| 161 |
+
setattr(target, k, source[k])
|
| 162 |
+
except AttributeError:
|
| 163 |
+
pass # clave no existe en el dataclass — ignorar silenciosamente
|
gfn/realizations/gssm/config/schema.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# schema.py — GFN V5
|
| 2 |
+
# Definiciones de clases de configuración (Schema)
|
| 3 |
+
# SEPARACIÓN: Los valores por defecto van a defaults.py, las constantes físicas a constants.py
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass, field, asdict
|
| 6 |
+
from typing import Dict, Any, Optional, List
|
| 7 |
+
|
| 8 |
+
# Import corrected physics constants
|
| 9 |
+
from ..constants import (
|
| 10 |
+
EPSILON_STANDARD,
|
| 11 |
+
TOPOLOGY_TORUS,
|
| 12 |
+
MIN_DT,
|
| 13 |
+
MAX_DT,
|
| 14 |
+
CURVATURE_CLAMP,
|
| 15 |
+
SINGULARITY_THRESHOLD,
|
| 16 |
+
BLACK_HOLE_STRENGTH,
|
| 17 |
+
DEFAULT_DT,
|
| 18 |
+
DEFAULT_FRICTION,
|
| 19 |
+
DEFAULT_PLASTICITY,
|
| 20 |
+
MAX_VELOCITY,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
class TopologyConfig:
    """Manifold topology configuration."""
    type: str = TOPOLOGY_TORUS          # one of the TOPOLOGY_* names from constants.py
    R: float = 2.0                      # torus major radius (default)
    r: float = 1.0                      # torus minor radius (default)
    curvature: float = 0.0
    riemannian_type: str = 'reactive'
    riemannian_rank: int = 16
    riemannian_class: Optional[str] = None
    geometry_scope: str = 'local'       # 'local' (dim/heads) or 'global' (full dim)
    # NEW: learnable parameters
    learnable_R: bool = True            # make R learnable (as described in the paper)
    learnable_r: bool = True            # make r learnable (as described in the paper)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@dataclass
class StabilityConfig:
    """Numerical-stability configuration for the integrator."""
    base_dt: float = DEFAULT_DT
    adaptive: bool = True
    dt_min: float = MIN_DT
    dt_max: float = MAX_DT
    enable_trace_normalization: bool = True
    wrap_x: bool = True
    friction: float = DEFAULT_FRICTION
    velocity_friction_scale: float = 0.0
    velocity_saturation: float = 0.0  # P2.3: 0 = disabled, >0 = clamp magnitude via tanh
    curvature_clamp: float = CURVATURE_CLAMP
    friction_mode: str = 'static'  # 'static' or 'lif'
    integrator_type: str = 'leapfrog'
    toroidal_curvature_scale: float = 0.01  # scale for torus Christoffel contribution
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@dataclass
class DynamicTimeConfig:
    """Dynamic time-step settings (nested inside ActiveInferenceConfig)."""
    enabled: bool = False
    type: str = 'riemannian'
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@dataclass
class HysteresisConfig:
    """Hysteresis ("ghost force") configuration."""
    enabled: bool = False
    ghost_force: bool = True
    # NOTE(review): defaults.py ships 'hysteresis_decay': 0.95, which differs
    # from the 0.1 default below — confirm which value is authoritative.
    hyst_decay: float = 0.1
    hyst_update_w: float = 1.0
    hyst_update_b: float = 0.0
    hyst_readout_w: float = 1.0
    hyst_readout_b: float = 0.0
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@dataclass
class ActiveInferenceConfig:
    """Active-inference feature flags plus their nested sub-configurations.

    The reactive_curvature / geodesic_lensing / stochasticity / curiosity
    sub-configs are stored as plain dicts (see loader, which .update()s them).
    """
    enabled: bool = False
    holographic_geometry: bool = False
    thermodynamic_geometry: bool = False
    plasticity: float = DEFAULT_PLASTICITY
    dynamic_time: DynamicTimeConfig = field(default_factory=DynamicTimeConfig)
    reactive_curvature: Dict[str, Any] = field(default_factory=lambda: {
        "enabled": False,
        "plasticity": 0.0
    })
    geodesic_lensing: Dict[str, Any] = field(default_factory=lambda: {"enabled": False})

    # Exploration / Noise
    stochasticity: Dict[str, Any] = field(default_factory=lambda: {
        "enabled": False,
        "type": "brownian",
        "sigma": 0.01,
        "theta": 0.15,
        "mu": 0.0
    })
    curiosity: Dict[str, Any] = field(default_factory=lambda: {
        "enabled": False,
        "strength": 0.1,
        "decay": 0.99
    })
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@dataclass
class EmbeddingConfig:
    """Input embedding configuration."""
    type: str = 'standard'
    mode: str = 'linear'
    coord_dim: int = 16
    impulse_scale: float = 1.0
    omega_0: float = 30.0
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@dataclass
class ReadoutConfig:
    """Readout head configuration."""
    type: str = 'standard'
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@dataclass
class MixtureConfig:
    """Trajectory-mixture coupling configuration."""
    coupler_mode: str = 'mean_field'
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@dataclass
class DynamicsConfig:
    """Dynamics-mode selection (see DYNAMICS_* names in constants.py)."""
    type: str = 'direct'
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
@dataclass
class FractalConfig:
    """Fractal feature configuration."""
    enabled: bool = False
    threshold: float = 0.5
    alpha: float = 0.2
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
@dataclass
class SingularityConfig:
    """Singularity ("black hole") configuration; defaults come from constants.py."""
    enabled: bool = False
    epsilon: float = EPSILON_STANDARD
    strength: float = BLACK_HOLE_STRENGTH
    threshold: float = SINGULARITY_THRESHOLD
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
@dataclass
class PhysicsConfig:
    """Complete physics configuration aggregating all sub-configs."""
    topology: TopologyConfig = field(default_factory=TopologyConfig)
    stability: StabilityConfig = field(default_factory=StabilityConfig)
    dynamics: DynamicsConfig = field(default_factory=DynamicsConfig)
    active_inference: ActiveInferenceConfig = field(default_factory=ActiveInferenceConfig)
    embedding: EmbeddingConfig = field(default_factory=EmbeddingConfig)
    readout: ReadoutConfig = field(default_factory=ReadoutConfig)
    mixture: MixtureConfig = field(default_factory=MixtureConfig)
    fractal: FractalConfig = field(default_factory=FractalConfig)
    hysteresis: HysteresisConfig = field(default_factory=HysteresisConfig)
    singularities: SingularityConfig = field(default_factory=SingularityConfig)
    trajectory_mode: str = 'partition'
    lensing: Dict[str, Any] = field(default_factory=lambda: {'enabled': False})
    checkpointing: Dict[str, Any] = field(default_factory=lambda: {'enabled': False})

    def to_dict(self) -> Dict[str, Any]:
        """Return a recursively-converted plain-dict view of this config."""
        return asdict(self)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@dataclass
class TrainerConfig:
    """Training / optimizer hyperparameters."""
    lr: float = 1e-3
    optimizer: str = 'adamw'
    max_lr: Optional[float] = None
    total_steps: Optional[int] = None
    loss_config: Dict[str, Any] = field(default_factory=lambda: {
        'lambda_g': 0.001,
        'lambda_h': 0.0,
        'geodesic_mode': 'magnitude'
    })
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
@dataclass
class ManifoldConfig:
    """Top-level configuration of the Manifold model."""
    vocab_size: int                     # required — the only field without a default
    dim: int = 512
    depth: int = 4
    heads: int = 4
    rank: int = 32
    integrator: str = 'leapfrog'
    physics: PhysicsConfig = field(default_factory=PhysicsConfig)
    trainer: TrainerConfig = field(default_factory=TrainerConfig)
    adjoint_enabled: bool = False
    adjoint_rtol: float = 1e-4
    adjoint_atol: float = 1e-4
    holographic: bool = False
    impulse_scale: float = 1.0
    dynamics_type: str = 'direct'
    mixer_type: str = 'low_rank'
    trajectory_mode: str = 'partition'
    coupler_mode: str = 'mean_field'
    initial_spread: float = 1e-3

    def to_dict(self) -> Dict[str, Any]:
        """Return a recursively-converted plain-dict view of this config."""
        return asdict(self)
|
gfn/realizations/gssm/config/serialization.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
from typing import Any, Dict, Type, TypeVar, get_type_hints, get_args, get_origin, Union

T = TypeVar('T')

def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
    """Reconstruct a (possibly nested) dataclass from a dictionary.

    Handles nested dataclasses, Optional[...] fields, and basic types.

    Fix: values that are already dataclass *instances* are now passed
    through unchanged. The previous version recursed into them, evaluating
    ``field.name in data`` on the instance, which raised TypeError.

    Args:
        cls: Target dataclass type. Non-dataclass targets return *data*.
        data: Mapping of field names to values; keys absent from *cls*'s
            fields are ignored, and missing fields keep their defaults.

    Returns:
        An instance of *cls* built from *data*.
    """
    if not dataclasses.is_dataclass(cls):
        return data

    field_types = get_type_hints(cls)
    kwargs = {}

    for f in dataclasses.fields(cls):
        if f.name not in data:
            continue  # keep the field's declared default
        value = data[f.name]
        field_type = field_types[f.name]

        # Unwrap Optional[T] (Union[T, None]) to its non-None member.
        if get_origin(field_type) is Union:
            args = get_args(field_type)
            if type(None) in args:
                field_type = [arg for arg in args if arg is not type(None)][0]

        # Recurse only for dict payloads of nested dataclass fields;
        # None and ready-made instances pass straight through.
        if dataclasses.is_dataclass(field_type) and isinstance(value, dict):
            kwargs[f.name] = from_dict(field_type, value)
        else:
            kwargs[f.name] = value

    return cls(**kwargs)
|
gfn/realizations/gssm/config/validator.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Validación de configuraciones — GFN V5
|
| 3 |
+
Verifica la compatibilidad de parámetros antes de construir componentes.
|
| 4 |
+
Fusionado de utils/validation.py y config/validator.py original.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from typing import Dict, Any, List, Optional
|
| 8 |
+
from .schema import ManifoldConfig, PhysicsConfig
|
| 9 |
+
from ..constants import TOPOLOGY_TORUS, TOPOLOGY_EUCLIDEAN, TOPOLOGY_SPHERE
|
| 10 |
+
|
| 11 |
+
class ConfigValidationError(Exception):
    """Raised for critical configuration-validation failures."""
    pass
|
| 14 |
+
|
| 15 |
+
class ConfigValidator:
    """Central validator for GFN configurations."""

    @staticmethod
    def validate_physics(cfg: PhysicsConfig, dim: Optional[int] = None, heads: Optional[int] = None):
        """Validate physical and architectural consistency of a PhysicsConfig.

        Raises ConfigValidationError when a strict topology/stability rule
        is violated; returns None otherwise.
        """
        topo = cfg.topology
        stab = cfg.stability

        # 1. Topology checks
        if topo.type == TOPOLOGY_TORUS and dim is not None and heads is not None:
            head_dim = dim // heads
            if head_dim % 2 != 0:
                raise ConfigValidationError(
                    f"Toroid geometry requires head_dim (dim//heads) to be even. "
                    f"Found {dim}//{heads}={head_dim}"
                )

        if topo.type == TOPOLOGY_SPHERE and topo.curvature <= 0:
            raise ConfigValidationError("Spherical topology requires positive curvature.")

        # 2. Stability checks
        if stab.base_dt <= 0:
            raise ConfigValidationError("base_dt must be positive.")
        if stab.friction < 0:
            raise ConfigValidationError("friction cannot be negative.")

        # 3. Mode compatibility
        if heads is not None and heads <= 1 and cfg.trajectory_mode == 'ensemble':
            raise ConfigValidationError("Ensemble trajectory mode requires more than 1 head.")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def validate_manifold_config(config: ManifoldConfig) -> List[str]:
    """
    Validate a full ManifoldConfig and its nested PhysicsConfig.
    Returns a list of warning strings (empty if everything is OK).
    Raises ConfigValidationError on critical or compatibility errors.
    """
    warnings = []

    # Critical validations (raise exceptions)
    if config.dim % config.heads != 0:
        raise ConfigValidationError(
            f"dim={config.dim} no es divisible por heads={config.heads}. "
            f"head_dim={config.dim/config.heads:.1f} no es entero."
        )

    if config.vocab_size <= 0:
        raise ConfigValidationError(f"vocab_size={config.vocab_size} debe ser > 0.")

    if config.depth <= 0:
        raise ConfigValidationError(f"depth={config.depth} debe ser > 0.")

    # Validate Physics properties via centralized method
    ConfigValidator.validate_physics(config.physics, config.dim, config.heads)

    # Soft validations (warnings only)
    head_dim = config.dim // config.heads
    topo_type = config.physics.topology.type.lower()

    # NOTE(review): validate_physics above already raises for torus + odd
    # head_dim when the type matches TOPOLOGY_TORUS exactly, so this warning
    # appears reachable only for non-canonical casing — confirm intent.
    if topo_type == TOPOLOGY_TORUS and head_dim % 2 != 0:
        warnings.append(
            f"[WARN] Para geometría toroidal, head_dim={head_dim} debería ser par "
            f"para representaciones sin/cos. Considera usar heads={config.dim // (head_dim + 1)} o similar."
        )

    if config.rank > config.dim:
        warnings.append(
            f"[WARN] rank={config.rank} > dim={config.dim}. "
            f"La descomposición no es de rango bajo. ¿Intencional?"
        )

    dt = config.physics.stability.base_dt
    if dt > 1.0:
        warnings.append(f"[WARN] base_dt={dt} > 1.0 puede causar inestabilidad numérica.")
    if dt < 1e-5:
        warnings.append(f"[WARN] base_dt={dt} < 1e-5 puede ralentizar convergencia.")

    return warnings
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def validate_and_print(config: ManifoldConfig) -> bool:
    """Validate *config*, printing any soft warnings.

    Returns True when the configuration is valid; on a critical
    ConfigValidationError the error is printed (not re-raised) and
    False is returned.
    """
    try:
        for warning in validate_manifold_config(config):
            print(warning)
    except ConfigValidationError as e:
        print(f"[CONFIG ERROR] {e}")
        return False
    return True
|
gfn/realizations/gssm/constants.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# constants.py — GFN V5
# Universal physical and mathematical constants.
# Contains NO training hyperparameters and no configurable values.

import torch

# ─── Mathematical constants ─────────────────────────────────────────────────
PI = 3.14159265358979  # NOTE(review): truncated vs math.pi (…93) — confirm precision is intended
E = 2.718281828459045
SQRT_2 = 1.4142135623730951
LOG_2 = 0.6931471805599453

# ─── Numerical stability ───────────────────────────────────────────────────
EPS = 1e-8
INF = 1e12
EPSILON_STANDARD = 1e-7
EPSILON_SMOOTH = 1e-9
EPSILON_STRONG = 1e-6
CLAMP_MIN_STRONG = 1e-4

# ─── Physical limits ───────────────────────────────────────────────────────
MIN_DT = 0.001
MAX_DT = 1.0

# ─── Geometry / Curvature ──────────────────────────────────────────────────
CURVATURE_CLAMP = 5.0  # Maximum absolute value of Christoffel output
FRICTION_SCALE = 0.1  # Global friction scaling factor
VELOCITY_FRICTION_SCALE = 0.01

# ─── Gate initialization constants ─────────────────────────────────────────
GATE_BIAS_OPEN = 2.0  # sigmoid(2.0) ≈ 0.88
GATE_BIAS_CLOSED = -2.0  # sigmoid(-2.0) ≈ 0.12

# ─── Singularity / Active Inference ───────────────────────────────────────
SINGULARITY_THRESHOLD = 0.5
BLACK_HOLE_STRENGTH = 3.0
SINGULARITY_GATE_SLOPE = 10.0

# ─── Torus geometry ───────────────────────────────────────────────────────
TOROIDAL_MAJOR_RADIUS = 1.0
TOROIDAL_MINOR_RADIUS = 0.3
TOROIDAL_PERIOD = 2.0 * PI
TOROIDAL_CURVATURE_SCALE = 0.1

# ─── Default dtype ────────────────────────────────────────────────────────
DTYPE = torch.float32

# ─── Topology Names ───────────────────────────────────────────────────────
TOPOLOGY_TORUS = "torus"
TOPOLOGY_SPHERE = "spherical"
TOPOLOGY_HYPERBOLIC = "hyperbolic"
TOPOLOGY_EUCLIDEAN = "euclidean"

# ─── Dynamics Modes ───────────────────────────────────────────────────────
DYNAMICS_DIRECT = "direct"
DYNAMICS_RESIDUAL = "residual"
DYNAMICS_MIX = "mix"
DYNAMICS_GATED = "gated"
DYNAMICS_STOCHASTIC = "stochastic"

# ─── Compatibility aliases (defaults that moved to defaults.py) ───────────
# NOTE: kept here for backward compatibility; new code should import these
# from config/defaults.py instead.
DEFAULT_FRICTION = 0.01
DEFAULT_DT = 0.1
DEFAULT_PLASTICITY = 0.05
MAX_VELOCITY = 10.0
|
gfn/realizations/gssm/core/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
core/__init__.py — GFN V5
|
| 3 |
+
"""
|
| 4 |
+
from ..core.types import ManifoldState, Trajectory, StepResult, ModelOutput
|
| 5 |
+
from ..core.state import ManifoldStateManager
|
| 6 |
+
|
| 7 |
+
__all__ = ['ManifoldState', 'Trajectory', 'StepResult', 'ModelOutput', 'ManifoldStateManager']
|
gfn/realizations/gssm/core/state.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
core/state.py — GFN V5
|
| 3 |
+
Manejo de estado del manifold (posición + velocidad).
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from typing import Optional, Tuple
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ManifoldStateManager:
|
| 12 |
+
"""
|
| 13 |
+
Gestiona la inicialización y manipulación del estado (x, v).
|
| 14 |
+
Compatible con batches y múltiples cabezales.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
@staticmethod
|
| 18 |
+
def initialize(x0: nn.Parameter, v0: nn.Parameter,
|
| 19 |
+
batch_size: int, n_trajectories: int = 1,
|
| 20 |
+
initial_spread: float = 1e-3) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 21 |
+
"""
|
| 22 |
+
Inicializa el estado (x, v) para un batch dado.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
x0, v0: Parámetros iniciales [1, H, HD]
|
| 26 |
+
batch_size: Tamaño del batch
|
| 27 |
+
n_trajectories: Número de trayectorias paralelas
|
| 28 |
+
initial_spread: Ruido inicial
|
| 29 |
+
|
| 30 |
+
Returns:
|
| 31 |
+
(x, v) — [B, H, HD]
|
| 32 |
+
"""
|
| 33 |
+
x = x0.expand(batch_size, -1, -1)
|
| 34 |
+
v = v0.expand(batch_size, -1, -1)
|
| 35 |
+
|
| 36 |
+
if initial_spread > 0:
|
| 37 |
+
x = x + torch.randn_like(x) * initial_spread
|
| 38 |
+
|
| 39 |
+
return x.contiguous(), v.contiguous()
|
| 40 |
+
|
| 41 |
+
@staticmethod
|
| 42 |
+
def from_tuple(state: Optional[Tuple], x0: nn.Parameter, v0: nn.Parameter,
|
| 43 |
+
batch_size: int, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 44 |
+
"""
|
| 45 |
+
Construye (x, v) desde un estado previo o desde parámetros iniciales.
|
| 46 |
+
Compatible con el API de BasicModel.
|
| 47 |
+
"""
|
| 48 |
+
if state is not None and isinstance(state, (tuple, list)) and len(state) == 2:
|
| 49 |
+
return state[0], state[1]
|
| 50 |
+
return ManifoldStateManager.initialize(x0, v0, batch_size, **kwargs)
|
| 51 |
+
|
| 52 |
+
@staticmethod
|
| 53 |
+
def wrap_torus(x: torch.Tensor) -> torch.Tensor:
|
| 54 |
+
"""Proyecta posición al dominio toroidal [-π, π]."""
|
| 55 |
+
return torch.atan2(torch.sin(x), torch.cos(x))
|
| 56 |
+
|
| 57 |
+
@staticmethod
|
| 58 |
+
def energy(v: torch.Tensor) -> torch.Tensor:
|
| 59 |
+
"""Energía cinética H = 0.5 * ||v||² por muestra."""
|
| 60 |
+
return 0.5 * (v ** 2).sum(dim=-1)
|
gfn/realizations/gssm/core/types.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
core/types.py — GFN V5
|
| 3 |
+
Tipos y type aliases del framework.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from typing import Dict, Any, Tuple, Optional, List, Union
|
| 8 |
+
|
| 9 |
+
# ─── State types ─────────────────────────────────────────────────────────────
|
| 10 |
+
# (position, velocity) pair
|
| 11 |
+
ManifoldState = Tuple[torch.Tensor, torch.Tensor]
|
| 12 |
+
|
| 13 |
+
# Trajectory: list of (x, v) states over time
|
| 14 |
+
Trajectory = List[ManifoldState]
|
| 15 |
+
|
| 16 |
+
# Force tensor (same shape as x, v)
|
| 17 |
+
Force = torch.Tensor
|
| 18 |
+
|
| 19 |
+
# Integration step result
|
| 20 |
+
StepResult = Dict[str, torch.Tensor] # {'x': ..., 'v': ...}
|
| 21 |
+
|
| 22 |
+
# ─── Config types ─────────────────────────────────────────────────────────────
|
| 23 |
+
ConfigDict = Dict[str, Any]
|
| 24 |
+
|
| 25 |
+
# ─── Forward pass outputs ─────────────────────────────────────────────────────
|
| 26 |
+
# (logits, state, info_dict)
|
| 27 |
+
ModelOutput = Tuple[torch.Tensor, ManifoldState, Dict[str, Any]]
|
gfn/realizations/gssm/csrc/README.md
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Native C++/CUDA Extensions
|
| 2 |
+
Store raw .cpp and .cu files here. Bindings remain in new_gfn/cuda/
|
gfn/realizations/gssm/csrc/compile_cuda_12.9.bat
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@echo off

REM Ensure we are in the script's directory so we run the correct setup.py
cd /d "%~dp0"

echo ================================================================
echo [GFN] Custom CUDA Kernel Compilation Pipeline (VS 2022 + CUDA 12.9)
echo ================================================================

REM --- 1. Find Visual Studio 2022 Installation ---
set "VS_PATH="

if exist "C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat" (
    set "VS_PATH=C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat"
) else if exist "C:\Program Files\Microsoft Visual Studio\2022\Professional\VC\Auxiliary\Build\vcvars64.bat" (
    set "VS_PATH=C:\Program Files\Microsoft Visual Studio\2022\Professional\VC\Auxiliary\Build\vcvars64.bat"
) else if exist "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat" (
    set "VS_PATH=C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
)

if "%VS_PATH%"=="" (
    echo [ERROR] Could not find Visual Studio 2022 installation.
    pause
    exit /b 1
)

echo [*] Found MSVC Environment: "%VS_PATH%"
echo [*] Initializing Developer Console...
call "%VS_PATH%"

REM --- 2. Setup CUDA Environment (12.9) ---
set "CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.9"
set "PATH=%CUDA_PATH%\bin;%PATH%"

echo [*] Using CUDA Path: "%CUDA_PATH%"
nvcc --version

REM --- 3. Compile Kernels ---
echo [*] Cleaning old builds...
REM Redirect stderr to nul so a missing file/directory does not spam errors.
rmdir /s /q build 2>nul
rmdir /s /q gfn_cuda.egg-info 2>nul
del /q *.pyc 2>nul
rmdir /s /q __pycache__ 2>nul

echo.
echo [*] Starting Setup Compilation (In-place)...

REM Fix for "It seems that the VC environment is activated..." warning
set DISTUTILS_USE_SDK=1

python setup.py build_ext --inplace

if %errorlevel% neq 0 (
    echo [ERROR] Compilation failed.
    echo Ensure you have PyTorch installed for CUDA 12.x:
    echo    pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
    pause
    exit /b 1
)

echo.
echo [SUCCESS] Kernels compiled to local .pyd file!
echo Verified import:

REM Actually import the freshly built extension. (The previous version only
REM printed a success banner without importing anything, so a broken .pyd
REM would still report success.)
python -c "import gfn_cuda; print('SUCCESS: gfn_cuda module imported directly!')"
if %errorlevel% neq 0 (
    echo [ERROR] Build succeeded but the gfn_cuda module failed to import.
    pause
    exit /b 1
)
pause
|
gfn/realizations/gssm/csrc/extension.cpp
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/extension.h>
#include <vector>
#include "integrators/integrators.h"

// Forward declarations — Toroidal Distance Loss.
// Implementations live in the CUDA translation unit (losses/toroidal.cu).
torch::Tensor toroidal_distance_loss_fwd(const torch::Tensor& y_pred, const torch::Tensor& y_true);
torch::Tensor toroidal_distance_loss_bwd(const torch::Tensor& grad_output, const torch::Tensor& y_pred, const torch::Tensor& y_true);

// Forward declaration — Low-Rank Christoffel forward.
// Implementation lives in geometry/low_rank.cu.
torch::Tensor low_rank_christoffel_fwd(
    const torch::Tensor& v, const torch::Tensor& U, const torch::Tensor& W,
    double clamp_val, bool enable_trace_norm, bool is_paper_version);

// Backward pass for the low-rank Christoffel operator, implemented in pure
// ATen C++ so it is compiled by MSVC (sidestepping an NVCC CICC bug — see
// original author's note).
//
// Expected shapes (per the comments below — TODO confirm against the
// forward kernel):
//   grad_gamma — upstream gradient w.r.t. gamma   [B, H, D]
//   v          — velocity                          [B, H, D]
//   U, W       — low-rank factors                  [H, D, R]
//   gamma_out  — forward output (after tanh clamp) [B, H, D]
// Returns {d_v, d_U, d_W}; U/W gradients are summed over the batch.
std::vector<torch::Tensor> low_rank_christoffel_bwd(
    const torch::Tensor& grad_gamma,
    const torch::Tensor& v,
    const torch::Tensor& U,
    const torch::Tensor& W,
    const torch::Tensor& gamma_out,
    double clamp_val,
    bool enable_trace_norm,
    bool is_paper_version)
{
    // Differentiate through g = clamp_val * tanh(raw / clamp_val):
    // d(raw) = grad * (1 - tanh^2(raw/clamp)) = grad * (1 - (g/clamp)^2),
    // recovering tanh from the saved forward output instead of recomputing.
    auto g_norm = gamma_out / clamp_val;
    auto d_tanh = 1.0 - g_norm.pow(2);
    auto grad_raw = grad_gamma * d_tanh; // [B, H, D]

    // The forward's trace normalization subtracts the per-vector mean, so
    // the backward projects the gradient onto the zero-mean subspace.
    if (enable_trace_norm) {
        auto mean_d = grad_raw.mean(-1, /*keepdim=*/true);
        grad_raw = grad_raw - mean_d;
    }

    // Explicit batched matrix multiplication (bmm) instead of broadcasting
    // matmul, to avoid matmul broadcast crashes (original author's note).
    // W is [H, D, R], grad_raw is [B, H, D].
    auto grad_raw_h = grad_raw.permute({1, 0, 2}); // [H, B, D]
    auto d_sq_h = torch::bmm(grad_raw_h, W);       // [H, B, D] @ [H, D, R] -> [H, B, R]
    auto d_sq = d_sq_h.permute({1, 0, 2});         // [B, H, R]

    auto v_h = v.permute({1, 0, 2});               // [H, B, D]
    auto v_r_h = torch::bmm(v_h, U);               // [H, B, D] @ [H, D, R] -> [H, B, R]
    auto v_r = v_r_h.permute({1, 0, 2});           // [B, H, R]

    torch::Tensor d_vr;
    if (is_paper_version) {
        // Paper variant: forward used sq = v_r^2 / (1 + ||v_r||).
        auto vr_norm = torch::norm(v_r, 2, -1, true);
        auto denom = 1.0 + vr_norm;

        // Chain rule for the normalized-denominator coupling:
        // grad_vr_j = grad_phi_j * (2 v_j / denom)
        //             - v_j * Sum_k(grad_phi_k * v_k^2) / (||v|| * denom^2)
        auto S = (d_sq * v_r.pow(2)).sum(-1, true);
        auto term1 = d_sq * (2.0 * v_r / denom);
        // +1e-8 guards against division by zero when ||v_r|| == 0.
        auto term2 = (v_r * S) / (vr_norm * denom.pow(2) + 1e-8);
        d_vr = term1 - term2;
    } else {
        // Plain variant: sq = v_r^2, so d(sq)/d(v_r) = 2 v_r.
        d_vr = d_sq * 2.0 * v_r;
    }

    // Propagate back through v_r = v @ U to get the velocity gradient.
    auto d_vr_h = d_vr.permute({1, 0, 2}); // [H, B, R]
    auto U_t = U.transpose(-1, -2);        // [H, R, D]
    auto d_v_h = torch::bmm(d_vr_h, U_t);  // [H, B, R] @ [H, R, D] -> [H, B, D]
    auto d_v = d_v_h.permute({1, 0, 2});   // [B, H, D]

    // Parameter gradients for W and U, accumulated over the batch dim
    // via the bmm contraction on B.
    auto sq = is_paper_version ? v_r.pow(2) / (1.0 + torch::norm(v_r, 2, -1, true)) : v_r.pow(2); // [B, H, R]
    auto sq_h = sq.permute({1, 0, 2}); // [H, B, R]

    auto grad_raw_h_t = grad_raw_h.transpose(-1, -2); // [H, D, B]
    auto d_W = torch::bmm(grad_raw_h_t, sq_h);        // [H, D, B] @ [H, B, R] -> [H, D, R]

    auto v_h_t = v_h.transpose(-1, -2); // [H, D, B]
    auto d_U = torch::bmm(v_h_t, d_vr_h); // [H, D, B] @ [H, B, R] -> [H, D, R]

    return {d_v, d_U, d_W};
}

// Python bindings. The yoshida/leapfrog entry points come from
// integrators/integrators.h (declared there, defined in integrators.cpp).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("toroidal_distance_loss_fwd", &toroidal_distance_loss_fwd, "Toroidal Distance Loss Forward (CUDA)");
    m.def("toroidal_distance_loss_bwd", &toroidal_distance_loss_bwd, "Toroidal Distance Loss Backward (CUDA)");

    m.def("low_rank_christoffel_fwd", &low_rank_christoffel_fwd, "Low Rank Christoffel Forward Kernel");
    m.def("low_rank_christoffel_bwd", &low_rank_christoffel_bwd, "Low Rank Christoffel Backward ATen");

    m.def("yoshida_fwd", &yoshida_fwd_aten, "Yoshida C++ Macro Integrator Step");
    m.def("leapfrog_fwd", &leapfrog_fwd_aten, "Leapfrog C++ Macro Integrator Step");
}
|
gfn/realizations/gssm/csrc/geometry/low_rank.cu
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>

// Forward kernel for the low-rank Christoffel operator.
// One thread block handles one (batch, head) pair; threads cooperate via
// dynamic shared memory on the D- and R-sized intermediates.
//
// Pipeline per (b, h):
//   v_r  = v @ U              (rank-R projection)
//   sq   = v_r^2              (optionally / (1 + ||v_r||) in "paper" mode)
//   raw  = sq @ W^T
//   gamma = clamp * tanh(raw / clamp)   (after optional mean subtraction)
template <typename scalar_t>
__global__ void low_rank_christoffel_fwd_kernel(
    const scalar_t* __restrict__ v,     // [B, H, D]
    const scalar_t* __restrict__ U,     // [H, D, R]
    const scalar_t* __restrict__ W,     // [H, D, R]
    scalar_t* __restrict__ gamma,       // [B, H, D] output
    const int B, const int H, const int D, const int R,
    const scalar_t clamp_val,
    const bool enable_trace_norm,
    const bool is_paper_version)
{
    // A block computes the output for one (b, h) pair.
    int bh = blockIdx.x;
    if (bh >= B * H) return;

    int h = bh % H;

    // Dynamic shared memory layout: [v (D) | v_r^2 (R) | gamma (D)].
    // NOTE(review): only instantiated for float on the host side, so the
    // sizeof(float) sizing in the launcher matches scalar_t here.
    extern __shared__ char smem[];
    scalar_t* v_s_d = reinterpret_cast<scalar_t*>(smem);             // Size: D
    scalar_t* vr_sq_s = reinterpret_cast<scalar_t*>(&v_s_d[D]);      // Size: R
    scalar_t* gamma_s_d = reinterpret_cast<scalar_t*>(&vr_sq_s[R]);  // Size: D

    const scalar_t* v_b = v + bh * D;
    scalar_t* gamma_b = gamma + bh * D;

    // Per-head parameter slices.
    const scalar_t* U_h = U + h * D * R;
    const scalar_t* W_h = W + h * D * R;

    const int tid = threadIdx.x;
    const int bdim = blockDim.x;

    // 1. Load v into shared memory (strided over threads).
    for (int i = tid; i < D; i += bdim) {
        v_s_d[i] = v_b[i];
    }
    __syncthreads();

    // 2. Compute v_r = v @ U, storing the element-wise square v_r^2.
    for (int r = tid; r < R; r += bdim) {
        scalar_t sum = 0;
        for (int j = 0; j < D; ++j) {
            sum += v_s_d[j] * U_h[j * R + r];
        }
        vr_sq_s[r] = sum * sum;
    }
    __syncthreads();

    // 3. Optional "paper" variant: divide sq by (1 + ||v_r||).
    //    ||v_r|| is reduced block-wide via an atomicAdd into shared memory
    //    (one block per (b,h), so no cross-block contention).
    if (is_paper_version) {
        __shared__ scalar_t block_sum_sq;
        if (tid == 0) block_sum_sq = 0;
        __syncthreads();

        scalar_t local_sq = 0;
        for (int r = tid; r < R; r += bdim) {
            local_sq += vr_sq_s[r];
        }
        atomicAdd(&block_sum_sq, local_sq);
        __syncthreads();

        scalar_t norm_vr = sqrt(block_sum_sq);
        scalar_t denom = 1.0 + norm_vr;

        for (int r = tid; r < R; r += bdim) {
            vr_sq_s[r] = vr_sq_s[r] / denom;
        }
        __syncthreads();
    }

    // 4. Compute gamma_raw = sq @ W^T back into D dimensions.
    for (int d = tid; d < D; d += bdim) {
        scalar_t sum = 0;
        for (int r = 0; r < R; ++r) {
            sum += vr_sq_s[r] * W_h[d * R + r];
        }
        gamma_s_d[d] = sum;
    }
    __syncthreads();

    // 5. Trace normalization: subtract the mean over D (block-wide sum).
    scalar_t mean_val = 0;
    if (enable_trace_norm) {
        __shared__ scalar_t block_sum_gamma;
        if (tid == 0) block_sum_gamma = 0;
        __syncthreads();

        scalar_t local_gamma_sum = 0;
        for (int d = tid; d < D; d += bdim) {
            local_gamma_sum += gamma_s_d[d];
        }
        atomicAdd(&block_sum_gamma, local_gamma_sum);
        __syncthreads();

        mean_val = block_sum_gamma / D;
    }

    // 6. Soft clamp (clamp * tanh(g / clamp)) and write to global memory.
    for (int d = tid; d < D; d += bdim) {
        scalar_t g = gamma_s_d[d];
        if (enable_trace_norm) {
            g -= mean_val;
        }
        g = clamp_val * tanh(g / clamp_val);
        gamma_b[d] = g;
    }
}

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// Host-side launcher. Accepts v [B, H, D] and factors U, W [H, D, R];
// returns gamma with the same shape as v. float32 only.
torch::Tensor low_rank_christoffel_fwd(
    const torch::Tensor& v,
    const torch::Tensor& U,
    const torch::Tensor& W,
    double clamp_val,
    bool enable_trace_norm,
    bool is_paper_version)
{
    CHECK_INPUT(v);
    CHECK_INPUT(U);
    CHECK_INPUT(W);

    // Ensure shapes: v is [B, H, D], U is [H, D, R], W is [H, D, R]
    int B = v.size(0);
    int H = v.size(1);
    int D = v.size(2);
    int R = U.size(2);

    auto gamma = torch::empty_like(v);

    const int threads = 256;
    const int blocks = B * H;  // one block per (batch, head) pair

    // Shared memory size: (D + R + D) * sizeof(float) — matches the
    // float-only dispatch below.
    const int shared_mem_size = (2 * D + R) * sizeof(float);

    if (v.scalar_type() == torch::kFloat32) {
        low_rank_christoffel_fwd_kernel<float><<<blocks, threads, shared_mem_size>>>(
            v.data_ptr<float>(),
            U.data_ptr<float>(),
            W.data_ptr<float>(),
            gamma.data_ptr<float>(),
            B, H, D, R,
            static_cast<float>(clamp_val),
            enable_trace_norm,
            is_paper_version
        );
    } else {
        TORCH_CHECK(false, "low_rank_christoffel_fwd only supports float32");
    }

    return gamma;
}
|
gfn/realizations/gssm/csrc/integrators/integrators.cpp
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/extension.h>
#include <vector>
#define _USE_MATH_DEFINES
#include <cmath>

#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

// -------------------------------------------------------------
// Pure ATen Implementation of the Integrators Loop
// -------------------------------------------------------------
// Built in standard C++ using ATen so each integrator runs in a single
// Python call but is compiled by MSVC, averting NVCC OOM bugs (original
// author's note). Performs `steps` iterations of the GFN low-rank
// geometry dynamics.

// Compute the (clamped) Christoffel contraction for velocity v:
//   v_r = v @ U; sq = v_r^2 (paper mode: / (1 + ||v_r||));
//   gamma = sq @ W^T, optionally mean-centered, then soft-clamped with
//   clamp * tanh(g / clamp). Mirrors the CUDA forward kernel.
torch::Tensor _compute_gamma(
    const torch::Tensor& v,
    const torch::Tensor& U,
    const torch::Tensor& W,
    double clamp_val,
    bool enable_trace_norm,
    bool is_paper_version)
{
    auto v_r = torch::matmul(v.unsqueeze(-2), U).squeeze(-2); // [..., R]
    torch::Tensor sq;
    if (is_paper_version) {
        auto vr_norm = torch::norm(v_r, 2, -1, true);
        sq = v_r.pow(2) / (1.0 + vr_norm);
    } else {
        sq = v_r.pow(2);
    }

    auto gamma = torch::matmul(sq.unsqueeze(-2), W.transpose(-1, -2)).squeeze(-2); // [..., D]

    // Trace normalization: remove the mean over the last dimension.
    if (enable_trace_norm) {
        auto mean_g = gamma.mean(-1, /*keepdim=*/true);
        gamma = gamma - mean_g;
    }

    // Soft clamp keeps |gamma| < clamp_val while staying differentiable.
    return clamp_val * torch::tanh(gamma / clamp_val);
}

// Velocity saturation (soft clamp). v_sat <= 0 disables saturation.
torch::Tensor _clamp_velocity(const torch::Tensor& v, double v_sat) {
    if (v_sat > 0) {
        return v_sat * torch::tanh(v / v_sat);
    }
    return v;
}

// Yoshida 4th-order symplectic integrator coefficients
// (standard w0/w1 splitting; c = drift, d = kick weights).
const double w1 = 1.3512071919596576;
const double w0 = -1.7024143839193153;
const double y_c1 = w1 / 2.0;
const double y_c2 = (w0 + w1) / 2.0;
const double y_c3 = y_c2;
const double y_c4 = y_c1;
const double y_d1 = w1;
const double y_d2 = w0;
const double y_d3 = w1;

// Gated friction coefficient (Active Inference):
//   mu = (base + sigmoid(gate(features))) * (1 + scale * ||v|| / sqrt(D))
// The gate is skipped when gate_w is empty.
torch::Tensor _compute_mu(
    const torch::Tensor& x,
    const torch::Tensor& v,
    const torch::Tensor& gate_w,
    const torch::Tensor& gate_b,
    double base_friction,
    double vel_fric_scale)
{
    const double eps = 1e-8;
    const double D = x.size(-1);

    // Start from the constant base friction, broadcast to [..., 1].
    torch::Tensor mu = torch::full_like(x.select(-1, 0).unsqueeze(-1), base_friction);

    // If gate weights are provided, add a learnable friction component.
    if (gate_w.numel() > 0) {
        torch::Tensor feat;
        // Torus features [sin(x), cos(x)] when the gate expects 2*D inputs;
        // otherwise use positions directly (Euclidean / flat).
        if (gate_w.size(1) == 2 * D) {
            feat = torch::cat({torch::sin(x), torch::cos(x)}, -1); // [..., 2D]
        } else {
            feat = x; // Euclidean / Flat
        }

        // Linear gate: sigmoid(feat @ w + b)
        auto gate_out = torch::matmul(feat.unsqueeze(-2), gate_w).squeeze(-2); // [B, H, 1]
        if (gate_b.numel() > 0) {
            gate_out = gate_out + gate_b;
        }
        mu = mu + torch::sigmoid(gate_out);
    }

    // Velocity-dependent scaling: mu * (1 + scale * ||v|| / sqrt(D)).
    auto v_norm = torch::norm(v, 2, -1, true) / (std::sqrt(D) + eps);
    mu = mu * (1.0 + vel_fric_scale * v_norm);

    return mu;
}

// Singularity damping ("shield"): scales the acceleration down where the
// (diagonal) metric is small, i.e. near singular points.
torch::Tensor _apply_singularity_damping(
    const torch::Tensor& acc,
    const torch::Tensor& v,
    const torch::Tensor& U,
    double sing_thresh,
    double sing_strength)
{
    // NOTE(review): the guard disables damping for any strength <= 1.0;
    // whether 1.0 (rather than 0.0) is the intended cut-off is not
    // evident from this file — confirm against the callers' defaults.
    if (sing_strength <= 1.0 || sing_thresh <= 0.0) return acc;

    // Detect singularity: metrics are low near singular points.
    // In the low-rank parameterization, g_diag = sum(U^2) over R.
    auto g_diag = (U.pow(2)).sum(-1); // [H, D]

    // Soft mask = sigmoid(5.0 * (g - thresh)); -> 0 near a singularity.
    auto soft_mask = torch::sigmoid(5.0 * (g_diag - sing_thresh));

    // Scale acceleration by soft_mask (damping shield):
    // near a singularity the mask -> 0, suppressing the forces.
    return acc * soft_mask;
}

// Yoshida 4th-order macro step: runs `steps` iterations of
// drift/kick/drift/... with friction applied implicitly
// (v <- (v + d*dt*a) / (1 + d*dt*mu)) and positions wrapped to [-pi, pi)
// after every drift (toroidal topology). Returns {x, v}.
std::vector<torch::Tensor> yoshida_fwd_aten(
    const torch::Tensor& x_init,
    const torch::Tensor& v_init,
    const torch::Tensor& U,
    const torch::Tensor& W,
    const torch::Tensor& force,
    const torch::Tensor& dt,
    int steps,
    double clamp_val,
    double friction,
    double vel_fric_scale,
    double vel_sat,
    const torch::Tensor& gate_w,
    const torch::Tensor& gate_b,
    double sing_thresh,
    double sing_strength,
    bool enable_trace_norm,
    bool is_paper_version)
{
    auto x = x_init.clone();
    auto v = v_init.clone();

    const double eps = 1e-8; // guards the implicit-friction denominator

    for (int i = 0; i < steps; ++i) {
        // Sub-step 1: drift, then kick with weight d1.
        x = x + y_c1 * dt * v;
        x = torch::remainder(x + M_PI, 2 * M_PI) - M_PI; // Toroidal wrap

        auto gamma1 = _compute_gamma(v, U, W, clamp_val, enable_trace_norm, is_paper_version);
        auto a1_nf = force - gamma1; // "no-friction" acceleration
        a1_nf = _apply_singularity_damping(a1_nf, v, U, sing_thresh, sing_strength);

        auto mu1 = _compute_mu(x, v, gate_w, gate_b, friction, vel_fric_scale);

        // Implicit (unconditionally stable) friction update.
        v = (v + y_d1 * dt * a1_nf) / (1.0 + y_d1 * dt * mu1 + eps);
        v = _clamp_velocity(v, vel_sat);

        // Sub-step 2 (note y_d2 = w0 is negative — backward kick).
        x = x + y_c2 * dt * v;
        x = torch::remainder(x + M_PI, 2 * M_PI) - M_PI;

        auto gamma2 = _compute_gamma(v, U, W, clamp_val, enable_trace_norm, is_paper_version);
        auto a2_nf = force - gamma2;
        a2_nf = _apply_singularity_damping(a2_nf, v, U, sing_thresh, sing_strength);

        auto mu2 = _compute_mu(x, v, gate_w, gate_b, friction, vel_fric_scale);

        v = (v + y_d2 * dt * a2_nf) / (1.0 + y_d2 * dt * mu2 + eps);
        v = _clamp_velocity(v, vel_sat);

        // Sub-step 3
        x = x + y_c3 * dt * v;
        x = torch::remainder(x + M_PI, 2 * M_PI) - M_PI;

        auto gamma3 = _compute_gamma(v, U, W, clamp_val, enable_trace_norm, is_paper_version);
        auto a3_nf = force - gamma3;
        a3_nf = _apply_singularity_damping(a3_nf, v, U, sing_thresh, sing_strength);

        auto mu3 = _compute_mu(x, v, gate_w, gate_b, friction, vel_fric_scale);

        v = (v + y_d3 * dt * a3_nf) / (1.0 + y_d3 * dt * mu3 + eps);
        v = _clamp_velocity(v, vel_sat);

        // Final drift closes the symmetric composition.
        x = x + y_c4 * dt * v;
        x = torch::remainder(x + M_PI, 2 * M_PI) - M_PI;
    }

    return {x, v};
}

// Leapfrog-style macro step. Each iteration does: half-kick to get v_half,
// full drift with v_half, then a full kick using accelerations/frictions
// averaged between the two evaluation points.
// NOTE(review): the final update applies the averaged values over a full
// dt from the ORIGINAL v (a Heun-like scheme) rather than the classic
// kick-drift-kick composition — presumed intentional; confirm with author.
std::vector<torch::Tensor> leapfrog_fwd_aten(
    const torch::Tensor& x_init,
    const torch::Tensor& v_init,
    const torch::Tensor& U,
    const torch::Tensor& W,
    const torch::Tensor& force,
    const torch::Tensor& dt,
    int steps,
    double clamp_val,
    double friction,
    double vel_fric_scale,
    double vel_sat,
    const torch::Tensor& gate_w,
    const torch::Tensor& gate_b,
    double sing_thresh,
    double sing_strength,
    bool enable_trace_norm,
    bool is_paper_version)
{
    auto x = x_init.clone();
    auto v = v_init.clone();

    const double eps = 1e-8; // guards the implicit-friction denominator

    for (int i = 0; i < steps; ++i) {
        // Half-kick 1: evaluate forces at the current state.
        auto gamma1 = _compute_gamma(v, U, W, clamp_val, enable_trace_norm, is_paper_version);
        auto a1_nf = force - gamma1;
        a1_nf = _apply_singularity_damping(a1_nf, v, U, sing_thresh, sing_strength);

        auto mu1 = _compute_mu(x, v, gate_w, gate_b, friction, vel_fric_scale);

        auto v_half = (v + 0.5 * dt * a1_nf) / (1.0 + 0.5 * dt * mu1 + eps);
        v_half = _clamp_velocity(v_half, vel_sat);

        // Drift with the half-step velocity, wrapped to the torus.
        x = x + dt * v_half;
        x = torch::remainder(x + M_PI, 2 * M_PI) - M_PI;

        // Half-kick 2: re-evaluate forces at (new x, v_half).
        auto gamma2 = _compute_gamma(v_half, U, W, clamp_val, enable_trace_norm, is_paper_version);
        auto a2_nf = force - gamma2;
        a2_nf = _apply_singularity_damping(a2_nf, v_half, U, sing_thresh, sing_strength);

        auto mu2 = _compute_mu(x, v_half, gate_w, gate_b, friction, vel_fric_scale);

        // Average the two evaluations and apply one full implicit update.
        auto a_avg = (a1_nf + a2_nf) / 2.0;
        auto mu_avg = (mu1 + mu2) / 2.0;

        v = (v + dt * a_avg) / (1.0 + dt * mu_avg + eps);
        v = _clamp_velocity(v, vel_sat);
    }

    return {x, v};
}
|
gfn/realizations/gssm/csrc/integrators/integrators.h
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// integrators.h — declarations for the pure-ATen macro integrators
// defined in integrators.cpp and bound to Python in extension.cpp.
#pragma once  // guard against double inclusion (was missing)

#include <torch/extension.h>
#include <vector>

// Yoshida 4th-order macro integrator step.
// Runs `steps` iterations of the GFN low-rank geometry dynamics starting
// from (x_init, v_init); returns the final {x, v}.
std::vector<torch::Tensor> yoshida_fwd_aten(
    const torch::Tensor& x_init,
    const torch::Tensor& v_init,
    const torch::Tensor& U,
    const torch::Tensor& W,
    const torch::Tensor& force,
    const torch::Tensor& dt,
    int steps,
    double clamp_val,
    double friction,
    double vel_fric_scale,
    double vel_sat,
    const torch::Tensor& gate_w,
    const torch::Tensor& gate_b,
    double sing_thresh,
    double sing_strength,
    bool enable_trace_norm,
    bool is_paper_version);

// Leapfrog-style macro integrator step; same parameters and return shape
// as yoshida_fwd_aten.
std::vector<torch::Tensor> leapfrog_fwd_aten(
    const torch::Tensor& x_init,
    const torch::Tensor& v_init,
    const torch::Tensor& U,
    const torch::Tensor& W,
    const torch::Tensor& force,
    const torch::Tensor& dt,
    int steps,
    double clamp_val,
    double friction,
    double vel_fric_scale,
    double vel_sat,
    const torch::Tensor& gate_w,
    const torch::Tensor& gate_b,
    double sing_thresh,
    double sing_strength,
    bool enable_trace_norm,
    bool is_paper_version);
|
gfn/realizations/gssm/csrc/losses/toroidal.cu
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math_constants.h>

// ------------------------------------------------------------------------
// Toroidal Distance Loss
// L(y_pred, y_true) = (atan2(sin(y_pred - y_true), cos(y_pred - y_true)))^2
// ------------------------------------------------------------------------

// Element-wise forward kernel: squared wrapped angular difference.
// atan2(sin(d), cos(d)) maps the raw difference d into (-pi, pi], so the
// loss respects the circular topology of each coordinate.
template <typename scalar_t>
__global__ void toroidal_distance_loss_fwd_kernel(
    const scalar_t* __restrict__ y_pred,
    const scalar_t* __restrict__ y_true,
    scalar_t* __restrict__ out,
    const int numel)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < numel) {
        scalar_t diff = y_pred[idx] - y_true[idx];
        scalar_t wrapped = atan2(sin(diff), cos(diff));
        out[idx] = wrapped * wrapped;
    }
}

// Element-wise backward kernel. Produces the gradient w.r.t. y_pred only;
// y_true is treated as a constant target.
template <typename scalar_t>
__global__ void toroidal_distance_loss_bwd_kernel(
    const scalar_t* __restrict__ grad_output,
    const scalar_t* __restrict__ y_pred,
    const scalar_t* __restrict__ y_true,
    scalar_t* __restrict__ grad_pred,
    const int numel)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < numel) {
        // The derivative of atan2(sin(x), cos(x))^2 w.r.t. x is 2 * atan2(sin(x), cos(x))
        scalar_t diff = y_pred[idx] - y_true[idx];
        scalar_t wrapped = atan2(sin(diff), cos(diff));
        grad_pred[idx] = grad_output[idx] * 2.0 * wrapped;
    }
}

// ------------------------------------------------------------------------
// ATen wrappers
// ------------------------------------------------------------------------

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// Forward wrapper: returns the per-element loss with the same shape as
// y_pred (no reduction). float32 only.
// NOTE(review): there is no cudaGetLastError()/sync check after the launch,
// so a bad launch configuration would only surface at a later sync point —
// consider adding C10_CUDA_KERNEL_LAUNCH_CHECK().
torch::Tensor toroidal_distance_loss_fwd(const torch::Tensor& y_pred, const torch::Tensor& y_true) {
    CHECK_INPUT(y_pred);
    CHECK_INPUT(y_true);

    auto out = torch::empty_like(y_pred);
    int numel = y_pred.numel();

    const int threads = 256;
    const int blocks = (numel + threads - 1) / threads;

    if (y_pred.scalar_type() == torch::kFloat32) {
        toroidal_distance_loss_fwd_kernel<float><<<blocks, threads>>>(
            y_pred.data_ptr<float>(),
            y_true.data_ptr<float>(),
            out.data_ptr<float>(),
            numel
        );
    } else {
        TORCH_CHECK(false, "toroidal_distance_loss_fwd only supports float32");
    }

    return out;
}

// Backward wrapper: returns grad w.r.t. y_pred only. float32 only.
torch::Tensor toroidal_distance_loss_bwd(const torch::Tensor& grad_output, const torch::Tensor& y_pred, const torch::Tensor& y_true) {
    CHECK_INPUT(grad_output);
    CHECK_INPUT(y_pred);
    CHECK_INPUT(y_true);

    auto grad_pred = torch::empty_like(y_pred);
    int numel = y_pred.numel();

    const int threads = 256;
    const int blocks = (numel + threads - 1) / threads;

    if (y_pred.scalar_type() == torch::kFloat32) {
        toroidal_distance_loss_bwd_kernel<float><<<blocks, threads>>>(
            grad_output.data_ptr<float>(),
            y_pred.data_ptr<float>(),
            y_true.data_ptr<float>(),
            grad_pred.data_ptr<float>(),
            numel
        );
    } else {
        TORCH_CHECK(false, "toroidal_distance_loss_bwd only supports float32");
    }

    return grad_pred;
}
|
gfn/realizations/gssm/csrc/setup.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Build script for the fused `gfn_cuda` C++/CUDA extension."""
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Force sequential compilation to prevent NVCC Out-Of-Memory (LLVM ERROR)
os.environ["MAX_JOBS"] = "1"

# Base directory: the csrc/ folder containing this script.
csrc_dir = os.path.dirname(os.path.abspath(__file__))

# All translation units linked into the single `gfn_cuda` module.
sources = [
    os.path.join(csrc_dir, "extension.cpp"),
    os.path.join(csrc_dir, "losses", "toroidal.cu"),
    os.path.join(csrc_dir, "geometry", "low_rank.cu"),
    os.path.join(csrc_dir, "integrators", "integrators.cpp")
]

# MSVC/Windows vs Linux specific configuration.
extra_compile_args = {
    'cxx': ['-O2'],
    'nvcc': ['-O2', '-allow-unsupported-compiler']
}
if os.name == 'nt':
    # MSVC spells the C++17 switch as /std:c++17.
    extra_compile_args['cxx'].append('/std:c++17')

setup(
    name='gfn_cuda',
    ext_modules=[
        CUDAExtension(
            name='gfn_cuda',
            sources=sources,
            extra_compile_args=extra_compile_args,
        )
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)
|
gfn/realizations/gssm/cuda/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
gfn/cuda/__init__.py
|
| 3 |
+
Infraestructura CUDA para GFN V5.
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
CUDA_AVAILABLE = torch.cuda.is_available()
|
| 8 |
+
|
| 9 |
+
def is_cuda_active(tensor: torch.Tensor) -> bool:
|
| 10 |
+
"""Verifica si CUDA está disponible y el tensor está en un dispositivo GPU."""
|
| 11 |
+
return CUDA_AVAILABLE and tensor.is_cuda
|
gfn/realizations/gssm/cuda/autograd/__init__.py
ADDED
|
File without changes
|
gfn/realizations/gssm/cuda/kernels/__init__.py
ADDED
|
File without changes
|
gfn/realizations/gssm/cuda/kernels/geometry_kernels.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Geometry Kernels — GFN V5
|
| 3 |
+
Unified entry points for geometric computations with hardware dispatching.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from typing import Optional, Tuple, Union, Any
|
| 8 |
+
from ...registry import GEOMETRY_REGISTRY
|
| 9 |
+
from ...cuda import is_cuda_active
|
| 10 |
+
|
| 11 |
+
# Lazy import for CUDA ops to avoid loading if not available
|
| 12 |
+
_christoffel_cuda = None
|
| 13 |
+
|
| 14 |
+
def _get_cuda_ops():
    """Lazily resolve and cache the fused CUDA Christoffel op (None if unavailable)."""
    global _christoffel_cuda
    if _christoffel_cuda is not None:
        return _christoffel_cuda
    try:
        from ...cuda.ops import christoffel_cuda_fwd as _fwd
    except ImportError:
        # Extension not compiled / not importable: keep the None placeholder.
        return _christoffel_cuda
    _christoffel_cuda = _fwd
    return _christoffel_cuda
|
| 23 |
+
|
| 24 |
+
def unified_christoffel_fwd(
    x: torch.Tensor,
    v: torch.Tensor,
    U: torch.Tensor,
    W: torch.Tensor,
    clamp_val: float = 5.0,
    **kwargs: Any
) -> torch.Tensor:
    """
    Unified forward pass for Christoffel symbols.

    Uses the fused CUDA kernel when the inputs live on a GPU and the
    extension is compiled; on any kernel failure (or when unavailable)
    it silently falls back to the pure-PyTorch implementation.
    """
    cuda_op = _get_cuda_ops() if is_cuda_active(v) else None
    if cuda_op is not None:
        try:
            return _run_cuda_christoffel(x, v, U, W, clamp_val, cuda_op, **kwargs)
        except Exception:
            # Non-fatal: fall through to the PyTorch path below.
            pass

    return _run_pytorch_christoffel(x, v, U, W, clamp_val, **kwargs)
|
| 46 |
+
|
| 47 |
+
def _run_cuda_christoffel(x, v, U, W, clamp_val, cuda_op, **kwargs):
    """Invoke the compiled CUDA kernel, adapting inputs to its head-aware layout."""
    single_head = v.dim() == 2
    # The kernel expects [B, H, HD]; promote flat [B, D] inputs to one head.
    x_k = x.unsqueeze(1) if single_head else x
    v_k = v.unsqueeze(1) if single_head else v

    # Low-rank factors go in as [H, R, HD]; shared 2-D factors become one head.
    if U.dim() == 3:
        U_k, W_k = (t.transpose(1, 2).contiguous() for t in (U, W))
    else:
        U_k, W_k = (t.T.unsqueeze(0).contiguous() for t in (U, W))

    # Trailing scalar args follow the legacy kernel signature.
    gamma = cuda_op(U_k, W_k, x_k, v_k, 0, 2.0, 1.0, 0.0)

    if single_head:
        gamma = gamma.squeeze(1)

    # Soft clamp keeps the output bounded in (-clamp_val, clamp_val).
    return clamp_val * torch.tanh(gamma / clamp_val)
|
| 72 |
+
|
| 73 |
+
def _run_pytorch_christoffel(x, v, U, W, clamp_val, **kwargs):
    """Pure-PyTorch fallback: gamma = ((v @ U) ** 2) @ W^T, soft-clamped via tanh."""
    if v.dim() == 3:
        # Multi-head input [B, H, HD].
        B, H, HD = v.shape
        if U.dim() == 3:
            # Per-head factors: batch the matmuls over the head dimension.
            proj = torch.bmm(v.transpose(0, 1), U).transpose(0, 1)  # [B, H, R]
            squared = proj * proj
            gamma = torch.bmm(squared.transpose(0, 1), W.transpose(-1, -2)).transpose(0, 1)  # [B, H, HD]
        else:
            # Factors shared across heads: fold batch*heads into one matmul.
            proj = torch.matmul(v.reshape(B * H, HD), U)  # [B*H, R]
            gamma = torch.matmul(proj * proj, W.t()).view(B, H, HD)
    else:
        # Single-head input [B, D]; take head 0 when factors are per-head.
        U_eff = U[0] if U.dim() == 3 else U
        W_eff = W[0] if W.dim() == 3 else W
        proj = torch.matmul(v, U_eff)
        gamma = torch.matmul(proj * proj, W_eff.t())

    return clamp_val * torch.tanh(gamma / clamp_val)
|
gfn/realizations/gssm/cuda/kernels/integrator_kernels.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Integrator Kernels — GFN V5
|
| 3 |
+
Unified entry points for numerical integration with hardware dispatching.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from typing import Optional, Tuple, Any, Callable
|
| 8 |
+
from ...cuda import is_cuda_active
|
| 9 |
+
|
| 10 |
+
# Lazy imports for CUDA kernels
|
| 11 |
+
_euler_fused = None
|
| 12 |
+
_rk4_fused = None
|
| 13 |
+
_leapfrog_fused = None
|
| 14 |
+
|
| 15 |
+
def _get_cuda_integrators():
    """Lazily import the fused CUDA integrators; each entry is None when uncompiled."""
    global _euler_fused, _rk4_fused, _leapfrog_fused
    if _euler_fused is None:
        try:
            from ...cuda.ops import euler_fused, rk4_fused, leapfrog_fused
        except ImportError:
            # Extension not built: keep the None placeholders.
            return _euler_fused, _rk4_fused, _leapfrog_fused
        _euler_fused, _rk4_fused, _leapfrog_fused = euler_fused, rk4_fused, leapfrog_fused
    return _euler_fused, _rk4_fused, _leapfrog_fused
|
| 26 |
+
|
| 27 |
+
def unified_leapfrog_step(
    x: torch.Tensor,
    v: torch.Tensor,
    force: Optional[torch.Tensor],
    U: torch.Tensor,
    W: torch.Tensor,
    dt: float,
    steps: int = 1,
    **kwargs
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Unified Leapfrog integration step.

    Runs `steps` fused leapfrog iterations on the GPU when the compiled
    kernel is available. Returns (None, None) to signal that the caller's
    higher-level Integrator class must perform the Python fallback.

    Recognized kwargs: 'topology_id' (int, default 0), 'R' (default 2.0)
    and 'r' (default 1.0) — presumably the torus major/minor radii; confirm
    against the kernel implementation.
    """
    if is_cuda_active(v):
        _, _, f_leapfrog = _get_cuda_integrators()
        if f_leapfrog is not None:
            try:
                # Prep parameters for CUDA kernel
                topo_id = kwargs.get('topology_id', 0)
                R = kwargs.get('R', 2.0)
                r = kwargs.get('r', 1.0)
                # Number of heads; flat [B, D] input is treated as one head.
                H = x.shape[1] if x.dim() == 3 else 1

                # U, W transformations into the kernel's expected layouts.
                if U.dim() == 2:
                    # [D, R] -> [1, R, D] -> [H, R, D]
                    U_k = U.T.unsqueeze(0).expand(H, -1, -1).contiguous()
                else:
                    # [H, D, R] -> [H, R, D]
                    U_k = U.transpose(1, 2).contiguous()

                if W.dim() == 2:
                    # [D, R] -> [1, R] -> [H, R]
                    # Use mean instead of sum to preserve effective force scale
                    W_k = W.mean(dim=0).unsqueeze(0).expand(H, -1).contiguous()
                else:
                    # [H, D, R] -> [H, R]
                    W_k = W.abs().mean(dim=1).contiguous()

                cx, cv = x, v
                # Iterate the fused single-step kernel `steps` times.
                for _ in range(steps):
                    cx, cv = f_leapfrog(U_k, W_k, cx, cv, force, float(dt), int(topo_id), float(R), float(r), 0.0)
                return cx, cv
            except Exception:
                # Any kernel failure silently falls through to the fallback signal.
                pass

    # Python fallback is handled by the higher-level Integrator classes in gfn/integrators/
    # This unified layer is primarily for hardware acceleration.
    return None, None  # Signal fallback
|
gfn/realizations/gssm/cuda/ops/__init__.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
gfn/cuda/ops/__init__.py
|
| 3 |
+
Exports all fused CUDA operations.
|
| 4 |
+
Gracefully returns None for any op whose kernel is not compiled.
|
| 5 |
+
"""
|
| 6 |
+
from ...cuda import CUDA_AVAILABLE
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
# Make gfn/csrc importable so the compiled extension binary produced by
# csrc/setup.py (.pyd on Windows, .so on Linux) can be found by import.
_csrc_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "csrc"))
if _csrc_path not in sys.path:
    sys.path.insert(0, _csrc_path)
|
| 14 |
+
|
| 15 |
+
def _get_op(module_path: str, name: str):
    """Safely import a CUDA binding, returning None on any failure."""
    import importlib
    try:
        # Missing attribute yields None rather than raising.
        return getattr(importlib.import_module(module_path), name, None)
    except Exception:
        return None
|
| 23 |
+
|
| 24 |
+
# ── Geometry ──────────────────────────────────────────────────────────────────
# Each name binds to the compiled symbol from the gfn_cuda extension,
# or to None when the extension (or that specific symbol) is unavailable.
christoffel_cuda_fwd = _get_op("gfn_cuda", "compute_christoffel_symbols_fwd")
christoffel_cuda_bwd = _get_op("gfn_cuda", "compute_christoffel_symbols_bwd")
low_rank_christoffel_fwd = _get_op("gfn_cuda", "low_rank_christoffel_fwd")
low_rank_christoffel_bwd = _get_op("gfn_cuda", "low_rank_christoffel_bwd")
toroidal_christ_fwd = _get_op("gfn_cuda", "toroidal_geo_christoffel_fwd")

# ── Integrators ───────────────────────────────────────────────────────────────
heun_fused = _get_op("gfn_cuda", "heun_fwd")
leapfrog_fused = _get_op("gfn_cuda", "leapfrog_fwd")
yoshida_fused = _get_op("gfn_cuda", "yoshida_fwd")
rk4_fused = _get_op("gfn_cuda", "rk4_fwd")

# ── Loss ──────────────────────────────────────────────────────────────────────
toroidal_loss_fwd = _get_op("gfn_cuda", "toroidal_distance_loss_fwd")
toroidal_loss_bwd = _get_op("gfn_cuda", "toroidal_distance_loss_bwd")
|
| 40 |
+
|
| 41 |
+
def __getattr__(name):
    """Module-level fallback: unresolved kernel-binding names read as None."""
    kernel_suffixes = ("_fused", "_fwd", "_bwd", "_cuda")
    if not name.endswith(kernel_suffixes):
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
    return None
|
| 45 |
+
|
| 46 |
+
__all__ = [
|
| 47 |
+
"CUDA_AVAILABLE",
|
| 48 |
+
"christoffel_cuda_fwd", "christoffel_cuda_bwd",
|
| 49 |
+
"low_rank_christoffel_fwd", "low_rank_christoffel_bwd",
|
| 50 |
+
"toroidal_christ_fwd", "heun_fused", "leapfrog_fused",
|
| 51 |
+
"yoshida_fused", "rk4_fused", "toroidal_loss_fwd", "toroidal_loss_bwd"
|
| 52 |
+
]
|
gfn/realizations/gssm/data/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
data/__init__.py — GFN V5
|
| 3 |
+
"""
|
| 4 |
+
from ..data.dataset import SequenceDataset
|
| 5 |
+
from ..data.loader import create_dataloaders
|
| 6 |
+
from ..data.transforms import shift_targets, add_bos_token, pad_sequences
|
| 7 |
+
from ..data.replay import TrajectoryReplayBuffer
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
'SequenceDataset',
|
| 11 |
+
'create_dataloaders',
|
| 12 |
+
'shift_targets',
|
| 13 |
+
'add_bos_token',
|
| 14 |
+
'pad_sequences',
|
| 15 |
+
'TrajectoryReplayBuffer'
|
| 16 |
+
]
|
gfn/realizations/gssm/data/dataset.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset
|
| 3 |
+
|
| 4 |
+
class SequenceDataset(Dataset):
    """Minimal map-style dataset yielding aligned (input, target) tensor pairs."""

    def __init__(self, x: torch.Tensor, y: torch.Tensor):
        # Tensors are held by reference; no copy is made.
        self.x = x
        self.y = y

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]
|
gfn/realizations/gssm/data/loader.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
data/loader.py — GFN V5
|
| 3 |
+
DataLoaders y datasets para tareas GFN.
|
| 4 |
+
WATCHOUT: El DataLoader se crea UNA vez fuera del loop de entrenamiento.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch.utils.data import DataLoader, Dataset, random_split
|
| 9 |
+
from typing import Tuple, Optional
|
| 10 |
+
from ..data.dataset import SequenceDataset
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def create_dataloaders(
    x: torch.Tensor,
    y: torch.Tensor,
    batch_size: int = 32,
    val_split: float = 0.1,
    shuffle: bool = True,
    num_workers: int = 0,
    seed: int = 42,
) -> Tuple[DataLoader, Optional[DataLoader]]:
    """
    Build train/validation DataLoaders from raw tensors.

    IMPORTANT: create the loaders ONCE, outside the training loop.

    Args:
        x, y: input and target tensors (first dim = samples).
        batch_size: mini-batch size.
        val_split: validation fraction; 0 disables the validation loader.
        shuffle: shuffle training data each epoch.
        num_workers: DataLoader worker processes.
        seed: RNG seed for a reproducible train/val split.

    Returns:
        (train_loader, val_loader) — val_loader is None when val_split == 0.
    """
    full_ds = SequenceDataset(x, y)

    if val_split <= 0:
        train_only = DataLoader(full_ds, batch_size=batch_size,
                                shuffle=shuffle, num_workers=num_workers)
        return train_only, None

    # At least one validation sample, reproducible split.
    n_val = max(1, int(len(full_ds) * val_split))
    split_gen = torch.Generator().manual_seed(seed)
    train_ds, val_ds = random_split(full_ds, [len(full_ds) - n_val, n_val],
                                    generator=split_gen)

    return (
        DataLoader(train_ds, batch_size=batch_size,
                   shuffle=shuffle, num_workers=num_workers),
        DataLoader(val_ds, batch_size=batch_size,
                   shuffle=False, num_workers=num_workers),
    )
|
gfn/realizations/gssm/data/replay.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Replay / Trajectory Buffer — GFN V5
|
| 3 |
+
Maneja el almacenamiento y muestreo de estados físicos (x, v, forces)
|
| 4 |
+
para soporte de entrenamiento Off-Policy y exploración de GFlowNets reales.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from typing import Optional, Tuple
|
| 9 |
+
|
| 10 |
+
class TrajectoryReplayBuffer:
    """
    A persistent buffer for storing and managing manifold trajectories (x, v states).
    Serves as replay memory for Hamiltonian/Geodesic flows in V5.

    The three ring buffers (x, v, force) are allocated lazily on the first
    `add()` call using that batch's trailing shape; a wrapping write pointer
    overwrites the oldest entries once capacity is reached.
    """

    def __init__(
        self,
        capacity: int,
        dim: int,
        device: torch.device = torch.device('cpu'),
        dtype: torch.dtype = torch.float32
    ):
        self.capacity = capacity
        self.dim = dim
        self.device = device
        self.dtype = dtype

        # Buffer shape is [capacity, *element_shape], fixed on first add().
        self._initialized_shape = False

        self.pointer = 0
        self.size = 0
        self.is_full = False

    def _init_buffers(self, example_shape: torch.Size):
        """Allocate the three ring buffers from the first observed batch shape."""
        import math
        element_shape = example_shape[1:]  # drop the batch dimension

        # Rough memory estimate; warn on large CUDA allocations.
        bytes_per_el = 4 if self.dtype == torch.float32 else 8
        total_elements = self.capacity * math.prod(element_shape)
        total_mb = (3 * total_elements * bytes_per_el) / (1024 ** 2)
        if total_mb > 1024 and self.device.type == 'cuda':
            import logging
            logging.warning(f"ReplayBuffer: Allocating {total_mb:.1f} MB on CUDA. Risk of OOM.")

        full_shape = (self.capacity, *element_shape)
        self.x_buffer = torch.zeros(full_shape, device=self.device, dtype=self.dtype)
        self.v_buffer = torch.zeros(full_shape, device=self.device, dtype=self.dtype)
        self.force_buffer = torch.zeros(full_shape, device=self.device, dtype=self.dtype)
        self._initialized_shape = True

    def add(
        self,
        x: torch.Tensor,
        v: torch.Tensor,
        force: Optional[torch.Tensor] = None
    ):
        """Add a batch of transitions, overwriting the oldest entries on wrap-around."""
        n = x.size(0)
        if not self._initialized_shape:
            self._init_buffers(x.shape)

        # Ring indices, wrapping past the end of the buffer.
        slots = torch.arange(self.pointer, self.pointer + n) % self.capacity

        self.x_buffer[slots] = x.to(self.device).detach()
        self.v_buffer[slots] = v.to(self.device).detach()
        if force is not None:
            self.force_buffer[slots] = force.to(self.device).detach()

        self.pointer = (self.pointer + n) % self.capacity
        self.size = min(self.size + n, self.capacity)
        if self.size == self.capacity:
            self.is_full = True

    def sample_random(self, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Uniformly sample `batch_size` stored transitions (with replacement).

        Returns: (x, v, force)
        """
        if self.size == 0:
            raise ValueError("Cannot sample from an empty buffer.")

        picks = torch.randint(0, self.size, (batch_size,), device=self.device)
        return self.x_buffer[picks], self.v_buffer[picks], self.force_buffer[picks]

    def sample_recent(self, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the most recently written transitions."""
        if self.size == 0:
            raise ValueError("Cannot sample from an empty buffer.")

        if self.size < batch_size:
            # Fewer entries than requested: return everything stored.
            picks = torch.arange(0, self.size, device=self.device)
        else:
            picks = torch.arange(self.pointer - batch_size, self.pointer, device=self.device) % self.capacity

        return self.x_buffer[picks], self.v_buffer[picks], self.force_buffer[picks]

    def sample_with_noise(self, batch_size: int, noise_std: float = 1e-3) -> Tuple[torch.Tensor, torch.Tensor]:
        """Randomly sample and jitter positions with Gaussian noise for robustness."""
        x, v, _ = self.sample_random(batch_size)
        return x + torch.randn_like(x) * noise_std, v

    def clear(self):
        """Reset bookkeeping; buffers are re-allocated on the next add()."""
        self.pointer = 0
        self.size = 0
        self.is_full = False
        self._initialized_shape = False

    def __len__(self):
        return self.size
|
gfn/realizations/gssm/data/transforms.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
data/transforms.py — GFN V5
|
| 3 |
+
Transformaciones de datos para secuencias.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from typing import Tuple, Optional
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def shift_targets(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build shifted (input, target) pairs for next-token language modeling:
    inputs drop the last position, targets drop the first.
    """
    inputs = x[:, :-1]
    targets = x[:, 1:]
    return inputs, targets
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def add_bos_token(x: torch.Tensor, bos_id: int = 0) -> torch.Tensor:
    """Prepend a BOS token column to every sequence in the batch."""
    batch = x.size(0)
    bos_col = torch.full((batch, 1), bos_id, dtype=x.dtype, device=x.device)
    return torch.cat((bos_col, x), dim=1)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def pad_sequences(sequences, max_len: int, pad_id: int = 0) -> torch.Tensor:
    """
    Pad (and truncate) variable-length sequences into a fixed-width batch.

    Args:
        sequences: iterable of Python sequences or 1-D tensors of token ids.
        max_len: output width; longer sequences are truncated.
        pad_id: fill value for unused positions.

    Returns:
        LongTensor of shape [len(sequences), max_len].
    """
    result = torch.full((len(sequences), max_len), pad_id, dtype=torch.long)
    for i, seq in enumerate(sequences):
        length = min(len(seq), max_len)
        if length:
            # as_tensor avoids the copy-construct warning (and an extra copy)
            # when seq is already a tensor, and pins the dtype explicitly.
            result[i, :length] = torch.as_tensor(seq[:length], dtype=torch.long)
    return result
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def create_attention_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """
    Build a boolean attention mask from per-sequence lengths.
    Returns: [B, max_len] with True at positions holding valid data.
    """
    positions = torch.arange(max_len, device=lengths.device)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)
|
gfn/realizations/gssm/errors.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class GFNError(Exception):
    """Root of the GFN exception hierarchy; catch this for any GFN failure."""


class ConfigurationError(GFNError):
    """Raised when a configuration is invalid or inconsistent."""


class GeometryError(GFNError):
    """Raised when a geometric operation fails (e.g., out of manifold)."""


class PhysicsError(GFNError):
    """Raised during physics engine failures (e.g., NaN detected)."""


class IntegrationError(GFNError):
    """Raised during numerical integration failures."""


class TrainingError(GFNError):
    """Raised during model training or optimization failures."""
|
gfn/realizations/gssm/geometry/__init__.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
gfn/geometry/__init__.py
|
| 3 |
+
Public API for the geometry module — GFN V5
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
# Base and factory
|
| 7 |
+
from ..geometry.base import BaseGeometry
|
| 8 |
+
from ..geometry.factory import GeometryFactory
|
| 9 |
+
|
| 10 |
+
# Concrete geometries (imports trigger @register_geometry decorators)
|
| 11 |
+
from ..geometry.euclidean import EuclideanGeometry
|
| 12 |
+
from ..geometry.torus import ToroidalRiemannianGeometry, FlatToroidalRiemannianGeometry
|
| 13 |
+
from ..geometry.low_rank import LowRankRiemannianGeometry, PaperLowRankRiemannianGeometry
|
| 14 |
+
from ..geometry.adaptive import AdaptiveRiemannianGeometry
|
| 15 |
+
from ..geometry.reactive import ReactiveRiemannianGeometry
|
| 16 |
+
from ..geometry.hyperbolic import HyperRiemannianGeometry
|
| 17 |
+
from ..geometry.holographic import HolographicRiemannianGeometry
|
| 18 |
+
from ..geometry.spherical import SphericalGeometry
|
| 19 |
+
from ..geometry.hierarchical import HierarchicalGeometry
|
| 20 |
+
|
| 21 |
+
# Re-export FrictionGate from unified physics.components location
|
| 22 |
+
from ..physics.components.friction import FrictionGate
|
| 23 |
+
|
| 24 |
+
__all__ = [
|
| 25 |
+
# Base
|
| 26 |
+
"BaseGeometry",
|
| 27 |
+
"GeometryFactory",
|
| 28 |
+
# Implementations
|
| 29 |
+
"EuclideanGeometry",
|
| 30 |
+
"ToroidalRiemannianGeometry",
|
| 31 |
+
"FlatToroidalRiemannianGeometry",
|
| 32 |
+
"LowRankRiemannianGeometry",
|
| 33 |
+
"PaperLowRankRiemannianGeometry",
|
| 34 |
+
"AdaptiveRiemannianGeometry",
|
| 35 |
+
"ReactiveRiemannianGeometry",
|
| 36 |
+
"HyperRiemannianGeometry",
|
| 37 |
+
"HolographicRiemannianGeometry",
|
| 38 |
+
"SphericalGeometry",
|
| 39 |
+
"HierarchicalGeometry",
|
| 40 |
+
# Shared components
|
| 41 |
+
"FrictionGate",
|
| 42 |
+
]
|
gfn/realizations/gssm/geometry/adaptive.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AdaptiveRiemannianGeometry — GFN V5
|
| 3 |
+
Adaptive rank Christoffel symbol decomposition.
|
| 4 |
+
Migrated from gfn/geo/riemannian/adaptive_geometry.py
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from typing import Optional, Union, Tuple
|
| 10 |
+
|
| 11 |
+
from ..constants import CURVATURE_CLAMP
|
| 12 |
+
from ..config.schema import PhysicsConfig
|
| 13 |
+
from ..geometry.base import BaseGeometry
|
| 14 |
+
from ..registry import register_geometry
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@register_geometry('adaptive')
class AdaptiveRiemannianGeometry(BaseGeometry):
    """
    Adjusts the effective curvature rank dynamically based on velocity complexity.

    Architecture:
        eff_rank = f(||v||) in [min_rank, max_rank]
        Γ(v) = W[:, :eff_rank] @ (U[:, :eff_rank]^T v)^2

    The hard rank slice above is approximated in ``forward`` by a
    differentiable sigmoid mask over the rank dimension.
    """

    def __init__(self, dim: int, max_rank: int = 64, config: Optional[PhysicsConfig] = None):
        """
        Args:
            dim: feature dimension of positions/velocities.
            max_rank: upper bound on the learned curvature rank.
            config: physics configuration; BaseGeometry substitutes a default when None.
        """
        super().__init__(config)
        self.dim = dim
        self.max_rank = max_rank
        # Lower bound on the rank ratio. Not read in forward(); presumably
        # informational or consumed by external schedulers — TODO confirm.
        self.min_rank_ratio = 0.1

        # Low-rank factors of the Christoffel operator; small init for stability.
        self.U_full = nn.Parameter(torch.randn(dim, max_rank) * 0.01)
        self.W_full = nn.Parameter(torch.randn(dim, max_rank) * 0.01)

        # Complexity predictor: maps v → rank_ratio ∈ [0, 1]
        self.complexity_net = nn.Sequential(
            nn.Linear(dim, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )
        # Initialize bias to start with a mostly-open rank to avoid vanishing gradients
        nn.init.constant_(self.complexity_net[-2].bias, 1.0)

        self.return_friction_separately = True

    def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
                force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Return (clamped Christoffel term, friction) or their sum.

        NOTE(review): when ``v`` is None a single zero tensor is returned even
        though ``return_friction_separately`` is True — callers that always
        unpack a tuple should confirm this path.
        """
        if v is None:
            return torch.zeros_like(x)

        # 1. Predict rank weight (soft-mask)
        # Use complexity_net to predict a value p in [0, 1]
        p = self.complexity_net(v)  # [B, 1]

        # Create a soft mask for the rank dimension: [B, max_rank]
        # Mask[i] = sigmoid(slope * (p * max_rank - i))
        # This approximates hard-slicing but is differentiable.
        indices = torch.arange(self.max_rank, device=v.device).float()
        slope = 10.0
        soft_mask = torch.sigmoid(slope * (p * self.max_rank - indices))  # [B, max_rank]

        # 2. Christoffel using all components modulated by mask
        proj = torch.matmul(v, self.U_full)  # [B, max_rank]
        sq = proj * proj  # [B, max_rank]
        modulated_sq = sq * soft_mask  # [B, max_rank]
        gamma = torch.matmul(modulated_sq, self.W_full.t())  # [B, dim]

        # 3. Friction (ensure mu is not just zero)
        # Fallback to config friction or a base value
        # NOTE(review): assumes self.config.stability exists — confirm schema.
        friction_base = getattr(self.config.stability, 'friction', 0.1)
        mu = torch.full_like(v, friction_base)

        # Soft clamp keeps |gamma| < CURVATURE_CLAMP while remaining differentiable.
        gamma_clamped = CURVATURE_CLAMP * torch.tanh(gamma / CURVATURE_CLAMP)

        if self.return_friction_separately:
            return gamma_clamped, mu

        return gamma_clamped + mu * v

    def metric_tensor(self, x: torch.Tensor) -> torch.Tensor:
        """Identity diagonal metric; curvature enters only through forward()."""
        return torch.ones_like(x)
|
gfn/realizations/gssm/geometry/base.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from typing import Optional, Tuple, Union
|
| 4 |
+
from ..interfaces.geometry import Geometry
|
| 5 |
+
from ..config.schema import PhysicsConfig
|
| 6 |
+
from ..constants import TOPOLOGY_EUCLIDEAN
|
| 7 |
+
|
| 8 |
+
class BaseGeometry(nn.Module):
    """
    Base implementation for Riemannian Geometries in GFN V5.
    Conforms to the Geometry protocol.

    Defaults implement flat Euclidean space; subclasses override
    metric_tensor / christoffel_symbols / project / dist as needed.
    """
    def __init__(self, config: Optional[PhysicsConfig] = None):
        super().__init__()
        # Substitute a default config so attribute access below never fails on None.
        self.config = config or PhysicsConfig()
        # When True, forward() returns (acceleration, friction) instead of their sum.
        self.return_friction_separately = True
        self.topology_type = self.config.topology.type

    def metric_tensor(self, x: torch.Tensor) -> torch.Tensor:
        """Default metric is identity (Euclidean). Subclasses should override."""
        return torch.ones_like(x)

    def christoffel_symbols(self, x: torch.Tensor) -> torch.Tensor:
        """Default Christoffel symbols are zero (Euclidean). Subclasses should override."""
        return torch.zeros_like(x)

    def compute_kinetic_energy(self, x: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """
        Calculates Riemannian kinetic energy: T = (1/2) Σ_i g_ii v_i²
        Supports position-dependent metrics (like Torus).
        """
        g = self.metric_tensor(x)  # [..., D]
        return 0.5 * (g * v.pow(2)).sum(dim=-1)

    def compute_potential_energy(self, x: torch.Tensor) -> torch.Tensor:
        """
        Calculates physical potential energy V(x).
        Default is 0.0 unless overwritten by specific topologies or forces.
        """
        # Zero per-sample potential with the right batch shape/dtype/device.
        return torch.zeros_like(x).sum(dim=-1)

    def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None, force: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Computes acceleration: acc = -Gamma(v, v) + F/g
        Subclasses can override for more complex physics.
        """
        if v is None:
            return torch.zeros_like(x)

        gamma = self.christoffel_symbols(x)
        # Standard geodesic acceleration: -Gamma^k_ij v^i v^j
        # In our simplified 1D-per-dimension metric, it's often just a point-wise product
        acc = -gamma * (v**2)

        if force is not None:
            g = self.metric_tensor(x)
            # Epsilon guards against division by a degenerate (zero) metric entry.
            acc = acc + (force / (g + 1e-8))

        if getattr(self, 'return_friction_separately', False):
            # NOTE: v is guaranteed non-None here (guarded above), so the
            # ternary always yields zeros_like(v); the base class has no friction.
            return acc, torch.zeros_like(v) if v is not None else torch.zeros_like(x)

        return acc

    def project(self, x: torch.Tensor) -> torch.Tensor:
        """Default projection is identity. Subclasses should override for periodic spaces."""
        return x

    def dist(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Default distance is Euclidean. Subclasses should override."""
        return torch.norm(x1 - x2, dim=-1)
|
gfn/realizations/gssm/geometry/euclidean.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from .base import BaseGeometry
|
| 3 |
+
from ..registry import register_geometry
|
| 4 |
+
from ..constants import TOPOLOGY_EUCLIDEAN
|
| 5 |
+
|
| 6 |
+
@register_geometry(TOPOLOGY_EUCLIDEAN)
class EuclideanGeometry(BaseGeometry):
    """Flat Euclidean space: identity metric, zero curvature, L2 distance."""

    def metric_tensor(self, x: torch.Tensor) -> torch.Tensor:
        # Identity diagonal metric: g_ii = 1 everywhere.
        return torch.ones_like(x)

    def christoffel_symbols(self, x: torch.Tensor) -> torch.Tensor:
        # Flat space carries no curvature correction.
        return torch.zeros_like(x)

    def project(self, x: torch.Tensor) -> torch.Tensor:
        # R^n requires no projection back onto the manifold.
        return x

    def dist(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        # Straight-line (L2) distance along the last dimension.
        return (x1 - x2).norm(dim=-1)
|
gfn/realizations/gssm/geometry/factory.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GeometryFactory — GFN V5
|
| 3 |
+
Creates geometry instances from PhysicsConfig.
|
| 4 |
+
Supports: euclidean, torus, low_rank, reactive, adaptive, hyperbolic, holographic.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from typing import Optional
|
| 8 |
+
from ..config.schema import PhysicsConfig
|
| 9 |
+
from ..registry import GEOMETRY_REGISTRY
|
| 10 |
+
from ..constants import TOPOLOGY_TORUS, TOPOLOGY_EUCLIDEAN
|
| 11 |
+
import logging
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
_GEOMETRIES_REGISTERED = False
|
| 16 |
+
|
| 17 |
+
def _register_all_geometries():
    """Import the geometry submodules so their @register_geometry decorators run.

    Idempotent: guarded by the module-level _GEOMETRIES_REGISTERED flag.
    """
    global _GEOMETRIES_REGISTERED
    if _GEOMETRIES_REGISTERED:
        return
    # Each import triggers decorator-based registration as a side effect.
    from . import euclidean, torus, low_rank, adaptive, reactive, hyperbolic
    _GEOMETRIES_REGISTERED = True
|
| 29 |
+
|
| 30 |
+
class GeometryFactory:
    """
    Creates manifold geometries from configuration.

    Primary key: topology.type ('euclidean', 'torus', 'hyperbolic', ...)
    Secondary key: topology.riemannian_type ('low_rank', 'reactive', 'adaptive', ...)

    riemannian_type overrides topology.type when explicitly set and registered.
    """

    # Learned Riemannian families that take precedence over analytic topologies,
    # since they can model topology through their features.
    _LEARNED_TYPES = frozenset({'low_rank', 'reactive', 'adaptive', 'low_rank_paper'})

    @staticmethod
    def _lookup_key(config: PhysicsConfig) -> str:
        """Resolve the registry key for this config (priority in class docstring)."""
        _register_all_geometries()
        registered = GEOMETRY_REGISTRY.list_keys()
        topo_key = config.topology.type.lower()
        riem_key = getattr(config.topology, 'riemannian_type', 'reactive').lower()

        # 1. Learned Riemannian geometries win even over specialized topologies.
        if riem_key in GeometryFactory._LEARNED_TYPES and riem_key in registered:
            return riem_key

        # 2. A registered non-Euclidean topology uses its analytical model.
        if topo_key != TOPOLOGY_EUCLIDEAN and topo_key in registered:
            return topo_key

        # 3. Fall back, preferring riemannian_type when it is registered.
        return riem_key if riem_key in registered else topo_key

    @staticmethod
    def create(config: PhysicsConfig):
        """
        Create geometry using default dim from config.
        Falls back to 64 when the config exposes no 'dim', and to
        EuclideanGeometry when the resolved key is not registered.
        """
        lookup_key = GeometryFactory._lookup_key(config)
        if lookup_key in GEOMETRY_REGISTRY.list_keys():
            geometry_cls = GEOMETRY_REGISTRY.get(lookup_key)
            dim = getattr(config, 'dim', 64)
            rank = getattr(config.topology, 'riemannian_rank', 16)
            # Try progressively simpler signatures until one matches.
            for kwargs in (dict(dim=dim, rank=rank, config=config),
                           dict(config=config)):
                try:
                    return geometry_cls(**kwargs)
                except TypeError:
                    continue
            return geometry_cls()

        logger.warning(f"Geometry '{lookup_key}' not found. Using EuclideanGeometry.")
        from .euclidean import EuclideanGeometry
        return EuclideanGeometry(config=config)

    @staticmethod
    def create_with_dim(dim: int, rank: int, num_heads: int, config: PhysicsConfig):
        """
        Create geometry with explicit dim and rank.
        Used by ModelFactory to pass head_dim (not total dim) to the geometry,
        since geometry operates on per-head tensors [B, H, HD].
        """
        lookup_key = GeometryFactory._lookup_key(config)
        if lookup_key in GEOMETRY_REGISTRY.list_keys():
            geometry_cls = GEOMETRY_REGISTRY.get(lookup_key)
            # Richest signature first, then progressively simpler ones.
            for kwargs in (
                dict(dim=dim, rank=rank, num_heads=num_heads, config=config),
                dict(dim=dim, rank=rank, config=config),
                dict(config=config),
            ):
                try:
                    return geometry_cls(**kwargs)
                except TypeError:
                    continue
            return geometry_cls()

        logger.warning(f"Geometry '{lookup_key}' not found. Using EuclideanGeometry.")
        from .euclidean import EuclideanGeometry
        try:
            return EuclideanGeometry(dim=dim, num_heads=num_heads, config=config)
        except TypeError:
            return EuclideanGeometry(config=config)
|
gfn/realizations/gssm/geometry/hierarchical.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from typing import List, Optional, Union, Tuple, Any
|
| 4 |
+
from ..geometry.base import BaseGeometry
|
| 5 |
+
from ..geometry.low_rank import LowRankRiemannianGeometry
|
| 6 |
+
from ..registry import register_geometry
|
| 7 |
+
|
| 8 |
+
@register_geometry('hierarchical')
class HierarchicalGeometry(BaseGeometry):
    """
    Multi-Scale Riemannian Geometry (Christoffel Mixture).
    Combines multiple geometries (typically Low-Rank) with different scales,
    mixed by learnable softmax weights.

    Migrated from gfn_old HierarchicalRiemannianGeometry.
    """
    def __init__(self, dim: int, rank: int = 16, ranks: Optional[List[int]] = None,
                 num_heads: int = 1, config: Optional[Any] = None, **kwargs):
        """
        Args:
            dim: feature dimension handled by every sub-geometry.
            rank: factory-suggested rank, merged into `ranks` when absent.
            ranks: explicit list of low-rank scales (default [8, 16, 32]).
            num_heads: per-head layout forwarded to each sub-geometry.
            config: physics configuration shared by all scales.
        """
        super().__init__(config)
        self.dim = dim
        self.ranks = ranks if ranks is not None else [8, 16, 32]
        if rank not in self.ranks:
            # Optionally include the factory-suggested rank
            self.ranks = sorted(list(set(self.ranks + [rank])))
        self.num_heads = num_heads

        # Initialize sub-geometries (defaulting to LowRank)
        self.scales = nn.ModuleList([
            LowRankRiemannianGeometry(dim, rank=r, num_heads=num_heads, config=config)
            for r in self.ranks
        ])

        # Learnable mixing weights
        self.scale_weights = nn.Parameter(torch.ones(len(self.ranks)) / len(self.ranks))
        self.return_friction_separately = False

    def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
                force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Mix per-scale Christoffel/friction terms with softmax weights.

        Returns (gamma, friction) when return_friction_separately is True,
        otherwise gamma + friction * v (or gamma alone when v is None).
        """
        gammas = []
        frictions = []

        # Execute each scale
        for scale in self.scales:
            # Temporarily force tuple returns so every scale is read uniformly.
            was_sep = getattr(scale, 'return_friction_separately', False)
            scale.return_friction_separately = True

            res = scale(x, v, force=force, **kwargs)
            if isinstance(res, tuple):
                g, f = res
            else:
                # Defensive: a scale ignoring the flag contributes zero friction.
                g, f = res, torch.zeros_like(v) if v is not None else torch.zeros_like(x)

            gammas.append(g)
            frictions.append(f)
            scale.return_friction_separately = was_sep

        # Mix using softmax weights
        weights = torch.softmax(self.scale_weights, dim=0)

        gamma_mixed = sum(w * g for w, g in zip(weights, gammas))
        friction_mixed = sum(w * f for w, f in zip(weights, frictions))

        if self.return_friction_separately:
            return gamma_mixed, friction_mixed

        if v is not None:
            return gamma_mixed + friction_mixed * v
        return gamma_mixed

    def metric_tensor(self, x: torch.Tensor) -> torch.Tensor:
        """Softmax-weighted mixture of the per-scale metric tensors."""
        weights = torch.softmax(self.scale_weights, dim=0)
        metrics = [scale.metric_tensor(x) for scale in self.scales]
        return sum(w * m for w, m in zip(weights, metrics))

    def project(self, x: torch.Tensor) -> torch.Tensor:
        """Project with the first scale; all scales share the same manifold."""
        # typing.cast removed: it was a runtime no-op re-imported on every call.
        return self.scales[0].project(x)

    def dist(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Softmax-weighted mixture of the per-scale distances."""
        weights = torch.softmax(self.scale_weights, dim=0)
        dists = [scale.dist(x1, x2) for scale in self.scales]
        return sum(w * d for w, d in zip(weights, dists))
|
gfn/realizations/gssm/geometry/holographic.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HolographicRiemannianGeometry — GFN V5
|
| 3 |
+
AdS/CFT-inspired holographic extensions (Paper 18).
|
| 4 |
+
Migrated from gfn/geo/physical/holographic_geometry.py
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from typing import Optional, Union, Tuple
|
| 10 |
+
|
| 11 |
+
from ..config.schema import PhysicsConfig
|
| 12 |
+
from ..geometry.base import BaseGeometry
|
| 13 |
+
from ..registry import register_geometry
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@register_geometry('holographic')
class HolographicRiemannianGeometry(BaseGeometry):
    """
    Conformal manifold inspired by Bulk-Boundary (AdS/CFT) correspondence.

    Lifts boundary state x → bulk (x, z) where z is the holographic radial dim.
    Conformal metric: g_ij = (1/z(x)²) · δ_ij

    The Christoffel correction adds an AdS-geodesic term to any base geometry.
    """

    def __init__(self, base_geometry: BaseGeometry, z_min: float = 0.1,
                 z_max: float = 10.0, config: Optional[PhysicsConfig] = None):
        """
        Args:
            base_geometry: wrapped geometry providing base Christoffel/friction terms.
            z_min: additive lower bound for the learned radial coordinate z(x).
            z_max: clamp upper bound for z(x).
            config: physics configuration forwarded to BaseGeometry.
        """
        super().__init__(config)
        self.base_geometry = base_geometry
        # Reuse the wrapped geometry's dim when it exposes one.
        self.dim = getattr(base_geometry, 'dim', None)
        self.z_min = z_min
        self.z_max = z_max

        dim = self.dim or 0
        if dim > 0:
            # Learns z(x) > 0 via Softplus; z_min is added in get_z_and_grad.
            self.radial_net: nn.Module = nn.Sequential(
                nn.Linear(dim, dim // 2),
                nn.SiLU(),
                nn.Linear(dim // 2, 1),
                nn.Softplus()
            )
        else:
            # Unknown dim: degenerate identity mapping.
            self.radial_net = nn.Identity()

    def get_z_and_grad(self, x: torch.Tensor):
        """Return (z, ∂z/∂x) evaluated at x.

        NOTE(review): x is detached first, so gradients reach radial_net's
        parameters but do not flow back into x's upstream graph — confirm
        this is intentional.
        """
        x_req = x.detach().requires_grad_(True)
        with torch.enable_grad():
            z = self.radial_net(x_req) + self.z_min
            z = torch.clamp(z, max=self.z_max)
            # create_graph only in training, so higher-order grads exist for the loss.
            grad_z = torch.autograd.grad(
                z.sum(), x_req,
                create_graph=self.training,
                retain_graph=False
            )[0]
        return z, grad_z

    def metric_tensor(self, x: torch.Tensor) -> torch.Tensor:
        """Conformal diagonal metric g_ii = 1/z²; z broadcasts over features."""
        z, _ = self.get_z_and_grad(x)
        return (1.0 / z.pow(2)) * torch.ones_like(x)

    def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
                force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Add the AdS-geodesic Christoffel term to the base geometry's output."""
        out_base = self.base_geometry(x, v, force=force, **kwargs)
        if isinstance(out_base, tuple):
            gamma_base, mu = out_base
        else:
            # Base returned a single tensor: treat friction as zero.
            gamma_base, mu = out_base, torch.zeros_like(v) if v is not None else torch.zeros_like(x)

        if v is None:
            # No velocity → no AdS correction; pass the base result through.
            if self.return_friction_separately:
                return gamma_base, mu
            return gamma_base

        z, grad_z = self.get_z_and_grad(x)
        v_dot_gradz = (v * grad_z).sum(dim=-1, keepdim=True)
        v_sq = (v * v).sum(dim=-1, keepdim=True)
        # AdS geodesic term: Γ_AdS = -(1/z)(2 (v·∇z) v − |v|² ∇z)
        gamma_ads = -(1.0 / z) * (2.0 * v_dot_gradz * v - v_sq * grad_z)

        gamma_total = gamma_base + gamma_ads

        if self.return_friction_separately:
            return gamma_total, mu

        return gamma_total + mu * v

    def project(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate projection to the wrapped base geometry."""
        return self.base_geometry.project(x)

    def dist(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Delegate distance to the wrapped base geometry."""
        return self.base_geometry.dist(x1, x2)
|
gfn/realizations/gssm/geometry/hyperbolic.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HyperRiemannianGeometry — GFN V5
|
| 3 |
+
Context-dependent (gated) Christoffel symbols.
|
| 4 |
+
Migrated from gfn/geo/topological/hyperbolic_geometry.py
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from typing import Optional, Union, Tuple
|
| 10 |
+
|
| 11 |
+
from ..constants import CURVATURE_CLAMP, EPS, TOPOLOGY_TORUS
|
| 12 |
+
from ..config.schema import PhysicsConfig
|
| 13 |
+
from ..geometry.low_rank import LowRankRiemannianGeometry
|
| 14 |
+
from ..registry import register_geometry
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@register_geometry('hyperbolic')
class HyperRiemannianGeometry(LowRankRiemannianGeometry):
    """
    Hyper-Christoffel: geometry conditioned on current position.

    Architecture:
        U(x) = U_static * diag(Gate_u(x)) — position-scaled basis
        W(x) = W_static * diag(Gate_w(x))
        Γ(v | x) = W(x) @ (U(x)^T v)²

    Gates output values in [0, 2] initialized near 1.0 (identity).
    """

    def __init__(self, dim: int, rank: int = 16, num_heads: int = 1,
                 config: Optional[PhysicsConfig] = None):
        """
        Args:
            dim: per-head feature dimension.
            rank: rank of the low-rank Christoffel factorization.
            num_heads: number of heads (inputs may be [B, H, HD]).
            config: physics configuration forwarded to the parent geometry.
        """
        super().__init__(dim, rank, num_heads=num_heads, config=config)
        self.return_friction_separately = True

        # Zero-init gates → sigmoid(0)*2 = 1.0, i.e. gates start as the identity.
        self.gate_u = nn.Linear(dim, rank)
        self.gate_w = nn.Linear(dim, rank)
        nn.init.zeros_(self.gate_u.weight); nn.init.zeros_(self.gate_u.bias)
        nn.init.zeros_(self.gate_w.weight); nn.init.zeros_(self.gate_w.bias)

    def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
                force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Return (gamma, mu) when return_friction_separately, else gamma + mu*v.

        NOTE(review): when v is None a single zero tensor is returned even
        though return_friction_separately defaults to True — tuple-unpacking
        callers should confirm this path.
        """
        if v is None:
            return torch.zeros_like(x)

        original_shape = v.shape
        # Handle multi-head [B, H, HD] -> [B*H, HD]
        if v.dim() == 3:
            B, H, HD = v.shape
            v_flat = v.reshape(B * H, HD)
            x_flat = x.reshape(B * H, HD)
        else:
            v_flat = v
            x_flat = x
            B, H = None, None

        # Context gates in [0, 2]
        g_u = torch.sigmoid(self.gate_u(x_flat)) * 2.0  # [B*H, rank]
        g_w = torch.sigmoid(self.gate_w(x_flat)) * 2.0

        # Modulate static basis
        # self.U is [HD, rank] or [H, HD, rank]; average heads in the latter case.
        U_eff = self.U if self.U.dim() == 2 else self.U.mean(0)
        proj_static = torch.matmul(v_flat, U_eff)  # [B*H, rank]
        proj_dynamic = proj_static * g_u

        # Soft-saturation to prevent energy explosion
        sq_dynamic = (proj_dynamic * proj_dynamic) / (1.0 + torch.abs(proj_dynamic) + EPS)
        sq_modulated = sq_dynamic * g_w

        W_t = self.W.t() if self.W.dim() == 2 else self.W.mean(0).t()
        gamma = torch.matmul(sq_modulated, W_t)  # [B*H, HD]

        # Restore original shape if multi-head.
        # (Dead locals x_flat_for_mu / v_flat_for_mu removed: they were never read.)
        if B is not None:
            gamma = gamma.view(original_shape)

        # Friction: torus topologies get periodic (sin/cos) position features.
        x_in = torch.cat([torch.sin(x_flat), torch.cos(x_flat)], dim=-1) \
            if self.topology_type == TOPOLOGY_TORUS else x_flat
        mu_base = self.friction + self.friction_gate(x_in, force=force)
        # Velocity-dependent friction scaling; normalized by sqrt(dim).
        v_norm = torch.norm(v, dim=-1, keepdim=True) / (self.dim ** 0.5 + EPS)
        mu = mu_base * (1.0 + self.velocity_friction_scale * v_norm)
        if mu.shape != v.shape:
            # Reconcile gate output shape with v (reshape or reduce to broadcastable).
            mu = mu.view_as(v) if mu.numel() == v.numel() else mu.mean(dim=-1, keepdim=True)

        # Normalize then soft-clamp gamma to keep the dynamics bounded.
        gamma = self._normalize(gamma)
        gamma = self.clamp_val * torch.tanh(gamma / self.clamp_val)

        if self.return_friction_separately:
            return gamma, mu

        return gamma + mu * v
|
gfn/realizations/gssm/geometry/low_rank.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LowRankRiemannianGeometry — GFN V5
|
| 3 |
+
Computes Christoffel symbols via a low-rank decomposition.
|
| 4 |
+
Migrated from gfn/geo/riemannian/low_rank_geometry.py
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from typing import Optional, Union, Tuple, Dict, Any
|
| 10 |
+
|
| 11 |
+
from ..constants import (
|
| 12 |
+
EPS, MAX_VELOCITY, TOPOLOGY_TORUS, TOPOLOGY_EUCLIDEAN,
|
| 13 |
+
DEFAULT_FRICTION, CURVATURE_CLAMP, GATE_BIAS_OPEN
|
| 14 |
+
)
|
| 15 |
+
from ..config.schema import PhysicsConfig
|
| 16 |
+
from ..geometry.base import BaseGeometry
|
| 17 |
+
from ..registry import register_geometry
|
| 18 |
+
from ..cuda.ops import CUDA_AVAILABLE, low_rank_christoffel_fwd, low_rank_christoffel_bwd
|
| 19 |
+
|
| 20 |
+
class LowRankChristoffelFunction(torch.autograd.Function):
    """Autograd bridge for the fused CUDA low-rank Christoffel kernels.

    Forward delegates to ``low_rank_christoffel_fwd`` and backward to
    ``low_rank_christoffel_bwd``. The non-tensor inputs (clamp value and the
    two boolean flags) are non-differentiable and receive ``None`` gradients.
    """

    @staticmethod
    def forward(ctx, v, U, W, clamp_val, enable_trace_norm, is_paper_version=False):
        # The CUDA kernels require contiguous memory layouts.
        v_in = v.contiguous()
        U_in = U.contiguous()
        W_in = W.contiguous()

        gamma = low_rank_christoffel_fwd(
            v_in, U_in, W_in, float(clamp_val), enable_trace_norm, is_paper_version
        )
        # Stash inputs and output for the backward kernel, which recomputes
        # gradients from the saved activations.
        ctx.save_for_backward(v_in, U_in, W_in, gamma)
        ctx.clamp_val = float(clamp_val)
        ctx.enable_trace_norm = enable_trace_norm
        ctx.is_paper_version = is_paper_version
        return gamma

    @staticmethod
    def backward(ctx, grad_gamma):
        v_in, U_in, W_in, gamma_saved = ctx.saved_tensors
        if grad_gamma is None:
            # No incoming gradient: nothing flows to any input.
            return None, None, None, None, None, None

        grad_v, grad_U, grad_W = low_rank_christoffel_bwd(
            grad_gamma.contiguous(), v_in, U_in, W_in, gamma_saved,
            ctx.clamp_val, ctx.enable_trace_norm, ctx.is_paper_version
        )
        # clamp_val / enable_trace_norm / is_paper_version are not tensors.
        return grad_v, grad_U, grad_W, None, None, None
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Use unified FrictionGate from physics.components (no duplication)
|
| 49 |
+
from ..physics.components.friction import FrictionGate
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@register_geometry('low_rank')
class LowRankRiemannianGeometry(BaseGeometry):
    r"""
    Low-rank Christoffel symbol decomposition.

    Γ^k_ij ≈ Σ_r W_{rk} * (U_ir * U_jr)

    This is an approximation — symmetry is preserved but Bianchi identities
    are not guaranteed. Chosen for computational efficiency.

    Args:
        dim: Manifold dimension.
        rank: Rank of the decomposition.
        num_heads: Number of parallel heads.
        config: PhysicsConfig instance.
    """

    def __init__(self, dim: int, rank: int = 16, num_heads: int = 1,
                 config: Optional[PhysicsConfig] = None):
        super().__init__(config)
        self.dim = dim
        self.rank = rank
        self.num_heads = num_heads

        topo = self.config.topology.type.lower()
        self.topology = topo
        self.clamp_val = self.config.stability.curvature_clamp
        # FIX: this flag was previously assigned twice in a row; the duplicate
        # assignment has been removed (no behavioral change).
        self.enable_trace_normalization = self.config.stability.enable_trace_normalization
        # Friction parameters are now handled by PhysicsEngine to avoid duplication

        # Feature dimension for gate input (Fourier features for torus topology)
        gate_input_dim = dim * 2 if topo == TOPOLOGY_TORUS else dim

        # Low-rank basis parameters — initialized with small noise to break symmetry
        if num_heads > 1:
            self.U = nn.Parameter(torch.randn(num_heads, dim, rank) * 1e-4)
            self.W = nn.Parameter(torch.randn(num_heads, dim, rank) * 1e-4)
        else:
            self.U = nn.Parameter(torch.randn(dim, rank) * 1e-4)
            self.W = nn.Parameter(torch.randn(dim, rank) * 1e-4)

        # Friction gate (position-dependent friction coefficient)
        friction_mode = getattr(self.config.stability, 'friction_mode', 'static')
        self.friction_gate = FrictionGate(dim, gate_input_dim, mode=friction_mode, num_heads=num_heads)

        # CONTRACT: LowRank ALWAYS returns (gamma_christoffel, mu_friction) separately.
        # The physics engine is the single authority on when/how friction is applied.
        # This prevents the P0.1 double-friction bug (geometry + engine both applying mu*v).
        self.return_friction_separately = True

    def _get_features(self, x: torch.Tensor) -> torch.Tensor:
        """Convert position to gate input features (Fourier lift on the torus)."""
        if self.topology == TOPOLOGY_TORUS:
            return torch.cat([torch.sin(x), torch.cos(x)], dim=-1)
        return x

    def connection(self, v: torch.Tensor, w: torch.Tensor,
                   x: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Bilinear Christoffel contraction: Γ(v,w)^k.

        Γ^k_ij ≈ Σ_r W[r,k] * (U[i,r] * U[j,r])

        Args:
            v, w: Tangent vectors to contract.
            x: Unused here (position-independent connection); kept for API parity.

        Returns:
            Contracted connection, hard-clamped to ±clamp_val.
        """
        if v.dim() == 3 and self.U.dim() == 3:
            # Structured multi-head path: [B, H, D] with per-head bases.
            v_r = torch.einsum('bhd, hdr -> bhr', v, self.U)
            w_r = torch.einsum('bhd, hdr -> bhr', w, self.U)
            vw_r = v_r * w_r
            gamma = torch.einsum('bhr, hdr -> bhd', vw_r, self.W)
        else:
            v_r = v @ self.U  # [..., rank] (works for both 2D and 3D U)
            w_r = w @ self.U
            vw_r = v_r * w_r
            W_t = self.W.transpose(-1, -2) if self.W.dim() == 3 else self.W.t()
            gamma = vw_r @ W_t
        return torch.clamp(gamma, -self.clamp_val, self.clamp_val)

    def _normalize(self, gamma: torch.Tensor) -> torch.Tensor:
        """Symmetry-preserving trace normalization (removes mean/diagonal drift)."""
        if gamma.dim() < 2:
            return gamma

        is_multi_head = (gamma.dim() == 3 and self.num_heads > 1)

        # Matrix case [..., D, D]: symmetrize, then subtract the diagonal mean.
        if not is_multi_head and gamma.dim() >= 3 and gamma.shape[-1] == gamma.shape[-2]:
            gamma_sym = 0.5 * (gamma + gamma.transpose(-1, -2))
            if self.enable_trace_normalization:
                diag_mean = torch.diagonal(gamma_sym, dim1=-1, dim2=-2).mean(dim=-1, keepdim=True)
                correction = torch.diag_embed(diag_mean.expand(-1, self.dim))
                return gamma_sym - correction
            return gamma_sym

        # Vector case [..., D]: subtract the per-vector mean.
        if self.enable_trace_normalization:
            return gamma - gamma.mean(dim=-1, keepdim=True)
        return gamma

    def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
                force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Compute the Christoffel term Γ(v,v) and the friction coefficient μ(x).

        Returns:
            (gamma, mu) — pure curvature term and friction coefficient; the
            PhysicsEngine is responsible for applying μ·v. If ``v`` is None,
            returns zeros shaped like ``x``.
        """
        if v is None:
            return torch.zeros_like(x)

        original_shape = v.shape
        # Handle multi-head [B, H, HD] → reshape to [B*H, HD] for matmul with U=[HD, rank]
        if v.dim() == 3 and self.U.dim() == 2:
            B, H, HD = v.shape
            v_flat = v.reshape(B * H, HD)  # [B*H, HD]
            x_flat = x.reshape(B * H, HD)
        else:
            v_flat = v
            x_flat = x
            B, H, HD = None, None, v_flat.shape[-1]
        R = self.rank

        # Fast fused-CUDA path is only valid for fp32 multi-head weights on GPU.
        use_cuda_fused = (
            CUDA_AVAILABLE and
            low_rank_christoffel_fwd is not None and
            v_flat.is_cuda and
            v_flat.dtype == torch.float32 and
            self.W.dim() == 3
        )

        if use_cuda_fused:
            # Reshape [B*H, HD] -> [B, H, HD] to match what the kernel expects.
            actual_B = original_shape[0] if v.dim() == 3 else 1
            actual_H = self.num_heads

            v_re = v_flat.view(actual_B, actual_H, HD)
            U_re = self.U.view(actual_H, HD, R)  # self.U is [H, D, R]
            W_re = self.W.view(actual_H, HD, R)

            gamma_re = LowRankChristoffelFunction.apply(
                v_re, U_re, W_re, self.clamp_val, self.enable_trace_normalization, False
            )
            gamma = gamma_re.view_as(v_flat)
        else:
            # Christoffel symbols via self-connection (native fallback).
            if v_flat.dim() == 3 and self.U.dim() == 3:
                v_r = torch.einsum('bhd, hdr -> bhr', v_flat, self.U)
                sq = v_r * v_r
                gamma = torch.einsum('bhr, hdr -> bhd', sq, self.W)
            else:
                v_r = v_flat @ self.U  # [..., rank]
                sq = v_r * v_r
                W_t = self.W.transpose(-1, -2) if self.W.dim() == 3 else self.W.t()
                gamma = sq @ W_t  # [..., HD]

        # Friction coefficient (position-dependent, gated)
        x_in = self._get_features(x_flat)
        mu = self.friction_gate(x_in, force=force)

        # Note: PhysicsEngine will add base friction and apply velocity scaling.

        # Normalization + soft clamp only if the CUDA kernel didn't already do it.
        if not use_cuda_fused:
            gamma = self._normalize(gamma)
            gamma = self.clamp_val * torch.tanh(gamma / self.clamp_val)

        # Restore original shape if we flattened the head dimension.
        if B is not None:
            gamma = gamma.view(original_shape)
            mu = mu.view(original_shape)

        # CONTRACT: always return (gamma_pure, mu) so engine has single authority over friction
        return gamma, mu

    def metric_tensor(self, x: torch.Tensor) -> torch.Tensor:
        """
        Implicit Riemannian metric from the low-rank decomposition.

        The Christoffel parametrization Γ^k_ij ≈ Σ_r W_rk (U_ir U_jr) implies
        an underlying metric: g_ij ≈ Σ_r U_ir * U_jr = diag(U @ Uᵀ)

        Returns per-coordinate metric scale [..., D] broadcast to x's shape.
        T = (1/2) Σ_i g_ii v_i² (Riemannian kinetic energy)
        """
        if self.U.dim() == 2:
            # Single head: U is [D, rank] → g_diag is [D]
            g_diag = (self.U ** 2).sum(dim=-1)  # [D]
            # Broadcast to x shape: handles [B, D], [B*H, D], any [..., D]
            return g_diag.expand_as(x)
        else:
            # Multi-head: U is [H, D, rank] → g_diag is [H, D]
            g_diag = (self.U ** 2).sum(dim=-1)  # [H, D]
            if x.dim() == 3 and x.shape[1] == self.num_heads:
                # [B, H, D]: structured multi-head
                return g_diag.unsqueeze(0).expand_as(x)
            else:
                # [B, H*D]: flat layout — expand g_diag [H,D] → [H*D], broadcast to [B, H*D]
                g_flat = g_diag.reshape(-1)  # [H*D]
                return g_flat.expand(x.shape[0], -1) if x.dim() == 2 else g_flat.expand_as(x)

    def dist(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Geodesic-style distance; wraps angular differences on the torus."""
        if self.topology == TOPOLOGY_TORUS:
            diff = x1 - x2
            diff = torch.atan2(torch.sin(diff), torch.cos(diff))
            return torch.norm(diff, dim=-1)
        return torch.norm(x1 - x2, dim=-1)

    def project(self, x: torch.Tensor) -> torch.Tensor:
        """Project onto the manifold: wrap angles to (-π, π] on the torus."""
        if self.topology == TOPOLOGY_TORUS:
            return torch.atan2(torch.sin(x), torch.cos(x))
        return x
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
@register_geometry('low_rank_paper')
class PaperLowRankRiemannianGeometry(LowRankRiemannianGeometry):
    """Paper-faithful variant of the low-rank geometry.

    Differs from the parent only in the curvature nonlinearity: the squared
    low-rank coordinates are divided by (1 + ||v_r||), which bounds curvature
    growth for large velocities. The CUDA path signals this by passing
    ``is_paper_version=True`` to the fused kernel.
    """

    def __init__(self, dim: int, rank: int = 16, num_heads: int = 1,
                 config: Optional[PhysicsConfig] = None):
        super().__init__(dim, rank, num_heads=num_heads, config=config)
        # Same contract as parent: friction is returned, never applied here.
        self.return_friction_separately = True

    def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
                force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Compute (gamma, mu) with the paper's normalized quadratic form.

        Mirrors the parent's forward: flatten heads if needed, prefer the
        fused CUDA kernel, otherwise use the native einsum/matmul fallback.
        """
        if v is None:
            return torch.zeros_like(x)

        original_shape = v.shape
        # Multi-head [B, H, HD] with single-head weights → flatten to [B*H, HD].
        if v.dim() == 3 and self.U.dim() == 2:
            B, H, HD = v.shape
            v_flat = v.reshape(B * H, HD)
            x_flat = x.reshape(B * H, HD)
        else:
            v_flat = v
            x_flat = x
            B, H, HD = None, None, v_flat.shape[-1]
        R = self.rank

        # Fused kernel only for fp32 CUDA tensors with per-head [H, D, R] weights.
        use_cuda_fused = (
            CUDA_AVAILABLE and
            low_rank_christoffel_fwd is not None and
            v_flat.is_cuda and
            v_flat.dtype == torch.float32 and
            self.W.dim() == 3
        )

        if use_cuda_fused:
            actual_B = original_shape[0] if v.dim() == 3 else 1
            actual_H = self.num_heads
            v_re = v_flat.view(actual_B, actual_H, HD)
            U_re = self.U.view(actual_H, HD, R)
            W_re = self.W.view(actual_H, HD, R)

            # Last argument True selects the paper-version kernel path.
            gamma_re = LowRankChristoffelFunction.apply(
                v_re, U_re, W_re, self.clamp_val, self.enable_trace_normalization, True
            )
            gamma = gamma_re.view_as(v_flat)
        else:
            # Native fallback: phi = v_r² / (1 + ||v_r||) instead of plain v_r².
            if v_flat.dim() == 3 and self.U.dim() == 3:
                v_r = torch.einsum('bhd, hdr -> bhr', v_flat, self.U)
                denom = 1.0 + torch.norm(v_r, dim=-1, keepdim=True)
                phi = (v_r * v_r) / denom
                gamma = torch.einsum('bhr, hdr -> bhd', phi, self.W)
            else:
                v_r = v_flat @ self.U
                denom = 1.0 + torch.norm(v_r, dim=-1, keepdim=True)
                phi = (v_r * v_r) / denom
                W_t = self.W.transpose(-1, -2) if self.W.dim() == 3 else self.W.t()
                gamma = phi @ W_t

        # Position-dependent friction coefficient (Fourier features on torus).
        x_in = self._get_features(x_flat)
        mu = self.friction_gate(x_in, force=force)

        # Normalization + soft clamp only when the CUDA kernel didn't handle it.
        if not use_cuda_fused:
            gamma = self._normalize(gamma)
            gamma = self.clamp_val * torch.tanh(gamma / self.clamp_val)

        # Restore the [B, H, HD] layout if heads were flattened.
        if B is not None:
            gamma = gamma.view(original_shape)
            mu = mu.view(original_shape)

        # CONTRACT: Always return (gamma_pure, mu) for unified PhysicsEngine handling.
        return gamma, mu
|
gfn/realizations/gssm/geometry/reactive.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ReactiveRiemannianGeometry — GFN V5
|
| 3 |
+
Active-inference geometry: curvature reacts to system state.
|
| 4 |
+
Migrated from gfn/geo/physical/reactive_field_geometry.py
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from typing import Optional, Union, Tuple, Dict, Any
|
| 10 |
+
|
| 11 |
+
from ..constants import CURVATURE_CLAMP, EPS, DEFAULT_PLASTICITY, TOPOLOGY_TORUS
|
| 12 |
+
from ..config.schema import PhysicsConfig
|
| 13 |
+
from ..geometry.low_rank import LowRankRiemannianGeometry
|
| 14 |
+
from ..registry import register_geometry
|
| 15 |
+
|
| 16 |
+
# Default constants
|
| 17 |
+
SINGULARITY_THRESHOLD = 0.5
|
| 18 |
+
BLACK_HOLE_STRENGTH = 3.0
|
| 19 |
+
SINGULARITY_GATE_SLOPE = 10.0
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@register_geometry('reactive')
class ReactiveRiemannianGeometry(LowRankRiemannianGeometry):
    """
    Geometry that reacts to the system's own state via active inference.

    Enhancements over LowRank:
    1. Plasticity: Christoffel symbols scaled by kinetic energy (curv. amplification ≈ attention).
    2. Singularities: Soft curvature amplification near semantic attractors.

    Note: These are regularization/attention mechanisms, NOT physical manifold properties.
    """

    def __init__(self, dim: int, rank: int = 16, num_heads: int = 1,
                 config: Optional[PhysicsConfig] = None):
        super().__init__(dim, rank, num_heads=num_heads, config=config)
        self.return_friction_separately = True

        self.active_cfg = self.config.active_inference
        self.plasticity = getattr(self.active_cfg, 'plasticity', DEFAULT_PLASTICITY)

        sing_cfg = self.config.singularities
        sing_enabled = getattr(sing_cfg, 'enabled', False)

        if sing_enabled:
            self.semantic_certainty_threshold = getattr(sing_cfg, 'threshold', SINGULARITY_THRESHOLD)
            self.curvature_amplification_factor = getattr(sing_cfg, 'strength', BLACK_HOLE_STRENGTH)
            gate_input_dim = dim * 2 if self.topology == TOPOLOGY_TORUS else dim
            if num_heads > 1:
                # Per-head linear potential; bias-free (see _get_potential).
                self.V_weight = nn.Parameter(torch.zeros(num_heads, gate_input_dim, 1))
            else:
                self.V = nn.Linear(gate_input_dim, 1)
                nn.init.zeros_(self.V.weight)
                nn.init.constant_(self.V.bias, -2.0)  # Start gate closed
        else:
            self.semantic_certainty_threshold = SINGULARITY_THRESHOLD
            self.curvature_amplification_factor = BLACK_HOLE_STRENGTH
            self.V = None

    def _get_potential(self, x_in: torch.Tensor) -> Optional[torch.Tensor]:
        """Compute singularity potential in [0, 1]; returns None if disabled."""
        if not getattr(self.config.singularities, 'enabled', False):
            return None
        if self.num_heads > 1:
            return torch.sigmoid(torch.matmul(x_in.unsqueeze(-2), self.V_weight).squeeze(-2))
        elif self.V is not None:
            return torch.sigmoid(self.V(x_in))
        return None

    def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
                force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Compute reactive curvature: base LowRank gamma, then plasticity and
        singularity amplification. Returns (gamma, mu) per the geometry contract.
        """
        if v is None:
            return torch.zeros_like(x)

        # 1. Base curvature from LowRank.
        res = super().forward(x, v, force=force, **kwargs)
        if isinstance(res, tuple):
            gamma, mu = res
        else:
            # FIX: v is guaranteed non-None past the guard above, so the old
            # `if v is not None` ternary here (and in the returns below) was
            # dead code; the friction placeholder always matches v's shape.
            gamma, mu = res, torch.zeros_like(v)

        if not self.active_cfg.enabled:
            if self.return_friction_separately:
                return gamma, mu
            return gamma + mu * v

        # 2. Plasticity: scale curvature by (bounded) kinetic energy.
        react_cfg = self.active_cfg.reactive_curvature
        react_enabled = react_cfg.get('enabled', False) if isinstance(react_cfg, dict) else False
        if react_enabled and self.plasticity > 0.0:
            energy = torch.tanh(v.pow(2).mean(dim=-1, keepdim=True))
            gamma = gamma * (1.0 + self.plasticity * energy)

        # 3. Singularity amplification near semantic attractors.
        if getattr(self.config.singularities, 'enabled', False) and x is not None:
            x_in = torch.cat([torch.sin(x), torch.cos(x)], dim=-1) if self.topology == TOPOLOGY_TORUS else x
            potential = self._get_potential(x_in)
            if potential is not None:
                gate_slope = getattr(self.config.singularities, 'gate_slope', SINGULARITY_GATE_SLOPE)
                # Soft gate around the certainty threshold, then interpolate
                # the amplification factor and re-bound with a scaled tanh.
                is_amplified = torch.sigmoid(gate_slope * (potential - self.semantic_certainty_threshold))
                amp = 1.0 + is_amplified * (self.curvature_amplification_factor - 1.0)
                gamma = gamma * amp
                limit = self.curvature_amplification_factor * CURVATURE_CLAMP
                gamma = limit * torch.tanh(gamma / limit)

        if self.return_friction_separately:
            return gamma, mu

        return gamma + mu * v
|
gfn/realizations/gssm/geometry/spherical.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Optional, Union, Tuple, Any
|
| 3 |
+
from ..geometry.base import BaseGeometry
|
| 4 |
+
from ..registry import register_geometry
|
| 5 |
+
from ..constants import EPS, TOPOLOGY_SPHERE
|
| 6 |
+
|
| 7 |
+
@register_geometry(TOPOLOGY_SPHERE)
class SphericalGeometry(BaseGeometry):
    """
    Spherical Geometry (Analytical).
    Computes Christoffel symbols for a constant positive curvature space.
    """

    def __init__(self, dim: int, rank: int = 16, config: Optional[Any] = None, **kwargs):
        super().__init__(config)
        self.dim = dim

    def christoffel_symbols(self, x: torch.Tensor) -> torch.Tensor:
        """Analytical Christoffel symbols for S^n are typically zero in standard embedding or complex in other charts."""
        # For our GFN purposes, we often use the simplified 'spherical_christoffel_torch'
        # which acts as a centering/restoring force towards the sphere surface.
        return torch.zeros_like(x)

    def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
                force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Restoring curvature term for the simplified spherical chart."""
        if v is None:
            return torch.zeros_like(x)

        # Simplified analytical spherical coupling (restoring force):
        # gamma = -(2 <x,v> v - <v,v> x)
        pos_vel = (x * v).sum(dim=-1, keepdim=True)
        speed_sq = (v * v).sum(dim=-1, keepdim=True)
        restoring = -(2.0 * pos_vel * v - speed_sq * x)

        # Apply standard V5 soft clamping (default 5.0 when unset).
        bound = getattr(self, 'clamp_val', 5.0)
        return bound * torch.tanh(restoring / bound)

    def metric_tensor(self, x: torch.Tensor) -> torch.Tensor:
        """Identity metric for simplified spherical chart."""
        return torch.ones_like(x)

    def dist(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Great-circle distance approximation (assumes unit-sphere points)."""
        cos_angle = (x1 * x2).sum(dim=-1)
        clamped = torch.clamp(cos_angle, -1.0 + EPS, 1.0 - EPS)
        return torch.acos(clamped)
|
gfn/realizations/gssm/geometry/torus.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ToroidalRiemannianGeometry — GFN V5
|
| 3 |
+
Full analytical torus geometry (canonical implementation).
|
| 4 |
+
Replaces old stub. Migrated from gfn/geo/topological/toroidal_geometry.py
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from typing import Optional, Union, Tuple, Dict, Any
|
| 10 |
+
|
| 11 |
+
from ..constants import (
|
| 12 |
+
EPS, TOPOLOGY_TORUS, CURVATURE_CLAMP,
|
| 13 |
+
DEFAULT_FRICTION, GATE_BIAS_CLOSED
|
| 14 |
+
)
|
| 15 |
+
from ..config.schema import PhysicsConfig
|
| 16 |
+
from ..geometry.base import BaseGeometry
|
| 17 |
+
from ..registry import register_geometry
|
| 18 |
+
|
| 19 |
+
# Import modular friction component
|
| 20 |
+
from ..physics.components.friction import FrictionGate
|
| 21 |
+
|
| 22 |
+
# Torus-specific constants
|
| 23 |
+
TOROIDAL_CURVATURE_SCALE = 0.1
|
| 24 |
+
CLAMP_MIN_STRONG = 1e-4
|
| 25 |
+
EPSILON_SMOOTH = 1e-9
|
| 26 |
+
|
| 27 |
+
try:
|
| 28 |
+
import toroidal_cuda
|
| 29 |
+
HAS_CUDA_EXT = True
|
| 30 |
+
except ImportError:
|
| 31 |
+
HAS_CUDA_EXT = False
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@register_geometry(TOPOLOGY_TORUS)
|
| 36 |
+
class ToroidalRiemannianGeometry(BaseGeometry):
|
| 37 |
+
"""
|
| 38 |
+
Curved torus geometry with exact (analytical) Christoffel symbols.
|
| 39 |
+
|
| 40 |
+
Metric (2D torus paired dimensions):
|
| 41 |
+
g_theta = r²
|
| 42 |
+
g_phi = (R + r·cos θ)²
|
| 43 |
+
|
| 44 |
+
Generalizes to N dims by pairing (theta_0,phi_0), (theta_1,phi_1), ...
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
def __init__(self, dim: int = 64, rank: int = 16, num_heads: int = 1,
|
| 48 |
+
config: Optional[PhysicsConfig] = None):
|
| 49 |
+
super().__init__(config)
|
| 50 |
+
self.dim = dim
|
| 51 |
+
self.num_heads = num_heads
|
| 52 |
+
|
| 53 |
+
topo = self.config.topology
|
| 54 |
+
|
| 55 |
+
# CORRECCIÓN: Hacer R y r aprendibles según GFN_Paper_Complete.md Sección 5.1
|
| 56 |
+
# "where R_i are the (learnable) radii"
|
| 57 |
+
# Y 01_HYPER_TORUS.md Sección 2.2: R es radio mayor, r es radio menor
|
| 58 |
+
learnable_R = getattr(topo, 'learnable_R', True)
|
| 59 |
+
learnable_r = getattr(topo, 'learnable_r', True)
|
| 60 |
+
|
| 61 |
+
if learnable_R:
|
| 62 |
+
# R como parámetro aprendible
|
| 63 |
+
self.R = nn.Parameter(torch.tensor(topo.R, dtype=torch.float32))
|
| 64 |
+
else:
|
| 65 |
+
# R como buffer no-entrenable
|
| 66 |
+
self.register_buffer('R', torch.tensor(topo.R, dtype=torch.float32))
|
| 67 |
+
|
| 68 |
+
if learnable_r:
|
| 69 |
+
# r como parámetro aprendible
|
| 70 |
+
self.r = nn.Parameter(torch.tensor(topo.r, dtype=torch.float32))
|
| 71 |
+
else:
|
| 72 |
+
# r como buffer no-entrenable
|
| 73 |
+
self.register_buffer('r', torch.tensor(topo.r, dtype=torch.float32))
|
| 74 |
+
|
| 75 |
+
self.topology = topo.type.lower()
|
| 76 |
+
self.clamp_val = self.config.stability.curvature_clamp
|
| 77 |
+
|
| 78 |
+
active_cfg = self.config.active_inference
|
| 79 |
+
self.active_cfg = active_cfg
|
| 80 |
+
rc = active_cfg.reactive_curvature
|
| 81 |
+
self.plasticity = rc.get('plasticity', getattr(active_cfg, 'plasticity', 0.0)) \
|
| 82 |
+
if isinstance(rc, dict) else getattr(active_cfg, 'plasticity', 0.0)
|
| 83 |
+
|
| 84 |
+
sing_cfg = self.config.singularities
|
| 85 |
+
self.singularity_threshold = sing_cfg.threshold
|
| 86 |
+
self.black_hole_strength = sing_cfg.strength
|
| 87 |
+
|
| 88 |
+
gate_input_dim = dim * 2 # [sin(x), cos(x)]
|
| 89 |
+
friction_mode = getattr(self.config.stability, 'friction_mode', 'static')
|
| 90 |
+
self.friction_gate = FrictionGate(
|
| 91 |
+
dim, gate_input_dim=gate_input_dim, mode=friction_mode, num_heads=num_heads
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# Optional singularity potential gate
|
| 95 |
+
if getattr(sing_cfg, 'enabled', False):
|
| 96 |
+
if num_heads > 1:
|
| 97 |
+
self.V_weight = nn.Parameter(torch.zeros(num_heads, gate_input_dim, 1))
|
| 98 |
+
self.V_bias = nn.Parameter(torch.full((num_heads, 1), GATE_BIAS_CLOSED))
|
| 99 |
+
else:
|
| 100 |
+
self.V = nn.Linear(gate_input_dim, 1)
|
| 101 |
+
nn.init.zeros_(self.V.weight)
|
| 102 |
+
nn.init.constant_(self.V.bias, GATE_BIAS_CLOSED)
|
| 103 |
+
else:
|
| 104 |
+
self.V = None
|
| 105 |
+
|
| 106 |
+
def validate_dimensions(self, x: torch.Tensor):
|
| 107 |
+
if x.shape[-1] % 2 != 0:
|
| 108 |
+
raise ValueError(
|
| 109 |
+
f"ToroidalGeometry requires even dim for (θ,φ) pairing. Got {x.shape[-1]}"
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
def metric_tensor(self, x: torch.Tensor) -> torch.Tensor:
|
| 113 |
+
self.validate_dimensions(x)
|
| 114 |
+
g = torch.ones_like(x)
|
| 115 |
+
for i in range(0, x.shape[-1], 2):
|
| 116 |
+
th = x[..., i]
|
| 117 |
+
g[..., i] = self.r ** 2
|
| 118 |
+
g[..., i + 1] = (self.R + self.r * torch.cos(th)) ** 2
|
| 119 |
+
return g
|
| 120 |
+
|
| 121 |
+
def connection(self, v: torch.Tensor, w: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
| 122 |
+
d = x.shape[-1]
|
| 123 |
+
gamma = torch.zeros_like(v)
|
| 124 |
+
for i in range(0, (d // 2) * 2, 2):
|
| 125 |
+
th = x[..., i]
|
| 126 |
+
v_th, v_ph = v[..., i], v[..., i + 1]
|
| 127 |
+
w_th, w_ph = w[..., i], w[..., i + 1]
|
| 128 |
+
denom = torch.clamp(self.R + self.r * torch.cos(th), min=CLAMP_MIN_STRONG)
|
| 129 |
+
term_th = (denom * torch.sin(th) / (self.r + EPSILON_SMOOTH)) * TOROIDAL_CURVATURE_SCALE
|
| 130 |
+
gamma[..., i] = term_th * (v_ph * w_ph)
|
| 131 |
+
term_ph = (-(self.r * torch.sin(th)) / (denom + EPSILON_SMOOTH)) * TOROIDAL_CURVATURE_SCALE
|
| 132 |
+
gamma[..., i + 1] = term_ph * (v_ph * w_th + v_th * w_ph)
|
| 133 |
+
return gamma
|
| 134 |
+
|
| 135 |
+
def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
|
| 136 |
+
force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 137 |
+
if v is None:
|
| 138 |
+
return torch.zeros_like(x)
|
| 139 |
+
|
| 140 |
+
self.validate_dimensions(x)
|
| 141 |
+
|
| 142 |
+
if HAS_CUDA_EXT and x.is_cuda and v is not None and v.is_cuda:
|
| 143 |
+
gamma = toroidal_cuda.forward(
|
| 144 |
+
x, v, self.R, self.r,
|
| 145 |
+
TOROIDAL_CURVATURE_SCALE, EPSILON_SMOOTH, CLAMP_MIN_STRONG
|
| 146 |
+
)
|
| 147 |
+
else:
|
| 148 |
+
is_odd = (self.dim % 2 != 0)
|
| 149 |
+
x_pad = torch.nn.functional.pad(x, (0, 1)) if is_odd else x
|
| 150 |
+
v_pad = torch.nn.functional.pad(v, (0, 1)) if is_odd else v
|
| 151 |
+
|
| 152 |
+
th = x_pad[..., 0::2]
|
| 153 |
+
v_th = v_pad[..., 0::2]
|
| 154 |
+
v_ph = v_pad[..., 1::2]
|
| 155 |
+
|
| 156 |
+
denom = torch.clamp(self.R + self.r * torch.cos(th), min=CLAMP_MIN_STRONG)
|
| 157 |
+
term_th = (denom * torch.sin(th) / (self.r + EPSILON_SMOOTH))
|
| 158 |
+
term_ph = -(self.r * torch.sin(th)) / (denom + EPSILON_SMOOTH)
|
| 159 |
+
|
| 160 |
+
gamma_th = term_th * (v_ph ** 2) * TOROIDAL_CURVATURE_SCALE
|
| 161 |
+
gamma_ph = 2.0 * term_ph * v_ph * v_th * TOROIDAL_CURVATURE_SCALE
|
| 162 |
+
|
| 163 |
+
half = x.shape[-1] // 2
|
| 164 |
+
gamma = torch.zeros_like(x)
|
| 165 |
+
gamma[..., 0::2] = gamma_th[..., :half + x.shape[-1] % 2]
|
| 166 |
+
gamma[..., 1::2] = gamma_ph[..., :half]
|
| 167 |
+
|
| 168 |
+
# Friction gate
|
| 169 |
+
x_in = torch.cat([torch.sin(x), torch.cos(x)], dim=-1)
|
| 170 |
+
mu = self.friction_gate(x_in, force=force)
|
| 171 |
+
if mu.shape[-1] != v.shape[-1]:
|
| 172 |
+
if mu.shape[-1] == 2 * v.shape[-1]:
|
| 173 |
+
mu = (mu[..., :v.shape[-1]] + mu[..., v.shape[-1]:]) / 2.0
|
| 174 |
+
|
| 175 |
+
# Active inference
|
| 176 |
+
if self.active_cfg.enabled:
|
| 177 |
+
rc = self.active_cfg.reactive_curvature
|
| 178 |
+
react_enabled = rc.get('enabled', False) if isinstance(rc, dict) else False
|
| 179 |
+
if react_enabled and self.plasticity > 0.0:
|
| 180 |
+
energy = torch.tanh(v.pow(2).mean(dim=-1, keepdim=True))
|
| 181 |
+
gamma = gamma * (1.0 + self.plasticity * energy)
|
| 182 |
+
|
| 183 |
+
if getattr(self.config.singularities, 'enabled', False):
|
| 184 |
+
if self.num_heads > 1:
|
| 185 |
+
potential = torch.sigmoid(
|
| 186 |
+
torch.matmul(x_in.unsqueeze(-2), self.V_weight).squeeze(-2)
|
| 187 |
+
+ self.V_bias
|
| 188 |
+
)
|
| 189 |
+
elif self.V is not None:
|
| 190 |
+
potential = torch.sigmoid(self.V(x_in))
|
| 191 |
+
else:
|
| 192 |
+
potential = None
|
| 193 |
+
|
| 194 |
+
if potential is not None:
|
| 195 |
+
soft_m = torch.sigmoid(5.0 * (potential - self.singularity_threshold))
|
| 196 |
+
gamma = gamma * (1.0 + soft_m * (self.black_hole_strength - 1.0))
|
| 197 |
+
|
| 198 |
+
# CONTRACT: Always return (gamma_pure, mu) — engine applies friction, not geometry.
|
| 199 |
+
# This unifies the contract with LowRankRiemannianGeometry (P0.2 fix).
|
| 200 |
+
return gamma, mu
|
| 201 |
+
|
| 202 |
+
def project(self, x: torch.Tensor) -> torch.Tensor:
    """Map angular coordinates onto the principal branch (-pi, pi].

    Uses atan2(sin, cos), which wraps each coordinate without branching.
    """
    sin_x = torch.sin(x)
    cos_x = torch.cos(x)
    return torch.atan2(sin_x, cos_x)
|
| 204 |
+
|
| 205 |
+
def dist(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """Toroidal distance: Euclidean norm of the wrapped coordinate gap.

    Each component of (x1 - x2) is first wrapped to (-pi, pi] so the
    shortest angular path is measured on every circle factor.
    """
    delta = x1 - x2
    wrapped = torch.atan2(torch.sin(delta), torch.cos(delta))
    return torch.norm(wrapped, dim=-1)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
@register_geometry('flat_torus')
class FlatToroidalRiemannianGeometry(BaseGeometry):
    """Flat-torus geometry: identity metric and zero curvature.

    Coordinates are read as interleaved (theta, phi) angle pairs, so the
    trailing dimension must be even.  Because the metric is flat, the
    Christoffel term (gamma) is identically zero; only the friction gate
    contributes to the returned dynamics.
    """

    def __init__(self, dim: int = 64, rank: int = 16, num_heads: int = 1,
                 config: Optional[PhysicsConfig] = None):
        super().__init__(config)
        self.dim = dim
        self.num_heads = num_heads
        topo = self.config.topology

        # Keep R and r learnable on the flat torus too, mirroring the
        # curved toroidal geometry (buffer fallback when learning is off).
        if getattr(topo, 'learnable_R', True):
            self.R = nn.Parameter(torch.tensor(topo.R, dtype=torch.float32))
        else:
            self.register_buffer('R', torch.tensor(topo.R, dtype=torch.float32))

        if getattr(topo, 'learnable_r', True):
            self.r = nn.Parameter(torch.tensor(topo.r, dtype=torch.float32))
        else:
            self.register_buffer('r', torch.tensor(topo.r, dtype=torch.float32))

        self.topology = topo.type.lower()
        # Contract: forward() hands (gamma, mu) back to the engine, which
        # applies friction itself.
        self.return_friction_separately = True
        self.friction_gate = FrictionGate(
            dim,
            gate_input_dim=dim * 2,  # gate sees [sin(x), cos(x)]
            mode=getattr(self.config.stability, 'friction_mode', 'static'),
            num_heads=num_heads,
        )

    def validate_dimensions(self, x: torch.Tensor):
        """Raise ValueError unless the last dim splits into (theta, phi) pairs."""
        if x.shape[-1] % 2:
            raise ValueError(
                f"FlatToroidalGeometry requires even dim for (θ,φ) pairing. Got {x.shape[-1]}"
            )

    def metric_tensor(self, x: torch.Tensor) -> torch.Tensor:
        """Flat metric: the identity, represented as a diagonal of ones."""
        self.validate_dimensions(x)
        return torch.ones_like(x)

    def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
                force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Return (gamma, mu) — zero curvature plus the gated friction field.

        With no velocity supplied there is nothing to damp, so a zero
        tensor shaped like x is returned instead of the tuple.
        """
        if v is None:
            return torch.zeros_like(x)

        self.validate_dimensions(x)
        gate_features = torch.cat([torch.sin(x), torch.cos(x)], dim=-1)
        mu = self.friction_gate(gate_features, force=force)
        width = v.shape[-1]
        if mu.shape[-1] != width:
            if mu.shape[-1] == 2 * width:
                # Gate produced one value per (sin, cos) feature: fold the
                # two halves back onto the coordinate axes by averaging.
                mu = 0.5 * (mu[..., :width] + mu[..., width:])
            else:
                mu = mu.expand_as(v)
        gamma = torch.zeros_like(v)  # flat space: no Christoffel contribution
        if self.return_friction_separately:
            return gamma, mu
        return gamma + mu * v

    def project(self, x: torch.Tensor) -> torch.Tensor:
        """Wrap coordinates onto the principal angle branch (-pi, pi]."""
        sin_x = torch.sin(x)
        cos_x = torch.cos(x)
        return torch.atan2(sin_x, cos_x)

    def dist(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Toroidal distance: norm of the component-wise wrapped gap."""
        delta = x1 - x2
        wrapped = torch.atan2(torch.sin(delta), torch.cos(delta))
        return torch.norm(wrapped, dim=-1)
|