Final Fix: Bind to 0.0.0.0 and show_api=False
This view is limited to 50 files because it contains too many changes. See raw diff.
- README.md +83 -5
- app.py +150 -0
- config.json +56 -0
- convergence_plot.png +0 -0
- gfn/__init__.py +30 -0
- gfn/realizations/__init__.py +19 -0
- gfn/realizations/api.py +66 -0
- gfn/realizations/gssm/__init__.py +10 -0
- gfn/realizations/gssm/api.py +102 -0
- gfn/realizations/gssm/config/__init__.py +63 -0
- gfn/realizations/gssm/config/defaults.py +111 -0
- gfn/realizations/gssm/config/loader.py +163 -0
- gfn/realizations/gssm/config/schema.py +199 -0
- gfn/realizations/gssm/config/serialization.py +39 -0
- gfn/realizations/gssm/config/validator.py +109 -0
- gfn/realizations/gssm/constants.py +67 -0
- gfn/realizations/gssm/core/__init__.py +7 -0
- gfn/realizations/gssm/core/state.py +60 -0
- gfn/realizations/gssm/core/types.py +27 -0
- gfn/realizations/gssm/csrc/README.md +2 -0
- gfn/realizations/gssm/csrc/compile_cuda_12.9.bat +68 -0
- gfn/realizations/gssm/csrc/extension.cpp +87 -0
- gfn/realizations/gssm/csrc/geometry/low_rank.cu +160 -0
- gfn/realizations/gssm/csrc/integrators/integrators.cpp +252 -0
- gfn/realizations/gssm/csrc/integrators/integrators.h +41 -0
- gfn/realizations/gssm/csrc/losses/toroidal.cu +99 -0
- gfn/realizations/gssm/csrc/setup.py +38 -0
- gfn/realizations/gssm/cuda/__init__.py +11 -0
- gfn/realizations/gssm/cuda/autograd/__init__.py +0 -0
- gfn/realizations/gssm/cuda/kernels/__init__.py +0 -0
- gfn/realizations/gssm/cuda/kernels/geometry_kernels.py +99 -0
- gfn/realizations/gssm/cuda/kernels/integrator_kernels.py +73 -0
- gfn/realizations/gssm/cuda/ops/__init__.py +52 -0
- gfn/realizations/gssm/data/__init__.py +16 -0
- gfn/realizations/gssm/data/dataset.py +14 -0
- gfn/realizations/gssm/data/loader.py +53 -0
- gfn/realizations/gssm/data/replay.py +130 -0
- gfn/realizations/gssm/data/transforms.py +40 -0
- gfn/realizations/gssm/errors.py +23 -0
- gfn/realizations/gssm/geometry/__init__.py +42 -0
- gfn/realizations/gssm/geometry/adaptive.py +83 -0
- gfn/realizations/gssm/geometry/base.py +70 -0
- gfn/realizations/gssm/geometry/euclidean.py +20 -0
- gfn/realizations/gssm/geometry/factory.py +117 -0
- gfn/realizations/gssm/geometry/hierarchical.py +84 -0
- gfn/realizations/gssm/geometry/holographic.py +91 -0
- gfn/realizations/gssm/geometry/hyperbolic.py +97 -0
- gfn/realizations/gssm/geometry/low_rank.py +324 -0
- gfn/realizations/gssm/geometry/reactive.py +109 -0
- gfn/realizations/gssm/geometry/spherical.py +47 -0
README.md
CHANGED
@@ -1,12 +1,90 @@

````diff
 ---
-title:
-emoji:
+title: GFN XOR Parity Solver
+emoji: 🌀
 colorFrom: blue
-colorTo:
+colorTo: purple
 sdk: gradio
-sdk_version:
+sdk_version: 4.44.1
+python_version: "3.11"
 app_file: app.py
 pinned: false
+license: cc-by-nc-nd-4.0
+library_name: gfn
+language: en
+pipeline_tag: tabular-classification
+tags:
+- gfn
+- physics-informed
+- geometric-deep-learning
+- g-ssm
+- parity
+- xor
+model-index:
+- name: gfn-gssm-xor-parity
+  results:
+  - task:
+      type: tabular-classification
+      name: XOR Parity
+    dataset:
+      name: synthetic-bitstreams
+      type: synthetic
+    metrics:
+    - name: Accuracy
+      type: accuracy
+      value: 100.0
 ---
 
-
+# 🌀 G-SSM XOR Parity Solver
+
+[](https://doi.org/10.5281/zenodo.19141133)
+[](https://huggingface.co/DepthMuun)
+[](https://github.com/DepthMuun/gfn)
+
+This model is a spatial, differential realization of the **Geometric Flow Network (GFN)** paradigm, specifically implementing the **Geodesic State Space Model (G-SSM)** specialized for cumulative parity (XOR) logic.
+
+## 🚀 Technical Highlights
+- **O(1) Memory Complexity**: The physical state is a single point on a 16D torus, regardless of bitstream length. It does not use KV caching.
+- **Symplectic Integration**: Uses the **Yoshida 4th-order integrator** to preserve the Hamiltonian structure of the belief flow.
+- **Infinite Generalization**: Trained on 20-bit sequences, it generalizes to **1,000,000+ bits** with zero error.
+
+## 🛠️ Local Installation & Usage
+
+To run this model locally, you need the **GFN Framework** and the model assets.
+
+### 1. Install the GFN Framework
+```bash
+pip install git+https://github.com/DepthMuun/gfn.git
+```
+
+### 2. Clone this repository
+```bash
+git lfs install
+git clone https://huggingface.co/spaces/DepthMuun/gfn-gssm-xor-parity-space
+cd gfn-gssm-xor-parity-space
+```
+
+### 3. Run the Interactive Demo
+```bash
+python app.py
+```
+
+## 🧠 Technical Concept
+Unlike standard statistical models, the **G-SSM** treats input as physical impulses on a Riemannian manifold. Parity is decoded as a geodesic position on a torus ($+\pi/2$ for 1, $-\pi/2$ for 0).
+
+## 📜 Citation
+If you use this work, please cite:
+```bibtex
+@article{sturtz2026geometry,
+  title={Geometric Flow Networks: A Physics-Informed Paradigm for Sequential Intelligence},
+  author={Stürtz, Joaquín},
+  journal={Zenodo Preprints},
+  year={2026},
+  doi={10.5281/zenodo.19141133},
+  url={https://doi.org/10.5281/zenodo.19141133}
+}
+```
+
+## 🔗 Resources
+- **Official Checkpoint**: [GFN XOR Model](https://huggingface.co/DepthMuun/gfn-gssm-xor-parity)
+- **Framework Source**: [GitHub: DepthMuun/gfn](https://github.com/DepthMuun/gfn)
+- **Official Paper**: [Zenodo](https://doi.org/10.5281/zenodo.19141133)
````
app.py
ADDED
@@ -0,0 +1,150 @@

```python
import gradio as gr
import torch
import math
import sys
import os
import json
from pathlib import Path

# Add local gfn folder to path if it exists (for HF Spaces)
script_dir = os.path.dirname(os.path.abspath(__file__))
if os.path.exists(os.path.join(script_dir, "gfn")):
    sys.path.insert(0, script_dir)

import gfn

def load_model():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load config safely using absolute path
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.json")
    with open(config_path, "r") as f:
        config = json.load(f)

    model = gfn.gssm.create(
        vocab_size=config['architecture']['vocab_size'],
        dim=config['architecture']['dim'],
        depth=config['architecture']['depth'],
        heads=config['architecture']['heads'],
        physics=config['physics'],
        trajectory_mode=config['architecture']['trajectory_mode'],
        coupler_mode=config['architecture']['coupler_mode'],
        initial_spread=config['architecture']['initial_spread'],
        integrator=config['architecture']['integrator'],
        holographic=config['architecture'].get('holographic', True)
    ).to(device)

    checkpoint_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "xor_best_model.bin")
    if os.path.exists(checkpoint_path):
        model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
    model.eval()
    return model, device

model, device = load_model()

import json
import tempfile

def predict_parity(bitstream):
    if not all(c in "01" for c in bitstream):
        return "Error: Input must be a binary string.", 0, None

    if len(bitstream) == 0:
        return "Empty input", 0, None

    x_in = torch.tensor([[int(c) for c in bitstream]], device=device)

    with torch.no_grad():
        output = model(x_in)
        x_pred = output[0]  # [B, T, D]

    # Parity calculation for display
    bits = [int(c) for c in bitstream]
    cumulative_parity = []
    curr = 0
    for b in bits:
        curr = curr ^ b
        cumulative_parity.append(int(curr))

    # Prediction
    PI = math.pi
    TWO_PI = 2.0 * PI
    half_pi = PI * 0.5

    # Last token prediction
    final_state = x_pred[0, -1, :]
    dist_pos = torch.min(
        torch.abs(final_state - half_pi) % TWO_PI,
        TWO_PI - (torch.abs(final_state - half_pi) % TWO_PI)
    ).mean().item()
    dist_neg = torch.min(
        torch.abs(final_state + half_pi) % TWO_PI,
        TWO_PI - (torch.abs(final_state + half_pi) % TWO_PI)
    ).mean().item()

    prediction = 1 if dist_pos < dist_neg else 0
    is_correct = (prediction == cumulative_parity[-1])
    accuracy = 100.0 if is_correct else 0.0
    confidence = 1.0 - min(dist_pos, dist_neg) / half_pi

    result_data = {
        "input": bitstream,
        "target_parity": cumulative_parity[-1],
        "model_prediction": prediction,
        "is_correct": is_correct,
        "geometric_confidence": f"{confidence:.4f}",
        "sequence_length": len(bitstream),
        "full_target_trace": "".join(map(str, cumulative_parity))
    }

    # Save to temp file for download
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".json", mode='w')
    json.dump(result_data, temp_file, indent=4)
    temp_file.close()

    status = "✅ SUCCESS" if is_correct else "❌ FAILURE"
    return status, f"{accuracy}%", temp_file.name

with gr.Blocks(title="G-SSM XOR Parity Solver", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🌀 G-SSM XOR Parity Solver")
    gr.Markdown("""
    ### Geodesic State Space Model (G-SSM) — Zero-Shot Logic Generalization
    This model demonstrates **O(1) memory scaling** by solving XOR parity on arbitrarily long sequences.
    """)

    with gr.Row():
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                label="Input Binary Stream",
                placeholder="Enter 0s and 1s...",
                value="10110",
                lines=2
            )
            submit_btn = gr.Button("🔥 Run Geometric Inference", variant="primary")

        with gr.Column(scale=1):
            acc_label = gr.Label(label="REAL ACCURACY")
            status_output = gr.Textbox(label="Status")

    with gr.Row():
        download_btn = gr.File(label="Full Trace (JSON)")

    gr.Examples(
        examples=["10110", "1" * 20, "10" * 50, "1" * 1000, "0" * 500 + "1"],
        inputs=input_text
    )

    # LINK EVENTS
    submit_btn.click(
        fn=predict_parity,
        inputs=input_text,
        outputs=[status_output, acc_label, download_btn]
    )
    input_text.submit(
        fn=predict_parity,
        inputs=input_text,
        outputs=[status_output, acc_label, download_btn]
    )

if __name__ == "__main__":
    demo.launch(show_api=False, server_name="0.0.0.0", server_port=7860)
```
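The commit title refers to the last line above: binding to `0.0.0.0` is what makes a Gradio server reachable through the Space container's proxy (binding to `127.0.0.1` would leave it invisible from outside the container), and `show_api=False` hides the auto-generated API page. A minimal self-contained sketch of the same launch pattern, with an optional `PORT` environment fallback (that env var is an assumption for illustration; the committed `app.py` hardcodes 7860):

```python
import os
import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("health check")  # placeholder UI

if __name__ == "__main__":
    demo.launch(
        show_api=False,                                 # hide the auto-generated API docs
        server_name="0.0.0.0",                          # bind all interfaces (required on Spaces)
        server_port=int(os.environ.get("PORT", 7860)),  # assumed fallback, not in the committed file
    )
```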
config.json
ADDED
@@ -0,0 +1,56 @@

```json
{
  "architecture": {
    "vocab_size": 2,
    "dim": 8,
    "depth": 1,
    "heads": 2,
    "trajectory_mode": "partition",
    "coupler_mode": "mean_field",
    "initial_spread": 0.01,
    "integrator": "yoshida",
    "holographic": true
  },
  "physics": {
    "embedding": {
      "type": "functional",
      "mode": "linear",
      "coord_dim": 16,
      "impulse_scale": 80.0
    },
    "readout": {
      "type": "implicit",
      "coord_dim": 16
    },
    "active_inference": {
      "enabled": false,
      "dynamic_time": {
        "enabled": false
      },
      "reactive_curvature": {
        "enabled": false,
        "plasticity": 0.05
      },
      "singularities": {
        "enabled": true,
        "strength": 5.0,
        "threshold": 0.8
      },
      "topology": {
        "type": "torus",
        "riemannian_type": "low_rank"
      }
    },
    "fractal": {
      "enabled": false,
      "threshold": 0.5,
      "alpha": 0.2
    },
    "stability": {
      "enable_trace_normalization": true,
      "base_dt": 0.4,
      "velocity_saturation": 15.0,
      "friction": 2.0,
      "toroidal_curvature_scale": 0.01
    }
  }
}
```
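A quick sketch of how this file is consumed: `app.py` above reads the `architecture` block as keyword arguments for `gfn.gssm.create` and passes the `physics` block through unchanged. Inspecting a few resolved values (stdlib only):

```python
import json

# Inspect the committed config the same way app.py consumes it.
with open("config.json") as f:
    config = json.load(f)

arch = config["architecture"]
print(arch["dim"], arch["heads"], arch["integrator"])   # -> 8 2 yoshida
print(config["physics"]["embedding"]["impulse_scale"])  # -> 80.0
print(config["physics"]["stability"]["base_dt"])        # -> 0.4
```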
convergence_plot.png
ADDED
gfn/__init__.py
ADDED
@@ -0,0 +1,30 @@

```python
"""
GFN (Geodesic Flow Network) Package
==================================
Unified framework for Geodesic State Space Models (G-SSM)
and Inertial State Networks (ISN).

This package implements the GFN paradigm as a platform for
physics-informed neural dynamics.
"""

# ── Realizations ──────────────────────────────────────────────────────────────
from .realizations import api, gssm, isn
from .realizations.api import create, load, save

# ── Dynamic Registry
REALIZATIONS = api.list_available()

# ── Package Metadata ──────────────────────────────────────────────────────────
__version__ = "2.7.0"
__author__ = "DepthMuun"

__all__ = [
    "gssm",
    "isn",
    "api",
    "create",
    "load",
    "save",
    "REALIZATIONS",
]
```
gfn/realizations/__init__.py
ADDED
@@ -0,0 +1,19 @@

```python
"""
GFN Realizations Subpackage
===========================
Contains specific implementations of the GFN paradigm:
- G-SSM: Geodesic State Space Model (Riemannian/Symplectic)
- ISN: Inertial State Network (Physics-Informed Interaction Engine)
"""

from . import api
from .api import create, list_available

# Trigger registration of standard realizations
from . import gssm
from . import isn

# Future realizations can be added here or via external plugins
# from . import rt

__all__ = ['gssm', 'isn', 'api', 'create', 'list_available']
```
gfn/realizations/api.py
ADDED
@@ -0,0 +1,66 @@

```python
"""
GFN Realizations API Router (Purified Version)
==============================================
Agnostic factory and dynamic registry for GFN architectural realizations.
Follows SOLID principles: open for extension, closed for modification.
"""

import logging
from typing import List, Dict, Any, Optional, Protocol, runtime_checkable
import torch.nn as nn

logger = logging.getLogger(__name__)

@runtime_checkable
class RealizationProvider(Protocol):
    """Protocol defining the interface any GFN realization must provide."""
    def create(self, **kwargs) -> nn.Module: ...
    def save(self, model: nn.Module, path: str): ...
    def load(self, path: str, **kwargs) -> nn.Module: ...

# The Dynamic Registry
_REGISTRY: Dict[str, RealizationProvider] = {}

def register(name: str, provider: RealizationProvider):
    """
    Register a new realization architecture.

    Args:
        name: Unique identifier for the realization.
        provider: An object or module implementing the RealizationProvider protocol.
    """
    name = name.lower()
    if name in _REGISTRY:
        logger.debug(f"Overwriting GFN realization provider: {name}")
    _REGISTRY[name] = provider

def list_available() -> List[str]:
    """List all dynamically registered architectural realizations."""
    return list(_REGISTRY.keys())

def create(name: str, **kwargs) -> nn.Module:
    """
    Unified factory to create any registered GFN realization by name.
    """
    name = name.lower()
    if name not in _REGISTRY:
        raise ValueError(
            f"GFN Error: Realization '{name}' is not registered. "
            f"Ensure the subpackage is imported. Available: {list_available()}"
        )
    return _REGISTRY[name].create(**kwargs)

def save(model: nn.Module, path: str, realization: Optional[str] = None):
    """Unified save interface delegate."""
    if realization and realization.lower() in _REGISTRY:
        _REGISTRY[realization.lower()].save(model, path)
    else:
        import torch
        torch.save(model.state_dict(), path)

def load(path: str, realization: str, **kwargs) -> nn.Module:
    """Unified load interface delegate."""
    realization = realization.lower()
    if realization not in _REGISTRY:
        raise ValueError(f"GFN Error: Realization provider for '{realization}' not found.")
    return _REGISTRY[realization].load(path, **kwargs)
```
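A short usage sketch of the registry above. The `ToyProvider` class is hypothetical, written only to show the protocol shape; `register`, `list_available`, and `create` are the functions defined in this file:

```python
import torch
import torch.nn as nn
from gfn.realizations import api

class ToyProvider:
    """Structurally satisfies the RealizationProvider protocol."""
    def create(self, **kwargs) -> nn.Module:
        dim = kwargs.get("dim", 8)
        return nn.Linear(dim, dim)
    def save(self, model: nn.Module, path: str):
        torch.save(model.state_dict(), path)
    def load(self, path: str, **kwargs) -> nn.Module:
        model = self.create(**kwargs)
        model.load_state_dict(torch.load(path))
        return model

api.register("toy", ToyProvider())
print(api.list_available())        # includes 'gssm', 'isn', and now 'toy'
model = api.create("toy", dim=16)  # dispatches to ToyProvider.create
```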
gfn/realizations/gssm/__init__.py
ADDED
@@ -0,0 +1,10 @@

```python
# Export core API
from .api import create, save, load, Model, Manifold, loss, Trainer

# Register with central realization registry
try:
    from .. import api as central_api
    from . import api as gssm_api
    central_api.register('gssm', gssm_api)
except ImportError:
    pass  # Fallback for standalone GSSM usage
```
gfn/realizations/gssm/api.py
ADDED
@@ -0,0 +1,102 @@

```python
"""
gfn/api.py — GFN V5
Simplified public interface and high-level orchestration.
Centralizes model creation, loading, and evaluation.
"""

import torch
import torch.nn as nn
from typing import Optional, Dict, Any, Union

from .models.factory import ModelFactory
from .models.manifold import ManifoldModel
from .losses.factory import LossFactory
from .training.trainer import GFNTrainer
from .training.evaluation import ManifoldMetricEvaluator

# -- Main aliases
Model = ManifoldModel
Manifold = ManifoldModel
Trainer = GFNTrainer

def create(*args, **kwargs):
    """Factory for Manifold models (V5)."""
    return ModelFactory.create(*args, **kwargs)

def loss(config, **kwargs):
    """Factory for loss functions (V5)."""
    return LossFactory.create(config, **kwargs)

def save(model: nn.Module, path: str):
    """
    Saves the model and its configuration (HuggingFace style).
    """
    if hasattr(model, 'save_pretrained'):
        model.save_pretrained(path)
    else:
        # Fallback for models that do not inherit from BaseModel
        torch.save({'state_dict': model.state_dict()}, path)

def load(path: str, device: Optional[str] = None):
    """
    Loads a saved model together with its configuration.
    Supports directories (HF style) or legacy .pth/.bin files.
    """
    import os
    if os.path.isdir(path):
        return ModelFactory.from_pretrained(path)

    # Fallback for legacy standalone files
    checkpoint = torch.load(path, map_location=device or 'cpu', weights_only=True)
    config = checkpoint.get('config')
    if config is None:
        raise ValueError(f"No configuration found in checkpoint {path}. Use HF directories for full loading.")

    model = create(config=config)

    # Robust state_dict extraction (handles different saving conventions)
    state_dict = checkpoint.get('state_dict') or checkpoint.get('model') or checkpoint

    # Filter state_dict against the model's actual parameters
    model_state = model.state_dict()
    filtered_state = {k: v for k, v in state_dict.items() if k in model_state}

    # Log filtered keys for debugging (optional)
    n_filtered = len(state_dict) - len(filtered_state)
    if n_filtered > 0:
        import logging
        logging.getLogger("gssm.api").info(f"Filtered {n_filtered} unexpected keys from state_dict (legacy or auxiliary data).")

    # Load with strict=False to handle potential missing non-essential parameters
    model.load_state_dict(filtered_state, strict=False)
    return model

def benchmark(model: nn.Module, dataloader: torch.utils.data.DataLoader,
              device: Optional[str] = None) -> Dict[str, float]:
    """
    Runs a quick evaluation of geometric and task metrics.
    """
    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    evaluator = ManifoldMetricEvaluator(model)
    all_x, all_v, all_y = [], [], []

    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            logits, (xf, vf), info = model(x)

            all_x.append(xf.detach().cpu())
            all_v.append(vf.detach().cpu())
            all_y.append(y.detach().cpu())

    if not all_x:
        return {}

    x_total = torch.cat(all_x, dim=0)
    v_total = torch.cat(all_v, dim=0)
    y_total = torch.cat(all_y, dim=0)

    return evaluator.full_report(x_total, v_total, y_total)
```
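A hedged usage sketch for the save/load delegates above. The paths and the exact `create` kwargs are assumptions; the argument values mirror the `architecture` block of `config.json` in this same commit:

```python
import gfn

# Create a tiny G-SSM (kwargs assumed from config.json / app.py usage).
model = gfn.gssm.create(vocab_size=2, dim=8, depth=1, heads=2)

# save() prefers HuggingFace-style save_pretrained() when the model provides it.
gfn.gssm.save(model, "checkpoints/xor")

# load() dispatches to ModelFactory.from_pretrained for directories and falls
# back to filtered state_dict loading for legacy .pth/.bin files.
restored = gfn.gssm.load("checkpoints/xor")
```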
gfn/realizations/gssm/config/__init__.py
ADDED
@@ -0,0 +1,63 @@

```python
"""
gfn/config/__init__.py
Public API for the configuration module — GFN V5
Centralized configuration system for all GFN components.
"""

from .schema import (
    TopologyConfig,
    StabilityConfig,
    DynamicTimeConfig,
    HysteresisConfig,
    ActiveInferenceConfig,
    EmbeddingConfig,
    ReadoutConfig,
    MixtureConfig,
    DynamicsConfig,
    FractalConfig,
    SingularityConfig,
    PhysicsConfig,
    TrainerConfig,
    ManifoldConfig,
)

from .defaults import (
    PHYSICS_DEFAULTS,
    MODEL_DEFAULTS,
    TRAINING_DEFAULTS,
    LOSS_DEFAULTS,
    get_default,
)

from .loader import dict_to_physics_config
from .validator import ConfigValidator, validate_manifold_config, validate_and_print

__all__ = [
    # Schema classes
    "TopologyConfig",
    "StabilityConfig",
    "DynamicTimeConfig",
    "HysteresisConfig",
    "ActiveInferenceConfig",
    "EmbeddingConfig",
    "ReadoutConfig",
    "MixtureConfig",
    "DynamicsConfig",
    "FractalConfig",
    "SingularityConfig",
    "PhysicsConfig",
    "TrainerConfig",
    "ManifoldConfig",
    # Defaults
    "PHYSICS_DEFAULTS",
    "MODEL_DEFAULTS",
    "TRAINING_DEFAULTS",
    "LOSS_DEFAULTS",
    "get_default",
    # Loader
    "dict_to_physics_config",
    # Validator
    "ConfigValidator",
    "validate_manifold_config",
    "validate_and_print",
]
```
gfn/realizations/gssm/config/defaults.py
ADDED
@@ -0,0 +1,111 @@

```python
"""
config/defaults.py — GFN V5
Centralized default values for all configurations.
Removes hardcoded values scattered across implementations.
"""

from typing import Dict, Any
from ..constants import (
    DEFAULT_DT, DEFAULT_FRICTION, DEFAULT_PLASTICITY,
    MAX_VELOCITY, CURVATURE_CLAMP, VELOCITY_FRICTION_SCALE,
    SINGULARITY_THRESHOLD, BLACK_HOLE_STRENGTH, EPSILON_STANDARD,
    TOPOLOGY_TORUS, TOPOLOGY_EUCLIDEAN
)

# ─── Physics ─────────────────────────────────────────────────────────────────
PHYSICS_DEFAULTS: Dict[str, Any] = {
    # Topology
    'topology_type': TOPOLOGY_EUCLIDEAN,
    'riemannian_type': 'low_rank',
    'major_radius_R': 2.0,
    'minor_radius_r': 1.0,

    # Stability — references to constants.py (single source of truth)
    'base_dt': DEFAULT_DT,
    'adaptive_dt': True,
    'friction': DEFAULT_FRICTION,
    'velocity_clamp': MAX_VELOCITY,
    'curvature_clamp': CURVATURE_CLAMP,
    'enable_trace_normalization': True,
    'velocity_friction_scale': VELOCITY_FRICTION_SCALE,
    'integrator_type': 'leapfrog',
    'friction_mode': 'static',

    # Active inference
    'active_inference_enabled': True,
    'holographic_geometry': False,
    'plasticity': DEFAULT_PLASTICITY,

    # Singularities
    'singularity_enabled': False,
    'singularity_threshold': SINGULARITY_THRESHOLD,
    'singularity_strength': BLACK_HOLE_STRENGTH,
    'singularity_epsilon': EPSILON_STANDARD,

    # Hysteresis
    'hysteresis_enabled': False,
    'hysteresis_decay': 0.95,
    'hysteresis_ghost_force': True,

    # Stochasticity / Curiosity
    'stochasticity_enabled': False,
    'stochasticity_type': 'brownian',
    'stochasticity_sigma': 0.01,
    'curiosity_enabled': False,
    'curiosity_strength': 0.1,
}

# ─── Model ───────────────────────────────────────────────────────────────────
MODEL_DEFAULTS: Dict[str, Any] = {
    'dim': 64,
    'heads': 4,
    'depth': 2,
    'rank': 16,
    'vocab_size': 256,
    'holographic': False,
    'pooling_type': None,
    'initial_spread': 1e-3,
    'n_trajectories': 1,
}

# ─── Training ────────────────────────────────────────────────────────────────
TRAINING_DEFAULTS: Dict[str, Any] = {
    'lr': 1e-3,
    'optimizer_type': 'adam',
    'weight_decay': 0.0,
    'grad_clip': 1.0,
    'epochs': 10,
    'batch_size': 32,
    'scheduler_type': 'cosine_warmup',
    'warmup_steps': 100,
    'min_lr': 1e-6,
    'task': 'lm',
}

# ─── Losses ──────────────────────────────────────────────────────────────────
LOSS_DEFAULTS: Dict[str, Any] = {
    'type': 'generative',
    'mode': 'nll',
    'entropy_coef': 0.0,
    'label_smoothing': 0.0,

    # Physics-informed
    'lambda_physics': 0.01,
    'lambda_geo': 0.001,
    'lambda_ham': 0.0,
    'lambda_kin': 0.0,
}


def get_default(section: str, key: str, fallback=None):
    """
    Gets a default value from the corresponding section.
    Usage: get_default('physics', 'base_dt') -> 0.1
    """
    mapping = {
        'physics': PHYSICS_DEFAULTS,
        'model': MODEL_DEFAULTS,
        'training': TRAINING_DEFAULTS,
        'loss': LOSS_DEFAULTS,
    }
    return mapping.get(section, {}).get(key, fallback)
```
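The lookup helper above resolves defaults per section with an optional fallback. A minimal sketch; the values shown are the ones defined in this file, except `base_dt`, which resolves to whatever `DEFAULT_DT` is set to in constants.py:

```python
from gfn.realizations.gssm.config.defaults import get_default

get_default('physics', 'base_dt')          # -> DEFAULT_DT (imported from constants.py)
get_default('model', 'dim')                # -> 64
get_default('training', 'optimizer_type')  # -> 'adam'
get_default('loss', 'missing_key', 0.0)    # -> 0.0 (fallback for unknown keys)
```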
gfn/realizations/gssm/config/loader.py
ADDED
@@ -0,0 +1,163 @@

```python
"""
config/loader.py — GFN V5
Converts configuration dicts into typed PhysicsConfig objects.
Supports nested overrides on top of existing configs.
"""
from typing import Dict, Any, Optional
from .schema import (
    PhysicsConfig, TopologyConfig, StabilityConfig, DynamicsConfig,
    ActiveInferenceConfig, DynamicTimeConfig, HysteresisConfig,
    EmbeddingConfig, FractalConfig, SingularityConfig,
)


def dict_to_physics_config(d: Dict[str, Any]) -> PhysicsConfig:
    """
    Converts a nested dict into a typed PhysicsConfig.

    Supports all PhysicsConfig sub-fields. Fields absent from the dict
    keep their schema default values.
    If `d` is already a PhysicsConfig, it is returned untouched.
    """
    if isinstance(d, PhysicsConfig):
        return d

    cfg = PhysicsConfig()
    _apply_dict_to_physics_config(cfg, d)
    return cfg


def apply_physics_overrides(cfg: PhysicsConfig, overrides: Dict[str, Any]) -> PhysicsConfig:
    """
    Applies an override dict onto an EXISTING PhysicsConfig (in place).

    Unlike dict_to_physics_config(), this function does NOT start from defaults;
    it modifies only the fields present in the dict and leaves the rest intact.
    It is the function ModelFactory uses when combining a preset with the physics kwarg.

    Args:
        cfg: Existing PhysicsConfig (e.g. the result of get_preset())
        overrides: Nested dict with the fields to overwrite

    Returns:
        The same cfg, modified in place (also returned for chaining).
    """
    if not overrides:
        return cfg
    _apply_dict_to_physics_config(cfg, overrides)
    return cfg


def _apply_dict_to_physics_config(cfg: PhysicsConfig, d: Dict[str, Any]) -> None:
    """Internal helper — applies the dict's fields onto cfg in place."""

    # ── Topology ──────────────────────────────────────────────────────────────
    t_d = d.get('topology', d.get('topology_config', {}))
    if isinstance(t_d, dict) and t_d:
        _apply(cfg.topology, t_d, [
            'type', 'R', 'r', 'curvature',
            'riemannian_type', 'riemannian_rank', 'riemannian_class',
            'geometry_scope'
        ])
        if 'major_radius' in t_d: cfg.topology.R = t_d['major_radius']
        if 'minor_radius' in t_d: cfg.topology.r = t_d['minor_radius']

    # ── Stability ─────────────────────────────────────────────────────────────
    s_d = d.get('stability', d.get('stability_config', {}))
    if isinstance(s_d, dict) and s_d:
        _apply(cfg.stability, s_d, [
            'base_dt', 'adaptive', 'dt_min', 'dt_max',
            'enable_trace_normalization', 'wrap_x',
            'friction', 'velocity_friction_scale',
            'curvature_clamp', 'friction_mode',
            'integrator_type',
            # legacy alias
            'velocity_saturation',  # legacy alias (StabilityConfig now defines this field)
        ])
        # Legacy name aliases
        if 'toroidal_curvature_scale' in s_d:
            cfg.stability.curvature_clamp = s_d['toroidal_curvature_scale']

    # ── Dynamics ──────────────────────────────────────────────────────────────
    dyn_d = d.get('dynamics', d.get('dynamics_config', {}))
    if isinstance(dyn_d, dict) and dyn_d:
        if 'type' in dyn_d:
            cfg.dynamics.type = dyn_d['type']

    # ── Active Inference ──────────────────────────────────────────────────────
    ai_d = d.get('active_inference', d.get('active_inference_config', {}))
    if isinstance(ai_d, dict) and ai_d:
        _apply(cfg.active_inference, ai_d, [
            'enabled', 'holographic_geometry',
            'thermodynamic_geometry', 'plasticity',
        ])
        # Dynamic time
        dt_d = ai_d.get('dynamic_time', {})
        if isinstance(dt_d, dict) and dt_d:
            _apply(cfg.active_inference.dynamic_time, dt_d, ['enabled', 'type'])
        # Reactive curvature — an inner dict
        rc_d = ai_d.get('reactive_curvature', {})
        if isinstance(rc_d, dict) and rc_d:
            cfg.active_inference.reactive_curvature.update(rc_d)
        # Stochasticity — an inner dict
        st_d = ai_d.get('stochasticity', {})
        if isinstance(st_d, dict) and st_d:
            cfg.active_inference.stochasticity.update(st_d)
        # Curiosity — an inner dict
        cu_d = ai_d.get('curiosity', {})
        if isinstance(cu_d, dict) and cu_d:
            cfg.active_inference.curiosity.update(cu_d)

    # ── Hysteresis (may live at the root OR inside active_inference) ──────────
    hyst_src = d.get('hysteresis', ai_d.get('hysteresis', {}) if isinstance(ai_d, dict) else {})
    if isinstance(hyst_src, dict) and hyst_src:
        _apply(cfg.hysteresis, hyst_src, [
            'enabled', 'ghost_force', 'hyst_decay',
            'hyst_update_w', 'hyst_update_b',
            'hyst_readout_w', 'hyst_readout_b',
        ])

    # ── Singularities (may live at the root OR inside active_inference) ───────
    sing_src = d.get('singularities', ai_d.get('singularities', {}) if isinstance(ai_d, dict) else {})
    if isinstance(sing_src, dict) and sing_src:
        _apply(cfg.singularities, sing_src, [
            'enabled', 'epsilon', 'strength', 'threshold'
        ])

    # ── Embedding ─────────────────────────────────────────────────────────────
    emb_d = d.get('embedding', d.get('embedding_config', {}))
    if isinstance(emb_d, dict) and emb_d:
        _apply(cfg.embedding, emb_d, [
            'type', 'mode', 'coord_dim', 'impulse_scale', 'omega_0'
        ])

    # ── Readout ───────────────────────────────────────────────────────────────
    read_d = d.get('readout', d.get('readout_config', {}))
    if isinstance(read_d, dict) and read_d:
        _apply(cfg.readout, read_d, ['type'])

    # ── Mixture ───────────────────────────────────────────────────────────────
    mix_d = d.get('mixture', d.get('mixture_config', {}))
    if isinstance(mix_d, dict) and mix_d:
        _apply(cfg.mixture, mix_d, ['coupler_mode'])

    # ── Fractal ───────────────────────────────────────────────────────────────
    frac_d = d.get('fractal', {})
    if isinstance(frac_d, dict) and frac_d:
        _apply(cfg.fractal, frac_d, ['enabled', 'threshold', 'alpha'])

    # ── Top-level trajectory_mode ─────────────────────────────────────────────
    if 'trajectory_mode' in d:
        cfg.trajectory_mode = d['trajectory_mode']

    # ── Attention/mixer alias (legacy ECG configs) ────────────────────────────
    # 'attention': {'mixer_type': 'low_rank'} — ignored here; applied in ManifoldConfig


def _apply(target, source: dict, keys: list) -> None:
    """Copies the keys present in source onto target (setattr)."""
    for k in keys:
        if k in source:
            try:
                setattr(target, k, source[k])
            except AttributeError:
                pass  # key does not exist on the dataclass — ignore silently
```
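The two entry points above differ in their starting point: one builds a config from schema defaults, the other patches an existing one. A minimal sketch using keys taken from `config.json` in this commit:

```python
from gfn.realizations.gssm.config.loader import (
    dict_to_physics_config, apply_physics_overrides,
)

# Build a typed config from a nested dict; unset fields keep schema defaults.
cfg = dict_to_physics_config({
    "topology": {"type": "torus", "riemannian_type": "low_rank"},
    "stability": {"base_dt": 0.4, "friction": 2.0},
})

# Patch an existing config in place; only the listed fields change.
apply_physics_overrides(cfg, {"singularities": {"enabled": True, "strength": 5.0}})
print(cfg.stability.base_dt, cfg.singularities.strength)  # -> 0.4 5.0
```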
gfn/realizations/gssm/config/schema.py
ADDED
@@ -0,0 +1,199 @@

```python
# schema.py — GFN V5
# Configuration class definitions (Schema)
# SEPARATION: Default values go in defaults.py, physical constants in constants.py

from dataclasses import dataclass, field, asdict
from typing import Dict, Any, Optional, List

# Import centralized physics constants
from ..constants import (
    EPSILON_STANDARD,
    TOPOLOGY_TORUS,
    MIN_DT,
    MAX_DT,
    CURVATURE_CLAMP,
    SINGULARITY_THRESHOLD,
    BLACK_HOLE_STRENGTH,
    DEFAULT_DT,
    DEFAULT_FRICTION,
    DEFAULT_PLASTICITY,
    MAX_VELOCITY,
)


@dataclass
class TopologyConfig:
    """Manifold topology configuration."""
    type: str = TOPOLOGY_TORUS
    R: float = 2.0  # Major torus radius (default)
    r: float = 1.0  # Minor torus radius (default)
    curvature: float = 0.0
    riemannian_type: str = 'reactive'
    riemannian_rank: int = 16
    riemannian_class: Optional[str] = None
    geometry_scope: str = 'local'  # 'local' (dim/heads) or 'global' (full dim)
    # NEW: Learnable parameters
    learnable_R: bool = True  # Make R learnable (as the paper describes)
    learnable_r: bool = True  # Make r learnable (as the paper describes)


@dataclass
class StabilityConfig:
    """Numerical stability configuration."""
    base_dt: float = DEFAULT_DT
    adaptive: bool = True
    dt_min: float = MIN_DT
    dt_max: float = MAX_DT
    enable_trace_normalization: bool = True
    wrap_x: bool = True
    friction: float = DEFAULT_FRICTION
    velocity_friction_scale: float = 0.0
    velocity_saturation: float = 0.0  # P2.3: 0 = disabled, >0 = clamp magnitude via tanh
    curvature_clamp: float = CURVATURE_CLAMP
    friction_mode: str = 'static'  # 'static' or 'lif'
    integrator_type: str = 'leapfrog'
    toroidal_curvature_scale: float = 0.01  # scale for torus Christoffel contribution


@dataclass
class DynamicTimeConfig:
    enabled: bool = False
    type: str = 'riemannian'


@dataclass
class HysteresisConfig:
    enabled: bool = False
    ghost_force: bool = True
    hyst_decay: float = 0.1
    hyst_update_w: float = 1.0
    hyst_update_b: float = 0.0
    hyst_readout_w: float = 1.0
    hyst_readout_b: float = 0.0


@dataclass
class ActiveInferenceConfig:
    enabled: bool = False
    holographic_geometry: bool = False
    thermodynamic_geometry: bool = False
    plasticity: float = DEFAULT_PLASTICITY
    dynamic_time: DynamicTimeConfig = field(default_factory=DynamicTimeConfig)
    reactive_curvature: Dict[str, Any] = field(default_factory=lambda: {
        "enabled": False,
        "plasticity": 0.0
    })
    geodesic_lensing: Dict[str, Any] = field(default_factory=lambda: {"enabled": False})

    # Exploration / Noise
    stochasticity: Dict[str, Any] = field(default_factory=lambda: {
        "enabled": False,
        "type": "brownian",
        "sigma": 0.01,
        "theta": 0.15,
        "mu": 0.0
    })
    curiosity: Dict[str, Any] = field(default_factory=lambda: {
        "enabled": False,
        "strength": 0.1,
        "decay": 0.99
    })


@dataclass
class EmbeddingConfig:
    type: str = 'standard'
    mode: str = 'linear'
    coord_dim: int = 16
    impulse_scale: float = 1.0
    omega_0: float = 30.0


@dataclass
class ReadoutConfig:
    type: str = 'standard'


@dataclass
class MixtureConfig:
    coupler_mode: str = 'mean_field'


@dataclass
class DynamicsConfig:
    type: str = 'direct'


@dataclass
class FractalConfig:
    enabled: bool = False
    threshold: float = 0.5
    alpha: float = 0.2


@dataclass
class SingularityConfig:
    enabled: bool = False
    epsilon: float = EPSILON_STANDARD
    strength: float = BLACK_HOLE_STRENGTH
    threshold: float = SINGULARITY_THRESHOLD


@dataclass
class PhysicsConfig:
    """Complete physics configuration."""
    topology: TopologyConfig = field(default_factory=TopologyConfig)
    stability: StabilityConfig = field(default_factory=StabilityConfig)
    dynamics: DynamicsConfig = field(default_factory=DynamicsConfig)
    active_inference: ActiveInferenceConfig = field(default_factory=ActiveInferenceConfig)
    embedding: EmbeddingConfig = field(default_factory=EmbeddingConfig)
    readout: ReadoutConfig = field(default_factory=ReadoutConfig)
    mixture: MixtureConfig = field(default_factory=MixtureConfig)
    fractal: FractalConfig = field(default_factory=FractalConfig)
    hysteresis: HysteresisConfig = field(default_factory=HysteresisConfig)
    singularities: SingularityConfig = field(default_factory=SingularityConfig)
    trajectory_mode: str = 'partition'
    lensing: Dict[str, Any] = field(default_factory=lambda: {'enabled': False})
    checkpointing: Dict[str, Any] = field(default_factory=lambda: {'enabled': False})

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)


@dataclass
class TrainerConfig:
    lr: float = 1e-3
    optimizer: str = 'adamw'
    max_lr: Optional[float] = None
    total_steps: Optional[int] = None
    loss_config: Dict[str, Any] = field(default_factory=lambda: {
        'lambda_g': 0.001,
        'lambda_h': 0.0,
        'geodesic_mode': 'magnitude'
    })


@dataclass
class ManifoldConfig:
    """Main configuration for the Manifold model."""
    vocab_size: int
    dim: int = 512
    depth: int = 4
    heads: int = 4
    rank: int = 32
    integrator: str = 'leapfrog'
    physics: PhysicsConfig = field(default_factory=PhysicsConfig)
    trainer: TrainerConfig = field(default_factory=TrainerConfig)
    adjoint_enabled: bool = False
    adjoint_rtol: float = 1e-4
    adjoint_atol: float = 1e-4
    holographic: bool = False
    impulse_scale: float = 1.0
    dynamics_type: str = 'direct'
    mixer_type: str = 'low_rank'
    trajectory_mode: str = 'partition'
    coupler_mode: str = 'mean_field'
    initial_spread: float = 1e-3

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)
```
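A minimal construction sketch for the dataclasses above; the field names are exactly those defined in this file:

```python
from gfn.realizations.gssm.config.schema import (
    ManifoldConfig, PhysicsConfig, TopologyConfig,
)

physics = PhysicsConfig(topology=TopologyConfig(type='torus', riemannian_type='low_rank'))
cfg = ManifoldConfig(vocab_size=2, dim=8, heads=2, physics=physics)

# Nested dataclasses serialize to plain dicts for JSON / checkpoint storage.
d = cfg.to_dict()
print(d['physics']['topology']['type'])  # -> 'torus'
```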
gfn/realizations/gssm/config/serialization.py
ADDED
@@ -0,0 +1,39 @@

```python
import dataclasses
from typing import Any, Dict, Type, TypeVar, get_type_hints, get_args, get_origin, Union

T = TypeVar('T')

def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
    """
    Reconstructs a nested dataclass from a dictionary.
    Handles nested dataclasses and basic types.
    """
    if not dataclasses.is_dataclass(cls):
        return data

    field_types = get_type_hints(cls)
    kwargs = {}

    for field in dataclasses.fields(cls):
        if field.name in data:
            value = data[field.name]
            field_type = field_types[field.name]

            # Handle Optional[T]
            origin = get_origin(field_type)
            if origin is Union:
                args = get_args(field_type)
                if type(None) in args:
                    # It's an Optional. Find the non-None type
                    field_type = [arg for arg in args if arg is not type(None)][0]

            # Handle nested dataclasses
            if dataclasses.is_dataclass(field_type):
                if value is not None:
                    kwargs[field.name] = from_dict(field_type, value)
                else:
                    kwargs[field.name] = None
            else:
                kwargs[field.name] = value

    return cls(**kwargs)
```
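`from_dict` acts as the inverse of `dataclasses.asdict` for this schema. A roundtrip sketch:

```python
import dataclasses
from gfn.realizations.gssm.config.schema import PhysicsConfig
from gfn.realizations.gssm.config.serialization import from_dict

cfg = PhysicsConfig()
d = dataclasses.asdict(cfg)             # nested dataclasses -> nested plain dicts
restored = from_dict(PhysicsConfig, d)  # rebuilds the nested dataclass tree

assert restored.topology.type == cfg.topology.type
```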
gfn/realizations/gssm/config/validator.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Configuration validation — GFN V5

Checks parameter compatibility before components are built.
Merged from the original utils/validation.py and config/validator.py.
"""

from typing import Dict, Any, List, Optional
from .schema import ManifoldConfig, PhysicsConfig
from ..constants import TOPOLOGY_TORUS, TOPOLOGY_EUCLIDEAN, TOPOLOGY_SPHERE

class ConfigValidationError(Exception):
    """Critical configuration validation error."""
    pass

class ConfigValidator:
    """Central validator for GFN configurations."""

    @staticmethod
    def validate_physics(cfg: PhysicsConfig, dim: Optional[int] = None, heads: Optional[int] = None):
        """
        Validate physical and architectural consistency of PhysicsConfig.
        Raises ConfigValidationError if strict topology/stability rules are violated.
        """
        # 1. Topology checks
        if cfg.topology.type == TOPOLOGY_TORUS:
            if dim is not None and heads is not None:
                head_dim = dim // heads
                if head_dim % 2 != 0:
                    raise ConfigValidationError(
                        f"Toroidal geometry requires head_dim (dim//heads) to be even. "
                        f"Found {dim}//{heads}={head_dim}"
                    )

        if cfg.topology.type == TOPOLOGY_SPHERE and cfg.topology.curvature <= 0:
            raise ConfigValidationError("Spherical topology requires positive curvature.")

        # 2. Stability checks
        if cfg.stability.base_dt <= 0:
            raise ConfigValidationError("base_dt must be positive.")
        if cfg.stability.friction < 0:
            raise ConfigValidationError("friction cannot be negative.")

        # 3. Mode compatibility
        if cfg.trajectory_mode == 'ensemble' and heads is not None and heads <= 1:
            raise ConfigValidationError("Ensemble trajectory mode requires more than 1 head.")


def validate_manifold_config(config: ManifoldConfig) -> List[str]:
    """
    Validates a complete ManifoldConfig and its nested PhysicsConfig.
    Returns a list of warnings (empty if everything is OK).
    Raises ConfigValidationError on critical or compatibility errors.
    """
    warnings = []

    # Critical validations (raise exceptions)
    if config.dim % config.heads != 0:
        raise ConfigValidationError(
            f"dim={config.dim} is not divisible by heads={config.heads}. "
            f"head_dim={config.dim/config.heads:.1f} is not an integer."
        )

    if config.vocab_size <= 0:
        raise ConfigValidationError(f"vocab_size={config.vocab_size} must be > 0.")

    if config.depth <= 0:
        raise ConfigValidationError(f"depth={config.depth} must be > 0.")

    # Validate physics properties via the centralized method
    ConfigValidator.validate_physics(config.physics, config.dim, config.heads)

    # Soft validations (warnings)
    head_dim = config.dim // config.heads
    topo_type = config.physics.topology.type.lower()

    if topo_type == TOPOLOGY_TORUS and head_dim % 2 != 0:
        warnings.append(
            f"[WARN] For toroidal geometry, head_dim={head_dim} should be even "
            f"for sin/cos representations. Consider heads={config.dim // (head_dim + 1)} or similar."
        )

    if config.rank > config.dim:
        warnings.append(
            f"[WARN] rank={config.rank} > dim={config.dim}. "
            f"The decomposition is not low-rank. Intentional?"
        )

    dt = config.physics.stability.base_dt
    if dt > 1.0:
        warnings.append(f"[WARN] base_dt={dt} > 1.0 may cause numerical instability.")
    if dt < 1e-5:
        warnings.append(f"[WARN] base_dt={dt} < 1e-5 may slow convergence.")

    return warnings


def validate_and_print(config: ManifoldConfig) -> bool:
    """
    Validates the configuration and prints warnings.
    Returns True if valid, False if there were errors.
    """
    try:
        warnings = validate_manifold_config(config)
        for w in warnings:
            print(w)
        return True
    except ConfigValidationError as e:
        print(f"[CONFIG ERROR] {e}")
        return False
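A minimal usage sketch for the validator above. The real ManifoldConfig/PhysicsConfig schema lives in config/schema.py (elsewhere in this commit), so the SimpleNamespace stand-in below is a hypothetical placeholder that only mimics the attributes the validator reads; it assumes the gfn package is importable.

from types import SimpleNamespace
from gfn.realizations.gssm.config.validator import validate_manifold_config

# Hypothetical stand-in for ManifoldConfig; the real schema is in config/schema.py.
cfg = SimpleNamespace(
    dim=64, heads=4, vocab_size=16, depth=2, rank=8,
    physics=SimpleNamespace(
        topology=SimpleNamespace(type="torus", curvature=1.0),
        stability=SimpleNamespace(base_dt=0.1, friction=0.01),
        trajectory_mode="single",
    ),
)

warnings = validate_manifold_config(cfg)  # head_dim = 64 // 4 = 16 (even), so no torus warning
print(warnings)                           # []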
gfn/realizations/gssm/constants.py
ADDED
@@ -0,0 +1,67 @@
# constants.py — GFN V5
# Universal physical and mathematical constants.
# Does NOT contain training hyperparameters or configurable values.

import torch

# ─── Mathematical Constants ─────────────────────────────────────────────────
PI = 3.141592653589793
E = 2.718281828459045
SQRT_2 = 1.4142135623730951
LOG_2 = 0.6931471805599453

# ─── Numerical Stability ───────────────────────────────────────────────────
EPS = 1e-8
INF = 1e12
EPSILON_STANDARD = 1e-7
EPSILON_SMOOTH = 1e-9
EPSILON_STRONG = 1e-6
CLAMP_MIN_STRONG = 1e-4

# ─── Physical Limits ───────────────────────────────────────────────────────
MIN_DT = 0.001
MAX_DT = 1.0

# ─── Geometry / Curvature ──────────────────────────────────────────────────
CURVATURE_CLAMP = 5.0  # Maximum absolute value of Christoffel output
FRICTION_SCALE = 0.1  # Global friction scaling factor
VELOCITY_FRICTION_SCALE = 0.01

# ─── Gate initialization constants ─────────────────────────────────────────
GATE_BIAS_OPEN = 2.0   # sigmoid(2.0) ≈ 0.88
GATE_BIAS_CLOSED = -2.0  # sigmoid(-2.0) ≈ 0.12

# ─── Singularity / Active Inference ────────────────────────────────────────
SINGULARITY_THRESHOLD = 0.5
BLACK_HOLE_STRENGTH = 3.0
SINGULARITY_GATE_SLOPE = 10.0

# ─── Torus geometry ─────────────────────────────────────────────────────────
TOROIDAL_MAJOR_RADIUS = 1.0
TOROIDAL_MINOR_RADIUS = 0.3
TOROIDAL_PERIOD = 2.0 * PI
TOROIDAL_CURVATURE_SCALE = 0.1

# ─── Default dtype ──────────────────────────────────────────────────────────
DTYPE = torch.float32

# ─── Topology Names ─────────────────────────────────────────────────────────
TOPOLOGY_TORUS = "torus"
TOPOLOGY_SPHERE = "spherical"
TOPOLOGY_HYPERBOLIC = "hyperbolic"
TOPOLOGY_EUCLIDEAN = "euclidean"

# ─── Dynamics Modes ─────────────────────────────────────────────────────────
DYNAMICS_DIRECT = "direct"
DYNAMICS_RESIDUAL = "residual"
DYNAMICS_MIX = "mix"
DYNAMICS_GATED = "gated"
DYNAMICS_STOCHASTIC = "stochastic"

# ─── Compatibility aliases (defaults that moved to defaults.py) ─────────────
# NOTE: These values are kept here for compatibility, but new code should
# import them from config/defaults.py.
DEFAULT_FRICTION = 0.01
DEFAULT_DT = 0.1
DEFAULT_PLASTICITY = 0.05
MAX_VELOCITY = 10.0
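The soft clamp gamma = C * tanh(gamma_raw / C) with C = CURVATURE_CLAMP recurs throughout the geometry and integrator code below. A quick sketch of its behavior: near-identity for small inputs, smooth saturation at +/-C.

import torch

C = 5.0  # CURVATURE_CLAMP from constants.py
g = torch.tensor([0.1, 1.0, 5.0, 50.0])
print(C * torch.tanh(g / C))  # tensor([0.1000, 0.9869, 3.8080, 5.0000])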
gfn/realizations/gssm/core/__init__.py
ADDED
@@ -0,0 +1,7 @@
"""
core/__init__.py — GFN V5
"""
from ..core.types import ManifoldState, Trajectory, StepResult, ModelOutput
from ..core.state import ManifoldStateManager

__all__ = ['ManifoldState', 'Trajectory', 'StepResult', 'ModelOutput', 'ManifoldStateManager']
gfn/realizations/gssm/core/state.py
ADDED
@@ -0,0 +1,60 @@
"""
core/state.py — GFN V5
Manifold state handling (position + velocity).
"""

import torch
import torch.nn as nn
from typing import Optional, Tuple


class ManifoldStateManager:
    """
    Manages initialization and manipulation of the (x, v) state.
    Compatible with batches and multiple heads.
    """

    @staticmethod
    def initialize(x0: nn.Parameter, v0: nn.Parameter,
                   batch_size: int, n_trajectories: int = 1,
                   initial_spread: float = 1e-3) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Initializes the (x, v) state for a given batch.

        Args:
            x0, v0: Initial parameters of shape [1, H, HD]
            batch_size: Batch size
            n_trajectories: Number of parallel trajectories
            initial_spread: Initial noise

        Returns:
            (x, v) of shape [B, H, HD]
        """
        x = x0.expand(batch_size, -1, -1)
        v = v0.expand(batch_size, -1, -1)

        if initial_spread > 0:
            x = x + torch.randn_like(x) * initial_spread

        return x.contiguous(), v.contiguous()

    @staticmethod
    def from_tuple(state: Optional[Tuple], x0: nn.Parameter, v0: nn.Parameter,
                   batch_size: int, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Builds (x, v) from a previous state or from the initial parameters.
        Compatible with the BasicModel API.
        """
        if state is not None and isinstance(state, (tuple, list)) and len(state) == 2:
            return state[0], state[1]
        return ManifoldStateManager.initialize(x0, v0, batch_size, **kwargs)

    @staticmethod
    def wrap_torus(x: torch.Tensor) -> torch.Tensor:
        """Projects a position onto the toroidal domain [-π, π]."""
        return torch.atan2(torch.sin(x), torch.cos(x))

    @staticmethod
    def energy(v: torch.Tensor) -> torch.Tensor:
        """Kinetic energy H = 0.5 * ||v||² per sample."""
        return 0.5 * (v ** 2).sum(dim=-1)
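A short usage sketch for ManifoldStateManager, assuming the package import path of this commit; the shapes follow the docstrings above.

import torch
import torch.nn as nn
from gfn.realizations.gssm.core.state import ManifoldStateManager

x0 = nn.Parameter(torch.zeros(1, 4, 16))  # [1, H, HD]
v0 = nn.Parameter(torch.zeros(1, 4, 16))
x, v = ManifoldStateManager.initialize(x0, v0, batch_size=8)
print(x.shape, v.shape)                   # torch.Size([8, 4, 16]) twice

x_wrapped = ManifoldStateManager.wrap_torus(x + 4 * torch.pi)
print(bool(x_wrapped.abs().max() <= torch.pi))  # True: wrapped back into [-pi, pi]
print(ManifoldStateManager.energy(v).shape)     # torch.Size([8, 4]), one energy per head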
gfn/realizations/gssm/core/types.py
ADDED
@@ -0,0 +1,27 @@
"""
core/types.py — GFN V5
Framework types and type aliases.
"""

import torch
from typing import Dict, Any, Tuple, Optional, List, Union

# ─── State types ─────────────────────────────────────────────────────────────
# (position, velocity) pair
ManifoldState = Tuple[torch.Tensor, torch.Tensor]

# Trajectory: list of (x, v) states over time
Trajectory = List[ManifoldState]

# Force tensor (same shape as x, v)
Force = torch.Tensor

# Integration step result
StepResult = Dict[str, torch.Tensor]  # {'x': ..., 'v': ...}

# ─── Config types ─────────────────────────────────────────────────────────────
ConfigDict = Dict[str, Any]

# ─── Forward pass outputs ─────────────────────────────────────────────────────
# (logits, state, info_dict)
ModelOutput = Tuple[torch.Tensor, ManifoldState, Dict[str, Any]]
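These aliases are plain typing constructs, so downstream code can annotate with them at no runtime cost. A tiny sketch (detach_state is a hypothetical helper, not part of this commit):

from gfn.realizations.gssm.core.types import ManifoldState

def detach_state(state: ManifoldState) -> ManifoldState:
    """Detach (x, v) from the autograd graph, e.g. for truncated BPTT."""
    x, v = state
    return x.detach(), v.detach()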
gfn/realizations/gssm/csrc/README.md
ADDED
@@ -0,0 +1,2 @@
# Native C++/CUDA Extensions
Store raw .cpp and .cu files here. Bindings remain in gfn/realizations/gssm/cuda/.
gfn/realizations/gssm/csrc/compile_cuda_12.9.bat
ADDED
@@ -0,0 +1,68 @@
@echo off

REM Ensure we are in the script's directory so we run the correct setup.py
cd /d "%~dp0"

echo ================================================================
echo [GFN] Custom CUDA Kernel Compilation Pipeline (VS 2022 + CUDA 12.9)
echo ================================================================

REM --- 1. Find Visual Studio 2022 Installation ---
set "VS_PATH="

if exist "C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat" (
    set "VS_PATH=C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat"
) else if exist "C:\Program Files\Microsoft Visual Studio\2022\Professional\VC\Auxiliary\Build\vcvars64.bat" (
    set "VS_PATH=C:\Program Files\Microsoft Visual Studio\2022\Professional\VC\Auxiliary\Build\vcvars64.bat"
) else if exist "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat" (
    set "VS_PATH=C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
)

if "%VS_PATH%"=="" (
    echo [ERROR] Could not find Visual Studio 2022 installation.
    pause
    exit /b 1
)

echo [*] Found MSVC Environment: "%VS_PATH%"
echo [*] Initializing Developer Console...
call "%VS_PATH%"

REM --- 2. Setup CUDA Environment (12.9) ---
set "CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.9"
set "PATH=%CUDA_PATH%\bin;%PATH%"

echo [*] Using CUDA Path: "%CUDA_PATH%"
nvcc --version

REM --- 3. Compile Kernels ---
echo [*] Cleaning old builds...
rmdir /s /q build
rmdir /s /q gfn_cuda.egg-info
del /q *.pyc
rmdir /s /q __pycache__

echo.
echo [*] Starting Setup Compilation (In-place)...

REM Fix for the "It seems that the VC environment is activated..." warning
set DISTUTILS_USE_SDK=1

python setup.py build_ext --inplace

if %errorlevel% neq 0 (
    echo [ERROR] Compilation failed.
    echo Ensure you have PyTorch installed for CUDA 12.x:
    echo pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
    pause
    exit /b 1
)

echo.
echo [SUCCESS] Kernels compiled to local .pyd file!
echo Verifying import:
python -c "import gfn_cuda; print('SUCCESS: gfn_cuda module imported directly!')"
pause
gfn/realizations/gssm/csrc/extension.cpp
ADDED
@@ -0,0 +1,87 @@
#include <torch/extension.h>
#include <vector>
#include "integrators/integrators.h"

// Function declarations (Toroidal Loss)
torch::Tensor toroidal_distance_loss_fwd(const torch::Tensor& y_pred, const torch::Tensor& y_true);
torch::Tensor toroidal_distance_loss_bwd(const torch::Tensor& grad_output, const torch::Tensor& y_pred, const torch::Tensor& y_true);

// Function declarations (Low Rank Christoffel)
torch::Tensor low_rank_christoffel_fwd(
    const torch::Tensor& v, const torch::Tensor& U, const torch::Tensor& W,
    double clamp_val, bool enable_trace_norm, bool is_paper_version);

// Pure ATen C++ backward implementation (compiled by MSVC, sidestepping an NVCC CICC bug)
std::vector<torch::Tensor> low_rank_christoffel_bwd(
    const torch::Tensor& grad_gamma,
    const torch::Tensor& v,
    const torch::Tensor& U,
    const torch::Tensor& W,
    const torch::Tensor& gamma_out,
    double clamp_val,
    bool enable_trace_norm,
    bool is_paper_version)
{
    // Fast pure ATen operations avoiding Python overhead
    auto g_norm = gamma_out / clamp_val;
    auto d_tanh = 1.0 - g_norm.pow(2);
    auto grad_raw = grad_gamma * d_tanh;  // [B, H, D]

    if (enable_trace_norm) {
        auto mean_d = grad_raw.mean(-1, /*keepdim=*/true);
        grad_raw = grad_raw - mean_d;
    }

    // Explicit batched matrix multiplication (bmm) to avoid matmul broadcast crashes
    // W is [H, D, R], grad_raw is [B, H, D]
    auto grad_raw_h = grad_raw.permute({1, 0, 2});  // [H, B, D]
    auto d_sq_h = torch::bmm(grad_raw_h, W);        // [H, B, D] @ [H, D, R] -> [H, B, R]
    auto d_sq = d_sq_h.permute({1, 0, 2});          // [B, H, R]

    auto v_h = v.permute({1, 0, 2});                // [H, B, D]
    auto v_r_h = torch::bmm(v_h, U);                // [H, B, D] @ [H, D, R] -> [H, B, R]
    auto v_r = v_r_h.permute({1, 0, 2});            // [B, H, R]

    torch::Tensor d_vr;
    if (is_paper_version) {
        auto vr_norm = torch::norm(v_r, 2, -1, true);
        auto denom = 1.0 + vr_norm;

        // Correct chain rule for the normalized-denominator coupling:
        // grad_vr_j = grad_phi_j * (2v_j / denom) - v_j * Sum_k(grad_phi_k * v_k^2) / (||v|| * denom^2)
        auto S = (d_sq * v_r.pow(2)).sum(-1, true);
        auto term1 = d_sq * (2.0 * v_r / denom);
        auto term2 = (v_r * S) / (vr_norm * denom.pow(2) + 1e-8);
        d_vr = term1 - term2;
    } else {
        d_vr = d_sq * 2.0 * v_r;
    }

    auto d_vr_h = d_vr.permute({1, 0, 2});          // [H, B, R]
    auto U_t = U.transpose(-1, -2);                 // [H, R, D]
    auto d_v_h = torch::bmm(d_vr_h, U_t);           // [H, B, R] @ [H, R, D] -> [H, B, D]
    auto d_v = d_v_h.permute({1, 0, 2});            // [B, H, D]

    // Accumulate the W and U gradients over the batch:
    auto sq = is_paper_version ? v_r.pow(2) / (1.0 + torch::norm(v_r, 2, -1, true)) : v_r.pow(2);  // [B, H, R]
    auto sq_h = sq.permute({1, 0, 2});              // [H, B, R]

    auto grad_raw_h_t = grad_raw_h.transpose(-1, -2);  // [H, D, B]
    auto d_W = torch::bmm(grad_raw_h_t, sq_h);         // [H, D, B] @ [H, B, R] -> [H, D, R]

    auto v_h_t = v_h.transpose(-1, -2);             // [H, D, B]
    auto d_U = torch::bmm(v_h_t, d_vr_h);           // [H, D, B] @ [H, B, R] -> [H, D, R]

    return {d_v, d_U, d_W};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("toroidal_distance_loss_fwd", &toroidal_distance_loss_fwd, "Toroidal Distance Loss Forward (CUDA)");
    m.def("toroidal_distance_loss_bwd", &toroidal_distance_loss_bwd, "Toroidal Distance Loss Backward (CUDA)");

    m.def("low_rank_christoffel_fwd", &low_rank_christoffel_fwd, "Low Rank Christoffel Forward Kernel");
    m.def("low_rank_christoffel_bwd", &low_rank_christoffel_bwd, "Low Rank Christoffel Backward ATen");

    m.def("yoshida_fwd", &yoshida_fwd_aten, "Yoshida C++ Macro Integrator Step");
    m.def("leapfrog_fwd", &leapfrog_fwd_aten, "Leapfrog C++ Macro Integrator Step");
}
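The hand-derived chain rule above can be sanity-checked against autograd on a pure-PyTorch restatement of the forward. This is a minimal sketch (non-paper variant, trace norm off), not the compiled extension; gradcheck confirms the reference is smoothly differentiable, and comparing gfn_cuda.low_rank_christoffel_bwd against these autograd gradients is the natural unit test for the derivation.

import torch

def christoffel_ref(v, U, W, clamp_val=5.0):
    # gamma = C * tanh(((v @ U)^2 @ W^T) / C), computed per head via einsum
    v_r = torch.einsum('bhd,hdr->bhr', v, U)
    gamma = torch.einsum('bhr,hdr->bhd', v_r.pow(2), W)
    return clamp_val * torch.tanh(gamma / clamp_val)

B, H, D, R = 2, 3, 8, 4
v = torch.randn(B, H, D, dtype=torch.double, requires_grad=True)
U = torch.randn(H, D, R, dtype=torch.double, requires_grad=True)
W = torch.randn(H, D, R, dtype=torch.double, requires_grad=True)

print(torch.autograd.gradcheck(christoffel_ref, (v, U, W)))  # True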
gfn/realizations/gssm/csrc/geometry/low_rank.cu
ADDED
@@ -0,0 +1,160 @@
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>

template <typename scalar_t>
__global__ void low_rank_christoffel_fwd_kernel(
    const scalar_t* __restrict__ v,      // [B, H, D]
    const scalar_t* __restrict__ U,      // [H, D, R]
    const scalar_t* __restrict__ W,      // [H, D, R]
    scalar_t* __restrict__ gamma,        // [B, H, D]
    const int B, const int H, const int D, const int R,
    const scalar_t clamp_val,
    const bool enable_trace_norm,
    const bool is_paper_version)
{
    // A block computes the output for one (b, h) pair
    int bh = blockIdx.x;
    if (bh >= B * H) return;

    int h = bh % H;

    // Dynamic shared memory allocations
    extern __shared__ char smem[];
    scalar_t* v_s_d = reinterpret_cast<scalar_t*>(smem);             // Size: D
    scalar_t* vr_sq_s = reinterpret_cast<scalar_t*>(&v_s_d[D]);      // Size: R
    scalar_t* gamma_s_d = reinterpret_cast<scalar_t*>(&vr_sq_s[R]);  // Size: D

    const scalar_t* v_b = v + bh * D;
    scalar_t* gamma_b = gamma + bh * D;

    // Pointers for H offset
    const scalar_t* U_h = U + h * D * R;
    const scalar_t* W_h = W + h * D * R;

    const int tid = threadIdx.x;
    const int bdim = blockDim.x;

    // 1. Load v into shared memory
    for (int i = tid; i < D; i += bdim) {
        v_s_d[i] = v_b[i];
    }
    __syncthreads();

    // 2. Compute v_r = v @ U -> sq = v_r^2
    for (int r = tid; r < R; r += bdim) {
        scalar_t sum = 0;
        for (int j = 0; j < D; ++j) {
            sum += v_s_d[j] * U_h[j * R + r];
        }
        vr_sq_s[r] = sum * sum;
    }
    __syncthreads();

    // 3. Optional: paper low-rank denominator logic
    if (is_paper_version) {
        __shared__ scalar_t block_sum_sq;
        if (tid == 0) block_sum_sq = 0;
        __syncthreads();

        scalar_t local_sq = 0;
        for (int r = tid; r < R; r += bdim) {
            local_sq += vr_sq_s[r];
        }
        atomicAdd(&block_sum_sq, local_sq);
        __syncthreads();

        scalar_t norm_vr = sqrt(block_sum_sq);
        scalar_t denom = 1.0 + norm_vr;

        for (int r = tid; r < R; r += bdim) {
            vr_sq_s[r] = vr_sq_s[r] / denom;
        }
        __syncthreads();
    }

    // 4. Compute gamma_raw = sq @ W.T
    for (int d = tid; d < D; d += bdim) {
        scalar_t sum = 0;
        for (int r = 0; r < R; ++r) {
            sum += vr_sq_s[r] * W_h[d * R + r];
        }
        gamma_s_d[d] = sum;
    }
    __syncthreads();

    // 5. Trace normalization (mean subtraction)
    scalar_t mean_val = 0;
    if (enable_trace_norm) {
        __shared__ scalar_t block_sum_gamma;
        if (tid == 0) block_sum_gamma = 0;
        __syncthreads();

        scalar_t local_gamma_sum = 0;
        for (int d = tid; d < D; d += bdim) {
            local_gamma_sum += gamma_s_d[d];
        }
        atomicAdd(&block_sum_gamma, local_gamma_sum);
        __syncthreads();

        mean_val = block_sum_gamma / D;
    }

    // 6. Normalization and storage
    for (int d = tid; d < D; d += bdim) {
        scalar_t g = gamma_s_d[d];
        if (enable_trace_norm) {
            g -= mean_val;
        }
        g = clamp_val * tanh(g / clamp_val);
        gamma_b[d] = g;
    }
}

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor low_rank_christoffel_fwd(
    const torch::Tensor& v,
    const torch::Tensor& U,
    const torch::Tensor& W,
    double clamp_val,
    bool enable_trace_norm,
    bool is_paper_version)
{
    CHECK_INPUT(v);
    CHECK_INPUT(U);
    CHECK_INPUT(W);

    // Expected shapes: v is [B, H, D], U is [H, D, R], W is [H, D, R]
    int B = v.size(0);
    int H = v.size(1);
    int D = v.size(2);
    int R = U.size(2);

    auto gamma = torch::empty_like(v);

    const int threads = 256;
    const int blocks = B * H;

    // Shared memory size: (D + R + D) * sizeof(float)
    const int shared_mem_size = (2 * D + R) * sizeof(float);

    if (v.scalar_type() == torch::kFloat32) {
        low_rank_christoffel_fwd_kernel<float><<<blocks, threads, shared_mem_size>>>(
            v.data_ptr<float>(),
            U.data_ptr<float>(),
            W.data_ptr<float>(),
            gamma.data_ptr<float>(),
            B, H, D, R,
            static_cast<float>(clamp_val),
            enable_trace_norm,
            is_paper_version
        );
    } else {
        TORCH_CHECK(false, "low_rank_christoffel_fwd only supports float32");
    }

    return gamma;
}
gfn/realizations/gssm/csrc/integrators/integrators.cpp
ADDED
@@ -0,0 +1,252 @@
#include <torch/extension.h>
#include <vector>
#define _USE_MATH_DEFINES
#include <cmath>

#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

// -------------------------------------------------------------
// Pure ATen Implementation of the Integrators Loop
// -------------------------------------------------------------
// We build this in standard C++ using ATen so it runs in a single
// Python call but is compiled by MSVC, averting NVCC OOM bugs.
// It runs the `steps`-iteration loop using the exact GFN LowRank geometry.

// Helper to compute the Christoffel term Gamma
torch::Tensor _compute_gamma(
    const torch::Tensor& v,
    const torch::Tensor& U,
    const torch::Tensor& W,
    double clamp_val,
    bool enable_trace_norm,
    bool is_paper_version)
{
    auto v_r = torch::matmul(v.unsqueeze(-2), U).squeeze(-2);  // [..., R]
    torch::Tensor sq;
    if (is_paper_version) {
        auto vr_norm = torch::norm(v_r, 2, -1, true);
        sq = v_r.pow(2) / (1.0 + vr_norm);
    } else {
        sq = v_r.pow(2);
    }

    auto gamma = torch::matmul(sq.unsqueeze(-2), W.transpose(-1, -2)).squeeze(-2);  // [..., D]

    if (enable_trace_norm) {
        auto mean_g = gamma.mean(-1, /*keepdim=*/true);
        gamma = gamma - mean_g;
    }

    return clamp_val * torch::tanh(gamma / clamp_val);
}

// Helper for velocity saturation (soft-clamp)
torch::Tensor _clamp_velocity(const torch::Tensor& v, double v_sat) {
    if (v_sat > 0) {
        return v_sat * torch::tanh(v / v_sat);
    }
    return v;
}

// Yoshida 4th-order coefficients
const double w1 = 1.3512071919596576;
const double w0 = -1.7024143839193153;
const double y_c1 = w1 / 2.0;
const double y_c2 = (w0 + w1) / 2.0;
const double y_c3 = y_c2;
const double y_c4 = y_c1;
const double y_d1 = w1;
const double y_d2 = w0;
const double y_d3 = w1;

// Helper for Gated Friction (Active Inference)
torch::Tensor _compute_mu(
    const torch::Tensor& x,
    const torch::Tensor& v,
    const torch::Tensor& gate_w,
    const torch::Tensor& gate_b,
    double base_friction,
    double vel_fric_scale)
{
    const double eps = 1e-8;
    const double D = x.size(-1);

    // mu_base = base_friction
    torch::Tensor mu = torch::full_like(x.select(-1, 0).unsqueeze(-1), base_friction);

    // If gate weights are provided, compute the learnable friction component
    if (gate_w.numel() > 0) {
        torch::Tensor feat;
        // Check whether torus features [sin, cos] are needed (gate_w dim will be 2*D)
        if (gate_w.size(1) == 2 * D) {
            feat = torch::cat({torch::sin(x), torch::cos(x)}, -1);  // [..., 2D]
        } else {
            feat = x;  // Euclidean / flat
        }

        // Linear gate: sigmoid(feat @ w + b)
        auto gate_out = torch::matmul(feat.unsqueeze(-2), gate_w).squeeze(-2);  // [B, H, 1]
        if (gate_b.numel() > 0) {
            gate_out = gate_out + gate_b;
        }
        mu = mu + torch::sigmoid(gate_out);
    }

    // Velocity-dependent scaling: mu * (1 + scale * ||v||)
    auto v_norm = torch::norm(v, 2, -1, true) / (std::sqrt(D) + eps);
    mu = mu * (1.0 + vel_fric_scale * v_norm);

    return mu;
}

// Helper for Singularity Damping
torch::Tensor _apply_singularity_damping(
    const torch::Tensor& acc,
    const torch::Tensor& v,
    const torch::Tensor& U,
    double sing_thresh,
    double sing_strength)
{
    if (sing_strength <= 1.0 || sing_thresh <= 0.0) return acc;

    // Detect singularities: the metric is small near singular points.
    // In LowRank, g_diag = sum(U^2).
    auto g_diag = (U.pow(2)).sum(-1);  // [H, D]

    // Potential = sigmoid(5.0 * (g - thresh))
    auto soft_mask = torch::sigmoid(5.0 * (g_diag - sing_thresh));

    // Scale the acceleration by soft_mask (damping shield).
    // Near a singularity, soft_mask -> 0, damping the forces.
    return acc * soft_mask;
}

std::vector<torch::Tensor> yoshida_fwd_aten(
    const torch::Tensor& x_init,
    const torch::Tensor& v_init,
    const torch::Tensor& U,
    const torch::Tensor& W,
    const torch::Tensor& force,
    const torch::Tensor& dt,
    int steps,
    double clamp_val,
    double friction,
    double vel_fric_scale,
    double vel_sat,
    const torch::Tensor& gate_w,
    const torch::Tensor& gate_b,
    double sing_thresh,
    double sing_strength,
    bool enable_trace_norm,
    bool is_paper_version)
{
    auto x = x_init.clone();
    auto v = v_init.clone();

    const double eps = 1e-8;

    for (int i = 0; i < steps; ++i) {
        // Sub-step 1
        x = x + y_c1 * dt * v;
        x = torch::remainder(x + M_PI, 2 * M_PI) - M_PI;  // Toroidal wrap

        auto gamma1 = _compute_gamma(v, U, W, clamp_val, enable_trace_norm, is_paper_version);
        auto a1_nf = force - gamma1;
        a1_nf = _apply_singularity_damping(a1_nf, v, U, sing_thresh, sing_strength);

        auto mu1 = _compute_mu(x, v, gate_w, gate_b, friction, vel_fric_scale);

        v = (v + y_d1 * dt * a1_nf) / (1.0 + y_d1 * dt * mu1 + eps);
        v = _clamp_velocity(v, vel_sat);

        // Sub-step 2
        x = x + y_c2 * dt * v;
        x = torch::remainder(x + M_PI, 2 * M_PI) - M_PI;

        auto gamma2 = _compute_gamma(v, U, W, clamp_val, enable_trace_norm, is_paper_version);
        auto a2_nf = force - gamma2;
        a2_nf = _apply_singularity_damping(a2_nf, v, U, sing_thresh, sing_strength);

        auto mu2 = _compute_mu(x, v, gate_w, gate_b, friction, vel_fric_scale);

        v = (v + y_d2 * dt * a2_nf) / (1.0 + y_d2 * dt * mu2 + eps);
        v = _clamp_velocity(v, vel_sat);

        // Sub-step 3
        x = x + y_c3 * dt * v;
        x = torch::remainder(x + M_PI, 2 * M_PI) - M_PI;

        auto gamma3 = _compute_gamma(v, U, W, clamp_val, enable_trace_norm, is_paper_version);
        auto a3_nf = force - gamma3;
        a3_nf = _apply_singularity_damping(a3_nf, v, U, sing_thresh, sing_strength);

        auto mu3 = _compute_mu(x, v, gate_w, gate_b, friction, vel_fric_scale);

        v = (v + y_d3 * dt * a3_nf) / (1.0 + y_d3 * dt * mu3 + eps);
        v = _clamp_velocity(v, vel_sat);

        // Final drift
        x = x + y_c4 * dt * v;
        x = torch::remainder(x + M_PI, 2 * M_PI) - M_PI;
    }

    return {x, v};
}

std::vector<torch::Tensor> leapfrog_fwd_aten(
    const torch::Tensor& x_init,
    const torch::Tensor& v_init,
    const torch::Tensor& U,
    const torch::Tensor& W,
    const torch::Tensor& force,
    const torch::Tensor& dt,
    int steps,
    double clamp_val,
    double friction,
    double vel_fric_scale,
    double vel_sat,
    const torch::Tensor& gate_w,
    const torch::Tensor& gate_b,
    double sing_thresh,
    double sing_strength,
    bool enable_trace_norm,
    bool is_paper_version)
{
    auto x = x_init.clone();
    auto v = v_init.clone();

    const double eps = 1e-8;

    for (int i = 0; i < steps; ++i) {
        // Half-kick 1
        auto gamma1 = _compute_gamma(v, U, W, clamp_val, enable_trace_norm, is_paper_version);
        auto a1_nf = force - gamma1;
        a1_nf = _apply_singularity_damping(a1_nf, v, U, sing_thresh, sing_strength);

        auto mu1 = _compute_mu(x, v, gate_w, gate_b, friction, vel_fric_scale);

        auto v_half = (v + 0.5 * dt * a1_nf) / (1.0 + 0.5 * dt * mu1 + eps);
        v_half = _clamp_velocity(v_half, vel_sat);

        // Drift
        x = x + dt * v_half;
        x = torch::remainder(x + M_PI, 2 * M_PI) - M_PI;

        // Half-kick 2
        auto gamma2 = _compute_gamma(v_half, U, W, clamp_val, enable_trace_norm, is_paper_version);
        auto a2_nf = force - gamma2;
        a2_nf = _apply_singularity_damping(a2_nf, v_half, U, sing_thresh, sing_strength);

        auto mu2 = _compute_mu(x, v_half, gate_w, gate_b, friction, vel_fric_scale);

        auto a_avg = (a1_nf + a2_nf) / 2.0;
        auto mu_avg = (mu1 + mu2) / 2.0;

        v = (v + dt * a_avg) / (1.0 + dt * mu_avg + eps);
        v = _clamp_velocity(v, vel_sat);
    }

    return {x, v};
}
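The w1 and w0 above are the standard Yoshida 4th-order composition weights, w1 = 1/(2 - 2^(1/3)) and w0 = -2^(1/3) * w1. A quick numerical check that the drift weights c1..c4 and kick weights d1..d3 each sum to 1, so one macro step advances time by exactly dt:

w1 = 1.3512071919596576          # value used in integrators.cpp
w0 = -1.7024143839193153

assert abs(w1 - 1 / (2 - 2 ** (1 / 3))) < 1e-14   # matches the closed form
c = [w1 / 2, (w0 + w1) / 2, (w0 + w1) / 2, w1 / 2]  # drift weights c1..c4
d = [w1, w0, w1]                                    # kick weights d1..d3
assert abs(sum(c) - 1.0) < 1e-12 and abs(sum(d) - 1.0) < 1e-12
print(sum(c), sum(d))  # ~1.0 ~1.0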
gfn/realizations/gssm/csrc/integrators/integrators.h
ADDED
@@ -0,0 +1,41 @@
#include <torch/extension.h>
#include <vector>

// Forward declarations for the integrators
std::vector<torch::Tensor> yoshida_fwd_aten(
    const torch::Tensor& x_init,
    const torch::Tensor& v_init,
    const torch::Tensor& U,
    const torch::Tensor& W,
    const torch::Tensor& force,
    const torch::Tensor& dt,
    int steps,
    double clamp_val,
    double friction,
    double vel_fric_scale,
    double vel_sat,
    const torch::Tensor& gate_w,
    const torch::Tensor& gate_b,
    double sing_thresh,
    double sing_strength,
    bool enable_trace_norm,
    bool is_paper_version);

std::vector<torch::Tensor> leapfrog_fwd_aten(
    const torch::Tensor& x_init,
    const torch::Tensor& v_init,
    const torch::Tensor& U,
    const torch::Tensor& W,
    const torch::Tensor& force,
    const torch::Tensor& dt,
    int steps,
    double clamp_val,
    double friction,
    double vel_fric_scale,
    double vel_sat,
    const torch::Tensor& gate_w,
    const torch::Tensor& gate_b,
    double sing_thresh,
    double sing_strength,
    bool enable_trace_norm,
    bool is_paper_version);
gfn/realizations/gssm/csrc/losses/toroidal.cu
ADDED
@@ -0,0 +1,99 @@
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math_constants.h>

// ------------------------------------------------------------------------
// Toroidal Distance Loss
// L(y_pred, y_true) = (atan2(sin(y_pred - y_true), cos(y_pred - y_true)))^2
// ------------------------------------------------------------------------

template <typename scalar_t>
__global__ void toroidal_distance_loss_fwd_kernel(
    const scalar_t* __restrict__ y_pred,
    const scalar_t* __restrict__ y_true,
    scalar_t* __restrict__ out,
    const int numel)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < numel) {
        scalar_t diff = y_pred[idx] - y_true[idx];
        scalar_t wrapped = atan2(sin(diff), cos(diff));
        out[idx] = wrapped * wrapped;
    }
}

template <typename scalar_t>
__global__ void toroidal_distance_loss_bwd_kernel(
    const scalar_t* __restrict__ grad_output,
    const scalar_t* __restrict__ y_pred,
    const scalar_t* __restrict__ y_true,
    scalar_t* __restrict__ grad_pred,
    const int numel)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < numel) {
        // The derivative of atan2(sin(x), cos(x))^2 with respect to x is 2 * atan2(sin(x), cos(x))
        scalar_t diff = y_pred[idx] - y_true[idx];
        scalar_t wrapped = atan2(sin(diff), cos(diff));
        grad_pred[idx] = grad_output[idx] * 2.0 * wrapped;
    }
}

// ------------------------------------------------------------------------
// ATen wrappers
// ------------------------------------------------------------------------

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor toroidal_distance_loss_fwd(const torch::Tensor& y_pred, const torch::Tensor& y_true) {
    CHECK_INPUT(y_pred);
    CHECK_INPUT(y_true);

    auto out = torch::empty_like(y_pred);
    int numel = y_pred.numel();

    const int threads = 256;
    const int blocks = (numel + threads - 1) / threads;

    if (y_pred.scalar_type() == torch::kFloat32) {
        toroidal_distance_loss_fwd_kernel<float><<<blocks, threads>>>(
            y_pred.data_ptr<float>(),
            y_true.data_ptr<float>(),
            out.data_ptr<float>(),
            numel
        );
    } else {
        TORCH_CHECK(false, "toroidal_distance_loss_fwd only supports float32");
    }

    return out;
}

torch::Tensor toroidal_distance_loss_bwd(const torch::Tensor& grad_output, const torch::Tensor& y_pred, const torch::Tensor& y_true) {
    CHECK_INPUT(grad_output);
    CHECK_INPUT(y_pred);
    CHECK_INPUT(y_true);

    auto grad_pred = torch::empty_like(y_pred);
    int numel = y_pred.numel();

    const int threads = 256;
    const int blocks = (numel + threads - 1) / threads;

    if (y_pred.scalar_type() == torch::kFloat32) {
        toroidal_distance_loss_bwd_kernel<float><<<blocks, threads>>>(
            grad_output.data_ptr<float>(),
            y_pred.data_ptr<float>(),
            y_true.data_ptr<float>(),
            grad_pred.data_ptr<float>(),
            numel
        );
    } else {
        TORCH_CHECK(false, "toroidal_distance_loss_bwd only supports float32");
    }

    return grad_pred;
}
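A pure-PyTorch restatement of the loss above; autograd reproduces the handwritten backward (gradient 2 * wrapped), which is exactly what the CUDA kernel computes elementwise.

import torch

y_pred = torch.tensor([3.0, -3.0], requires_grad=True)
y_true = torch.tensor([-3.0, 3.0])

diff = y_pred - y_true
wrapped = torch.atan2(torch.sin(diff), torch.cos(diff))  # wraps to (-pi, pi]
loss = (wrapped ** 2).sum()
loss.backward()

print(wrapped)  # ~[-0.2832, 0.2832]: a raw difference of 6.0 rad wraps to 6.0 - 2*pi
print(torch.allclose(y_pred.grad, 2 * wrapped.detach()))  # True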
gfn/realizations/gssm/csrc/setup.py
ADDED
@@ -0,0 +1,38 @@
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Force sequential compilation to prevent NVCC Out-Of-Memory (LLVM ERROR)
os.environ["MAX_JOBS"] = "1"

# Base directory
csrc_dir = os.path.dirname(os.path.abspath(__file__))

sources = [
    os.path.join(csrc_dir, "extension.cpp"),
    os.path.join(csrc_dir, "losses", "toroidal.cu"),
    os.path.join(csrc_dir, "geometry", "low_rank.cu"),
    os.path.join(csrc_dir, "integrators", "integrators.cpp")
]

# MSVC/Windows vs Linux specific configuration
extra_compile_args = {
    'cxx': ['-O2'],
    'nvcc': ['-O2', '-allow-unsupported-compiler']
}
if os.name == 'nt':
    extra_compile_args['cxx'].append('/std:c++17')

setup(
    name='gfn_cuda',
    ext_modules=[
        CUDAExtension(
            name='gfn_cuda',
            sources=sources,
            extra_compile_args=extra_compile_args,
        )
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)
gfn/realizations/gssm/cuda/__init__.py
ADDED
@@ -0,0 +1,11 @@
"""
gfn/cuda/__init__.py
CUDA infrastructure for GFN V5.
"""
import torch

CUDA_AVAILABLE = torch.cuda.is_available()

def is_cuda_active(tensor: torch.Tensor) -> bool:
    """Checks that CUDA is available and the tensor lives on a GPU device."""
    return CUDA_AVAILABLE and tensor.is_cuda
gfn/realizations/gssm/cuda/autograd/__init__.py
ADDED
File without changes
gfn/realizations/gssm/cuda/kernels/__init__.py
ADDED
File without changes
gfn/realizations/gssm/cuda/kernels/geometry_kernels.py
ADDED
@@ -0,0 +1,99 @@
"""
Geometry Kernels — GFN V5
Unified entry points for geometric computations with hardware dispatching.
"""

import torch
from typing import Optional, Tuple, Union, Any
from ...registry import GEOMETRY_REGISTRY
from ...cuda import is_cuda_active

# Lazy import for CUDA ops to avoid loading them if not available
_christoffel_cuda = None

def _get_cuda_ops():
    global _christoffel_cuda
    if _christoffel_cuda is None:
        try:
            from ...cuda.ops import christoffel_cuda_fwd
            _christoffel_cuda = christoffel_cuda_fwd
        except ImportError:
            pass
    return _christoffel_cuda

def unified_christoffel_fwd(
    x: torch.Tensor,
    v: torch.Tensor,
    U: torch.Tensor,
    W: torch.Tensor,
    clamp_val: float = 5.0,
    **kwargs: Any
) -> torch.Tensor:
    """
    Unified forward pass for Christoffel symbols.
    Dispatches to the CUDA kernel if available and on GPU, otherwise falls back to PyTorch.
    """
    if is_cuda_active(v):
        cuda_op = _get_cuda_ops()
        if cuda_op is not None:
            try:
                return _run_cuda_christoffel(x, v, U, W, clamp_val, cuda_op, **kwargs)
            except Exception as e:
                # print(f"[Dispatcher] CUDA Error: {e}. Falling back.")
                pass

    return _run_pytorch_christoffel(x, v, U, W, clamp_val, **kwargs)

def _run_cuda_christoffel(x, v, U, W, clamp_val, cuda_op, **kwargs):
    # The kernel expects head-aware tensors [B, H, HD].
    # Check whether x, v are [B, D] or [B, H, HD].
    if v.dim() == 2:
        x_k, v_k = x.unsqueeze(1), v.unsqueeze(1)
    else:
        x_k, v_k = x, v

    # U, W handling (assuming LowRank format).
    # This logic matches the legacy kernel expectations
    # (rank R is preserved per output dimension).
    if U.dim() == 3:
        U_k = U.transpose(1, 2).contiguous()  # [H, R, HD]
        W_k = W.transpose(1, 2).contiguous()  # [H, R, HD]
    else:
        U_k = U.T.unsqueeze(0).contiguous()   # [1, R, HD]
        W_k = W.T.unsqueeze(0).contiguous()   # [1, R, HD]

    # Execute the CUDA kernel
    gamma = cuda_op(U_k, W_k, x_k, v_k, 0, 2.0, 1.0, 0.0)

    if v.dim() == 2:
        gamma = gamma.squeeze(1)

    return clamp_val * torch.tanh(gamma / clamp_val)

def _run_pytorch_christoffel(x, v, U, W, clamp_val, **kwargs):
    # Multi-head PyTorch fallback
    if v.dim() == 3:
        B, H, HD = v.shape
        # Flatten batch and heads to use efficient matmuls
        v_flat = v.reshape(B * H, HD)
        if U.dim() == 3:
            # U: [H, HD, R], W: [H, HD, R]; apply per head
            proj = torch.bmm(v.transpose(0, 1), U).transpose(0, 1)  # [B, H, R]
            sq = proj * proj
            W_t = W.transpose(-1, -2)
            gamma = torch.bmm(sq.transpose(0, 1), W_t).transpose(0, 1)  # [B, H, HD]
        else:
            # Shared U, W across heads
            proj = torch.matmul(v_flat, U)        # [B*H, R]
            sq = proj * proj
            gamma_flat = torch.matmul(sq, W.t())  # [B*H, HD]
            gamma = gamma_flat.view(B, H, HD)
    else:
        # Single head [B, D]
        proj = torch.matmul(v, U[0] if U.dim() == 3 else U)
        sq = proj * proj
        W_t = (W[0] if W.dim() == 3 else W).t()
        gamma = torch.matmul(sq, W_t)

    return clamp_val * torch.tanh(gamma / clamp_val)
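A usage sketch of the dispatcher, assuming the package (including its registry module) is importable. On CPU tensors, is_cuda_active is False, so the call falls through to the PyTorch path; the output is bounded by the soft clamp.

import torch
from gfn.realizations.gssm.cuda.kernels.geometry_kernels import unified_christoffel_fwd

B, H, HD, R = 2, 4, 16, 8
x = torch.zeros(B, H, HD)
v = torch.randn(B, H, HD)
U = torch.randn(H, HD, R)
W = torch.randn(H, HD, R)

gamma = unified_christoffel_fwd(x, v, U, W, clamp_val=5.0)  # CPU -> PyTorch fallback
print(gamma.shape, bool(gamma.abs().max() <= 5.0))          # torch.Size([2, 4, 16]) True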
gfn/realizations/gssm/cuda/kernels/integrator_kernels.py
ADDED
@@ -0,0 +1,73 @@
"""
Integrator Kernels — GFN V5
Unified entry points for numerical integration with hardware dispatching.
"""

import torch
from typing import Optional, Tuple, Any, Callable
from ...cuda import is_cuda_active

# Lazy imports for CUDA kernels
_euler_fused = None
_rk4_fused = None
_leapfrog_fused = None

def _get_cuda_integrators():
    global _euler_fused, _rk4_fused, _leapfrog_fused
    if _euler_fused is None:
        try:
            from ...cuda.ops import euler_fused, rk4_fused, leapfrog_fused
            _euler_fused = euler_fused
            _rk4_fused = rk4_fused
            _leapfrog_fused = leapfrog_fused
        except ImportError:
            pass
    return _euler_fused, _rk4_fused, _leapfrog_fused

def unified_leapfrog_step(
    x: torch.Tensor,
    v: torch.Tensor,
    force: Optional[torch.Tensor],
    U: torch.Tensor,
    W: torch.Tensor,
    dt: float,
    steps: int = 1,
    **kwargs
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Unified Leapfrog integration step. Returns (None, None) to signal a Python fallback."""
    if is_cuda_active(v):
        _, _, f_leapfrog = _get_cuda_integrators()
        if f_leapfrog is not None:
            try:
                # Prepare parameters for the CUDA kernel
                topo_id = kwargs.get('topology_id', 0)
                R = kwargs.get('R', 2.0)
                r = kwargs.get('r', 1.0)
                H = x.shape[1] if x.dim() == 3 else 1

                # U, W transformations
                if U.dim() == 2:
                    # [D, R] -> [1, R, D] -> [H, R, D]
                    U_k = U.T.unsqueeze(0).expand(H, -1, -1).contiguous()
                else:
                    # [H, D, R] -> [H, R, D]
                    U_k = U.transpose(1, 2).contiguous()

                if W.dim() == 2:
                    # [D, R] -> [R] -> [1, R] -> [H, R]
                    # Use mean instead of sum to preserve the effective force scale
                    W_k = W.mean(dim=0).unsqueeze(0).expand(H, -1).contiguous()
                else:
                    # [H, D, R] -> [H, R]
                    W_k = W.abs().mean(dim=1).contiguous()

                cx, cv = x, v
                for _ in range(steps):
                    cx, cv = f_leapfrog(U_k, W_k, cx, cv, force, float(dt), int(topo_id), float(R), float(r), 0.0)
                return cx, cv
            except Exception:
                pass

    # The Python fallback is handled by the higher-level Integrator classes in gfn/integrators/;
    # this unified layer is primarily for hardware acceleration.
    return None, None  # Signal fallback
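An illustrative sketch of one damped leapfrog step, mirroring leapfrog_fwd_aten in csrc/integrators/integrators.cpp under simplifying assumptions: constant force, a scalar friction mu (so the Christoffel term and the gated friction drop out, making the averaged kick equal to a single full kick from the original v), implicit damping (v + dt*a) / (1 + dt*mu), and the toroidal wrap after the drift.

import math
import torch

def leapfrog_step(x, v, force, dt=0.1, mu=0.01, eps=1e-8):
    # Half-kick with implicit friction to get the drift velocity
    v_half = (v + 0.5 * dt * force) / (1.0 + 0.5 * dt * mu + eps)
    # Drift with toroidal wrap back into [-pi, pi)
    x = torch.remainder(x + dt * v_half + math.pi, 2 * math.pi) - math.pi
    # Full kick from the original v (a_avg = force, mu_avg = mu in this simplified case)
    v = (v + dt * force) / (1.0 + dt * mu + eps)
    return x, v

x, v = leapfrog_step(torch.zeros(2, 4, 8), torch.randn(2, 4, 8), torch.zeros(2, 4, 8))
print(x.shape, v.shape)  # torch.Size([2, 4, 8]) twice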
gfn/realizations/gssm/cuda/ops/__init__.py
ADDED
@@ -0,0 +1,52 @@
"""
gfn/cuda/ops/__init__.py
Exports all fused CUDA operations.
Gracefully returns None for any op whose kernel is not compiled.
"""
from ...cuda import CUDA_AVAILABLE
import os
import sys

# Ensure gfn/csrc is in PYTHONPATH to find the compiled .pyd
_csrc_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "csrc"))
if _csrc_path not in sys.path:
    sys.path.insert(0, _csrc_path)

def _get_op(module_path: str, name: str):
    """Safely import a CUDA binding, returning None on failure."""
    try:
        import importlib
        mod = importlib.import_module(module_path)
        return getattr(mod, name, None)
    except Exception:
        return None

# ── Geometry ──────────────────────────────────────────────────────────────────
christoffel_cuda_fwd = _get_op("gfn_cuda", "compute_christoffel_symbols_fwd")
christoffel_cuda_bwd = _get_op("gfn_cuda", "compute_christoffel_symbols_bwd")
low_rank_christoffel_fwd = _get_op("gfn_cuda", "low_rank_christoffel_fwd")
low_rank_christoffel_bwd = _get_op("gfn_cuda", "low_rank_christoffel_bwd")
toroidal_christ_fwd = _get_op("gfn_cuda", "toroidal_geo_christoffel_fwd")

# ── Integrators ───────────────────────────────────────────────────────────────
heun_fused = _get_op("gfn_cuda", "heun_fwd")
leapfrog_fused = _get_op("gfn_cuda", "leapfrog_fwd")
yoshida_fused = _get_op("gfn_cuda", "yoshida_fwd")
rk4_fused = _get_op("gfn_cuda", "rk4_fwd")

# ── Loss ──────────────────────────────────────────────────────────────────────
toroidal_loss_fwd = _get_op("gfn_cuda", "toroidal_distance_loss_fwd")
toroidal_loss_bwd = _get_op("gfn_cuda", "toroidal_distance_loss_bwd")

def __getattr__(name):
    if name.endswith(("_fused", "_fwd", "_bwd", "_cuda")):
        return None
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

__all__ = [
    "CUDA_AVAILABLE",
    "christoffel_cuda_fwd", "christoffel_cuda_bwd",
    "low_rank_christoffel_fwd", "low_rank_christoffel_bwd",
    "toroidal_christ_fwd", "heun_fused", "leapfrog_fused",
    "yoshida_fused", "rk4_fused", "toroidal_loss_fwd", "toroidal_loss_bwd"
]
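A quick way to see which fused kernels actually loaded on the current machine (each export above is None when gfn_cuda was not compiled); the import path assumes the package layout of this commit.

from gfn.realizations.gssm.cuda import ops

loaded = {name: getattr(ops, name) is not None
          for name in ("low_rank_christoffel_fwd", "leapfrog_fused", "yoshida_fused")}
print(ops.CUDA_AVAILABLE, loaded)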
gfn/realizations/gssm/data/__init__.py
ADDED
@@ -0,0 +1,16 @@
"""
data/__init__.py — GFN V5
"""
from ..data.dataset import SequenceDataset
from ..data.loader import create_dataloaders
from ..data.transforms import shift_targets, add_bos_token, pad_sequences
from ..data.replay import TrajectoryReplayBuffer

__all__ = [
    'SequenceDataset',
    'create_dataloaders',
    'shift_targets',
    'add_bos_token',
    'pad_sequences',
    'TrajectoryReplayBuffer'
]
gfn/realizations/gssm/data/dataset.py
ADDED
@@ -0,0 +1,14 @@
import torch
from torch.utils.data import Dataset

class SequenceDataset(Dataset):
    """Simple sequence dataset for (X, Y) pairs."""
    def __init__(self, x: torch.Tensor, y: torch.Tensor):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]
gfn/realizations/gssm/data/loader.py
ADDED
@@ -0,0 +1,53 @@
"""
data/loader.py — GFN V5
DataLoaders and datasets for GFN tasks.
WATCHOUT: The DataLoader is created ONCE, outside the training loop.
"""

import torch
from torch.utils.data import DataLoader, Dataset, random_split
from typing import Tuple, Optional
from ..data.dataset import SequenceDataset


def create_dataloaders(
    x: torch.Tensor,
    y: torch.Tensor,
    batch_size: int = 32,
    val_split: float = 0.1,
    shuffle: bool = True,
    num_workers: int = 0,
    seed: int = 42,
) -> Tuple[DataLoader, Optional[DataLoader]]:
    """
    Creates train and validation DataLoaders from tensors.
    IMPORTANT: Create the DataLoaders ONCE, outside the loop — not inside it.

    Args:
        x, y: Input and target tensors
        batch_size: Batch size
        val_split: Fraction of the data used for validation (0 = no validation)
        shuffle: Shuffle the training data
        num_workers: Workers for data loading
        seed: Seed for split reproducibility

    Returns:
        (train_loader, val_loader) — val_loader is None if val_split=0
    """
    dataset = SequenceDataset(x, y)

    if val_split > 0:
        n_val = max(1, int(len(dataset) * val_split))
        n_train = len(dataset) - n_val
        generator = torch.Generator().manual_seed(seed)
        train_ds, val_ds = random_split(dataset, [n_train, n_val], generator=generator)

        train_loader = DataLoader(train_ds, batch_size=batch_size,
                                  shuffle=shuffle, num_workers=num_workers)
        val_loader = DataLoader(val_ds, batch_size=batch_size,
                                shuffle=False, num_workers=num_workers)
        return train_loader, val_loader

    train_loader = DataLoader(dataset, batch_size=batch_size,
                              shuffle=shuffle, num_workers=num_workers)
    return train_loader, None
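A minimal usage sketch with synthetic parity data (shapes and task are illustrative); note that the loaders are built once and reused across epochs, per the WATCHOUT above:

```python
import torch

x = torch.randint(0, 2, (1024, 16))   # 16-bit binary inputs
y = x.sum(dim=1) % 2                  # parity targets

train_loader, val_loader = create_dataloaders(x, y, batch_size=32, val_split=0.1)

for epoch in range(3):                # loaders created once, outside the loop
    for xb, yb in train_loader:
        pass                          # forward / loss / backward would go here
```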
gfn/realizations/gssm/data/replay.py
ADDED
@@ -0,0 +1,130 @@
"""
Replay / Trajectory Buffer — GFN V5
Handles storage and sampling of physical states (x, v, forces)
to support off-policy training and exploration for real GFlowNets.
"""

import torch
from typing import Optional, Tuple

class TrajectoryReplayBuffer:
    """
    A persistent buffer for storing and managing manifold trajectories (x, v states).
    Serves as replay memory for Hamiltonian/Geodesic flows in V5.
    """
    def __init__(
        self,
        capacity: int,
        dim: int,
        device: torch.device = torch.device('cpu'),
        dtype: torch.dtype = torch.float32
    ):
        self.capacity = capacity
        self.dim = dim
        self.device = device
        self.dtype = dtype

        # Buffers for state (x), velocity (v), and optional force
        # Shape: [capacity, dim] or [capacity, heads, head_dim] depending on input
        # Note: we flatten the capacity dimension but keep the geometry shape.
        self._initialized_shape = False

        self.pointer = 0
        self.size = 0
        self.is_full = False

    def _init_buffers(self, example_shape: torch.Size):
        """Initializes the tensor buffers based on the first observed shape."""
        # example_shape might be [Batch, Dim] or [Batch, Heads, HeadDim]
        # We need [Capacity, *shape[1:]]
        element_shape = example_shape[1:]

        # Memory check safeguard
        import math
        bytes_per_el = 4 if self.dtype == torch.float32 else 8
        total_elements = self.capacity * math.prod(element_shape)
        # 3 buffers (x, v, force)
        total_mb = (3 * total_elements * bytes_per_el) / (1024 ** 2)
        if total_mb > 1024 and self.device.type == 'cuda':
            import logging
            logging.warning(f"ReplayBuffer: Allocating {total_mb:.1f} MB on CUDA. Risk of OOM.")

        self.x_buffer = torch.zeros((self.capacity, *element_shape), device=self.device, dtype=self.dtype)
        self.v_buffer = torch.zeros((self.capacity, *element_shape), device=self.device, dtype=self.dtype)
        self.force_buffer = torch.zeros((self.capacity, *element_shape), device=self.device, dtype=self.dtype)
        self._initialized_shape = True

    def add(
        self,
        x: torch.Tensor,
        v: torch.Tensor,
        force: Optional[torch.Tensor] = None
    ):
        """
        Adds a batch of transitions to the buffer.
        """
        batch_size = x.size(0)

        if not self._initialized_shape:
            self._init_buffers(x.shape)

        # Handle wrap-around indexing
        indices = torch.arange(self.pointer, self.pointer + batch_size) % self.capacity

        self.x_buffer[indices] = x.to(self.device).detach()
        self.v_buffer[indices] = v.to(self.device).detach()
        if force is not None:
            self.force_buffer[indices] = force.to(self.device).detach()

        self.pointer = (self.pointer + batch_size) % self.capacity
        self.size = min(self.size + batch_size, self.capacity)
        if self.size == self.capacity:
            self.is_full = True

    def sample_random(self, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Randomly samples a batch of states from the buffer.
        Returns: (x, v, force)
        """
        if self.size == 0:
            raise ValueError("Cannot sample from an empty buffer.")

        indices = torch.randint(0, self.size, (batch_size,), device=self.device)

        return (
            self.x_buffer[indices],
            self.v_buffer[indices],
            self.force_buffer[indices]
        )

    def sample_recent(self, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Samples the most recently added transitions."""
        if self.size == 0:
            raise ValueError("Cannot sample from an empty buffer.")

        if self.size < batch_size:
            idx = torch.arange(0, self.size, device=self.device)
        else:
            idx = (torch.arange(self.pointer - batch_size, self.pointer, device=self.device) % self.capacity)

        return (
            self.x_buffer[idx],
            self.v_buffer[idx],
            self.force_buffer[idx]
        )

    def sample_with_noise(self, batch_size: int, noise_std: float = 1e-3) -> Tuple[torch.Tensor, torch.Tensor]:
        """Samples with Gaussian jitter to improve robust training."""
        x, v, _ = self.sample_random(batch_size)
        x_noisy = x + torch.randn_like(x) * noise_std
        return x_noisy, v

    def clear(self):
        """Resets the buffer."""
        self.pointer = 0
        self.size = 0
        self.is_full = False
        self._initialized_shape = False

    def __len__(self):
        return self.size
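A minimal usage sketch; buffers are allocated lazily on the first add(), so the dim argument only has to match what is actually pushed:

```python
import torch

buf = TrajectoryReplayBuffer(capacity=10_000, dim=64)

x = torch.randn(32, 64)                # [batch, dim]
v = torch.randn(32, 64)
buf.add(x, v)                          # force buffer stays zero when omitted

x_s, v_s, f_s = buf.sample_random(16)  # uniform over the stored transitions
x_j, v_j = buf.sample_with_noise(16, noise_std=1e-3)
assert len(buf) == 32
```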
gfn/realizations/gssm/data/transforms.py
ADDED
@@ -0,0 +1,40 @@
"""
data/transforms.py — GFN V5
Data transformations for sequences.
"""

import torch
from typing import Tuple, Optional


def shift_targets(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Creates shifted (input, target) pairs for language modeling.
    input = x[:, :-1]
    target = x[:, 1:]
    """
    return x[:, :-1], x[:, 1:]


def add_bos_token(x: torch.Tensor, bos_id: int = 0) -> torch.Tensor:
    """Prepends a BOS token to each sequence."""
    bos = torch.full((x.size(0), 1), bos_id, dtype=x.dtype, device=x.device)
    return torch.cat([bos, x], dim=1)


def pad_sequences(sequences, max_len: int, pad_id: int = 0) -> torch.Tensor:
    """Pads a list of variable-length sequences."""
    result = torch.full((len(sequences), max_len), pad_id, dtype=torch.long)
    for i, seq in enumerate(sequences):
        length = min(len(seq), max_len)
        result[i, :length] = torch.tensor(seq[:length])
    return result


def create_attention_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """
    Creates an attention mask from sequence lengths.
    Returns: [B, max_len] with True where data is valid.
    """
    indices = torch.arange(max_len, device=lengths.device).unsqueeze(0)
    return indices < lengths.unsqueeze(1)
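A minimal end-to-end sketch of the four transforms on a toy batch (token ids are made up):

```python
import torch

seqs = [[5, 3, 7], [2, 9]]                        # variable-length token ids
batch = pad_sequences(seqs, max_len=4)            # [2, 4], right-padded with 0
batch = add_bos_token(batch, bos_id=1)            # [2, 5]
inp, tgt = shift_targets(batch)                   # inp: [2, 4], tgt: [2, 4]

lengths = torch.tensor([4, 3])                    # original lengths + 1 for BOS
mask = create_attention_mask(lengths, max_len=5)  # [2, 5] boolean mask
```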
gfn/realizations/gssm/errors.py
ADDED
@@ -0,0 +1,23 @@
class GFNError(Exception):
    """Base exception for all GFN errors."""
    pass

class ConfigurationError(GFNError):
    """Raised when a configuration is invalid or inconsistent."""
    pass

class GeometryError(GFNError):
    """Raised when a geometric operation fails (e.g., out of manifold)."""
    pass

class PhysicsError(GFNError):
    """Raised during physics engine failures (e.g., NaN detected)."""
    pass

class IntegrationError(GFNError):
    """Raised during numerical integration failures."""
    pass

class TrainingError(GFNError):
    """Raised during model training or optimization failures."""
    pass
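Because every exception derives from GFNError, callers can catch one specific failure mode or the whole family with a single handler:

```python
try:
    raise IntegrationError("dt too large: energy diverged")
except GFNError as exc:               # catches all GFN-specific errors
    print(f"GFN failure: {exc}")
```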
gfn/realizations/gssm/geometry/__init__.py
ADDED
@@ -0,0 +1,42 @@
"""
gfn/geometry/__init__.py
Public API for the geometry module — GFN V5
"""

# Base and factory
from ..geometry.base import BaseGeometry
from ..geometry.factory import GeometryFactory

# Concrete geometries (imports trigger @register_geometry decorators)
from ..geometry.euclidean import EuclideanGeometry
from ..geometry.torus import ToroidalRiemannianGeometry, FlatToroidalRiemannianGeometry
from ..geometry.low_rank import LowRankRiemannianGeometry, PaperLowRankRiemannianGeometry
from ..geometry.adaptive import AdaptiveRiemannianGeometry
from ..geometry.reactive import ReactiveRiemannianGeometry
from ..geometry.hyperbolic import HyperRiemannianGeometry
from ..geometry.holographic import HolographicRiemannianGeometry
from ..geometry.spherical import SphericalGeometry
from ..geometry.hierarchical import HierarchicalGeometry

# Re-export FrictionGate from unified physics.components location
from ..physics.components.friction import FrictionGate

__all__ = [
    # Base
    "BaseGeometry",
    "GeometryFactory",
    # Implementations
    "EuclideanGeometry",
    "ToroidalRiemannianGeometry",
    "FlatToroidalRiemannianGeometry",
    "LowRankRiemannianGeometry",
    "PaperLowRankRiemannianGeometry",
    "AdaptiveRiemannianGeometry",
    "ReactiveRiemannianGeometry",
    "HyperRiemannianGeometry",
    "HolographicRiemannianGeometry",
    "SphericalGeometry",
    "HierarchicalGeometry",
    # Shared components
    "FrictionGate",
]
gfn/realizations/gssm/geometry/adaptive.py
ADDED
@@ -0,0 +1,83 @@
"""
AdaptiveRiemannianGeometry — GFN V5
Adaptive rank Christoffel symbol decomposition.
Migrated from gfn/geo/riemannian/adaptive_geometry.py
"""

import torch
import torch.nn as nn
from typing import Optional, Union, Tuple

from ..constants import CURVATURE_CLAMP
from ..config.schema import PhysicsConfig
from ..geometry.base import BaseGeometry
from ..registry import register_geometry


@register_geometry('adaptive')
class AdaptiveRiemannianGeometry(BaseGeometry):
    """
    Adjusts the effective curvature rank dynamically based on velocity complexity.

    Architecture:
        eff_rank = f(||v||) in [min_rank, max_rank]
        Γ(v) = W[:, :eff_rank] @ (U[:, :eff_rank]^T v)^2
    """

    def __init__(self, dim: int, max_rank: int = 64, config: Optional[PhysicsConfig] = None):
        super().__init__(config)
        self.dim = dim
        self.max_rank = max_rank
        self.min_rank_ratio = 0.1

        self.U_full = nn.Parameter(torch.randn(dim, max_rank) * 0.01)
        self.W_full = nn.Parameter(torch.randn(dim, max_rank) * 0.01)

        # Complexity predictor: maps v → rank_ratio ∈ [0, 1]
        self.complexity_net = nn.Sequential(
            nn.Linear(dim, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )
        # Initialize bias to start with a mostly-open rank to avoid vanishing gradients
        nn.init.constant_(self.complexity_net[-2].bias, 1.0)

        self.return_friction_separately = True

    def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
                force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if v is None:
            return torch.zeros_like(x)

        # 1. Predict rank weight (soft-mask)
        # Use complexity_net to predict a value p in [0, 1]
        p = self.complexity_net(v)  # [B, 1]

        # Create a soft mask for the rank dimension: [B, max_rank]
        # Mask[i] = sigmoid(slope * (p * max_rank - i))
        # This approximates hard-slicing but is differentiable.
        indices = torch.arange(self.max_rank, device=v.device).float()
        slope = 10.0
        soft_mask = torch.sigmoid(slope * (p * self.max_rank - indices))  # [B, max_rank]

        # 2. Christoffel using all components modulated by mask
        proj = torch.matmul(v, self.U_full)                  # [B, max_rank]
        sq = proj * proj                                     # [B, max_rank]
        modulated_sq = sq * soft_mask                        # [B, max_rank]
        gamma = torch.matmul(modulated_sq, self.W_full.t())  # [B, dim]

        # 3. Friction (ensure mu is not just zero)
        # Fall back to config friction or a base value
        friction_base = getattr(self.config.stability, 'friction', 0.1)
        mu = torch.full_like(v, friction_base)

        gamma_clamped = CURVATURE_CLAMP * torch.tanh(gamma / CURVATURE_CLAMP)

        if self.return_friction_separately:
            return gamma_clamped, mu

        return gamma_clamped + mu * v

    def metric_tensor(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ones_like(x)
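The soft rank mask is the key trick here: a sigmoid ramp over the rank indices approximates hard truncation at p·max_rank while staying differentiable. A standalone sketch of just the mask, with values chosen for illustration:

```python
import torch

max_rank, slope = 8, 10.0
p = torch.tensor([[0.5]])                        # predicted rank ratio, one sample
idx = torch.arange(max_rank).float()
mask = torch.sigmoid(slope * (p * max_rank - idx))
print(mask.round(decimals=2))
# tensor([[1.00, 1.00, 1.00, 1.00, 0.50, 0.00, 0.00, 0.00]])
# → roughly the first p·max_rank = 4 rank components stay open
```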
gfn/realizations/gssm/geometry/base.py
ADDED
@@ -0,0 +1,70 @@
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union
from ..interfaces.geometry import Geometry
from ..config.schema import PhysicsConfig
from ..constants import TOPOLOGY_EUCLIDEAN

class BaseGeometry(nn.Module):
    """
    Base implementation for Riemannian Geometries in GFN V5.
    Conforms to the Geometry protocol.
    """
    def __init__(self, config: Optional[PhysicsConfig] = None):
        super().__init__()
        self.config = config or PhysicsConfig()
        self.return_friction_separately = True
        self.topology_type = self.config.topology.type

    def metric_tensor(self, x: torch.Tensor) -> torch.Tensor:
        """Default metric is identity (Euclidean). Subclasses should override."""
        return torch.ones_like(x)

    def christoffel_symbols(self, x: torch.Tensor) -> torch.Tensor:
        """Default Christoffel symbols are zero (Euclidean). Subclasses should override."""
        return torch.zeros_like(x)

    def compute_kinetic_energy(self, x: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """
        Calculates Riemannian kinetic energy: T = (1/2) Σ_i g_ii v_i²
        Supports position-dependent metrics (like Torus).
        """
        g = self.metric_tensor(x)  # [..., D]
        return 0.5 * (g * v.pow(2)).sum(dim=-1)

    def compute_potential_energy(self, x: torch.Tensor) -> torch.Tensor:
        """
        Calculates physical potential energy V(x).
        Default is 0.0 unless overridden by specific topologies or forces.
        """
        return torch.zeros_like(x).sum(dim=-1)

    def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None, force: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Computes acceleration: acc = -Gamma(v, v) + F/g
        Subclasses can override for more complex physics.
        """
        if v is None:
            return torch.zeros_like(x)

        gamma = self.christoffel_symbols(x)
        # Standard geodesic acceleration: -Gamma^k_ij v^i v^j
        # In our simplified 1D-per-dimension metric, it's often just a point-wise product
        acc = -gamma * (v**2)

        if force is not None:
            g = self.metric_tensor(x)
            acc = acc + (force / (g + 1e-8))

        if getattr(self, 'return_friction_separately', False):
            # v is guaranteed non-None here (early return above)
            return acc, torch.zeros_like(v)

        return acc

    def project(self, x: torch.Tensor) -> torch.Tensor:
        """Default projection is identity. Subclasses should override for periodic spaces."""
        return x

    def dist(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Default distance is Euclidean. Subclasses should override."""
        return torch.norm(x1 - x2, dim=-1)
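The forward contract above is just the damped geodesic equation with a diagonal metric; stripped of the module plumbing it is a two-line computation. A standalone restatement, assuming the same per-coordinate simplification:

```python
import torch

def geodesic_acc(gamma, g, v, force=None):
    # acc = -Γ·v² + F/g, the per-coordinate form used by BaseGeometry.forward
    acc = -gamma * v.pow(2)
    if force is not None:
        acc = acc + force / (g + 1e-8)
    return acc

v = torch.randn(4, 8)
acc = geodesic_acc(torch.zeros(4, 8), torch.ones(4, 8), v)   # flat space
assert torch.allclose(acc, torch.zeros(4, 8))                # zero curvature → zero acc
```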
gfn/realizations/gssm/geometry/euclidean.py
ADDED
@@ -0,0 +1,20 @@
import torch
from .base import BaseGeometry
from ..registry import register_geometry
from ..constants import TOPOLOGY_EUCLIDEAN

@register_geometry(TOPOLOGY_EUCLIDEAN)
class EuclideanGeometry(BaseGeometry):
    """Standard Euclidean Space (Flat)."""

    def metric_tensor(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ones_like(x)

    def christoffel_symbols(self, x: torch.Tensor) -> torch.Tensor:
        return torch.zeros_like(x)

    def project(self, x: torch.Tensor) -> torch.Tensor:
        return x

    def dist(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        return torch.norm(x1 - x2, dim=-1)
gfn/realizations/gssm/geometry/factory.py
ADDED
@@ -0,0 +1,117 @@
"""
GeometryFactory — GFN V5
Creates geometry instances from PhysicsConfig.
Supports: euclidean, torus, low_rank, reactive, adaptive, hyperbolic, holographic.
"""

from typing import Optional
from ..config.schema import PhysicsConfig
from ..registry import GEOMETRY_REGISTRY
from ..constants import TOPOLOGY_TORUS, TOPOLOGY_EUCLIDEAN
import logging

logger = logging.getLogger(__name__)

_GEOMETRIES_REGISTERED = False

def _register_all_geometries():
    """Imports the submodules explicitly so their geometries get registered."""
    global _GEOMETRIES_REGISTERED
    if _GEOMETRIES_REGISTERED:
        return
    from . import euclidean
    from . import torus
    from . import low_rank
    from . import adaptive
    from . import reactive
    from . import hyperbolic
    _GEOMETRIES_REGISTERED = True

class GeometryFactory:
    """
    Creates manifold geometries from configuration.

    Primary key: topology.type ('euclidean', 'torus', 'hyperbolic', ...)
    Secondary key: topology.riemannian_type ('low_rank', 'reactive', 'adaptive', ...)

    riemannian_type overrides topology.type when explicitly set and registered.
    """

    @staticmethod
    def _lookup_key(config: PhysicsConfig) -> str:
        _register_all_geometries()
        topo_type = config.topology.type.lower()
        riem_type = getattr(config.topology, 'riemannian_type', 'reactive').lower()
        available = GEOMETRY_REGISTRY.list_keys()

        # Priority logic:
        # 1. Prioritize learned Riemannian geometries (low_rank, reactive, adaptive)
        #    even if the topology is specialized (torus, etc.), as they handle topology via features.
        learned_types = {'low_rank', 'reactive', 'adaptive', 'low_rank_paper'}
        if riem_type in learned_types and riem_type in available:
            return riem_type

        # 2. Otherwise, if topology is specific (torus, hyperbolic, etc.), use its analytical model.
        if topo_type in available and topo_type != TOPOLOGY_EUCLIDEAN:
            return topo_type

        # 3. Fall back to riem_type or topo_type
        if riem_type in available:
            return riem_type

        return topo_type

    @staticmethod
    def create(config: PhysicsConfig):
        """
        Create geometry using default dim from config.
        Looks for 'dim' in topology config or falls back to 64.
        """
        lookup_key = GeometryFactory._lookup_key(config)
        available = GEOMETRY_REGISTRY.list_keys()

        if lookup_key in available:
            geometry_cls = GEOMETRY_REGISTRY.get(lookup_key)
            try:
                dim = getattr(config, 'dim', 64)
                rank = getattr(config.topology, 'riemannian_rank', 16)
                return geometry_cls(dim=dim, rank=rank, config=config)
            except TypeError:
                try:
                    return geometry_cls(config=config)
                except TypeError:
                    return geometry_cls()

        logger.warning(f"Geometry '{lookup_key}' not found. Using EuclideanGeometry.")
        from .euclidean import EuclideanGeometry
        return EuclideanGeometry(config=config)

    @staticmethod
    def create_with_dim(dim: int, rank: int, num_heads: int, config: PhysicsConfig):
        """
        Create geometry with explicit dim and rank.
        Used by ModelFactory to pass head_dim (not total dim) to the geometry,
        since geometry operates on per-head tensors [B, H, HD].
        """
        lookup_key = GeometryFactory._lookup_key(config)
        available = GEOMETRY_REGISTRY.list_keys()

        if lookup_key in available:
            geometry_cls = GEOMETRY_REGISTRY.get(lookup_key)
            try:
                return geometry_cls(dim=dim, rank=rank, num_heads=num_heads, config=config)
            except TypeError:
                try:
                    return geometry_cls(dim=dim, rank=rank, config=config)
                except TypeError:
                    try:
                        return geometry_cls(config=config)
                    except TypeError:
                        return geometry_cls()

        logger.warning(f"Geometry '{lookup_key}' not found. Using EuclideanGeometry.")
        from .euclidean import EuclideanGeometry
        try:
            return EuclideanGeometry(dim=dim, num_heads=num_heads, config=config)
        except TypeError:
            return EuclideanGeometry(config=config)
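A minimal usage sketch, assuming a mutable PhysicsConfig whose topology block exposes the type and riemannian_type fields that _lookup_key reads (the field values here are hypothetical):

```python
config = PhysicsConfig()
config.topology.type = 'torus'
config.topology.riemannian_type = 'low_rank'   # learned type outranks 'torus'

geom = GeometryFactory.create_with_dim(dim=32, rank=16, num_heads=4, config=config)
print(type(geom).__name__)                     # expected: LowRankRiemannianGeometry
```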
gfn/realizations/gssm/geometry/hierarchical.py
ADDED
@@ -0,0 +1,84 @@
import torch
import torch.nn as nn
from typing import List, Optional, Union, Tuple, Any
from ..geometry.base import BaseGeometry
from ..geometry.low_rank import LowRankRiemannianGeometry
from ..registry import register_geometry

@register_geometry('hierarchical')
class HierarchicalGeometry(BaseGeometry):
    """
    Multi-Scale Riemannian Geometry (Christoffel Mixture).
    Combines multiple geometries (typically Low-Rank) with different scales.

    Migrated from gfn_old HierarchicalRiemannianGeometry.
    """
    def __init__(self, dim: int, rank: int = 16, ranks: Optional[List[int]] = None,
                 num_heads: int = 1, config: Optional[Any] = None, **kwargs):
        super().__init__(config)
        self.dim = dim
        self.ranks = ranks if ranks is not None else [8, 16, 32]
        if rank not in self.ranks:
            # Optionally include the factory-suggested rank
            self.ranks = sorted(list(set(self.ranks + [rank])))
        self.num_heads = num_heads

        # Initialize sub-geometries (defaulting to LowRank)
        self.scales = nn.ModuleList([
            LowRankRiemannianGeometry(dim, rank=r, num_heads=num_heads, config=config)
            for r in self.ranks
        ])

        # Learnable mixing weights
        self.scale_weights = nn.Parameter(torch.ones(len(self.ranks)) / len(self.ranks))
        self.return_friction_separately = False

    def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
                force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:

        gammas = []
        frictions = []

        # Execute each scale
        for scale in self.scales:
            # Temporarily ensure consistent return mode
            was_sep = getattr(scale, 'return_friction_separately', False)
            scale.return_friction_separately = True

            res = scale(x, v, force=force, **kwargs)
            if isinstance(res, tuple):
                g, f = res
            else:
                g, f = res, torch.zeros_like(v) if v is not None else torch.zeros_like(x)

            gammas.append(g)
            frictions.append(f)
            scale.return_friction_separately = was_sep

        # Mix using softmax weights
        weights = torch.softmax(self.scale_weights, dim=0)

        gamma_mixed = sum(w * g for w, g in zip(weights, gammas))
        friction_mixed = sum(w * f for w, f in zip(weights, frictions))

        if self.return_friction_separately:
            return gamma_mixed, friction_mixed

        if v is not None:
            return gamma_mixed + friction_mixed * v
        return gamma_mixed

    def metric_tensor(self, x: torch.Tensor) -> torch.Tensor:
        weights = torch.softmax(self.scale_weights, dim=0)
        metrics = [scale.metric_tensor(x) for scale in self.scales]
        return sum(w * m for w, m in zip(weights, metrics))

    def project(self, x: torch.Tensor) -> torch.Tensor:
        # Hierarchical usually just projects using the first scale (geometry consistent)
        from typing import cast
        return cast(BaseGeometry, self.scales[0]).project(x)

    def dist(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        weights = torch.softmax(self.scale_weights, dim=0)
        dists = [scale.dist(x1, x2) for scale in self.scales]
        return sum(w * d for w, d in zip(weights, dists))
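The mixture itself is plain convex blending of the per-scale outputs under softmax weights; with the uniform initialization above it starts out as a simple average. A standalone sketch of just that step:

```python
import torch

weights = torch.softmax(torch.zeros(3), dim=0)           # uniform at init
gammas = [torch.randn(2, 8) for _ in range(3)]           # one tensor per rank scale
gamma_mixed = sum(w * g for w, g in zip(weights, gammas))
assert torch.allclose(gamma_mixed, torch.stack(gammas).mean(dim=0))
```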
gfn/realizations/gssm/geometry/holographic.py
ADDED
@@ -0,0 +1,91 @@
"""
HolographicRiemannianGeometry — GFN V5
AdS/CFT-inspired holographic extensions (Paper 18).
Migrated from gfn/geo/physical/holographic_geometry.py
"""

import torch
import torch.nn as nn
from typing import Optional, Union, Tuple

from ..config.schema import PhysicsConfig
from ..geometry.base import BaseGeometry
from ..registry import register_geometry


@register_geometry('holographic')
class HolographicRiemannianGeometry(BaseGeometry):
    """
    Conformal manifold inspired by Bulk-Boundary (AdS/CFT) correspondence.

    Lifts boundary state x → bulk (x, z) where z is the holographic radial dim.
    Conformal metric: g_ij = (1/z(x)²) · δ_ij

    The Christoffel correction adds an AdS-geodesic term to any base geometry.
    """

    def __init__(self, base_geometry: BaseGeometry, z_min: float = 0.1,
                 z_max: float = 10.0, config: Optional[PhysicsConfig] = None):
        super().__init__(config)
        self.base_geometry = base_geometry
        self.dim = getattr(base_geometry, 'dim', None)
        self.z_min = z_min
        self.z_max = z_max

        dim = self.dim or 0
        if dim > 0:
            self.radial_net: nn.Module = nn.Sequential(
                nn.Linear(dim, dim // 2),
                nn.SiLU(),
                nn.Linear(dim // 2, 1),
                nn.Softplus()
            )
        else:
            self.radial_net = nn.Identity()

    def get_z_and_grad(self, x: torch.Tensor):
        x_req = x.detach().requires_grad_(True)
        with torch.enable_grad():
            z = self.radial_net(x_req) + self.z_min
            z = torch.clamp(z, max=self.z_max)
            grad_z = torch.autograd.grad(
                z.sum(), x_req,
                create_graph=self.training,
                retain_graph=False
            )[0]
        return z, grad_z

    def metric_tensor(self, x: torch.Tensor) -> torch.Tensor:
        z, _ = self.get_z_and_grad(x)
        return (1.0 / z.pow(2)) * torch.ones_like(x)

    def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
                force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        out_base = self.base_geometry(x, v, force=force, **kwargs)
        if isinstance(out_base, tuple):
            gamma_base, mu = out_base
        else:
            gamma_base, mu = out_base, torch.zeros_like(v) if v is not None else torch.zeros_like(x)

        if v is None:
            if self.return_friction_separately:
                return gamma_base, mu
            return gamma_base

        z, grad_z = self.get_z_and_grad(x)
        v_dot_gradz = (v * grad_z).sum(dim=-1, keepdim=True)
        v_sq = (v * v).sum(dim=-1, keepdim=True)
        gamma_ads = -(1.0 / z) * (2.0 * v_dot_gradz * v - v_sq * grad_z)

        gamma_total = gamma_base + gamma_ads

        if self.return_friction_separately:
            return gamma_total, mu

        return gamma_total + mu * v

    def project(self, x: torch.Tensor) -> torch.Tensor:
        return self.base_geometry.project(x)

    def dist(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        return self.base_geometry.dist(x1, x2)
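For reference, gamma_ads above is exactly the Christoffel contraction of a conformally flat metric; writing z(x) for the learned radial coordinate:

```latex
g_{ij}(x) = \frac{\delta_{ij}}{z(x)^{2}}
\quad\Longrightarrow\quad
\Gamma^{k}_{ij} = -\frac{1}{z}\left(\delta^{k}_{i}\,\partial_{j}z + \delta^{k}_{j}\,\partial_{i}z - \delta_{ij}\,\partial^{k}z\right),
\qquad
\Gamma^{k}_{ij}\,v^{i}v^{j} = -\frac{1}{z}\Bigl(2\,(v\cdot\nabla z)\,v^{k} - \lVert v\rVert^{2}\,(\nabla z)^{k}\Bigr).
```

The second expression is the tensor form computed in forward().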
gfn/realizations/gssm/geometry/hyperbolic.py
ADDED
@@ -0,0 +1,97 @@
"""
HyperRiemannianGeometry — GFN V5
Context-dependent (gated) Christoffel symbols.
Migrated from gfn/geo/topological/hyperbolic_geometry.py
"""

import torch
import torch.nn as nn
from typing import Optional, Union, Tuple

from ..constants import CURVATURE_CLAMP, EPS, TOPOLOGY_TORUS
from ..config.schema import PhysicsConfig
from ..geometry.low_rank import LowRankRiemannianGeometry
from ..registry import register_geometry


@register_geometry('hyperbolic')
class HyperRiemannianGeometry(LowRankRiemannianGeometry):
    """
    Hyper-Christoffel: geometry conditioned on current position.

    Architecture:
        U(x) = U_static * diag(Gate_u(x)) — position-scaled basis
        W(x) = W_static * diag(Gate_w(x))
        Γ(v | x) = W(x) @ (U(x)^T v)²

    Gates output values in [0, 2] initialized near 1.0 (identity).
    """

    def __init__(self, dim: int, rank: int = 16, num_heads: int = 1,
                 config: Optional[PhysicsConfig] = None):
        super().__init__(dim, rank, num_heads=num_heads, config=config)
        self.return_friction_separately = True

        self.gate_u = nn.Linear(dim, rank)
        self.gate_w = nn.Linear(dim, rank)
        nn.init.zeros_(self.gate_u.weight); nn.init.zeros_(self.gate_u.bias)
        nn.init.zeros_(self.gate_w.weight); nn.init.zeros_(self.gate_w.bias)

    def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
                force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if v is None:
            return torch.zeros_like(x)

        original_shape = v.shape
        # Handle multi-head [B, H, HD] -> [B*H, HD]
        if v.dim() == 3:
            B, H, HD = v.shape
            v_flat = v.reshape(B * H, HD)
            x_flat = x.reshape(B * H, HD)
        else:
            v_flat = v
            x_flat = x
            B, H = None, None

        # Context gates in [0, 2]
        g_u = torch.sigmoid(self.gate_u(x_flat)) * 2.0  # [B*H, rank]
        g_w = torch.sigmoid(self.gate_w(x_flat)) * 2.0

        # Modulate static basis
        # self.U is [HD, rank] or [H, HD, rank]
        U_eff = self.U if self.U.dim() == 2 else self.U.mean(0)
        proj_static = torch.matmul(v_flat, U_eff)  # [B*H, rank]
        proj_dynamic = proj_static * g_u

        # Soft-saturation to prevent energy explosion
        sq_dynamic = (proj_dynamic * proj_dynamic) / (1.0 + torch.abs(proj_dynamic) + EPS)
        sq_modulated = sq_dynamic * g_w

        W_t = self.W.t() if self.W.dim() == 2 else self.W.mean(0).t()
        gamma = torch.matmul(sq_modulated, W_t)  # [B*H, HD]

        # Restore original shape if multi-head
        if B is not None:
            gamma = gamma.view(original_shape)
            x_flat_for_mu = x_flat
            v_flat_for_mu = v_flat
        else:
            x_flat_for_mu = x
            v_flat_for_mu = v

        # Friction
        x_in = torch.cat([torch.sin(x_flat), torch.cos(x_flat)], dim=-1) \
            if self.topology_type == TOPOLOGY_TORUS else x_flat
        mu_base = self.friction + self.friction_gate(x_in, force=force)
        v_norm = torch.norm(v, dim=-1, keepdim=True) / (self.dim ** 0.5 + EPS)
        mu = mu_base * (1.0 + self.velocity_friction_scale * v_norm)
        if mu.shape != v.shape:
            mu = mu.view_as(v) if mu.numel() == v.numel() else mu.mean(dim=-1, keepdim=True)

        gamma = self._normalize(gamma)
        gamma = self.clamp_val * torch.tanh(gamma / self.clamp_val)

        if self.return_friction_separately:
            return gamma, mu

        return gamma + mu * v
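Zero-initializing both gate layers is what makes this a safe drop-in over LowRank: sigmoid(0) · 2 = 1, so at initialization the gates modulate nothing and the geometry behaves exactly like its parent. A one-line check:

```python
import torch
assert torch.allclose(torch.sigmoid(torch.zeros(4)) * 2.0, torch.ones(4))
```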
gfn/realizations/gssm/geometry/low_rank.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LowRankRiemannianGeometry — GFN V5
|
| 3 |
+
Computes Christoffel symbols via a low-rank decomposition.
|
| 4 |
+
Migrated from gfn/geo/riemannian/low_rank_geometry.py
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from typing import Optional, Union, Tuple, Dict, Any
|
| 10 |
+
|
| 11 |
+
from ..constants import (
|
| 12 |
+
EPS, MAX_VELOCITY, TOPOLOGY_TORUS, TOPOLOGY_EUCLIDEAN,
|
| 13 |
+
DEFAULT_FRICTION, CURVATURE_CLAMP, GATE_BIAS_OPEN
|
| 14 |
+
)
|
| 15 |
+
from ..config.schema import PhysicsConfig
|
| 16 |
+
from ..geometry.base import BaseGeometry
|
| 17 |
+
from ..registry import register_geometry
|
| 18 |
+
from ..cuda.ops import CUDA_AVAILABLE, low_rank_christoffel_fwd, low_rank_christoffel_bwd
|
| 19 |
+
|
| 20 |
+
class LowRankChristoffelFunction(torch.autograd.Function):
|
| 21 |
+
@staticmethod
|
| 22 |
+
def forward(ctx, v, U, W, clamp_val, enable_trace_norm, is_paper_version=False):
|
| 23 |
+
v_c = v.contiguous()
|
| 24 |
+
U_c = U.contiguous()
|
| 25 |
+
W_c = W.contiguous()
|
| 26 |
+
|
| 27 |
+
gamma = low_rank_christoffel_fwd(v_c, U_c, W_c, float(clamp_val), enable_trace_norm, is_paper_version)
|
| 28 |
+
ctx.save_for_backward(v_c, U_c, W_c, gamma)
|
| 29 |
+
ctx.clamp_val = float(clamp_val)
|
| 30 |
+
ctx.enable_trace_norm = enable_trace_norm
|
| 31 |
+
ctx.is_paper_version = is_paper_version
|
| 32 |
+
return gamma
|
| 33 |
+
|
| 34 |
+
@staticmethod
|
| 35 |
+
def backward(ctx, grad_gamma):
|
| 36 |
+
v_c, U_c, W_c, gamma_out = ctx.saved_tensors
|
| 37 |
+
if grad_gamma is None:
|
| 38 |
+
return None, None, None, None, None, None
|
| 39 |
+
|
| 40 |
+
grad_gamma_c = grad_gamma.contiguous()
|
| 41 |
+
d_v, d_U, d_W = low_rank_christoffel_bwd(
|
| 42 |
+
grad_gamma_c, v_c, U_c, W_c, gamma_out,
|
| 43 |
+
ctx.clamp_val, ctx.enable_trace_norm, ctx.is_paper_version
|
| 44 |
+
)
|
| 45 |
+
return d_v, d_U, d_W, None, None, None
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Use unified FrictionGate from physics.components (no duplication)
|
| 49 |
+
from ..physics.components.friction import FrictionGate
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@register_geometry('low_rank')
|
| 53 |
+
class LowRankRiemannianGeometry(BaseGeometry):
|
| 54 |
+
r"""
|
| 55 |
+
Low-rank Christoffel symbol decomposition.
|
| 56 |
+
|
| 57 |
+
Γ^k_ij ≈ Σ_r W_{rk} * (U_ir * U_jr)
|
| 58 |
+
This is an approximation — symmetry is preserved but Bianchi identities are not guaranteed.
|
| 59 |
+
Chosen for computational efficiency.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
dim: Manifold dimension.
|
| 63 |
+
rank: Rank of the decomposition.
|
| 64 |
+
num_heads: Number of parallel heads.
|
| 65 |
+
config: PhysicsConfig instance.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
def __init__(self, dim: int, rank: int = 16, num_heads: int = 1,
|
| 69 |
+
config: Optional[PhysicsConfig] = None):
|
| 70 |
+
super().__init__(config)
|
| 71 |
+
self.dim = dim
|
| 72 |
+
self.rank = rank
|
| 73 |
+
self.num_heads = num_heads
|
| 74 |
+
|
| 75 |
+
topo = self.config.topology.type.lower()
|
| 76 |
+
self.topology = topo
|
| 77 |
+
self.clamp_val = self.config.stability.curvature_clamp
|
| 78 |
+
self.enable_trace_normalization = self.config.stability.enable_trace_normalization
|
| 79 |
+
self.enable_trace_normalization = self.config.stability.enable_trace_normalization
|
| 80 |
+
# Friction parameters are now handled by PhysicsEngine to avoid duplication
|
| 81 |
+
|
| 82 |
+
# Feature dimension for gate input (Fourier for torus)
|
| 83 |
+
gate_input_dim = dim * 2 if topo == TOPOLOGY_TORUS else dim
|
| 84 |
+
|
| 85 |
+
# Low-rank basis parameters - initialized with small noise to break symmetry
|
| 86 |
+
if num_heads > 1:
|
| 87 |
+
self.U = nn.Parameter(torch.randn(num_heads, dim, rank) * 1e-4)
|
| 88 |
+
self.W = nn.Parameter(torch.randn(num_heads, dim, rank) * 1e-4)
|
| 89 |
+
else:
|
| 90 |
+
self.U = nn.Parameter(torch.randn(dim, rank) * 1e-4)
|
| 91 |
+
self.W = nn.Parameter(torch.randn(dim, rank) * 1e-4)
|
| 92 |
+
|
| 93 |
+
# Friction gate
|
| 94 |
+
friction_mode = getattr(self.config.stability, 'friction_mode', 'static')
|
| 95 |
+
self.friction_gate = FrictionGate(dim, gate_input_dim, mode=friction_mode, num_heads=num_heads)
|
| 96 |
+
|
| 97 |
+
# CONTRACT: LowRank ALWAYS returns (gamma_christoffel, mu_friction) separately.
|
| 98 |
+
# The physics engine is the single authority on when/how friction is applied.
|
| 99 |
+
# This prevents the P0.1 double-friction bug (geometry + engine both applying mu*v).
|
| 100 |
+
self.return_friction_separately = True
|
| 101 |
+
|
| 102 |
+
def _get_features(self, x: torch.Tensor) -> torch.Tensor:
|
| 103 |
+
"""Convert position to gate input features."""
|
| 104 |
+
if self.topology == TOPOLOGY_TORUS:
|
| 105 |
+
return torch.cat([torch.sin(x), torch.cos(x)], dim=-1)
|
| 106 |
+
return x
|
| 107 |
+
|
| 108 |
+
def connection(self, v: torch.Tensor, w: torch.Tensor,
|
| 109 |
+
x: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 110 |
+
"""
|
| 111 |
+
Bilinear Christoffel contraction: Γ(v,w)^k
|
| 112 |
+
Γ^k_ij ≈ Σ_r W[r,k] * (U[i,r] * U[j,r])
|
| 113 |
+
"""
|
| 114 |
+
if v.dim() == 3 and self.U.dim() == 3:
|
| 115 |
+
v_r = torch.einsum('bhd, hdr -> bhr', v, self.U)
|
| 116 |
+
w_r = torch.einsum('bhd, hdr -> bhr', w, self.U)
|
| 117 |
+
vw_r = v_r * w_r
|
| 118 |
+
gamma = torch.einsum('bhr, hdr -> bhd', vw_r, self.W)
|
| 119 |
+
else:
|
| 120 |
+
v_r = v @ self.U # [..., rank] (works for both 2D and 3D U)
|
| 121 |
+
w_r = w @ self.U
|
| 122 |
+
vw_r = v_r * w_r
|
| 123 |
+
W_t = self.W.transpose(-1, -2) if self.W.dim() == 3 else self.W.t()
|
| 124 |
+
gamma = vw_r @ W_t
|
| 125 |
+
return torch.clamp(gamma, -self.clamp_val, self.clamp_val)
|
| 126 |
+
|
| 127 |
+
def _normalize(self, gamma: torch.Tensor) -> torch.Tensor:
|
| 128 |
+
"""Symmetry-preserving trace normalization."""
|
| 129 |
+
if gamma.dim() < 2:
|
| 130 |
+
return gamma
|
| 131 |
+
|
| 132 |
+
is_multi_head = (gamma.dim() == 3 and self.num_heads > 1)
|
| 133 |
+
|
| 134 |
+
# Matrix case [..., D, D]
|
| 135 |
+
if not is_multi_head and gamma.dim() >= 3 and gamma.shape[-1] == gamma.shape[-2]:
|
| 136 |
+
gamma_sym = 0.5 * (gamma + gamma.transpose(-1, -2))
|
| 137 |
+
if self.enable_trace_normalization:
|
| 138 |
+
diag_mean = torch.diagonal(gamma_sym, dim1=-1, dim2=-2).mean(dim=-1, keepdim=True)
|
| 139 |
+
correction = torch.diag_embed(diag_mean.expand(-1, self.dim))
|
| 140 |
+
return gamma_sym - correction
|
| 141 |
+
return gamma_sym
|
| 142 |
+
|
| 143 |
+
# Vector case [..., D]
|
| 144 |
+
if self.enable_trace_normalization:
|
| 145 |
+
return gamma - gamma.mean(dim=-1, keepdim=True)
|
| 146 |
+
return gamma
|
| 147 |
+
|
| 148 |
+
def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
|
| 149 |
+
force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 150 |
+
if v is None:
|
| 151 |
+
return torch.zeros_like(x)
|
| 152 |
+
|
| 153 |
+
original_shape = v.shape
|
| 154 |
+
# Handle multi-head [B, H, HD] → reshape to [B*H, HD] for matmul with U=[HD, rank]
|
| 155 |
+
if v.dim() == 3 and self.U.dim() == 2:
|
| 156 |
+
B, H, HD = v.shape
|
| 157 |
+
v_flat = v.reshape(B * H, HD) # [B*H, HD]
|
| 158 |
+
x_flat = x.reshape(B * H, HD)
|
| 159 |
+
else:
|
| 160 |
+
v_flat = v
|
| 161 |
+
x_flat = x
|
| 162 |
+
B, H, HD = None, None, v_flat.shape[-1]
|
| 163 |
+
R = self.rank
|
| 164 |
+
|
| 165 |
+
# Check if we can take the fast CUDA path
|
| 166 |
+
use_cuda_fused = (
|
| 167 |
+
CUDA_AVAILABLE and
|
| 168 |
+
low_rank_christoffel_fwd is not None and
|
| 169 |
+
v_flat.is_cuda and
|
| 170 |
+
v_flat.dtype == torch.float32 and
|
| 171 |
+
self.W.dim() == 3
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
if use_cuda_fused:
|
| 175 |
+
# Reshape [B*H, HD] -> [B, H, HD] to match what kernel expects
|
| 176 |
+
actual_B = original_shape[0] if v.dim() == 3 else 1
|
| 177 |
+
actual_H = self.num_heads
|
| 178 |
+
|
| 179 |
+
v_re = v_flat.view(actual_B, actual_H, HD)
|
| 180 |
+
U_re = self.U.view(actual_H, HD, R) # self.U is [H, D, R]
|
| 181 |
+
W_re = self.W.view(actual_H, HD, R)
|
| 182 |
+
|
| 183 |
+
gamma_re = LowRankChristoffelFunction.apply(
|
| 184 |
+
v_re, U_re, W_re, self.clamp_val, self.enable_trace_normalization, False
|
| 185 |
+
)
|
| 186 |
+
gamma = gamma_re.view_as(v_flat)
|
| 187 |
+
else:
|
| 188 |
+
# Christoffel symbols via self-connection (Native Fallback)
|
| 189 |
+
if v_flat.dim() == 3 and self.U.dim() == 3:
|
| 190 |
+
v_r = torch.einsum('bhd, hdr -> bhr', v_flat, self.U)
|
| 191 |
+
sq = v_r * v_r
|
| 192 |
+
gamma = torch.einsum('bhr, hdr -> bhd', sq, self.W)
|
| 193 |
+
else:
|
| 194 |
+
v_r = v_flat @ self.U # [..., rank]
|
| 195 |
+
sq = v_r * v_r
|
| 196 |
+
W_t = self.W.transpose(-1, -2) if self.W.dim() == 3 else self.W.t()
|
| 197 |
+
gamma = sq @ W_t # [..., HD]
|
| 198 |
+
|
| 199 |
+
# Friction coefficient (position-dependent, gated)
|
| 200 |
+
x_in = self._get_features(x_flat)
|
| 201 |
+
mu = self.friction_gate(x_in, force=force)
|
| 202 |
+
|
| 203 |
+
# Note: PhysicsEngine will add base friction and apply velocity scaling
|
| 204 |
+
|
| 205 |
+
# Friction and normalizing only if not already done in CUDA
|
| 206 |
+
if not use_cuda_fused:
|
| 207 |
+
gamma = self._normalize(gamma)
|
| 208 |
+
gamma = self.clamp_val * torch.tanh(gamma / self.clamp_val)
|
| 209 |
+
|
| 210 |
+
# Restore original shape if we reshaped
|
| 211 |
+
if B is not None:
|
| 212 |
+
gamma = gamma.view(original_shape)
|
| 213 |
+
mu = mu.view(original_shape)
|
| 214 |
+
|
| 215 |
+
# CONTRACT: always return (gamma_pure, mu) so engine has single authority over friction
|
| 216 |
+
return gamma, mu
|
| 217 |
+
|
| 218 |
+
def metric_tensor(self, x: torch.Tensor) -> torch.Tensor:
|
| 219 |
+
"""
|
| 220 |
+
Implicit Riemannian metric from the low-rank decomposition.
|
| 221 |
+
|
| 222 |
+
The Christoffel parametrization Γ^k_ij ≈ Σ_r W_rk (U_ir U_jr) implies
|
| 223 |
+
an underlying metric: g_ij ≈ Σ_r U_ir * U_jr = diag(U @ Uᵀ)
|
| 224 |
+
|
| 225 |
+
Returns per-coordinate metric scale [..., D] or ones if shape unknown.
|
| 226 |
+
T = (1/2) Σ_i g_ii v_i² (Riemannian kinetic energy)
|
| 227 |
+
"""
|
| 228 |
+
if self.U.dim() == 2:
|
| 229 |
+
# Single head: U is [D, rank] → g_diag is [D]
|
| 230 |
+
g_diag = (self.U ** 2).sum(dim=-1) # [D]
|
| 231 |
+
# Broadcast to x shape: handles [B, D], [B*H, D], any [..., D]
|
| 232 |
+
return g_diag.expand_as(x)
|
| 233 |
+
else:
|
| 234 |
+
# Multi-head: U is [H, D, rank] → g_diag is [H, D]
|
| 235 |
+
g_diag = (self.U ** 2).sum(dim=-1) # [H, D]
|
| 236 |
+
if x.dim() == 3 and x.shape[1] == self.num_heads:
|
| 237 |
+
# [B, H, D]: structured multi-head
|
| 238 |
+
return g_diag.unsqueeze(0).expand_as(x)
|
| 239 |
+
else:
|
| 240 |
+
# [B, H*D]: flat layout — expand g_diag [H,D] → [H*D], broadcast to [B, H*D]
|
| 241 |
+
g_flat = g_diag.reshape(-1) # [H*D]
|
| 242 |
+
return g_flat.expand(x.shape[0], -1) if x.dim() == 2 else g_flat.expand_as(x)
|
| 243 |
+
|
| 244 |
+
def dist(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
|
| 245 |
+
if self.topology == TOPOLOGY_TORUS:
|
| 246 |
+
diff = x1 - x2
|
| 247 |
+
diff = torch.atan2(torch.sin(diff), torch.cos(diff))
|
| 248 |
+
return torch.norm(diff, dim=-1)
|
| 249 |
+
return torch.norm(x1 - x2, dim=-1)
|
| 250 |
+
|
| 251 |
+
def project(self, x: torch.Tensor) -> torch.Tensor:
|
| 252 |
+
if self.topology == TOPOLOGY_TORUS:
|
| 253 |
+
return torch.atan2(torch.sin(x), torch.cos(x))
|
| 254 |
+
return x
|
| 255 |
+
@register_geometry('low_rank_paper')
class PaperLowRankRiemannianGeometry(LowRankRiemannianGeometry):
    def __init__(self, dim: int, rank: int = 16, num_heads: int = 1,
                 config: Optional[PhysicsConfig] = None):
        super().__init__(dim, rank, num_heads=num_heads, config=config)
        self.return_friction_separately = True

    def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
                force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if v is None:
            return torch.zeros_like(x)

        original_shape = v.shape
        if v.dim() == 3 and self.U.dim() == 2:
            # Shared single-head factors applied to a [B, H, D] batch: fold heads into the batch
            B, H, HD = v.shape
            v_flat = v.reshape(B * H, HD)
            x_flat = x.reshape(B * H, HD)
        else:
            v_flat = v
            x_flat = x
            B, H, HD = None, None, v_flat.shape[-1]
        R = self.rank

        use_cuda_fused = (
            CUDA_AVAILABLE and
            low_rank_christoffel_fwd is not None and
            v_flat.is_cuda and
            v_flat.dtype == torch.float32 and
            self.W.dim() == 3
        )

        if use_cuda_fused:
            actual_B = original_shape[0] if v.dim() == 3 else 1
            actual_H = self.num_heads
            v_re = v_flat.view(actual_B, actual_H, HD)
            U_re = self.U.view(actual_H, HD, R)
            W_re = self.W.view(actual_H, HD, R)

            gamma_re = LowRankChristoffelFunction.apply(
                v_re, U_re, W_re, self.clamp_val, self.enable_trace_normalization, True
            )
            gamma = gamma_re.view_as(v_flat)
        else:
            if v_flat.dim() == 3 and self.U.dim() == 3:
                # Per-head factors: project v into rank space, then back out through W
                v_r = torch.einsum('bhd,hdr->bhr', v_flat, self.U)
                denom = 1.0 + torch.norm(v_r, dim=-1, keepdim=True)
                phi = (v_r * v_r) / denom
                gamma = torch.einsum('bhr,hdr->bhd', phi, self.W)
            else:
                v_r = v_flat @ self.U
                denom = 1.0 + torch.norm(v_r, dim=-1, keepdim=True)
                phi = (v_r * v_r) / denom
                W_t = self.W.transpose(-1, -2) if self.W.dim() == 3 else self.W.t()
                gamma = phi @ W_t

        x_in = self._get_features(x_flat)
        mu = self.friction_gate(x_in, force=force)

        if not use_cuda_fused:
            gamma = self._normalize(gamma)
            gamma = self.clamp_val * torch.tanh(gamma / self.clamp_val)

        if B is not None:
            gamma = gamma.view(original_shape)
            mu = mu.view(original_shape)

        # CONTRACT: always return (gamma_pure, mu) for unified PhysicsEngine handling.
        return gamma, mu
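The fallback branch above is the whole low-rank trick: project v into an R-dimensional subspace, apply a quadratic norm-damped nonlinearity, and map back out. A minimal single-head sketch under assumed toy shapes (the function name is illustrative; the real module additionally applies self._normalize before clamping):

import torch

def low_rank_christoffel(v: torch.Tensor, U: torch.Tensor, W: torch.Tensor,
                         clamp_val: float = 5.0) -> torch.Tensor:
    # Gamma(v, v)_k ~ sum_r W_kr * phi_r(U^T v), quadratic in v like a true Christoffel contraction
    v_r = v @ U                                                  # [B, R]
    phi = (v_r * v_r) / (1.0 + v_r.norm(dim=-1, keepdim=True))   # [B, R], norm-damped
    gamma = phi @ W.t()                                          # [B, D]
    return clamp_val * torch.tanh(gamma / clamp_val)             # soft clamp (V5 convention)

B, D, R = 2, 8, 4
v = torch.randn(B, D)
U, W = torch.randn(D, R), torch.randn(D, R)
print(low_rank_christoffel(v, U, W).shape)  # torch.Size([2, 8])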
gfn/realizations/gssm/geometry/reactive.py
ADDED
@@ -0,0 +1,109 @@
"""
ReactiveRiemannianGeometry — GFN V5
Active-inference geometry: curvature reacts to the system's own state.
Migrated from gfn/geo/physical/reactive_field_geometry.py
"""

import torch
import torch.nn as nn
from typing import Optional, Union, Tuple

from ..constants import CURVATURE_CLAMP, DEFAULT_PLASTICITY, TOPOLOGY_TORUS
from ..config.schema import PhysicsConfig
from ..geometry.low_rank import LowRankRiemannianGeometry
from ..registry import register_geometry

# Default constants
SINGULARITY_THRESHOLD = 0.5
BLACK_HOLE_STRENGTH = 3.0
SINGULARITY_GATE_SLOPE = 10.0


@register_geometry('reactive')
class ReactiveRiemannianGeometry(LowRankRiemannianGeometry):
    """
    Geometry that reacts to the system's own state via active inference.

    Enhancements over LowRank:
    1. Plasticity: Christoffel symbols scaled by kinetic energy (curvature amplification ≈ attention).
    2. Singularities: soft curvature amplification near semantic attractors.

    Note: these are regularization/attention mechanisms, NOT physical manifold properties.
    """

    def __init__(self, dim: int, rank: int = 16, num_heads: int = 1,
                 config: Optional[PhysicsConfig] = None):
        super().__init__(dim, rank, num_heads=num_heads, config=config)
        self.return_friction_separately = True

        self.active_cfg = self.config.active_inference
        self.plasticity = getattr(self.active_cfg, 'plasticity', DEFAULT_PLASTICITY)

        sing_cfg = self.config.singularities
        sing_enabled = getattr(sing_cfg, 'enabled', False)

        if sing_enabled:
            self.semantic_certainty_threshold = getattr(sing_cfg, 'threshold', SINGULARITY_THRESHOLD)
            self.curvature_amplification_factor = getattr(sing_cfg, 'strength', BLACK_HOLE_STRENGTH)
            # Torus inputs are lifted to (sin, cos) features, doubling the gate input width
            gate_input_dim = dim * 2 if self.topology == TOPOLOGY_TORUS else dim
            if num_heads > 1:
                self.V_weight = nn.Parameter(torch.zeros(num_heads, gate_input_dim, 1))
            else:
                self.V = nn.Linear(gate_input_dim, 1)
                nn.init.zeros_(self.V.weight)
                nn.init.constant_(self.V.bias, -2.0)  # Start gate closed
        else:
            self.semantic_certainty_threshold = SINGULARITY_THRESHOLD
            self.curvature_amplification_factor = BLACK_HOLE_STRENGTH
            self.V = None

    def _get_potential(self, x_in: torch.Tensor) -> Optional[torch.Tensor]:
        """Compute the singularity potential; returns None if disabled."""
        if not getattr(self.config.singularities, 'enabled', False):
            return None
        if self.num_heads > 1:
            return torch.sigmoid(torch.matmul(x_in.unsqueeze(-2), self.V_weight).squeeze(-2))
        elif self.V is not None:
            return torch.sigmoid(self.V(x_in))
        return None

    def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
                force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if v is None:
            return torch.zeros_like(x)

        # 1. Base curvature from LowRank (v is guaranteed non-None past the early return)
        res = super().forward(x, v, force=force, **kwargs)
        if isinstance(res, tuple):
            gamma, mu = res
        else:
            gamma, mu = res, torch.zeros_like(v)

        if not self.active_cfg.enabled:
            if self.return_friction_separately:
                return gamma, mu
            return gamma + mu * v

        # 2. Plasticity: scale curvature by kinetic energy
        react_cfg = self.active_cfg.reactive_curvature
        react_enabled = react_cfg.get('enabled', False) if isinstance(react_cfg, dict) else False
        if react_enabled and self.plasticity > 0.0:
            energy = torch.tanh(v.pow(2).mean(dim=-1, keepdim=True))
            gamma = gamma * (1.0 + self.plasticity * energy)

        # 3. Singularity amplification
        if getattr(self.config.singularities, 'enabled', False) and x is not None:
            x_in = torch.cat([torch.sin(x), torch.cos(x)], dim=-1) if self.topology == TOPOLOGY_TORUS else x
            potential = self._get_potential(x_in)
            if potential is not None:
                gate_slope = getattr(self.config.singularities, 'gate_slope', SINGULARITY_GATE_SLOPE)
                is_amplified = torch.sigmoid(gate_slope * (potential - self.semantic_certainty_threshold))
                amp = 1.0 + is_amplified * (self.curvature_amplification_factor - 1.0)
                gamma = gamma * amp
                # Re-clamp with a widened limit so the amplification survives the soft saturation
                limit = self.curvature_amplification_factor * CURVATURE_CLAMP
                gamma = limit * torch.tanh(gamma / limit)

        if self.return_friction_separately:
            return gamma, mu

        return gamma + mu * v
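The singularity step above is a smooth gate that interpolates the curvature multiplier between 1 (below the certainty threshold) and the amplification factor (above it). A standalone sketch of just that gate; the function name is illustrative and the defaults mirror the module constants:

import torch

def singularity_gate(potential: torch.Tensor, threshold: float = 0.5,
                     strength: float = 3.0, slope: float = 10.0) -> torch.Tensor:
    # sigmoid(slope * (potential - threshold)) ramps from 0 to 1 around the threshold
    is_amplified = torch.sigmoid(slope * (potential - threshold))
    return 1.0 + is_amplified * (strength - 1.0)

p = torch.tensor([0.0, 0.5, 1.0])
print(singularity_gate(p))  # ~[1.01, 2.00, 2.99]: closed, half open, nearly fully open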
gfn/realizations/gssm/geometry/spherical.py
ADDED
@@ -0,0 +1,47 @@
import torch
from typing import Optional, Union, Tuple, Any

from ..geometry.base import BaseGeometry
from ..registry import register_geometry
from ..constants import EPS, TOPOLOGY_SPHERE


@register_geometry(TOPOLOGY_SPHERE)
class SphericalGeometry(BaseGeometry):
    """
    Spherical geometry (analytical).
    Computes Christoffel symbols for a space of constant positive curvature.
    """

    def __init__(self, dim: int, rank: int = 16, config: Optional[Any] = None, **kwargs):
        super().__init__(config)
        self.dim = dim

    def christoffel_symbols(self, x: torch.Tensor) -> torch.Tensor:
        """Analytical Christoffel symbols for S^n: zero in the ambient embedding used here, chart-dependent otherwise."""
        # For GFN purposes we use the simplified 'spherical_christoffel_torch' coupling,
        # which acts as a centering/restoring force toward the sphere surface.
        return torch.zeros_like(x)

    def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
                force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if v is None:
            return torch.zeros_like(x)

        # Simplified analytical spherical coupling (restoring force)
        xv = torch.sum(x * v, dim=-1, keepdim=True)
        vv = torch.sum(v * v, dim=-1, keepdim=True)

        # Gamma = -(2 <x, v> v - <v, v> x)
        gamma = -(2.0 * xv * v - vv * x)

        # Apply standard V5 clamping
        clamp_val = getattr(self, 'clamp_val', 5.0)
        return clamp_val * torch.tanh(gamma / clamp_val)

    def metric_tensor(self, x: torch.Tensor) -> torch.Tensor:
        """Identity metric for the simplified spherical chart."""
        return torch.ones_like(x)

    def dist(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Great-circle distance, assuming points on the unit sphere."""
        dot = torch.sum(x1 * x2, dim=-1)
        # Clamping keeps acos (and its gradient) finite at coincident/antipodal points
        return torch.acos(torch.clamp(dot, -1.0 + EPS, 1.0 - EPS))
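As a quick sanity check on dist, a self-contained snippet (EPS here is a stand-in for the module constant):

import torch

EPS = 1e-7  # stand-in for ..constants.EPS

def great_circle_dist(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    dot = torch.sum(x1 * x2, dim=-1)
    return torch.acos(torch.clamp(dot, -1.0 + EPS, 1.0 - EPS))

e1 = torch.tensor([[1.0, 0.0, 0.0]])
e2 = torch.tensor([[0.0, 1.0, 0.0]])
print(great_circle_dist(e1, e2))   # ~pi/2 for orthogonal unit vectors
print(great_circle_dist(e1, -e1))  # ~pi for antipodal points (clamped just inside)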